airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From art...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1028] Databricks Operator for Airflow
Date Thu, 06 Apr 2017 15:30:37 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master 5a6f18f1c -> 53ca50845


[AIRFLOW-1028] Databricks Operator for Airflow

Add DatabricksSubmitRun Operator

In this PR, we contribute a DatabricksSubmitRun operator and a
Databricks hook. This operator enables easy integration of Airflow
with Databricks. In addition to the operator, we have created a
databricks_default connection, an example_dag using this
DatabricksSubmitRunOperator, and matching documentation.

Closes #2202 from andrewmchen/databricks-operator-squashed


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/53ca5084
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/53ca5084
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/53ca5084

Branch: refs/heads/master
Commit: 53ca5084561fd5c13996609f2eda6baf717249b5
Parents: 5a6f18f
Author: Andrew Chen <andrewchen@databricks.com>
Authored: Thu Apr 6 08:30:01 2017 -0700
Committer: Arthur Wiedmer <arthur.wiedmer@gmail.com>
Committed: Thu Apr 6 08:30:33 2017 -0700

----------------------------------------------------------------------
 .../example_dags/example_databricks_operator.py |  82 +++++++
 airflow/contrib/hooks/databricks_hook.py        | 202 +++++++++++++++++
 .../contrib/operators/databricks_operator.py    | 211 +++++++++++++++++
 airflow/exceptions.py                           |   2 +-
 airflow/models.py                               |   1 +
 airflow/utils/db.py                             |   4 +
 docs/code.rst                                   |   1 +
 docs/integration.rst                            |  13 ++
 setup.py                                        |   2 +
 tests/contrib/hooks/databricks_hook.py          | 226 +++++++++++++++++++
 tests/contrib/operators/databricks_operator.py  | 185 +++++++++++++++
 11 files changed, 928 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/example_dags/example_databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/example_dags/example_databricks_operator.py b/airflow/contrib/example_dags/example_databricks_operator.py
