airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From criccom...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1271] Add Google CloudML Training Operator
Date Thu, 06 Jul 2017 18:46:17 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master d231dce37 -> 0fc45045a


[AIRFLOW-1271] Add Google CloudML Training Operator

Closes #2408 from leomzhong/cloudml_training


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0fc45045
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0fc45045
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0fc45045

Branch: refs/heads/master
Commit: 0fc45045a27a0b1867410613d6c0edba820e3abf
Parents: d231dce
Author: Ming Zhong <leomzhong@gmail.com>
Authored: Thu Jul 6 11:46:13 2017 -0700
Committer: Chris Riccomini <criccomini@apache.org>
Committed: Thu Jul 6 11:46:13 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/gcp_cloudml_hook.py       |  82 ++++-----
 airflow/contrib/operators/cloudml_operator.py   | 148 ++++++++++++++-
 tests/contrib/hooks/test_gcp_cloudml_hook.py    | 111 +++++++++++-
 .../contrib/operators/test_cloudml_operator.py  | 179 ++++++++++++++-----
 4 files changed, 428 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
index 3af8508..6f634b2 100644
--- a/airflow/contrib/hooks/gcp_cloudml_hook.py
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -62,30 +62,37 @@ class CloudMLHook(GoogleCloudBaseHook):
         credentials = GoogleCredentials.get_application_default()
         return build('ml', 'v1', credentials=credentials)
 
-    def create_job(self, project_name, job):
+    def create_job(self, project_name, job, use_existing_job_fn=None):
         """
-        Creates and executes a CloudML job.
-
-        Returns the job object if the job was created and finished
-        successfully, or raises an error otherwise.
-
-        Raises:
-            apiclient.errors.HttpError: if the job cannot be created
-            successfully
-
-        project_name is the name of the project to use, such as
-        'my-project'
-
-        job is the complete Cloud ML Job object that should be provided to the
-        Cloud ML API, such as
-
-        {
-          'jobId': 'my_job_id',
-          'trainingInput': {
-            'scaleTier': 'STANDARD_1',
-            ...
-          }
-        }
+        Launches a CloudML job and waits for it to reach a terminal state.
+
+        :param project_name: The Google Cloud project name within which CloudML
+            job will be launched.
+        :type project_name: string
+
+        :param job: CloudML Job object that should be provided to the CloudML
+            API, such as:
+            {
+              'jobId': 'my_job_id',
+              'trainingInput': {
+                'scaleTier': 'STANDARD_1',
+                ...
+              }
+            }
+        :type job: dict
+
+        :param use_existing_job_fn: If a CloudML job with the same job_id
+            already exists, this function (if provided) decides whether to
+            reuse that existing job, i.e. keep waiting for it to finish and
+            return its job object. It should accept a CloudML job object
+            and return a boolean indicating whether it is OK to reuse the
+            existing job. If 'use_existing_job_fn' is not provided, the
+            existing CloudML job is reused by default.
+        :type use_existing_job_fn: function
+
+        :return: The CloudML job object if the job successfully reaches a
+            terminal state (which might be FAILED or CANCELLED state).
+        :rtype: dict
         """
         request = self._cloudml.projects().jobs().create(
             parent='projects/{}'.format(project_name),
@@ -94,29 +101,24 @@ class CloudMLHook(GoogleCloudBaseHook):
 
         try:
             request.execute()
-            return self._wait_for_job_done(project_name, job_id)
         except errors.HttpError as e:
+            # 409 means there is an existing job with the same job ID.
             if e.resp.status == 409:
-                existing_job = self._get_job(project_name, job_id)
+                if use_existing_job_fn is not None:
+                    existing_job = self._get_job(project_name, job_id)
+                    if not use_existing_job_fn(existing_job):
+                        logging.error(
+                            'Job with job_id {} already exist, but it does '
+                            'not match our expectation: {}'.format(
+                                job_id, existing_job))
+                        raise
                 logging.info(
-                    'Job with job_id {} already exist: {}.'.format(
-                        job_id,
-                        existing_job))
-
-                if existing_job.get('predictionInput', None) == \
-                        job['predictionInput']:
-                    return self._wait_for_job_done(project_name, job_id)
-                else:
-                    logging.error(
-                        'Job with job_id {} already exists, but the '
-                        'predictionInput mismatch: {}'
-                        .format(job_id, existing_job))
-                    raise ValueError(
-                        'Found a existing job with job_id {}, but with '
-                        'different predictionInput.'.format(job_id))
+                    'Job with job_id {} already exist. Will waiting for it to '
+                    'finish'.format(job_id))
             else:
                 logging.error('Failed to create CloudML job: {}'.format(e))
                 raise
