airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From criccom...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1577] Add token support to DatabricksHook
Date Fri, 08 Sep 2017 18:24:29 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master ea9ab96cb -> c2c51518e


[AIRFLOW-1577] Add token support to DatabricksHook

Closes #2579 from andrewmchen/token


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c2c51518
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c2c51518
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c2c51518

Branch: refs/heads/master
Commit: c2c51518e8fc7ead22dc8c007289981e805827cf
Parents: ea9ab96
Author: Andrew Chen <andrewmnchen@gmail.com>
Authored: Fri Sep 8 11:24:14 2017 -0700
Committer: Chris Riccomini <criccomini@apache.org>
Committed: Fri Sep 8 11:24:14 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/databricks_hook.py        | 21 ++++++++++-
 .../contrib/operators/databricks_operator.py    |  4 ++-
 tests/contrib/hooks/test_databricks_hook.py     | 37 +++++++++++++++++++-
 3 files changed, 59 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/airflow/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
index 0cd5d0f..18e20c4 100644
--- a/airflow/contrib/hooks/databricks_hook.py
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -20,6 +20,7 @@ from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 from requests import exceptions as requests_exceptions
+from requests.auth import AuthBase
 
 
 try:
@@ -99,7 +100,12 @@ class DatabricksHook(BaseHook):
         url = 'https://{host}/{endpoint}'.format(
             host=self._parse_host(self.databricks_conn.host),
             endpoint=endpoint)
-        auth = (self.databricks_conn.login, self.databricks_conn.password)
+        if 'token' in self.databricks_conn.extra_dejson:
+            logging.info('Using token auth.')
+            auth = _TokenAuth(self.databricks_conn.extra_dejson['token'])
+        else:
+            logging.info('Using basic auth.')
+            auth = (self.databricks_conn.login, self.databricks_conn.password)
         if method == 'GET':
             request_func = requests.get
         elif method == 'POST':
@@ -200,3 +206,16 @@ class RunState:
 
     def __repr__(self):
         return str(self.__dict__)
+
+
+class _TokenAuth(AuthBase):
+    """
+    Helper class for requests Auth field. AuthBase requires you to implement the __call__
+    magic function.
+    """
+    def __init__(self, token):
+        self.token = token
+
+    def __call__(self, r):
+        r.headers['Authorization'] = 'Bearer ' + self.token
+        return r

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 9c995df..1aa1441 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -131,7 +131,9 @@ class DatabricksSubmitRunOperator(BaseOperator):
         This field will be templated.
     :type timeout_seconds: int32
     :param databricks_conn_id: The name of the Airflow connection to use.
-        By default and in the common case this will be ``databricks_default``.
+        By default and in the common case this will be ``databricks_default``. To use
+        token based authentication, provide the key ``token`` in the extra field for the
+        connection.
     :type databricks_conn_id: string
     :param polling_period_seconds: Controls the rate which we poll for the result of
         this run. By default the operator will poll every 30 seconds.

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2c51518/tests/contrib/hooks/test_databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_databricks_hook.py b/tests/contrib/hooks/test_databricks_hook.py
index 6c789f9..56288a1 100644
--- a/tests/contrib/hooks/test_databricks_hook.py
+++ b/tests/contrib/hooks/test_databricks_hook.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 #
 
+import json
 import unittest
 
 from airflow import __version__
-from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT, _TokenAuth
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.utils import db
@@ -45,6 +46,7 @@ HOST = 'xx.cloud.databricks.com'
 HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
 LOGIN = 'login'
 PASSWORD = 'password'
+TOKEN = 'token'
 USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
 RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
 LIFE_CYCLE_STATE = 'PENDING'
@@ -203,6 +205,39 @@ class DatabricksHookTest(unittest.TestCase):
             headers=USER_AGENT_HEADER,
             timeout=self.hook.timeout_seconds)
 
+
+class DatabricksHookTokenTest(unittest.TestCase):
+    """
+    Tests for DatabricksHook when auth is done with token.
+    """
+    @db.provide_session
+    def setUp(self, session=None):
+        conn = session.query(Connection) \
+            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+            .first()
+        conn.extra = json.dumps({'token': TOKEN})
+        session.commit()
+
+        self.hook = DatabricksHook()
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_submit_run(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        json = {
+          'notebook_task': NOTEBOOK_TASK,
+          'new_cluster': NEW_CLUSTER
+        }
+        run_id = self.hook.submit_run(json)
+
+        self.assertEquals(run_id, '1')
+        args = mock_requests.post.call_args
+        kwargs = args[1]
+        self.assertEquals(kwargs['auth'].token, TOKEN)
+
+
 class RunStateTest(unittest.TestCase):
     def test_is_terminal_true(self):
         terminal_states = ['TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']


Mime
View raw message