new file mode 100644
index 0000000..abf6844
--- /dev/null
+++ b/airflow/contrib/example_dags/example_databricks_operator.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import airflow
+
+from airflow import DAG
+from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+
+# This is an example DAG which uses the DatabricksSubmitRunOperator.
+# In this example, we create two tasks which execute sequentially.
+# The first task is to run a notebook at the workspace path
+# "/Users/airflow@example.com/PrepareData" and the second task is to run
+# a JAR uploaded to DBFS. Both tasks use new clusters.
+#
+# Because we have set a downstream dependency on the notebook task,
+# the spark jar task will NOT run until the notebook task completes
+# successfully.
+#
+# A run is considered successful if it has a result_state of "SUCCESS".
+# For more information about the state of a run refer to
+# https://docs.databricks.com/api/latest/jobs.html#runstate
+
+args = {
+    'owner': 'airflow',
+    'email': ['airflow@example.com'],
+    'depends_on_past': False,
+    'start_date': airflow.utils.dates.days_ago(2)
+}
+
+dag = DAG(
+    dag_id='example_databricks_operator', default_args=args,
+    schedule_interval='@daily')
+
+new_cluster = {
+    'spark_version': '2.1.0-db3-scala2.11',
+    'node_type_id': 'r3.xlarge',
+    'aws_attributes': {
+        'availability': 'ON_DEMAND'
+    },
+    'num_workers': 8
+}
+
+notebook_task_params = {
+    'new_cluster': new_cluster,
+    'notebook_task': {
+        'notebook_path': '/Users/airflow@example.com/PrepareData',
+    },
+}
+# Example of using the JSON parameter to initialize the operator.
+notebook_task = DatabricksSubmitRunOperator(
+    task_id='notebook_task',
+    dag=dag,
+    json=notebook_task_params)
+
+# Example of using the named parameters of DatabricksSubmitRunOperator
+# to initialize the operator.
+spark_jar_task = DatabricksSubmitRunOperator(
+    task_id='spark_jar_task',
+    dag=dag,
+    new_cluster=new_cluster,
+    spark_jar_task={
+        'main_class_name': 'com.example.ProcessData'
+    },
+    libraries=[
+        {
+            'jar': 'dbfs:/lib/etl-0.1.jar'
+        }
+    ]
+)
+
+notebook_task.set_downstream(spark_jar_task)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
new file mode 100644
index 0000000..0cd5d0f
--- /dev/null
+++ b/airflow/contrib/hooks/databricks_hook.py
@@ -0,0 +1,202 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import requests
+
+from airflow import __version__
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from requests import exceptions as requests_exceptions
+
+
+try:
+    from urllib import parse as urlparse
+except ImportError:
+    import urlparse
+
+
+SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/submit')
+GET_RUN_ENDPOINT = ('GET', 'api/2.0/jobs/runs/get')
+CANCEL_RUN_ENDPOINT = ('POST', 'api/2.0/jobs/runs/cancel')
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+
+
+class DatabricksHook(BaseHook):
+    """
+    Interact with Databricks.
+    """
+    def __init__(
+            self,
+            databricks_conn_id='databricks_default',
+            timeout_seconds=180,
+            retry_limit=3):
+        """
+        :param databricks_conn_id: The name of the databricks connection to use.
+        :type databricks_conn_id: string
+        :param timeout_seconds: The amount of time in seconds the requests library
+            will wait before timing-out.
+        :type timeout_seconds: int
+        :param retry_limit: The number of times to retry the connection in case of
+            service outages.
+        :type retry_limit: int
+        """
+        self.databricks_conn_id = databricks_conn_id
+        self.databricks_conn = self.get_connection(databricks_conn_id)
+        self.timeout_seconds = timeout_seconds
+        assert retry_limit >= 1, 'Retry limit must be greater than equal to 1'
+        self.retry_limit = retry_limit
+
+    def _parse_host(self, host):
+        """
+        The purpose of this function is to be robust to improper connection
+        settings provided by users, specifically in the host field.
+
+
+        For example -- when users supply ``https://xx.cloud.databricks.com`` as the
+        host, we must strip out the protocol to get the host.
+        >>> h = DatabricksHook()
+        >>> assert h._parse_host('https://xx.cloud.databricks.com') == \
+            'xx.cloud.databricks.com'
+
+        In the case where users supply the correct ``xx.cloud.databricks.com`` as the
+        host, this function is a no-op.
+        >>> assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
+        """
+        urlparse_host = urlparse.urlparse(host).hostname
+        if urlparse_host:
+            # In this case, host = https://xx.cloud.databricks.com
+            return urlparse_host
+        else:
+            # In this case, host = xx.cloud.databricks.com
+            return host
+
+    def _do_api_call(self, endpoint_info, json):
+        """
+        Utility function to perform an API call with retries
+        :param endpoint_info: Tuple of method and endpoint
+        :type endpoint_info: (string, string)
+        :param json: Parameters for this API call.
+        :type json: dict
+        :return: If the API call returns an OK status code,
+            this function returns the response in JSON. Otherwise,
+            we throw an AirflowException.
+        :rtype: dict
+        """
+        method, endpoint = endpoint_info
+        url = 'https://{host}/{endpoint}'.format(
+            host=self._parse_host(self.databricks_conn.host),
+            endpoint=endpoint)
+        auth = (self.databricks_conn.login, self.databricks_conn.password)
+        if method == 'GET':
+            request_func = requests.get
+        elif method == 'POST':
+            request_func = requests.post
+        else:
+            raise AirflowException('Unexpected HTTP Method: ' + method)
+
+        for attempt_num in range(1, self.retry_limit+1):
+            try:
+                response = request_func(
+                    url,
+                    json=json,
+                    auth=auth,
+                    headers=USER_AGENT_HEADER,
+                    timeout=self.timeout_seconds)
+                if response.status_code == requests.codes.ok:
+                    return response.json()
+                else:
+                    # In this case, the user probably made a mistake.
+                    # Don't retry.
+                    raise AirflowException('Response: {0}, Status Code: {1}'.format(
+                        response.content, response.status_code))
+            except (requests_exceptions.ConnectionError,
+                    requests_exceptions.Timeout) as e:
+                logging.error(('Attempt {0} API Request to Databricks failed ' +
+                              'with reason: {1}').format(attempt_num, e))
+        raise AirflowException(('API requests to Databricks failed {} times. ' +
+                               'Giving up.').format(self.retry_limit))
+
+    def submit_run(self, json):
+        """
+        Utility function to call the ``api/2.0/jobs/runs/submit`` endpoint.
+
+        :param json: The data used in the body of the request to the ``submit`` endpoint.
+        :type json: dict
+        :return: the run_id as a string
+        :rtype: string
+        """
+        response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
+        return response['run_id']
+
+    def get_run_page_url(self, run_id):
+        json = {'run_id': run_id}
+        response = self._do_api_call(GET_RUN_ENDPOINT, json)
+        return response['run_page_url']
+
+    def get_run_state(self, run_id):
+        json = {'run_id': run_id}
+        response = self._do_api_call(GET_RUN_ENDPOINT, json)
+        state = response['state']
+        life_cycle_state = state['life_cycle_state']
+        # result_state may not be in the state if not terminal
+        result_state = state.get('result_state', None)
+        state_message = state['state_message']
+        return RunState(life_cycle_state, result_state, state_message)
+
+    def cancel_run(self, run_id):
+        json = {'run_id': run_id}
+        self._do_api_call(CANCEL_RUN_ENDPOINT, json)
+
+
+RUN_LIFE_CYCLE_STATES = [
+    'PENDING',
+    'RUNNING',
+    'TERMINATING',
+    'TERMINATED',
+    'SKIPPED',
+    'INTERNAL_ERROR'
+]
+
+
+class RunState:
+    """
+    Utility class for the run state concept of Databricks runs.
+    """
+    def __init__(self, life_cycle_state, result_state, state_message):
+        self.life_cycle_state = life_cycle_state
+        self.result_state = result_state
+        self.state_message = state_message
+
+    @property
+    def is_terminal(self):
+        if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
+            raise AirflowException(('Unexpected life cycle state: {}: If the state has '
+                            'been introduced recently, please check the Databricks user '
+                            'guide for troubleshooting information').format(
+                                self.life_cycle_state))
+        return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')
+
+    @property
+    def is_successful(self):
+        return self.result_state == 'SUCCESS'
+
+    def __eq__(self, other):
+        return self.life_cycle_state == other.life_cycle_state and \
+            self.result_state == other.result_state and \
+            self.state_message == other.state_message
+
+    def __repr__(self):
+        return str(self.__dict__)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
new file mode 100644
index 0000000..46b1659
--- /dev/null
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -0,0 +1,211 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import time
+
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.databricks_hook import DatabricksHook
+from airflow.models import BaseOperator
+
+LINE_BREAK = ('-' * 80)
+
+
+class DatabricksSubmitRunOperator(BaseOperator):
+    """
+    Submits a Spark job run to Databricks using the
+    `api/2.0/jobs/runs/submit
+    <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_
+    API endpoint.
+
+    There are two ways to instantiate this operator.
+
+    In the first way, you can take the JSON payload that you typically use to
+    call the ``api/2.0/jobs/runs/submit`` endpoint and pass it directly to our
+    ``DatabricksSubmitRunOperator`` through the ``json`` parameter. For example ::
+
+        json = {
+          'new_cluster': {
+            'spark_version': '2.1.0-db3-scala2.11',
+            'num_workers': 2
+          },
+          'notebook_task': {
+            'notebook_path': '/Users/airflow@example.com/PrepareData',
+          },
+        }
+        notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json)
+
+    Alternatively, use the named parameters of the ``DatabricksSubmitRunOperator``
+    directly; there is exactly one named parameter for each top level parameter in
+    the ``runs/submit`` endpoint. In this method, your code would look like this::
+
+        new_cluster = {
+          'spark_version': '2.1.0-db3-scala2.11',
+          'num_workers': 2
+        }
+        notebook_task = {
+          'notebook_path': '/Users/airflow@example.com/PrepareData',
+        }
+        notebook_run = DatabricksSubmitRunOperator(
+            task_id='notebook_run',
+            new_cluster=new_cluster,
+            notebook_task=notebook_task)
+
+    In the case where both the json parameter **AND** the named parameters
+    are provided, they will be merged together. If there are conflicts during the merge,
+    the named parameters will take precedence and override the top level ``json`` keys.
+
+    Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
+        - ``spark_jar_task``
+        - ``notebook_task``
+        - ``new_cluster``
+        - ``existing_cluster_id``
+        - ``libraries``
+        - ``run_name``
+        - ``timeout_seconds``
+
+    :param json: A JSON object containing API parameters which will be passed
+        directly to the ``api/2.0/jobs/runs/submit`` endpoint. The other named parameters
+        (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
+        be merged with this json dictionary if they are provided.
+        If there are conflicts during the merge, the named parameters will
+        take precedence and override the top level json keys.
+        https://docs.databricks.com/api/latest/jobs.html#runs-submit
+    :type json: dict
+    :param spark_jar_task: The main class and parameters for the JAR task. Note that
+        the actual JAR is specified in the ``libraries``.
+        *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
+        https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
+    :type spark_jar_task: dict
+    :param notebook_task: The notebook path and parameters for the notebook task.
+        *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
+        https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
+    :type notebook_task: dict
+    :param new_cluster: Specs for a new cluster on which this task will be run.
+        *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+        https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
+    :type new_cluster: dict
+    :param existing_cluster_id: ID for existing cluster on which to run this task.
+        *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+    :type existing_cluster_id: string
+    :param libraries: Libraries which this run will use.
+        https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
+    :type libraries: list of dicts
+    :param run_name: The run name used for this task.
+        By default this will be set to the Airflow ``task_id``. This ``task_id`` is a
+        required parameter of the superclass ``BaseOperator``.
+    :type run_name: string
+    :param timeout_seconds: The timeout for this run. By default a value of 0 is used
+        which means to have no timeout.
+    :type timeout_seconds: int32
+    :param databricks_conn_id: The name of the Airflow connection to use.
+        By default and in the common case this will be ``databricks_default``.
+    :type databricks_conn_id: string
+    :param polling_period_seconds: Controls the rate which we poll for the result of
+        this run. By default the operator will poll every 30 seconds.
+    :type polling_period_seconds: int
+    :param databricks_retry_limit: Number of times to retry if the Databricks backend
+        is unreachable. Its value must be greater than or equal to 1.
+    :type databricks_retry_limit: int
+    """
+    # Databricks brand color (blue) under white text
+    ui_color = '#1CB1C2'
+    ui_fgcolor = '#fff'
+
+    def __init__(
+            self,
+            json=None,
+            spark_jar_task=None,
+            notebook_task=None,
+            new_cluster=None,
+            existing_cluster_id=None,
+            libraries=None,
+            run_name=None,
+            timeout_seconds=None,
+            databricks_conn_id='databricks_default',
+            polling_period_seconds=30,
+            databricks_retry_limit=3,
+            **kwargs):
+        """
+        Creates a new ``DatabricksSubmitRunOperator``.
+        """
+        super(DatabricksSubmitRunOperator, self).__init__(**kwargs)
+        self.json = json or {}
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        if spark_jar_task is not None:
+            self.json['spark_jar_task'] = spark_jar_task
+        if notebook_task is not None:
+            self.json['notebook_task'] = notebook_task
+        if new_cluster is not None:
+            self.json['new_cluster'] = new_cluster
+        if existing_cluster_id is not None:
+            self.json['existing_cluster_id'] = existing_cluster_id
+        if libraries is not None:
+            self.json['libraries'] = libraries
+        if run_name is not None:
+            self.json['run_name'] = run_name
+        if timeout_seconds is not None:
+            self.json['timeout_seconds'] = timeout_seconds
+        if 'run_name' not in self.json:
+            self.json['run_name'] = run_name or kwargs['task_id']
+
+        # This variable will be used in case our task gets killed.
+        self.run_id = None
+
+    def _log_run_page_url(self, url):
+        logging.info('View run status, Spark UI, and logs at {}'.format(url))
+
+    def get_hook(self):
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit)
+
+    def execute(self, context):
+        hook = self.get_hook()
+        self.run_id = hook.submit_run(self.json)
+        run_page_url = hook.get_run_page_url(self.run_id)
+        logging.info(LINE_BREAK)
+        logging.info('Run submitted with run_id: {}'.format(self.run_id))
+        self._log_run_page_url(run_page_url)
+        logging.info(LINE_BREAK)
+        while True:
+            run_state = hook.get_run_state(self.run_id)
+            if run_state.is_terminal:
+                if run_state.is_successful:
+                    logging.info('{} completed successfully.'.format(
+                        self.task_id))
+                    self._log_run_page_url(run_page_url)
+                    return
+                else:
+                    error_message = '{t} failed with terminal state: {s}'.format(
+                        t=self.task_id,
+                        s=run_state)
+                    raise AirflowException(error_message)
+            else:
+                logging.info('{t} in run state: {s}'.format(t=self.task_id,
+                                                            s=run_state))
+                self._log_run_page_url(run_page_url)
+                logging.info('Sleeping for {} seconds.'.format(
+                    self.polling_period_seconds))
+                time.sleep(self.polling_period_seconds)
+
+    def on_kill(self):
+        hook = self.get_hook()
+        hook.cancel_run(self.run_id)
+        logging.info('Task: {t} with run_id: {r} was requested to be cancelled.'.format(
+            t=self.task_id,
+            r=self.run_id))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/exceptions.py
----------------------------------------------------------------------
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 2231208..90d3e22 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -22,7 +22,7 @@ class AirflowException(Exception):
 
 class AirflowConfigException(AirflowException):
     pass
