airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] morgendave closed pull request #4101: [AIRFLOW-3272] Add base grpc hook
Date Tue, 06 Nov 2018 00:53:42 GMT
morgendave closed pull request #4101: [AIRFLOW-3272] Add base grpc hook
URL: https://github.com/apache/incubator-airflow/pull/4101
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py
new file mode 100644
index 0000000000..b260847f19
--- /dev/null
+++ b/airflow/contrib/hooks/grpc_hook.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import grpc
+from google import auth as google_auth
+from google.auth import jwt as google_auth_jwt
+from google.auth.transport import grpc as google_auth_transport_grpc
+from google.auth.transport import requests as google_auth_transport_requests
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowConfigException
+
+
+class GrpcHook(BaseHook):
+    """
+    General interaction with gRPC servers.
+    :param grpc_conn_id: The connection ID to use when fetching connection info.
+    :type grpc_conn_id: str
+    :param interceptors: a list of gRPC interceptor objects which would be applied
+        to the connected gRPC channle. None by default.
+    :type interceptors: a list of gRPC interceptors based on or extends the four
+        official gRPC interceptors, eg, UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
+        StreamUnaryClientInterceptor, StreamStreamClientInterceptor.
+    ::param custom_connection_func: The customized connection function to return gRPC channel.
+    :type custom_connection_func: python callable objects that accept the connection as
+        its only arg. Could be partial or lambda.
+    """
+
+    def __init__(self, grpc_conn_id, interceptors=None, custom_connection_func=None):
+        self.grpc_conn_id = grpc_conn_id
+        self.conn = self.get_connection(self.grpc_conn_id)
+        self.extras = self.conn.extra_dejson
+        self.interceptors = interceptors if interceptors else []
+        self.custom_connection_func = custom_connection_func
+
+    def get_conn(self):
+        if "://" in self.conn.host:
+            base_url = self.conn.host
+        else:
+            # schema defaults to HTTP
+            schema = self.conn.schema if self.conn.schema else "http"
+            base_url = schema + "://" + self.conn.host
+
+        if self.conn.port:
+            base_url = base_url + ":" + str(self.conn.port) + "/"
+
+        auth_type = self._get_field("auth_type")
+
+        if auth_type == "NO_AUTH":
+            channel = grpc.insecure_channel(base_url)
+        elif auth_type == "SSL" or auth_type == "TLS":
+            credential_file_name = self._get_field("credential_pem_file")
+            creds = grpc.ssl_channel_credentials(open(credential_file_name).read())
+            channel = grpc.secure_channel(base_url, creds)
+        elif auth_type == "JWT_GOOGLE":
+            credentials, _ = google_auth.default()
+            jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(
+                credentials)
+            channel = google_auth_transport_grpc.secure_authorized_channel(
+                jwt_creds, None, base_url)
+        elif auth_type == "OATH_GOOGLE":
+            scopes = self._get_field("scopes").split(",")
+            credentials, _ = google_auth.default(scopes=scopes)
+            request = google_auth_transport_requests.Request()
+            channel = google_auth_transport_grpc.secure_authorized_channel(
+                credentials, request, base_url)
+        elif auth_type == "CUSTOM":
+            if not self.custom_connection_func:
+                raise AirflowConfigException(
+                    "Customized connection function not set, not able to establish a channel")
+            channel = self.custom_connection_func(self.conn)
+
+        if self.interceptors:
+            for interceptor in self.interceptors:
+                channel = grpc.intercept_channel(channel,
+                                                 interceptor)
+
+        return channel
+
+    def run(self, stub_class, call_func, streaming=False, data={}):
+        with self.get_conn() as channel:
+            stub = stub_class(channel)
+            try:
+                response = stub.call_func(**data)
+                if not streaming:
+                    return response
+
+                for single_response in response:
+                    yield single_response
+            except grpc.FutureTimeoutError:
+                self.log.exception(
+                    "Timeout when calling the grpc service: %s, method: %s" %
+                    (stub_class.__name__, call_func.__name__))
+
+    def _get_field(self, field_name, default=None):
+        """
+        Fetches a field from extras, and returns it. This is some Airflow
+        magic. The grpc hook type adds custom UI elements
+        to the hook page, which allow admins to specify scopes, credential pem files, etc.
+        They get formatted as shown below.
+        """
+        full_field_name = 'extra__grpc__{}'.format(field_name)
+        if full_field_name in self.extras:
+            return self.extras[full_field_name]
+        else:
+            return default
diff --git a/airflow/models.py b/airflow/models.py
index fa33609852..37ac495451 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -667,6 +667,7 @@ class Connection(Base, LoggingMixin):
         ('azure_data_lake', 'Azure Data Lake'),
         ('cassandra', 'Cassandra',),
         ('qubole', 'Qubole'),
