airflow-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] kaxil closed pull request #3884: [AIRFLOW-3035] Allow custom 'job_error_states' in dataproc ops
Date Wed, 12 Sep 2018 21:42:15 GMT
kaxil closed pull request #3884: [AIRFLOW-3035] Allow custom 'job_error_states' in dataproc ops
URL: https://github.com/apache/incubator-airflow/pull/3884
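
In short, the patch threads a new 'job_error_states' argument from each Dataproc
operator through DataProcHook.submit() down to _DataProcJob, so that job states
other than ERROR (currently only CANCELLED) can also fail the task. A minimal
usage sketch follows; the DAG id, task id, cluster name and Pig query are
illustrative placeholders, not part of this PR:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataproc_operator import DataProcPigOperator

    dag = DAG('dataproc_pig_example', start_date=datetime(2018, 9, 1),
              schedule_interval=None)

    # With this patch, a job that ends up CANCELLED fails the task as well,
    # instead of only ERROR (which remains the default).
    run_pig = DataProcPigOperator(
        task_id='run_pig_query',
        query="DUMP foo;",                    # placeholder Pig statement
        cluster_name='my-dataproc-cluster',   # placeholder cluster
        region='global',
        job_error_states=['ERROR', 'CANCELLED'],
        dag=dag,
    )

The default stays ['ERROR'], so existing DAGs keep their current behaviour.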

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py
index f39325d15e..ca86f08795 100644
--- a/airflow/contrib/hooks/gcp_dataproc_hook.py
+++ b/airflow/contrib/hooks/gcp_dataproc_hook.py
@@ -28,7 +28,8 @@
 
 
 class _DataProcJob(LoggingMixin):
-    def __init__(self, dataproc_api, project_id, job, region='global'):
+    def __init__(self, dataproc_api, project_id, job, region='global',
+                 job_error_states=None):
         self.dataproc_api = dataproc_api
         self.project_id = project_id
         self.region = region
@@ -37,6 +38,7 @@ def __init__(self, dataproc_api, project_id, job, region='global'):
             region=self.region,
             body=job).execute()
         self.job_id = self.job['reference']['jobId']
+        self.job_error_states = job_error_states
         self.log.info(
             'DataProc job %s is %s',
             self.job_id, str(self.job['status']['state'])
@@ -49,7 +51,6 @@ def wait_for_done(self):
                 region=self.region,
                 jobId=self.job_id).execute(num_retries=5)
             if 'ERROR' == self.job['status']['state']:
-                print(str(self.job))
                 self.log.error('DataProc job %s has errors', self.job_id)
                 self.log.error(self.job['status']['details'])
                 self.log.debug(str(self.job))
@@ -57,7 +58,6 @@ def wait_for_done(self):
                               self.job['driverOutputResourceUri'])
                 return False
             if 'CANCELLED' == self.job['status']['state']:
-                print(str(self.job))
                 self.log.warning('DataProc job %s is cancelled', self.job_id)
                 if 'details' in self.job['status']:
                     self.log.warning(self.job['status']['details'])
@@ -76,10 +76,15 @@ def wait_for_done(self):
             time.sleep(5)
 
     def raise_error(self, message=None):
-        if 'ERROR' == self.job['status']['state']:
-            if message is None:
-                message = "Google DataProc job has error"
-            raise Exception(message + ": " + str(self.job['status']['details']))
+        job_state = self.job['status']['state']
+        # We always consider ERROR to be an error state.
+        if ((self.job_error_states and job_state in self.job_error_states)
+                or 'ERROR' == job_state):
+            ex_message = message or ("Google DataProc job has state: %s" % job_state)
+            ex_details = (str(self.job['status']['details'])
+                          if 'details' in self.job['status']
+                          else "No details available")
+            raise Exception(ex_message + ": " + ex_details)
 
     def get(self):
         return self.job
@@ -222,10 +227,11 @@ def get_cluster(self, project_id, region, cluster_name):
             clusterName=cluster_name
         ).execute(num_retries=5)
 
-    def submit(self, project_id, job, region='global'):
-        submitted = _DataProcJob(self.get_conn(), project_id, job, region)
+    def submit(self, project_id, job, region='global', job_error_states=None):
+        submitted = _DataProcJob(self.get_conn(), project_id, job, region,
+                                 job_error_states=job_error_states)
         if not submitted.wait_for_done():
-            submitted.raise_error('DataProcTask has errors')
+            submitted.raise_error()
 
     def create_job_template(self, task_id, cluster_name, job_type, properties):
         return _DataProcJobBuilder(self.project_id, task_id, cluster_name,
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 61f3895bc0..49e24a3df2 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -702,6 +702,13 @@ class DataProcPigOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -725,6 +732,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -739,6 +747,7 @@ def __init__(
         self.dataproc_properties = dataproc_pig_properties
         self.dataproc_jars = dataproc_pig_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -757,7 +766,7 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataProcHiveOperator(BaseOperator):
@@ -790,6 +799,13 @@ class DataProcHiveOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -813,6 +829,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -827,6 +844,7 @@ def __init__(
         self.dataproc_properties = dataproc_hive_properties
         self.dataproc_jars = dataproc_hive_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -846,7 +864,7 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataProcSparkSqlOperator(BaseOperator):
@@ -880,6 +898,13 @@ class DataProcSparkSqlOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -903,6 +928,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -917,6 +943,7 @@ def __init__(
         self.dataproc_properties = dataproc_spark_properties
         self.dataproc_jars = dataproc_spark_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -936,7 +963,7 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataProcSparkOperator(BaseOperator):
@@ -977,6 +1004,13 @@ class DataProcSparkOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -1002,6 +1036,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -1018,6 +1053,7 @@ def __init__(
         self.dataproc_properties = dataproc_spark_properties
         self.dataproc_jars = dataproc_spark_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -1035,7 +1071,7 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataProcHadoopOperator(BaseOperator):
@@ -1076,6 +1112,13 @@ class DataProcHadoopOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -1101,6 +1144,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -1117,6 +1161,7 @@ def __init__(
         self.dataproc_properties = dataproc_hadoop_properties
         self.dataproc_jars = dataproc_hadoop_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -1134,11 +1179,10 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataProcPySparkOperator(BaseOperator):
-    # TODO Add docs around dataproc_job_id.
     """
     Start a PySpark Job on a Cloud DataProc cluster.
 
@@ -1176,6 +1220,13 @@ class DataProcPySparkOperator(BaseOperator):
     :type delegate_to: str
     :param region: The specified region where the dataproc cluster is created.
     :type region: str
+    :param job_error_states: Job states that should be considered error states.
+        Any states in this list will result in an error being raised and failure of the
+        task. Eg, if the ``CANCELLED`` state should also be considered a task failure,
+        pass in ``['ERROR', 'CANCELLED']``. Possible values are currently only
+        ``'ERROR'`` and ``'CANCELLED'``, but could change in the future. Defaults to
+        ``['ERROR']``.
+    :type job_error_states: list
     :var dataproc_job_id: The actual "jobId" as submitted to the Dataproc API.
         This is useful for identifying or linking to the job in the Google Cloud Console
         Dataproc UI, as the actual "jobId" submitted to the Dataproc API is appended with
@@ -1228,6 +1279,7 @@ def __init__(
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             region='global',
+            job_error_states=['ERROR'],
             *args,
             **kwargs):
 
@@ -1244,6 +1296,7 @@ def __init__(
         self.dataproc_properties = dataproc_pyspark_properties
         self.dataproc_jars = dataproc_pyspark_jars
         self.region = region
+        self.job_error_states = job_error_states
 
     def execute(self, context):
         hook = DataProcHook(
@@ -1274,7 +1327,7 @@ def execute(self, context):
         job_to_submit = job.build()
         self.dataproc_job_id = job_to_submit["job"]["reference"]["jobId"]
 
-        hook.submit(hook.project_id, job_to_submit, self.region)
+        hook.submit(hook.project_id, job_to_submit, self.region, self.job_error_states)
 
 
 class DataprocWorkflowTemplateBaseOperator(BaseOperator):
diff --git a/tests/contrib/hooks/test_gcp_dataproc_hook.py b/tests/contrib/hooks/test_gcp_dataproc_hook.py
index f2629ff148..e22b27a8e7 100644
--- a/tests/contrib/hooks/test_gcp_dataproc_hook.py
+++ b/tests/contrib/hooks/test_gcp_dataproc_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,6 +19,7 @@
 #
 
 import unittest
+from airflow.contrib.hooks.gcp_dataproc_hook import _DataProcJob
 from airflow.contrib.hooks.gcp_dataproc_hook import DataProcHook
 
 try:
@@ -37,9 +38,11 @@
 BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
 DATAPROC_STRING = 'airflow.contrib.hooks.gcp_dataproc_hook.{}'
 
+
 def mock_init(self, gcp_conn_id, delegate_to=None):
     pass
 
+
 class DataProcHookTest(unittest.TestCase):
     def setUp(self):
         with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
@@ -48,6 +51,48 @@ def setUp(self):
 
     @mock.patch(DATAPROC_STRING.format('_DataProcJob'))
     def test_submit(self, job_mock):
-      with mock.patch(DATAPROC_STRING.format('DataProcHook.get_conn', return_value=None)):
-        self.dataproc_hook.submit(PROJECT_ID, JOB)
-        job_mock.assert_called_once_with(mock.ANY, PROJECT_ID, JOB, REGION)
+        with mock.patch(DATAPROC_STRING.format('DataProcHook.get_conn',
+                                               return_value=None)):
+            self.dataproc_hook.submit(PROJECT_ID, JOB)
+            job_mock.assert_called_once_with(mock.ANY, PROJECT_ID, JOB, REGION,
+                                             job_error_states=mock.ANY)
+
+
+class DataProcJobTest(unittest.TestCase):
+    @mock.patch(DATAPROC_STRING.format('_DataProcJob.__init__'), return_value=None)
+    def test_raise_error_default_job_error_states(self, mock_init):
+        job = _DataProcJob()
+        job.job = {'status': {'state': 'ERROR'}}
+        job.job_error_states = None
+        with self.assertRaises(Exception) as cm:
+            job.raise_error()
+        self.assertIn('ERROR', str(cm.exception))
+
+    @mock.patch(DATAPROC_STRING.format('_DataProcJob.__init__'), return_value=None)
+    def test_raise_error_custom_job_error_states(self, mock_init):
+        job = _DataProcJob()
+        job.job = {'status': {'state': 'CANCELLED'}}
+        job.job_error_states = ['ERROR', 'CANCELLED']
+        with self.assertRaises(Exception) as cm:
+            job.raise_error()
+        self.assertIn('CANCELLED', str(cm.exception))
+
+    @mock.patch(DATAPROC_STRING.format('_DataProcJob.__init__'), return_value=None)
+    def test_raise_error_fallback_job_error_states(self, mock_init):
+        job = _DataProcJob()
+        job.job = {'status': {'state': 'ERROR'}}
+        job.job_error_states = ['CANCELLED']
+        with self.assertRaises(Exception) as cm:
+            job.raise_error()
+        self.assertIn('ERROR', str(cm.exception))
+
+    @mock.patch(DATAPROC_STRING.format('_DataProcJob.__init__'), return_value=None)
+    def test_raise_error_with_state_done(self, mock_init):
+        job = _DataProcJob()
+        job.job = {'status': {'state': 'DONE'}}
+        job.job_error_states = None
+        try:
+            job.raise_error()
+            # Pass test
+        except Exception:
+            self.fail("raise_error() should not raise Exception when job=%s" % job.job)
diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py
index 4095ca2068..60c1268ee7 100644
--- a/tests/contrib/operators/test_dataproc_operator.py
+++ b/tests/contrib/operators/test_dataproc_operator.py
@@ -465,7 +465,7 @@ def test_cluster_name_log_no_sub(self):
                 dag=self.dag
             )
             with patch.object(dataproc_task.log, 'info') as mock_info:
-                with self.assertRaises(TypeError) as _:
+                with self.assertRaises(TypeError):
                     dataproc_task.execute(None)
                 mock_info.assert_called_with('Deleting cluster: %s', CLUSTER_NAME)
 
@@ -504,7 +504,7 @@ def test_hook_correct_region():
 
             dataproc_task.execute(None)
             mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY,
-                                                                  REGION)
+                                                                  REGION, mock.ANY)
 
     @staticmethod
     def test_dataproc_job_id_is_set():
@@ -528,7 +528,7 @@ def test_hook_correct_region():
 
             dataproc_task.execute(None)
             mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY,
-                                                                  REGION)
+                                                                  REGION, mock.ANY)
 
     @staticmethod
     def test_dataproc_job_id_is_set():
@@ -553,7 +553,7 @@ def test_hook_correct_region():
 
             dataproc_task.execute(None)
             mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY,
-                                                                  REGION)
+                                                                  REGION, mock.ANY)
 
     @staticmethod
     def test_dataproc_job_id_is_set():
@@ -578,7 +578,7 @@ def test_hook_correct_region():
 
             dataproc_task.execute(None)
             mock_hook.return_value.submit.assert_called_once_with(mock.ANY, mock.ANY,
-                                                                  REGION)
+                                                                  REGION, mock.ANY)
 
     @staticmethod
     def test_dataproc_job_id_is_set():


 

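For completeness, the hook-level entry point changes in the same way:
DataProcHook.submit() now forwards 'job_error_states' to _DataProcJob. A short
sketch of the new call shape, mirroring what the operators do in execute() and
using the hook's existing job builder helpers (the connection id, cluster name
and query are placeholders, and a working GCP connection is assumed):

    from airflow.contrib.hooks.gcp_dataproc_hook import DataProcHook

    hook = DataProcHook(gcp_conn_id='google_cloud_default')
    # Build a job body the same way the operators do.
    job = hook.create_job_template('run_pig_query', 'my-dataproc-cluster',
                                   'pigJob', {})
    job.add_query("DUMP foo;")  # placeholder Pig statement
    job_to_submit = job.build()
    # A CANCELLED job now fails the submission in addition to ERROR:
    hook.submit(hook.project_id, job_to_submit, region='global',
                job_error_states=['ERROR', 'CANCELLED'])
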
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
