Repository: incubator-airflow
Updated Branches:
refs/heads/master 86063ba4e -> af91e2ac0
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_mlengine_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator.py b/tests/contrib/operators/test_mlengine_operator.py
new file mode 100644
index 0000000..75b46a0
--- /dev/null
+++ b/tests/contrib/operators/test_mlengine_operator.py
@@ -0,0 +1,373 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+from apiclient import errors
+import httplib2
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator
+from airflow.contrib.operators.mlengine_operator import MLEngineTrainingOperator
+
+from mock import ANY
+from mock import patch
+
+DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+
+
+class MLEngineBatchPredictionOperatorTest(unittest.TestCase):
+ INPUT_MISSING_ORIGIN = {
+ 'dataFormat': 'TEXT',
+ 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+ 'outputPath': 'gs://legal-bucket/fake-output-path',
+ 'region': 'us-east1',
+ }
+ SUCCESS_MESSAGE_MISSING_INPUT = {
+ 'jobId': 'test_prediction',
+ 'predictionOutput': {
+ 'outputPath': 'gs://fake-output-path',
+ 'predictionCount': 5000,
+ 'errorCount': 0,
+ 'nodeHours': 2.78
+ },
+ 'state': 'SUCCEEDED'
+ }
+ BATCH_PREDICTION_DEFAULT_ARGS = {
+ 'project_id': 'test-project',
+ 'job_id': 'test_prediction',
+ 'region': 'us-east1',
+ 'data_format': 'TEXT',
+ 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
+ 'output_path':
+ 'gs://12_legal_bucket_underscore_number/legal-output-path',
+ 'task_id': 'test-prediction'
+ }
+
+ def setUp(self):
+ super(MLEngineBatchPredictionOperatorTest, self).setUp()
+ configuration.load_test_config()
+ self.dag = DAG(
+ 'test_dag',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'end_date': DEFAULT_DATE,
+ },
+ schedule_interval='@daily')
+
+ def testSuccessWithModel(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
+ input_with_model['modelName'] = \
+ 'projects/test-project/models/test_model'
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message['predictionInput'] = input_with_model
+
+ hook_instance = mock_hook.return_value
+ hook_instance.get_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': 404
+ }), content=b'some bytes')
+ hook_instance.create_job.return_value = success_message
+
+ prediction_task = MLEngineBatchPredictionOperator(
+ job_id='test_prediction',
+ project_id='test-project',
+ region=input_with_model['region'],
+ data_format=input_with_model['dataFormat'],
+ input_paths=input_with_model['inputPaths'],
+ output_path=input_with_model['outputPath'],
+ model_name=input_with_model['modelName'].split('/')[-1],
+ dag=self.dag,
+ task_id='test-prediction')
+ prediction_output = prediction_task.execute(None)
+
+ mock_hook.assert_called_with('google_cloud_default', None)
+ hook_instance.create_job.assert_called_once_with(
+ 'test-project',
+ {
+ 'jobId': 'test_prediction',
+ 'predictionInput': input_with_model
+ }, ANY)
+ self.assertEquals(
+ success_message['predictionOutput'],
+ prediction_output)
+
+ def testSuccessWithVersion(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+
+ input_with_version = self.INPUT_MISSING_ORIGIN.copy()
+ input_with_version['versionName'] = \
+ 'projects/test-project/models/test_model/versions/test_version'
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message['predictionInput'] = input_with_version
+
+ hook_instance = mock_hook.return_value
+ hook_instance.get_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': 404
+ }), content=b'some bytes')
+ hook_instance.create_job.return_value = success_message
+
+ prediction_task = MLEngineBatchPredictionOperator(
+ job_id='test_prediction', project_id='test-project',
+ region=input_with_version['region'],
+ data_format=input_with_version['dataFormat'],
+ input_paths=input_with_version['inputPaths'],
+ output_path=input_with_version['outputPath'],
+ model_name=input_with_version['versionName'].split('/')[-3],
+ version_name=input_with_version['versionName'].split('/')[-1],
+ dag=self.dag,
+ task_id='test-prediction')
+ prediction_output = prediction_task.execute(None)
+
+ mock_hook.assert_called_with('google_cloud_default', None)
+ hook_instance.create_job.assert_called_with(
+ 'test-project',
+ {
+ 'jobId': 'test_prediction',
+ 'predictionInput': input_with_version
+ }, ANY)
+ self.assertEquals(
+ success_message['predictionOutput'],
+ prediction_output)
+
+ def testSuccessWithURI(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+
+ input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
+ input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message['predictionInput'] = input_with_uri
+
+ hook_instance = mock_hook.return_value
+ hook_instance.get_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': 404
+ }), content=b'some bytes')
+ hook_instance.create_job.return_value = success_message
+
+ prediction_task = MLEngineBatchPredictionOperator(
+ job_id='test_prediction',
+ project_id='test-project',
+ region=input_with_uri['region'],
+ data_format=input_with_uri['dataFormat'],
+ input_paths=input_with_uri['inputPaths'],
+ output_path=input_with_uri['outputPath'],
+ uri=input_with_uri['uri'],
+ dag=self.dag,
+ task_id='test-prediction')
+ prediction_output = prediction_task.execute(None)
+
+ mock_hook.assert_called_with('google_cloud_default', None)
+ hook_instance.create_job.assert_called_with(
+ 'test-project',
+ {
+ 'jobId': 'test_prediction',
+ 'predictionInput': input_with_uri
+ }, ANY)
+ self.assertEquals(
+ success_message['predictionOutput'],
+ prediction_output)
+
+ def testInvalidModelOrigin(self):
+ # Test that both uri and model is given
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
+ task_args['uri'] = 'gs://fake-uri/saved_model'
+ task_args['model_name'] = 'fake_model'
+ with self.assertRaises(ValueError) as context:
+ MLEngineBatchPredictionOperator(**task_args).execute(None)
+ self.assertEquals('Ambiguous model origin.', str(context.exception))
+
+ # Test that both uri and model/version is given
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
+ task_args['uri'] = 'gs://fake-uri/saved_model'
+ task_args['model_name'] = 'fake_model'
+ task_args['version_name'] = 'fake_version'
+ with self.assertRaises(ValueError) as context:
+ MLEngineBatchPredictionOperator(**task_args).execute(None)
+ self.assertEquals('Ambiguous model origin.', str(context.exception))
+
+ # Test that a version is given without a model
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
+ task_args['version_name'] = 'bare_version'
+ with self.assertRaises(ValueError) as context:
+ MLEngineBatchPredictionOperator(**task_args).execute(None)
+ self.assertEquals(
+ 'Missing model origin.',
+ str(context.exception))
+
+ # Test that none of uri, model, model/version is given
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
+ with self.assertRaises(ValueError) as context:
+ MLEngineBatchPredictionOperator(**task_args).execute(None)
+ self.assertEquals(
+ 'Missing model origin.',
+ str(context.exception))
+
+ def testHttpError(self):
+ http_error_code = 403
+
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
+ input_with_model['modelName'] = \
+ 'projects/experimental/models/test_model'
+
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': http_error_code
+ }), content=b'Forbidden')
+
+ with self.assertRaises(errors.HttpError) as context:
+ prediction_task = MLEngineBatchPredictionOperator(
+ job_id='test_prediction',
+ project_id='test-project',
+ region=input_with_model['region'],
+ data_format=input_with_model['dataFormat'],
+ input_paths=input_with_model['inputPaths'],
+ output_path=input_with_model['outputPath'],
+ model_name=input_with_model['modelName'].split('/')[-1],
+ dag=self.dag,
+ task_id='test-prediction')
+ prediction_task.execute(None)
+
+ mock_hook.assert_called_with('google_cloud_default', None)
+ hook_instance.create_job.assert_called_with(
+ 'test-project',
+ {
+ 'jobId': 'test_prediction',
+ 'predictionInput': input_with_model
+ }, ANY)
+
+ self.assertEquals(http_error_code, context.exception.resp.status)
+
+ def testFailedJobError(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = {
+ 'state': 'FAILED',
+ 'errorMessage': 'A failure message'
+ }
+ task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
+ task_args['uri'] = 'a uri'
+
+ with self.assertRaises(RuntimeError) as context:
+ MLEngineBatchPredictionOperator(**task_args).execute(None)
+
+ self.assertEquals('A failure message', str(context.exception))
+
+
+class MLEngineTrainingOperatorTest(unittest.TestCase):
+ TRAINING_DEFAULT_ARGS = {
+ 'project_id': 'test-project',
+ 'job_id': 'test_training',
+ 'package_uris': ['gs://some-bucket/package1'],
+ 'training_python_module': 'trainer',
+ 'training_args': '--some_arg=\'aaa\'',
+ 'region': 'us-east1',
+ 'scale_tier': 'STANDARD_1',
+ 'task_id': 'test-training'
+ }
+ TRAINING_INPUT = {
+ 'jobId': 'test_training',
+ 'trainingInput': {
+ 'scaleTier': 'STANDARD_1',
+ 'packageUris': ['gs://some-bucket/package1'],
+ 'pythonModule': 'trainer',
+ 'args': '--some_arg=\'aaa\'',
+ 'region': 'us-east1'
+ }
+ }
+
+ def testSuccessCreateTrainingJob(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+ success_response = self.TRAINING_INPUT.copy()
+ success_response['state'] = 'SUCCEEDED'
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = success_response
+
+ training_op = MLEngineTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
+ delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+
+ def testHttpError(self):
+ http_error_code = 403
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.side_effect = errors.HttpError(
+ resp=httplib2.Response({
+ 'status': http_error_code
+ }), content=b'Forbidden')
+
+ with self.assertRaises(errors.HttpError) as context:
+ training_op = MLEngineTrainingOperator(
+ **self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+ self.assertEquals(http_error_code, context.exception.resp.status)
+
+ def testFailedJobError(self):
+ with patch('airflow.contrib.operators.mlengine_operator.MLEngineHook') \
+ as mock_hook:
+ failure_response = self.TRAINING_INPUT.copy()
+ failure_response['state'] = 'FAILED'
+ failure_response['errorMessage'] = 'A failure message'
+ hook_instance = mock_hook.return_value
+ hook_instance.create_job.return_value = failure_response
+
+ with self.assertRaises(RuntimeError) as context:
+ training_op = MLEngineTrainingOperator(
+ **self.TRAINING_DEFAULT_ARGS)
+ training_op.execute(None)
+
+ mock_hook.assert_called_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None)
+ # Make sure only 'create_job' is invoked on hook instance
+ self.assertEquals(len(hook_instance.mock_calls), 1)
+ hook_instance.create_job.assert_called_with(
+ 'test-project', self.TRAINING_INPUT, ANY)
+ self.assertEquals('A failure message', str(context.exception))
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_mlengine_operator_utils.py b/tests/contrib/operators/test_mlengine_operator_utils.py
new file mode 100644
index 0000000..9909c02
--- /dev/null
+++ b/tests/contrib/operators/test_mlengine_operator_utils.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators import mlengine_operator_utils
+from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops
+from airflow.exceptions import AirflowException
+
+from mock import ANY
+from mock import patch
+
+DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+
+
+class CreateEvaluateOpsTest(unittest.TestCase):
+
+ INPUT_MISSING_ORIGIN = {
+ 'dataFormat': 'TEXT',
+ 'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
+ 'outputPath': 'gs://legal-bucket/fake-output-path',
+ 'region': 'us-east1',
+ 'versionName': 'projects/test-project/models/test_model/versions/test_version',
+ }
+ SUCCESS_MESSAGE_MISSING_INPUT = {
+ 'jobId': 'eval_test_prediction',
+ 'predictionOutput': {
+ 'outputPath': 'gs://fake-output-path',
+ 'predictionCount': 5000,
+ 'errorCount': 0,
+ 'nodeHours': 2.78
+ },
+ 'state': 'SUCCEEDED'
+ }
+
+ def setUp(self):
+ super(CreateEvaluateOpsTest, self).setUp()
+ configuration.load_test_config()
+ self.dag = DAG(
+ 'test_dag',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'end_date': DEFAULT_DATE,
+ 'project_id': 'test-project',
+ 'region': 'us-east1',
+ 'model_name': 'test_model',
+ 'version_name': 'test_version',
+ },
+ schedule_interval='@daily')
+ self.metric_fn = lambda x: (0.1,)
+ self.metric_fn_encoded = mlengine_operator_utils.base64.b64encode(
+ mlengine_operator_utils.dill.dumps(self.metric_fn, recurse=True))
+
+ def testSuccessfulRun(self):
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
+
+ pred, summary, validate = create_evaluate_ops(
+ task_prefix='eval-test',
+ batch_prediction_job_id='eval-test-prediction',
+ data_format=input_with_model['dataFormat'],
+ input_paths=input_with_model['inputPaths'],
+ prediction_path=input_with_model['outputPath'],
+ metric_fn_and_keys=(self.metric_fn, ['err']),
+ validate_fn=(lambda x: 'err=%.1f' % x['err']),
+ dag=self.dag)
+
+ with patch('airflow.contrib.operators.mlengine_operator.'
+ 'MLEngineHook') as mock_mlengine_hook:
+
+ success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
+ success_message['predictionInput'] = input_with_model
+ hook_instance = mock_mlengine_hook.return_value
+ hook_instance.create_job.return_value = success_message
+ result = pred.execute(None)
+ mock_mlengine_hook.assert_called_with('google_cloud_default', None)
+ hook_instance.create_job.assert_called_once_with(
+ 'test-project',
+ {
+ 'jobId': 'eval_test_prediction',
+ 'predictionInput': input_with_model,
+ },
+ ANY)
+ self.assertEqual(success_message['predictionOutput'], result)
+
+ with patch('airflow.contrib.operators.dataflow_operator.'
+ 'DataFlowHook') as mock_dataflow_hook:
+
+ hook_instance = mock_dataflow_hook.return_value
+ hook_instance.start_python_dataflow.return_value = None
+ summary.execute(None)
+ mock_dataflow_hook.assert_called_with(
+ gcp_conn_id='google_cloud_default', delegate_to=None)
+ hook_instance.start_python_dataflow.assert_called_once_with(
+ 'eval-test-summary',
+ {
+ 'prediction_path': 'gs://legal-bucket/fake-output-path',
+ 'metric_keys': 'err',
+ 'metric_fn_encoded': self.metric_fn_encoded,
+ },
+ 'airflow.contrib.operators.mlengine_prediction_summary',
+ ['-m'])
+
+ with patch('airflow.contrib.operators.mlengine_operator_utils.'
+ 'GoogleCloudStorageHook') as mock_gcs_hook:
+
+ hook_instance = mock_gcs_hook.return_value
+ hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
+ result = validate.execute({})
+ hook_instance.download.assert_called_once_with(
+ 'legal-bucket', 'fake-output-path/prediction.summary.json')
+ self.assertEqual('err=0.9', result)
+
+ def testFailures(self):
+ dag = DAG(
+ 'test_dag',
+ default_args={
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ 'end_date': DEFAULT_DATE,
+ 'project_id': 'test-project',
+ 'region': 'us-east1',
+ },
+ schedule_interval='@daily')
+
+ input_with_model = self.INPUT_MISSING_ORIGIN.copy()
+ other_params_but_models = {
+ 'task_prefix': 'eval-test',
+ 'batch_prediction_job_id': 'eval-test-prediction',
+ 'data_format': input_with_model['dataFormat'],
+ 'input_paths': input_with_model['inputPaths'],
+ 'prediction_path': input_with_model['outputPath'],
+ 'metric_fn_and_keys': (self.metric_fn, ['err']),
+ 'validate_fn': (lambda x: 'err=%.1f' % x['err']),
+ 'dag': dag,
+ }
+
+ with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
+ _ = create_evaluate_ops(**other_params_but_models)
+
+ with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
+ _ = create_evaluate_ops(model_uri='abc', model_name='cde',
+ **other_params_but_models)
+
+ with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
+ _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
+ **other_params_but_models)
+
+ with self.assertRaisesRegexp(AirflowException,
+ '`metric_fn` param must be callable'):
+ params = other_params_but_models.copy()
+ params['metric_fn_and_keys'] = (None, ['abc'])
+ _ = create_evaluate_ops(model_uri='gs://blah', **params)
+
+ with self.assertRaisesRegexp(AirflowException,
+ '`validate_fn` param must be callable'):
+ params = other_params_but_models.copy()
+ params['validate_fn'] = None
+ _ = create_evaluate_ops(model_uri='gs://blah', **params)
+
+
+if __name__ == '__main__':
+ unittest.main()
|