+        ('grpc', 'GRPC Connection'),
     ]
 
     def __init__(
@@ -807,6 +808,9 @@ def get_hook(self):
             elif self.conn_type == 'cassandra':
                 from airflow.contrib.hooks.cassandra_hook import CassandraHook
                 return CassandraHook(cassandra_conn_id=self.conn_id)
+            elif self.conn_type == 'grpc':
+                from airflow.contrib.hooks.grpc_hook import GrpcHook
+                return GrpcHook(grpc_conn_id=self.conn_id)
         except Exception:
             pass
 
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 206caea7b9..cd32953d86 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -2980,6 +2980,9 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
         'extra__google_cloud_platform__key_path',
         'extra__google_cloud_platform__keyfile_dict',
         'extra__google_cloud_platform__scope',
+        'extra__grpc__auth_type',
+        'extra__grpc__credential_pem_file',
+        'extra__grpc__scopes',
     )
     verbose_name = "Connection"
     verbose_name_plural = "Connections"
@@ -3002,6 +3005,9 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
         'extra__google_cloud_platform__key_path': StringField('Keyfile Path'),
         'extra__google_cloud_platform__keyfile_dict': PasswordField('Keyfile JSON'),
         'extra__google_cloud_platform__scope': StringField('Scopes (comma separated)'),
+        'extra__grpc__auth_type': StringField('Authentication Type'),
+        'extra__grpc__credential_pem_file': StringField('Credential Pem File Path'),
+        'extra__grpc__scopes': StringField('Scopes (comma separated)'),
     }
     form_choices = {
         'conn_type': models.Connection._types
@@ -3009,7 +3015,7 @@ class ConnectionModelView(wwwutils.SuperUserMixin, AirflowModelView):
 
     def on_model_change(self, form, model, is_created):
         formdata = form.data
-        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
+        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
             extra = {
                 key: formdata[key]
                 for key in self.form_extra_fields.keys() if key in formdata}
diff --git a/airflow/www_rbac/views.py b/airflow/www_rbac/views.py
index b47cb0ec14..6e678ec8f5 100644
--- a/airflow/www_rbac/views.py
+++ b/airflow/www_rbac/views.py
@@ -1945,7 +1945,10 @@ class ConnectionModelView(AirflowModelView):
                     'extra__google_cloud_platform__project',
                     'extra__google_cloud_platform__key_path',
                     'extra__google_cloud_platform__keyfile_dict',
-                    'extra__google_cloud_platform__scope']
+                    'extra__google_cloud_platform__scope',
+                    'extra__grpc__auth_type',
+                    'extra__grpc__credential_pem_file',
+                    'extra__grpc__scopes']
     list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
                     'is_extra_encrypted']
     add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
@@ -1966,7 +1969,7 @@ def action_muldelete(self, items):
 
     def process_form(self, form, is_created):
         formdata = form.data
-        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
+        if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
             extra = {
                 key: formdata[key]
                 for key in self.extra_fields if key in formdata}
diff --git a/setup.py b/setup.py
index 8c6c927153..ae2ec56487 100644
--- a/setup.py
+++ b/setup.py
@@ -307,6 +307,7 @@ def do_setup():
             'funcsigs==1.0.0',
             'future>=0.16.0, <0.17',
             'gitpython>=2.0.2',
+            'grpcio>=1.15.0',
             'gunicorn>=19.4.0, <20.0',
             'iso8601>=0.1.12',
             'jinja2>=2.7.3, <2.9.0',


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message