-    
+
 
 class AirflowSensorTimeout(AirflowException):
     pass

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 95e2255..42b621d 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -543,6 +543,7 @@ class Connection(Base):
         ('jira', 'JIRA',),
         ('redis', 'Redis',),
         ('wasb', 'Azure Blob Storage'),
+        ('databricks', 'Databricks',),
     ]
 
     def __init__(

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 7da9217..54254f6 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -249,6 +249,10 @@ def initdb():
                     ]
                 }
             '''))
+    merge_conn(
+        models.Connection(
+            conn_id='databricks_default', conn_type='databricks',
+            host='localhost'))
 
     # Known event types
     KET = models.KnownEventType

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 683e85f..c31061c 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -97,6 +97,7 @@ Community-contributed Operators
 
 .. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator
 .. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
 .. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
 .. autoclass:: airflow.contrib.operators.file_to_wasb.FileToWasbOperator
 .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/docs/integration.rst
----------------------------------------------------------------------
diff --git a/docs/integration.rst b/docs/integration.rst
index 4a6b676..a6c9d7c 100644
--- a/docs/integration.rst
+++ b/docs/integration.rst
@@ -61,6 +61,19 @@ AWS: Amazon Webservices
 
 ---
 
+.. _Databricks:
+
+Databricks
+--------------------------
+`Databricks <https://databricks.com/>`_ has contributed an Airflow operator which enables
+submitting runs to the Databricks platform. Internally the operator talks to the
+``api/2.0/jobs/runs/submit`` `endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
+
+DatabricksSubmitRunOperator
+''''''''''''''''''''''''''''
+
+.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
+
 .. _GCP:
 
 GCP: Google Cloud Platform

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index ea60dca..6691208 100644
--- a/setup.py
+++ b/setup.py
@@ -116,6 +116,7 @@ crypto = ['cryptography>=0.9.3']
 dask = [
     'distributed>=1.15.2, <2'
     ]