+        return self._wait_for_job_done(project_name, job_id)
 
     def _get_job(self, project_name, job_id):
         """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
index 871cc73..3ad6f5a 100644
--- a/airflow/contrib/operators/cloudml_operator.py
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -18,8 +18,9 @@ import logging
 import re
 
 from airflow import settings
-from airflow.operators import BaseOperator
 from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
+from airflow.exceptions import AirflowException
+from airflow.operators import BaseOperator
 from airflow.utils.decorators import apply_defaults
 from apiclient import errors
 
@@ -239,10 +240,14 @@ class CloudMLBatchPredictionOperator(BaseOperator):
     def execute(self, context):
         hook = CloudMLHook(self.gcp_conn_id, self.delegate_to)
 
+        def check_existing_job(existing_job):
+            return existing_job.get('predictionInput', None) == \
+                self.prediction_job_request['predictionInput']
         try:
             finished_prediction_job = hook.create_job(
                 self.project_id,
-                self.prediction_job_request)
+                self.prediction_job_request,
+                check_existing_job)
         except errors.HttpError:
             raise
 
@@ -406,3 +411,142 @@ class CloudMLVersionOperator(BaseOperator):
                                        self._version['name'])
         else:
             raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
+class CloudMLTrainingOperator(BaseOperator):
+    """
+    Operator for launching a CloudML training job.
+
+    :param project_name: The Google Cloud project name within which CloudML
+        training job should run. This field could be templated.
+    :type project_name: string
+
+    :param job_id: A unique templated id for the submitted Google CloudML
+        training job.
+    :type job_id: string
+
+    :param package_uris: A list of package locations for CloudML training job,
+        which should include the main training program + any additional
+        dependencies.
+    :type package_uris: list
+
+    :param training_python_module: The Python module name to run within CloudML
+        training job after installing 'package_uris' packages.
+    :type training_python_module: string
+
+    :param training_args: A list of templated command line arguments to pass to
+        the CloudML training program.
+    :type training_args: string
+
+    :param region: The Google Compute Engine region to run the CloudML training
+        job in. This field could be templated.
+    :type region: string
+
+    :param scale_tier: Resource tier for CloudML training job.
+    :type scale_tier: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+
+    :param mode: Can be 'DRY_RUN' or 'CLOUD'. In 'DRY_RUN' mode, no real
+        training job will be launched, but the CloudML training job request
+        will be logged. Any other value (e.g. 'CLOUD', or the default
+        'PRODUCTION') issues a real CloudML training job creation request.
+    :type mode: string
+    """
+
+    template_fields = [
+        '_project_name',
+        '_job_id',
+        '_package_uris',
+        '_training_python_module',
+        '_training_args',
+        '_region',
+        '_scale_tier',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 project_name,
+                 job_id,
+                 package_uris,
+                 training_python_module,
+                 training_args,
+                 region,
+                 scale_tier=None,
+                 gcp_conn_id='google_cloud_default',
+                 delegate_to=None,
+                 mode='PRODUCTION',
+                 *args,
+                 **kwargs):
+        super(CloudMLTrainingOperator, self).__init__(*args, **kwargs)
+        self._project_name = project_name
+        self._job_id = job_id
+        self._package_uris = package_uris
+        self._training_python_module = training_python_module
+        self._training_args = training_args
+        self._region = region
+        self._scale_tier = scale_tier
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._mode = mode
+
+        if not self._project_name:
+            raise AirflowException('Google Cloud project name is required.')
+        if not self._job_id:
+            raise AirflowException(
+                'An unique job id is required for Google CloudML training '
+                'job.')
+        if not package_uris:
+            raise AirflowException(
+                'At least one python package is required for CloudML '
+                'Training job.')
+        if not training_python_module:
+            raise AirflowException(
+                'Python module name to run after installing required '
+                'packages is required.')
+        if not self._region:
+            raise AirflowException('Google Compute Engine region is required.')
+
+    def execute(self, context):
+        job_id = _normalize_cloudml_job_id(self._job_id)
+        training_request = {
+            'jobId': job_id,
+            'trainingInput': {
+                'scaleTier': self._scale_tier,
+                'packageUris': self._package_uris,
+                'pythonModule': self._training_python_module,
+                'region': self._region,
+                'args': self._training_args,
+            }
+        }
+
+        if self._mode == 'DRY_RUN':
+            logging.info('In dry_run mode.')
+            logging.info(
+                'CloudML Training job request is: {}'.format(training_request))
+            return
+
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+        # Helper method to check if the existing job's training input is the
+        # same as the request we get here.
+        def check_existing_job(existing_job):
+            return existing_job.get('trainingInput', None) == \
+                training_request['trainingInput']
+        try:
+            finished_training_job = hook.create_job(
+                self._project_name, training_request, check_existing_job)
+        except errors.HttpError:
+            raise
+
+        if finished_training_job['state'] != 'SUCCEEDED':
+            logging.error('CloudML training job failed: {}'.format(
+                str(finished_training_job)))
+            raise RuntimeError(finished_training_job['errorMessage'])

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
index e34e05f..53aba41 100644
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -20,6 +20,7 @@ except ImportError:  # python 3
     from urllib.parse import urlparse, parse_qsl
 
 from airflow.contrib.hooks import gcp_cloudml_hook as hook
+from apiclient import errors
 from apiclient.discovery import build
 from apiclient.http import HttpMockSequence
 from oauth2client.contrib.gce import HttpAccessTokenRefreshError
@@ -137,8 +138,8 @@ class TestCloudMLHook(unittest.TestCase):
 
         expected_requests = [
             ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project, model_name, version), 'POST',
-             '{}'),
+                self._SERVICE_URI_PREFIX, project, model_name, version),
+                'POST', '{}'),
         ]
 
         with _TestCloudMLHook(
@@ -175,7 +176,8 @@ class TestCloudMLHook(unittest.TestCase):
                 self._SERVICE_URI_PREFIX, project, model_name), 'GET',
              None),
         ] + [
-            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format(
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
+             '&pageSize=100'.format(
                 self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
              None) for ix in range(len(versions) - 1)
         ]
@@ -303,6 +305,109 @@ class TestCloudMLHook(unittest.TestCase):
                 project_name=project, job=my_job)
             self.assertEquals(create_job_response, my_job)
 
+    @_SKIP_IF
+    def test_create_cloudml_job_reuse_existing_job_by_default(self):
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        succeeded_response = ({'status': '200'}, response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+        responses = [job_already_exist_response, succeeded_response]
+
+        # By default, 'create_job' reuses the existing job.
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_name=project, job=my_job)
+            self.assertEquals(create_job_response, my_job)
+
+    @_SKIP_IF
+    def test_create_cloudml_job_check_existing_job(self):
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someInput'
+            }
+        }
+        different_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someDifferentInput'
+            }
+        }
+
+        my_job_response_body = json.dumps(my_job)
+        different_job_response_body = json.dumps(different_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        different_job_response = ({'status': '200'},
+                                  different_job_response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+
+        # Returning a different job (with a different 'someInput' field)
+        # causes the 'create_job' request to fail.
+        responses = [job_already_exist_response, different_job_response]
+
+        def check_input(existing_job):
+            return existing_job.get('someInput', None) == \
+                my_job['someInput']
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            with self.assertRaises(errors.HttpError):
+                cml_hook.create_job(
+                    project_name=project, job=my_job,
+                    use_existing_job_fn=check_input)
+
+        my_job_response = ({'status': '200'}, my_job_response_body)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [
+            job_already_exist_response,
+            my_job_response,
+            my_job_response]
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_name=project, job=my_job,
+                use_existing_job_fn=check_input)
+            self.assertEquals(create_job_response, my_job)
+
 
 if __name__ == '__main__':
     unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0fc45045/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
index b76a0c6..dc8c204 100644
--- a/tests/contrib/operators/test_cloudml_operator.py
+++ b/tests/contrib/operators/test_cloudml_operator.py
@@ -26,41 +26,41 @@ import unittest
 
 from airflow import configuration, DAG
 from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator
 
+from mock import ANY
 from mock import patch
 
 DEFAULT_DATE = datetime.datetime(2017, 6, 6)
 
-INPUT_MISSING_ORIGIN = {
-    'dataFormat': 'TEXT',
-    'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
-    'outputPath': 'gs://legal-bucket/fake-output-path',
-    'region': 'us-east1',
-}
-
-SUCCESS_MESSAGE_MISSING_INPUT = {
-    'jobId': 'test_prediction',
-    'predictionOutput': {
-        'outputPath': 'gs://fake-output-path',
-        'predictionCount': 5000,
-        'errorCount': 0,
-        'nodeHours': 2.78
-    },
-    'state': 'SUCCEEDED'
-}
-
-DEFAULT_ARGS = {
-    'project_id': 'test-project',
-    'job_id': 'test_prediction',
-    'region': 'us-east1',
-    'data_format': 'TEXT',
-    'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
-    'output_path': 'gs://12_legal_bucket_underscore_number/legal-output-path',
-    'task_id': 'test-prediction'
-}
-
 
 class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
+    INPUT_MISSING_ORIGIN = {
+        'dataFormat': 'TEXT',
+        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+        'outputPath': 'gs://legal-bucket/fake-output-path',
+        'region': 'us-east1',
+    }
+    SUCCESS_MESSAGE_MISSING_INPUT = {
+        'jobId': 'test_prediction',
+        'predictionOutput': {
+            'outputPath': 'gs://fake-output-path',
+            'predictionCount': 5000,
+            'errorCount': 0,
+            'nodeHours': 2.78
+        },
+        'state': 'SUCCEEDED'
+    }
+    BATCH_PREDICTION_DEFAULT_ARGS = {
+        'project_id': 'test-project',
+        'job_id': 'test_prediction',
+        'region': 'us-east1',
+        'data_format': 'TEXT',
+        'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
+        'output_path':
+            'gs://12_legal_bucket_underscore_number/legal-output-path',
+        'task_id': 'test-prediction'
+    }
 
     def setUp(self):
         super(CloudMLBatchPredictionOperatorTest, self).setUp()
@@ -78,10 +78,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
         with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                 as mock_hook:
 
-            input_with_model = INPUT_MISSING_ORIGIN.copy()
+            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
             input_with_model['modelName'] = \
                 'projects/test-project/models/test_model'
-            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
             success_message['predictionInput'] = input_with_model
 
             hook_instance = mock_hook.return_value
@@ -104,12 +104,12 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
             prediction_output = prediction_task.execute(None)
 
             mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
+            hook_instance.create_job.assert_called_once_with(
                 'test-project',
                 {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_model
-                })
+                }, ANY)
             self.assertEquals(
                 success_message['predictionOutput'],
                 prediction_output)
@@ -118,10 +118,10 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
         with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                 as mock_hook:
 
-            input_with_version = INPUT_MISSING_ORIGIN.copy()
+            input_with_version = self.INPUT_MISSING_ORIGIN.copy()
             input_with_version['versionName'] = \
                 'projects/test-project/models/test_model/versions/test_version'
-            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
             success_message['predictionInput'] = input_with_version
 
             hook_instance = mock_hook.return_value
@@ -132,8 +132,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
             hook_instance.create_job.return_value = success_message
 
             prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction',
-                project_id='test-project',
+                job_id='test_prediction', project_id='test-project',
                 region=input_with_version['region'],
                 data_format=input_with_version['dataFormat'],
                 input_paths=input_with_version['inputPaths'],
@@ -150,7 +149,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
                 {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_version
-                })
+                }, ANY)
             self.assertEquals(
                 success_message['predictionOutput'],
                 prediction_output)
@@ -159,9 +158,9 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
         with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                 as mock_hook:
 
-            input_with_uri = INPUT_MISSING_ORIGIN.copy()
+            input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
             input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
-            success_message = SUCCESS_MESSAGE_MISSING_INPUT.copy()
+            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
             success_message['predictionInput'] = input_with_uri
 
             hook_instance = mock_hook.return_value
@@ -189,14 +188,14 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
                 {
                     'jobId': 'test_prediction',
                     'predictionInput': input_with_uri
-                })
+                }, ANY)
             self.assertEquals(
                 success_message['predictionOutput'],
                 prediction_output)
 
     def testInvalidModelOrigin(self):
         # Test that both uri and model is given
-        task_args = DEFAULT_ARGS.copy()
+        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['uri'] = 'gs://fake-uri/saved_model'
         task_args['model_name'] = 'fake_model'
         with self.assertRaises(ValueError) as context:
@@ -204,7 +203,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
         self.assertEquals('Ambiguous model origin.', str(context.exception))
 
         # Test that both uri and model/version is given
-        task_args = DEFAULT_ARGS.copy()
+        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['uri'] = 'gs://fake-uri/saved_model'
         task_args['model_name'] = 'fake_model'
         task_args['version_name'] = 'fake_version'
@@ -213,7 +212,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
         self.assertEquals('Ambiguous model origin.', str(context.exception))
 
         # Test that a version is given without a model
-        task_args = DEFAULT_ARGS.copy()
+        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         task_args['version_name'] = 'bare_version'
         with self.assertRaises(ValueError) as context:
             CloudMLBatchPredictionOperator(**task_args).execute(None)
@@ -222,7 +221,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
             str(context.exception))
 
         # Test that none of uri, model, model/version is given
-        task_args = DEFAULT_ARGS.copy()
+        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
         with self.assertRaises(ValueError) as context:
             CloudMLBatchPredictionOperator(**task_args).execute(None)
         self.assertEquals(
@@ -234,7 +233,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
 
         with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
                 as mock_hook:
-            input_with_model = INPUT_MISSING_ORIGIN.copy()
+            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
             input_with_model['modelName'] = \
                 'projects/experimental/models/test_model'
 
@@ -263,7 +262,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
                     {
                         'jobId': 'test_prediction',
                         'predictionInput': input_with_model
-                    })
+                    }, ANY)
 
             self.assertEquals(http_error_code, context.exception.resp.status)
 
@@ -275,7 +274,7 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
                 'state': 'FAILED',
                 'errorMessage': 'A failure message'
             }
-            task_args = DEFAULT_ARGS.copy()
+            task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
             task_args['uri'] = 'a uri'
 
             with self.assertRaises(RuntimeError) as context:
@@ -284,5 +283,91 @@ class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
             self.assertEquals('A failure message', str(context.exception))
 
 
+class CloudMLTrainingOperatorTest(unittest.TestCase):
+    TRAINING_DEFAULT_ARGS = {
+        'project_name': 'test-project',
+        'job_id': 'test_training',
+        'package_uris': ['gs://some-bucket/package1'],
+        'training_python_module': 'trainer',
+        'training_args': '--some_arg=\'aaa\'',
+        'region': 'us-east1',
+        'scale_tier': 'STANDARD_1',
+        'task_id': 'test-training'
+    }
+    TRAINING_INPUT = {
+        'jobId': 'test_training',
+        'trainingInput': {
+            'scaleTier': 'STANDARD_1',
+            'packageUris': ['gs://some-bucket/package1'],
+            'pythonModule': 'trainer',
+            'args': '--some_arg=\'aaa\'',
+            'region': 'us-east1'
+        }
+    }
+
+    def testSuccessCreateTrainingJob(self):
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            success_response = self.TRAINING_INPUT.copy()
+            success_response['state'] = 'SUCCEEDED'
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = success_response
+
+            training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
+            training_op.execute(None)
+
+            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
+                                         delegate_to=None)
+            # Make sure only 'create_job' is invoked on hook instance
+            self.assertEquals(len(hook_instance.mock_calls), 1)
+            hook_instance.create_job.assert_called_with(
+                'test-project', self.TRAINING_INPUT, ANY)
+
+    def testHttpError(self):
+        http_error_code = 403
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.side_effect = errors.HttpError(
+                resp=httplib2.Response({
+                    'status': http_error_code
+                }), content=b'Forbidden')
+
+            with self.assertRaises(errors.HttpError) as context:
+                training_op = CloudMLTrainingOperator(
+                    **self.TRAINING_DEFAULT_ARGS)
+                training_op.execute(None)
+
+            mock_hook.assert_called_with(
+                gcp_conn_id='google_cloud_default', delegate_to=None)
+            # Make sure only 'create_job' is invoked on hook instance
+            self.assertEquals(len(hook_instance.mock_calls), 1)
+            hook_instance.create_job.assert_called_with(
+                'test-project', self.TRAINING_INPUT, ANY)
+            self.assertEquals(http_error_code, context.exception.resp.status)
+
+    def testFailedJobError(self):
+        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
+                as mock_hook:
+            failure_response = self.TRAINING_INPUT.copy()
+            failure_response['state'] = 'FAILED'
+            failure_response['errorMessage'] = 'A failure message'
+            hook_instance = mock_hook.return_value
+            hook_instance.create_job.return_value = failure_response
+
+            with self.assertRaises(RuntimeError) as context:
+                training_op = CloudMLTrainingOperator(
+                    **self.TRAINING_DEFAULT_ARGS)
+                training_op.execute(None)
+
+            mock_hook.assert_called_with(
+                gcp_conn_id='google_cloud_default', delegate_to=None)
+            # Make sure only 'create_job' is invoked on hook instance
+            self.assertEquals(len(hook_instance.mock_calls), 1)
+            hook_instance.create_job.assert_called_with(
+                'test-project', self.TRAINING_INPUT, ANY)
+            self.assertEquals('A failure message', str(context.exception))
+
+
 if __name__ == '__main__':
     unittest.main()


Mime
View raw message