+databricks = ['requests>=2.5.1, <3']
 datadog = ['datadog>=0.14.0']
 doc = [
     'sphinx>=1.2.3',
@@ -244,6 +245,7 @@ def do_setup():
             'cloudant': cloudant,
             'crypto': crypto,
             'dask': dask,
+            'databricks': databricks,
             'datadog': datadog,
             'devel': devel_minreq,
             'devel_hadoop': devel_hadoop,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/hooks/databricks_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/databricks_hook.py b/tests/contrib/hooks/databricks_hook.py
new file mode 100644
index 0000000..6c789f9
--- /dev/null
+++ b/tests/contrib/hooks/databricks_hook.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import __version__
+from airflow.contrib.hooks.databricks_hook import DatabricksHook, RunState, SUBMIT_RUN_ENDPOINT
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
+from requests import exceptions as requests_exceptions
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+TASK_ID = 'databricks-operator'
+DEFAULT_CONN_ID = 'databricks_default'
+NOTEBOOK_TASK = {
+    'notebook_path': '/test'
+}
+NEW_CLUSTER = {
+    'spark_version': '2.0.x-scala2.10',
+    'node_type_id': 'r3.xlarge',
+    'num_workers': 1
+}
+RUN_ID = 1
+HOST = 'xx.cloud.databricks.com'
+HOST_WITH_SCHEME = 'https://xx.cloud.databricks.com'
+LOGIN = 'login'
+PASSWORD = 'password'
+USER_AGENT_HEADER = {'user-agent': 'airflow-{v}'.format(v=__version__)}
+RUN_PAGE_URL = 'https://XX.cloud.databricks.com/#jobs/1/runs/1'
+LIFE_CYCLE_STATE = 'PENDING'
+STATE_MESSAGE = 'Waiting for cluster'
+GET_RUN_RESPONSE = {
+    'run_page_url': RUN_PAGE_URL,
+    'state': {
+        'life_cycle_state': LIFE_CYCLE_STATE,
+        'state_message': STATE_MESSAGE
+    }
+}
+RESULT_STATE = None
+
+
+def submit_run_endpoint(host):
+    """
+    Utility function to generate the submit run endpoint given the host.
+    """
+    return 'https://{}/api/2.0/jobs/runs/submit'.format(host)
+
+
+def get_run_endpoint(host):
+    """
+    Utility function to generate the get run endpoint given the host.
+    """
+    return 'https://{}/api/2.0/jobs/runs/get'.format(host)
+
+def cancel_run_endpoint(host):
+    """
+    Utility function to generate the cancel run endpoint given the host.
+    """
+    return 'https://{}/api/2.0/jobs/runs/cancel'.format(host)
+
+class DatabricksHookTest(unittest.TestCase):
+    """
+    Tests for DatabricksHook; API tests patch the hook module's ``requests``.
+    """
+    @db.provide_session
+    def setUp(self, session=None):
+        conn = session.query(Connection) \
+            .filter(Connection.conn_id == DEFAULT_CONN_ID) \
+            .first()
+        conn.host = HOST
+        conn.login = LOGIN
+        conn.password = PASSWORD
+        session.commit()
+
+        self.hook = DatabricksHook()
+
+    def test_parse_host_with_proper_host(self):
+        host = self.hook._parse_host(HOST)
+        self.assertEqual(host, HOST)
+
+    def test_parse_host_with_scheme(self):
+        host = self.hook._parse_host(HOST_WITH_SCHEME)
+        self.assertEqual(host, HOST)
+
+    def test_init_bad_retry_limit(self):
+        with self.assertRaises(AssertionError):
+            DatabricksHook(retry_limit=0)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.logging')
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_do_api_call_with_error_retry(self, mock_requests, mock_logging):
+        for exception in [requests_exceptions.ConnectionError, requests_exceptions.Timeout]:
+            mock_requests.reset_mock()
+            mock_logging.reset_mock()
+            mock_requests.post.side_effect = exception()
+
+            with self.assertRaises(AirflowException):
+                self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+            # Exactly retry_limit errors must have been logged before the failure.
+            self.assertEqual(len(mock_logging.error.mock_calls), self.hook.retry_limit)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_do_api_call_with_bad_status_code(self, mock_requests):
+        mock_requests.codes.ok = 200
+        status_code_mock = mock.PropertyMock(return_value=500)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        # A response whose status code is not requests.codes.ok must raise.
+        with self.assertRaises(AirflowException):
+            self.hook._do_api_call(SUBMIT_RUN_ENDPOINT, {})
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_submit_run(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = {'run_id': '1'}
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+        json = {
+            'notebook_task': NOTEBOOK_TASK,
+            'new_cluster': NEW_CLUSTER
+        }
+        run_id = self.hook.submit_run(json)
+
+        self.assertEqual(run_id, '1')
+        mock_requests.post.assert_called_once_with(
+            submit_run_endpoint(HOST),
+            json={
+                'notebook_task': NOTEBOOK_TASK,
+                'new_cluster': NEW_CLUSTER,
+            },
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_get_run_page_url(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.get.return_value).status_code = status_code_mock
+
+        run_page_url = self.hook.get_run_page_url(RUN_ID)
+
+        self.assertEqual(run_page_url, RUN_PAGE_URL)
+        mock_requests.get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_get_run_state(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = GET_RUN_RESPONSE
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.get.return_value).status_code = status_code_mock
+
+        run_state = self.hook.get_run_state(RUN_ID)
+
+        expected = RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
+        self.assertEqual(run_state, expected)
+        mock_requests.get.assert_called_once_with(
+            get_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+    @mock.patch('airflow.contrib.hooks.databricks_hook.requests')
+    def test_cancel_run(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.post.return_value.json.return_value = GET_RUN_RESPONSE
+        status_code_mock = mock.PropertyMock(return_value=200)
+        type(mock_requests.post.return_value).status_code = status_code_mock
+
+        self.hook.cancel_run(RUN_ID)
+
+        mock_requests.post.assert_called_once_with(
+            cancel_run_endpoint(HOST),
+            json={'run_id': RUN_ID},
+            auth=(LOGIN, PASSWORD),
+            headers=USER_AGENT_HEADER,
+            timeout=self.hook.timeout_seconds)
+
+class RunStateTest(unittest.TestCase):
+    def test_is_terminal_true(self):
+        # Every life cycle state listed here must be reported as terminal.
+        for life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR'):
+            state = RunState(life_cycle_state, '', '')
+            self.assertTrue(state.is_terminal)
+
+    def test_is_terminal_false(self):
+        # None of the life cycle states listed here may count as terminal.
+        for life_cycle_state in ('PENDING', 'RUNNING', 'TERMINATING'):
+            state = RunState(life_cycle_state, '', '')
+            self.assertFalse(state.is_terminal)
+
+    def test_is_terminal_with_nonexistent_life_cycle_state(self):
+        unknown = RunState('blah', '', '')
+        # Reading is_terminal on an unrecognised state must raise.
+        self.assertRaises(AirflowException, lambda: unknown.is_terminal)
+
+    def test_is_successful(self):
+        finished = RunState('TERMINATED', 'SUCCESS', '')
+        self.assertTrue(finished.is_successful)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/53ca5084/tests/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/databricks_operator.py b/tests/contrib/operators/databricks_operator.py
new file mode 100644
index 0000000..aab47fa
--- /dev/null
+++ b/tests/contrib/operators/databricks_operator.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow.contrib.hooks.databricks_hook import RunState
+from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
+from airflow.exceptions import AirflowException
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+TASK_ID = 'databricks-operator'  # task_id given to every operator built below
+DEFAULT_CONN_ID = 'databricks_default'  # conn id the exec tests expect the hook to receive
+NOTEBOOK_TASK = {  # notebook_task payload used in the run specs below
+    'notebook_path': '/test'
+}
+SPARK_JAR_TASK = {  # NOTE(review): unused in this module
+    'main_class_name': 'com.databricks.Test'
+}
+NEW_CLUSTER = {  # new_cluster payload used in the run specs below
+    'spark_version': '2.0.x-scala2.10',
+    'node_type_id': 'development-node',
+    'num_workers': 1
+}
+EXISTING_CLUSTER_ID = 'existing-cluster-id'  # NOTE(review): unused in this module
+RUN_NAME = 'run-name'  # explicit run_name for test_init_with_specified_run_name
+RUN_ID = 1  # run id the execute/on_kill tests assert against
+
+
+class DatabricksSubmitRunOperatorTest(unittest.TestCase):
+    """
+    Tests for DatabricksSubmitRunOperator. The execute and on_kill tests
+    replace DatabricksHook in the operator module with a mock.
+    """
+    def test_init_with_named_parameters(self):
+        """
+        Test the initializer with the named parameters.
+        """
+        op = DatabricksSubmitRunOperator(
+            task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_json(self):
+        """
+        Test the initializer with json data.
+        """
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK
+        }
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_specified_run_name(self):
+        """
+        Test the initializer with a specified run_name.
+        """
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': RUN_NAME
+        }
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': RUN_NAME
+        }
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_merging(self):
+        """
+        Test the initializer when json and other named parameters are both
+        provided. The named parameters should override top level keys in the
+        json dict.
+        """
+        override_new_cluster = {'workers': 999}
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+        }
+        op = DatabricksSubmitRunOperator(
+            task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
+        expected = {
+            'new_cluster': override_new_cluster,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
+        }
+        self.assertDictEqual(expected, op.json)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_success(self, db_mock_class):
+        """
+        Test the execute function in case where the run is successful.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'SUCCESS', '')
+
+        op.execute(None)
+
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID
+        }
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit)
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_exec_failure(self, db_mock_class):
+        """
+        Test the execute function in case where the run failed.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.submit_run.return_value = RUN_ID
+        db_mock.get_run_state.return_value = RunState('TERMINATED', 'FAILED', '')
+
+        with self.assertRaises(AirflowException):
+            op.execute(None)
+
+        expected = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': NOTEBOOK_TASK,
+            'run_name': TASK_ID,
+        }
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit)
+        db_mock.submit_run.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run_state.assert_called_once_with(RUN_ID)
+        self.assertEqual(RUN_ID, op.run_id)
+
+    @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
+    def test_on_kill(self, db_mock_class):
+        """
+        Test that on_kill cancels the run whose id is stored in op.run_id.
+        """
+        run = {'new_cluster': NEW_CLUSTER, 'notebook_task': NOTEBOOK_TASK}
+        op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
+        db_mock = db_mock_class.return_value
+        op.run_id = RUN_ID
+
+        op.on_kill()
+
+        db_mock.cancel_run.assert_called_once_with(RUN_ID)
+



Mime
View raw message