airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From criccom...@apache.org
Subject [2/3] incubator-airflow git commit: [AIRFLOW-1567][Airflow-1567] Renamed cloudml hook and operator to mlengine
Date Wed, 06 Sep 2017 16:51:29 GMT
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator_utils.py b/airflow/contrib/operators/mlengine_operator_utils.py
new file mode 100644
index 0000000..5fda6ae
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_operator_utils.py
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+import re
+
+import dill
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.exceptions import AirflowException
+from airflow.operators.python_operator import PythonOperator
+from six.moves.urllib.parse import urlsplit
+
+def create_evaluate_ops(task_prefix,
+                        data_format,
+                        input_paths,
+                        prediction_path,
+                        metric_fn_and_keys,
+                        validate_fn,
+                        batch_prediction_job_id=None,
+                        project_id=None,
+                        region=None,
+                        dataflow_options=None,
+                        model_uri=None,
+                        model_name=None,
+                        version_name=None,
+                        dag=None):
+    """
+    Creates Operators needed for model evaluation and returns.
+
+    It gets prediction over inputs via Cloud ML Engine BatchPrediction API by
+    calling MLEngineBatchPredictionOperator, then summarize and validate
+    the result via Cloud Dataflow using DataFlowPythonOperator.
+
+    For details and pricing about Batch prediction, please refer to the website
+    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
+    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/
+
+    It returns three chained operators for prediction, summary, and validation,
+    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
+    respectively.
+    (<prefix> should contain only alphanumeric characters or hyphen.)
+
+    The upstream and downstream can be set accordingly like:
+      pred, _, val = create_evaluate_ops(...)
+      pred.set_upstream(upstream_op)
+      ...
+      downstream_op.set_upstream(val)
+
+    Callers will provide two python callables, metric_fn and validate_fn, in
+    order to customize the evaluation behavior as they wish.
+    - metric_fn receives a dictionary per instance derived from json in the
+      batch prediction result. The keys might vary depending on the model.
+      It should return a tuple of metrics.
+    - validation_fn receives a dictionary of the averaged metrics that metric_fn
+      generated over all instances.
+      The key/value of the dictionary matches to what's given by
+      metric_fn_and_keys arg.
+      The dictionary contains an additional metric, 'count' to represent the
+      total number of instances received for evaluation.
+      The function would raise an exception to mark the task as failed, in a
+      case the validation result is not okay to proceed (i.e. to set the trained
+      version as default).
+
+    Typical examples are like this:
+
+    def get_metric_fn_and_keys():
+        import math  # imports should be outside of the metric_fn below.
+        def error_and_squared_error(inst):
+            label = float(inst['input_label'])
+            classes = float(inst['classes'])  # 0 or 1
+            err = abs(classes-label)
+            squared_err = math.pow(classes-label, 2)
+            return (err, squared_err)  # returns a tuple.
+        return error_and_squared_error, ['err', 'mse']  # key order must match.
+
+    def validate_err_and_count(summary):
+        if summary['err'] > 0.2:
+            raise ValueError('Too high err>0.2; summary=%s' % summary)
+        if summary['mse'] > 0.05:
+            raise ValueError('Too high mse>0.05; summary=%s' % summary)
+        if summary['count'] < 1000:
+            raise ValueError('Too few instances<1000; summary=%s' % summary)
+        return summary
+
+    For the details on the other BatchPrediction-related arguments (project_id,
+    job_id, region, data_format, input_paths, prediction_path, model_uri),
+    please refer to MLEngineBatchPredictionOperator too.
+
+    :param task_prefix: a prefix for the tasks. Only alphanumeric characters and
+        hyphen are allowed (no underscores), since this will be used as dataflow
+        job name, which doesn't allow other characters.
+    :type task_prefix: string
+
+    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
+    :type data_format: string
+
+    :param input_paths: a list of input paths to be sent to BatchPrediction.
+    :type input_paths: list of strings
+
+    :param prediction_path: GCS path to put the prediction results in.
+    :type prediction_path: string
+
+    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
+        - metric_fn is a function that accepts a dictionary (for an instance),
+          and returns a tuple of metric(s) that it calculates.
+        - metric_keys is a list of strings to denote the key of each metric.
+    :type metric_fn_and_keys: tuple of a function and a list of strings
+
+    :param validate_fn: a function to validate whether the averaged metric(s) is
+        good enough to push the model.
+    :type validate_fn: function
+
+    :param batch_prediction_job_id: the id to use for the Cloud ML Batch
+        prediction job. Passed directly to the MLEngineBatchPredictionOperator as
+        the job_id argument.
+    :type batch_prediction_job_id: string
+
+    :param project_id: the Google Cloud Platform project id in which to execute
+        Cloud ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+        `default_args['project_id']` will be used.
+    :type project_id: string
+
+    :param region: the Google Cloud Platform region in which to execute Cloud ML
+        Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+        `default_args['region']` will be used.
+    :type region: string
+
+    :param dataflow_options: options to run Dataflow jobs. If None, then the
+        `dag`'s `default_args['dataflow_default_options']` will be used.
+    :type dataflow_options: dictionary
+
+    :param model_uri: GCS path of the model exported by Tensorflow using
+        tensorflow.estimator.export_savedmodel(). It cannot be used with
+        model_name or version_name below. See MLEngineBatchPredictionOperator for
+        more detail.
+    :type model_uri: string
+
+    :param model_name: Used to indicate a model to use for prediction. Can be
+        used in combination with version_name, but cannot be used together with
+        model_uri. See MLEngineBatchPredictionOperator for more detail. If None,
+        then the `dag`'s `default_args['model_name']` will be used.
+    :type model_name: string
+
+    :param version_name: Used to indicate a model version to use for prediciton,
+        in combination with model_name. Cannot be used together with model_uri.
+        See MLEngineBatchPredictionOperator for more detail. If None, then the
+        `dag`'s `default_args['version_name']` will be used.
+    :type version_name: string
+
+    :param dag: The `DAG` to use for all Operators.
+    :type dag: airflow.DAG
+
+    :returns: a tuple of three operators, (prediction, summary, validation)
+    :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
+                  PythonOperator)
+    """
+
+    # Verify that task_prefix doesn't have any special characters except hyphen
+    # '-', which is the only allowed non-alphanumeric character by Dataflow.
+    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
+        raise AirflowException(
+            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
+            "and hyphens are allowed but got: " + task_prefix)
+
+    metric_fn, metric_keys = metric_fn_and_keys
+    if not callable(metric_fn):
+        raise AirflowException("`metric_fn` param must be callable.")
+    if not callable(validate_fn):
+        raise AirflowException("`validate_fn` param must be callable.")
+
+    if dag is not None and dag.default_args is not None:
+        default_args = dag.default_args
+        project_id = project_id or default_args.get('project_id')
+        region = region or default_args.get('region')
+        model_name = model_name or default_args.get('model_name')
+        version_name = version_name or default_args.get('version_name')
+        dataflow_options = dataflow_options or \
+            default_args.get('dataflow_default_options')
+
+    evaluate_prediction = MLEngineBatchPredictionOperator(
+        task_id=(task_prefix + "-prediction"),
+        project_id=project_id,
+        job_id=batch_prediction_job_id,
+        region=region,
+        data_format=data_format,
+        input_paths=input_paths,
+        output_path=prediction_path,
+        uri=model_uri,
+        model_name=model_name,
+        version_name=version_name,
+        dag=dag)
+
+    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
+    evaluate_summary = DataFlowPythonOperator(
+        task_id=(task_prefix + "-summary"),
+        py_options=["-m"],
+        py_file="airflow.contrib.operators.mlengine_prediction_summary",
+        dataflow_default_options=dataflow_options,
+        options={
+            "prediction_path": prediction_path,
+            "metric_fn_encoded": metric_fn_encoded,
+            "metric_keys": ','.join(metric_keys)
+        },
+        dag=dag)
+    evaluate_summary.set_upstream(evaluate_prediction)
+
+    def apply_validate_fn(*args, **kwargs):
+        prediction_path = kwargs["templates_dict"]["prediction_path"]
+        scheme, bucket, obj, _, _ = urlsplit(prediction_path)
+        if scheme != "gs" or not bucket or not obj:
+            raise ValueError("Wrong format prediction_path: %s",
+                             prediction_path)
+        summary = os.path.join(obj.strip("/"),
+                               "prediction.summary.json")
+        gcs_hook = GoogleCloudStorageHook()
+        summary = json.loads(gcs_hook.download(bucket, summary))
+        return validate_fn(summary)
+
+    evaluate_validation = PythonOperator(
+        task_id=(task_prefix + "-validation"),
+        python_callable=apply_validate_fn,
+        provide_context=True,
+        templates_dict={"prediction_path": prediction_path},
+        dag=dag)
+    evaluate_validation.set_upstream(evaluate_summary)
+
+    return evaluate_prediction, evaluate_summary, evaluate_validation

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
new file mode 100644
index 0000000..1f4d540
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
+
+It accepts a user function to calculate the metric(s) per instance in
+the prediction results, then aggregates to output as a summary.
+
+Args:
+  --prediction_path:
+      The GCS folder that contains BatchPrediction results, containing
+      prediction.results-NNNNN-of-NNNNN files in the json format.
+      Output will also be stored in this folder, as 'prediction.summary.json'.
+
+  --metric_fn_encoded:
+      An encoded function that calculates and returns a tuple of metric(s)
+      for a given instance (as a dictionary). It should be encoded
+      via base64.b64encode(dill.dumps(fn, recurse=True)).
+
+  --metric_keys:
+      Comma-separated key(s) of the aggregated metric(s) in the summary
+      output. The order and number of the keys must match the output
+      of metric_fn.
+      The summary will have an additional key, 'count', to represent the
+      total number of instances, so the keys shouldn't include 'count'.
+
+# Usage example:
+def get_metric_fn():
+    import math  # all imports must be outside of the function to be passed.
+    def metric_fn(inst):
+        label = float(inst["input_label"])
+        classes = float(inst["classes"])
+        prediction = float(inst["scores"][1])
+        log_loss = math.log(1 + math.exp(
+            -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
+        squared_err = (classes-label)**2
+        return (log_loss, squared_err)
+    return metric_fn
+metric_fn_encoded = base64.b64encode(
+    dill.dumps(get_metric_fn(), recurse=True)).decode()
+
+airflow.contrib.operators.dataflow_operator.DataFlowPythonOperator(
+    task_id="summary-prediction",
+    py_options=["-m"],
+    py_file="airflow.contrib.operators.mlengine_prediction_summary",
+    options={
+        "prediction_path": prediction_path,
+        "metric_fn_encoded": metric_fn_encoded,
+        "metric_keys": "log_loss,mse"
+    },
+    dataflow_default_options={
+        "project": "xxx", "region": "us-east1",
+        "staging_location": "gs://yy", "temp_location": "gs://zz",
+    },
+    dag=dag)
+
+# When the input file is like the following:
+{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
+{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
+{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
+{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
+
+# The output file will be:
+{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
+
+# To test outside of the dag:
+subprocess.check_call(["python",
+                       "-m",
+                       "airflow.contrib.operators.mlengine_prediction_summary",
+                       "--prediction_path=gs://...",
+                       "--metric_fn_encoded=" + metric_fn_encoded,
+                       "--metric_keys=log_loss,mse",
+                       "--runner=DataflowRunner",
+                       "--staging_location=gs://...",
+                       "--temp_location=gs://...",
+                       ])
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import base64
+import json
+import logging
+import os
+
+import apache_beam as beam
+import dill
+
+
+class JsonCoder(object):
+    def encode(self, x):
+        return json.dumps(x)
+
+    def decode(self, x):
+        return json.loads(x)
+
+
+@beam.ptransform_fn
+def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
+    return (
+        pcoll
+        | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
+        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
+        | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(
+            *([sum] * (len(metric_keys) + 1))))
+        | "AverageAndMakeDict" >> beam.Map(
+            lambda tup: dict(
+                [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] +
+                [("count", tup[-1])])))
+
+
+def run(argv=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--prediction_path", required=True,
+        help=(
+            "The GCS folder that contains BatchPrediction results, containing "
+            "prediction.results-NNNNN-of-NNNNN files in the json format. "
+            "Output will be also stored in this folder, as a file"
+            "'prediction.summary.json'."))
+    parser.add_argument(
+        "--metric_fn_encoded", required=True,
+        help=(
+            "An encoded function that calculates and returns a tuple of "
+            "metric(s) for a given instance (as a dictionary). It should be "
+            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
+    parser.add_argument(
+        "--metric_keys", required=True,
+        help=(
+            "A comma-separated keys of the aggregated metric(s) in the summary "
+            "output. The order and the size of the keys must match to the "
+            "output of metric_fn. The summary will have an additional key, "
+            "'count', to represent the total number of instances, so this flag "
+            "shouldn't include 'count'."))
+    known_args, pipeline_args = parser.parse_known_args(argv)
+
+    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
+    if not callable(metric_fn):
+        raise ValueError("--metric_fn_encoded must be an encoded callable.")
+    metric_keys = known_args.metric_keys.split(",")
+
+    with beam.Pipeline(
+        options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
+        # This is apache-beam ptransform's convention
+        # pylint: disable=no-value-for-parameter
+        _ = (p
+             | "ReadPredictionResult" >> beam.io.ReadFromText(
+                 os.path.join(known_args.prediction_path,
+                              "prediction.results-*-of-*"),
+                 coder=JsonCoder())
+             | "Summary" >> MakeSummary(metric_fn, metric_keys)
+             | "Write" >> beam.io.WriteToText(
+                 os.path.join(known_args.prediction_path,
+                              "prediction.summary.json"),
+                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
+                 coder=JsonCoder()))
+        # pylint: enable=no-value-for-parameter
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    run()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
deleted file mode 100644
index f56018d..0000000
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ /dev/null
@@ -1,413 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import mock
-import unittest
-
-try:  # python 2
-    from urlparse import urlparse, parse_qsl
-except ImportError:  # python 3
-    from urllib.parse import urlparse, parse_qsl
-
-from airflow.contrib.hooks import gcp_cloudml_hook as hook
-from apiclient import errors
-from apiclient.discovery import build
-from apiclient.http import HttpMockSequence
-from oauth2client.contrib.gce import HttpAccessTokenRefreshError
-
-# Probe once at import whether a CloudML connection can be built; tests below
-# are skipped (via _SKIP_IF) when refreshing the GCE access token fails.
-cml_available = True
-try:
-    hook.CloudMLHook().get_conn()
-except HttpAccessTokenRefreshError:
-    cml_available = False
-
-
-class _TestCloudMLHook(object):
-    """Context manager yielding a CloudMLHook backed by canned HTTP responses.
-
-    On exit it asserts that the requests the hook actually issued equal the
-    expected ones, with query parameters compared as unordered sets.
-    """
-
-    def __init__(self, test_cls, responses, expected_requests):
-        """
-        Init method.
-
-        Usage example:
-        with _TestCloudMLHook(self, responses, expected_requests) as hook:
-            self.run_my_test(hook)
-
-        Args:
-          test_cls: The caller's instance used for test communication.
-          responses: A list of (dict_response, response_content) tuples.
-          expected_requests: A list of (uri, http_method, body) tuples.
-        """
-
-        self._test_cls = test_cls
-        self._responses = responses
-        self._expected_requests = [
-            self._normalize_requests_for_comparison(x[0], x[1], x[2])
-            for x in expected_requests]
-        self._actual_requests = []
-
-    def _normalize_requests_for_comparison(self, uri, http_method, body):
-        # Turn the query string into a set so parameter order is irrelevant.
-        parts = urlparse(uri)
-        return (
-            parts._replace(query=set(parse_qsl(parts.query))),
-            http_method,
-            body)
-
-    def __enter__(self):
-        http = HttpMockSequence(self._responses)
-        native_request_method = http.request
-
-        # Collecting requests to validate at __exit__.
-        def _request_wrapper(*args, **kwargs):
-            self._actual_requests.append(args + (kwargs['body'],))
-            return native_request_method(*args, **kwargs)
-
-        http.request = _request_wrapper
-        service_mock = build('ml', 'v1', http=http)
-        # The patch is only active while the hook is constructed; presumably
-        # CloudMLHook.__init__ calls get_conn() and keeps the service (TODO
-        # confirm against the hook implementation).
-        with mock.patch.object(
-                hook.CloudMLHook, 'get_conn', return_value=service_mock):
-            return hook.CloudMLHook()
-
-    def __exit__(self, *args):
-        # If the body raised, return None (falsy) so the original exception
-        # propagates; running the request assertion would mask it.
-        if any(args):
-            return None
-        self._test_cls.assertEquals(
-            [self._normalize_requests_for_comparison(x[0], x[1], x[2])
-                for x in self._actual_requests],
-            self._expected_requests)
-
-
-class TestCloudMLHook(unittest.TestCase):
-    """Tests CloudMLHook methods against canned ML Engine HTTP responses.
-
-    Each test replays a fixed response sequence through _TestCloudMLHook and
-    checks the exact requests the hook issued. All tests are skipped when
-    cml_available is False.
-    """
-
-    def setUp(self):
-        pass
-
-    _SKIP_IF = unittest.skipIf(not cml_available,
-                               'CloudML is not available to run tests')
-
-    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
-
-    @_SKIP_IF
-    def test_create_version(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        version = 'test-version'
-        operation_name = 'projects/{}/operations/test-operation'.format(
-            project)
-
-        response_body = {'name': operation_name, 'done': True}
-        succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
-        expected_requests = [
-            ('{}projects/{}/models/{}/versions?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
-             '"{}"'.format(version)),
-            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
-             'GET', None),
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=[succeeded_response] * 2,
-                expected_requests=expected_requests) as cml_hook:
-            create_version_response = cml_hook.create_version(
-                project_id=project, model_name=model_name,
-                version_spec=version)
-            self.assertEquals(create_version_response, response_body)
-
-    @_SKIP_IF
-    def test_set_default_version(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        version = 'test-version'
-        operation_name = 'projects/{}/operations/test-operation'.format(
-            project)
-
-        response_body = {'name': operation_name, 'done': True}
-        succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
-        expected_requests = [
-            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project, model_name, version),
-                'POST', '{}'),
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=[succeeded_response],
-                expected_requests=expected_requests) as cml_hook:
-            set_default_version_response = cml_hook.set_default_version(
-                project_id=project, model_name=model_name,
-                version_name=version)
-            self.assertEquals(set_default_version_response, response_body)
-
-    @_SKIP_IF
-    def test_list_versions(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        operation_name = 'projects/{}/operations/test-operation'.format(
-            project)
-
-        # This test returns the versions one at a time.
-        versions = ['ver_{}'.format(ix) for ix in range(3)]
-
-        response_bodies = [
-            {
-                'name': operation_name,
-                'nextPageToken': ix,
-                'versions': [ver]
-            } for ix, ver in enumerate(versions)]
-        response_bodies[-1].pop('nextPageToken')
-        responses = [({'status': '200'}, json.dumps(body))
-                     for body in response_bodies]
-
-        expected_requests = [
-            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
-                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
-             None),
-        ] + [
-            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
-             '&pageSize=100'.format(
-                self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
-             None) for ix in range(len(versions) - 1)
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=responses,
-                expected_requests=expected_requests) as cml_hook:
-            list_versions_response = cml_hook.list_versions(
-                project_id=project, model_name=model_name)
-            self.assertEquals(list_versions_response, versions)
-
-    @_SKIP_IF
-    def test_delete_version(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        version = 'test-version'
-        operation_name = 'projects/{}/operations/test-operation'.format(
-            project)
-
-        not_done_response_body = {'name': operation_name, 'done': False}
-        done_response_body = {'name': operation_name, 'done': True}
-        not_done_response = (
-            {'status': '200'}, json.dumps(not_done_response_body))
-        succeeded_response = (
-            {'status': '200'}, json.dumps(done_response_body))
-
-        expected_requests = [
-            (
-                '{}projects/{}/models/{}/versions/{}?alt=json'.format(
-                    self._SERVICE_URI_PREFIX, project, model_name, version),
-                'DELETE',
-                None),
-            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
-             'GET', None),
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=[not_done_response, succeeded_response],
-                expected_requests=expected_requests) as cml_hook:
-            delete_version_response = cml_hook.delete_version(
-                project_id=project, model_name=model_name,
-                version_name=version)
-            self.assertEquals(delete_version_response, done_response_body)
-
-    @_SKIP_IF
-    def test_create_model(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        model = {
-            'name': model_name,
-        }
-        response_body = {}
-        succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
-        expected_requests = [
-            ('{}projects/{}/models?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project), 'POST',
-             json.dumps(model)),
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=[succeeded_response],
-                expected_requests=expected_requests) as cml_hook:
-            create_model_response = cml_hook.create_model(
-                project_id=project, model=model)
-            self.assertEquals(create_model_response, response_body)
-
-    @_SKIP_IF
-    def test_get_model(self):
-        project = 'test-project'
-        model_name = 'test-model'
-        response_body = {'model': model_name}
-        succeeded_response = ({'status': '200'}, json.dumps(response_body))
-
-        expected_requests = [
-            ('{}projects/{}/models/{}?alt=json'.format(
-                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
-             None),
-        ]
-
-        with _TestCloudMLHook(
-                self,
-                responses=[succeeded_response],
-                expected_requests=expected_requests) as cml_hook:
-            get_model_response = cml_hook.get_model(
-                project_id=project, model_name=model_name)
-            self.assertEquals(get_model_response, response_body)
-
-    @_SKIP_IF
-    def test_create_cloudml_job(self):
-        project = 'test-project'
-        job_id = 'test-job-id'
-        my_job = {
-            'jobId': job_id,
-            'foo': 4815162342,
-            'state': 'SUCCEEDED',
-        }
-        response_body = json.dumps(my_job)
-        succeeded_response = ({'status': '200'}, response_body)
-        queued_response = ({'status': '200'}, json.dumps({
-            'jobId': job_id,
-            'state': 'QUEUED',
-        }))
-
-        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
-        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
-        expected_requests = [
-            create_job_request,
-            ask_if_done_request,
-            ask_if_done_request,
-        ]
-        responses = [succeeded_response,
-                     queued_response, succeeded_response]
-
-        with _TestCloudMLHook(
-                self,
-                responses=responses,
-                expected_requests=expected_requests) as cml_hook:
-            create_job_response = cml_hook.create_job(
-                project_id=project, job=my_job)
-            self.assertEquals(create_job_response, my_job)
-
-    @_SKIP_IF
-    def test_create_cloudml_job_reuse_existing_job_by_default(self):
-        project = 'test-project'
-        job_id = 'test-job-id'
-        my_job = {
-            'jobId': job_id,
-            'foo': 4815162342,
-            'state': 'SUCCEEDED',
-        }
-        response_body = json.dumps(my_job)
-        job_already_exist_response = ({'status': '409'}, json.dumps({}))
-        succeeded_response = ({'status': '200'}, response_body)
-
-        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
-        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
-        expected_requests = [
-            create_job_request,
-            ask_if_done_request,
-        ]
-        responses = [job_already_exist_response, succeeded_response]
-
-        # By default, 'create_job' reuse the existing job.
-        with _TestCloudMLHook(
-                self,
-                responses=responses,
-                expected_requests=expected_requests) as cml_hook:
-            create_job_response = cml_hook.create_job(
-                project_id=project, job=my_job)
-            self.assertEquals(create_job_response, my_job)
-
-    @_SKIP_IF
-    def test_create_cloudml_job_check_existing_job(self):
-        project = 'test-project'
-        job_id = 'test-job-id'
-        my_job = {
-            'jobId': job_id,
-            'foo': 4815162342,
-            'state': 'SUCCEEDED',
-            'someInput': {
-                'input': 'someInput'
-            }
-        }
-        different_job = {
-            'jobId': job_id,
-            'foo': 4815162342,
-            'state': 'SUCCEEDED',
-            'someInput': {
-                'input': 'someDifferentInput'
-            }
-        }
-
-        my_job_response_body = json.dumps(my_job)
-        different_job_response_body = json.dumps(different_job)
-        job_already_exist_response = ({'status': '409'}, json.dumps({}))
-        different_job_response = ({'status': '200'},
-                                  different_job_response_body)
-
-        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
-        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
-            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
-        expected_requests = [
-            create_job_request,
-            ask_if_done_request,
-        ]
-
-        # Returns a different job (with different 'someInput' field) will
-        # cause 'create_job' request to fail.
-        responses = [job_already_exist_response, different_job_response]
-
-        def check_input(existing_job):
-            return existing_job.get('someInput', None) == \
-                my_job['someInput']
-        with _TestCloudMLHook(
-                self,
-                responses=responses,
-                expected_requests=expected_requests) as cml_hook:
-            with self.assertRaises(errors.HttpError):
-                cml_hook.create_job(
-                    project_id=project, job=my_job,
-                    use_existing_job_fn=check_input)
-
-        my_job_response = ({'status': '200'}, my_job_response_body)
-        expected_requests = [
-            create_job_request,
-            ask_if_done_request,
-            ask_if_done_request,
-        ]
-        responses = [
-            job_already_exist_response,
-            my_job_response,
-            my_job_response]
-        with _TestCloudMLHook(
-                self,
-                responses=responses,
-                expected_requests=expected_requests) as cml_hook:
-            create_job_response = cml_hook.create_job(
-                project_id=project, job=my_job,
-                use_existing_job_fn=check_input)
-            self.assertEquals(create_job_response, my_job)
-
-
-# Allow running this test module directly.
-if __name__ == '__main__':
-    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_mlengine_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_mlengine_hook.py b/tests/contrib/hooks/test_gcp_mlengine_hook.py
new file mode 100644
index 0000000..372d47c
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_mlengine_hook.py
@@ -0,0 +1,413 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import unittest
+
+try:  # python 2
+    from urlparse import urlparse, parse_qsl
+except ImportError:  # python 3
+    from urllib.parse import urlparse, parse_qsl
+
+from airflow.contrib.hooks import gcp_mlengine_hook as hook
+from apiclient import errors
+from apiclient.discovery import build
+from apiclient.http import HttpMockSequence
+from oauth2client.contrib.gce import HttpAccessTokenRefreshError
+
+# Probe once, at import time, whether ML Engine credentials can be refreshed;
+# every test in TestMLEngineHook is skipped when they cannot.
+try:
+    hook.MLEngineHook().get_conn()
+except HttpAccessTokenRefreshError:
+    cml_available = False
+else:
+    cml_available = True
+
+
+class _TestMLEngineHook(object):
+    """Context manager yielding an MLEngineHook backed by canned HTTP replies.
+
+    On entry it builds an ``ml`` v1 discovery client on top of an
+    ``HttpMockSequence`` that serves ``responses`` in order, and records every
+    request sent through it. On a clean exit it asserts that the recorded
+    requests equal ``expected_requests``. Query parameters are compared as a
+    set, so their order does not matter.
+
+    Usage example:
+        with _TestMLEngineHook(self, responses, expected_requests) as hook:
+            self.run_my_test(hook)
+    """
+
+    def __init__(self, test_cls, responses, expected_requests):
+        """
+        Args:
+          test_cls: The caller's TestCase instance, used to run assertions.
+          responses: A list of (dict_response, response_content) tuples,
+            served in order to successive HTTP requests.
+          expected_requests: A list of (uri, http_method, body) tuples.
+        """
+        self._test_cls = test_cls
+        self._responses = responses
+        self._expected_requests = [
+            self._normalize_requests_for_comparison(*x[:3])
+            for x in expected_requests]
+        self._actual_requests = []
+
+    @staticmethod
+    def _normalize_requests_for_comparison(uri, http_method, body):
+        """Return a comparable (parsed_uri, http_method, body) tuple.
+
+        The URI's query string is replaced by a set of (key, value) pairs, so
+        two URIs that differ only in parameter order compare equal.
+        """
+        parts = urlparse(uri)
+        return (
+            parts._replace(query=set(parse_qsl(parts.query))),
+            http_method,
+            body)
+
+    def __enter__(self):
+        http = HttpMockSequence(self._responses)
+        native_request_method = http.request
+
+        # Record (uri, method, body) for every request so __exit__ can compare
+        # them with the expected ones. ``.get`` tolerates a missing body
+        # keyword. If the body is passed positionally, it is already args[2].
+        def _request_wrapper(*args, **kwargs):
+            self._actual_requests.append(args + (kwargs.get('body'),))
+            return native_request_method(*args, **kwargs)
+
+        http.request = _request_wrapper
+        service_mock = build('ml', 'v1', http=http)
+        # get_conn is patched only while the hook is constructed.
+        # NOTE(review): this presumably relies on MLEngineHook.__init__
+        # calling get_conn() and keeping the client - verify in the hook.
+        with mock.patch.object(
+                hook.MLEngineHook, 'get_conn', return_value=service_mock):
+            return hook.MLEngineHook()
+
+    def __exit__(self, *args):
+        # If the with-body raised, skip the request check and return None.
+        # The original exception then propagates instead of being replaced
+        # by an assertion failure about the recorded requests.
+        if any(args):
+            return None
+        self._test_cls.assertEqual(
+            [self._normalize_requests_for_comparison(*x[:3])
+             for x in self._actual_requests],
+            self._expected_requests)
+
+
+class TestMLEngineHook(unittest.TestCase):
+    """Request/response tests for MLEngineHook against a mocked ML Engine API.
+
+    Each test scripts the HTTP responses the API returns and the exact
+    requests the hook must send; _TestMLEngineHook checks the requests.
+    """
+
+    _SKIP_IF = unittest.skipIf(not cml_available,
+                               'MLEngine is not available to run tests')
+
+    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
+
+    @_SKIP_IF
+    def test_create_version(self):
+        """create_version POSTs the version spec, then polls the operation."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
+             '"{}"'.format(version)),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response] * 2,
+                expected_requests=expected_requests) as cml_hook:
+            create_version_response = cml_hook.create_version(
+                project_id=project, model_name=model_name,
+                version_spec=version)
+            self.assertEqual(create_version_response, response_body)
+
+    @_SKIP_IF
+    def test_set_default_version(self):
+        """set_default_version issues a single setDefault POST."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version),
+                'POST', '{}'),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            set_default_version_response = cml_hook.set_default_version(
+                project_id=project, model_name=model_name,
+                version_name=version)
+            self.assertEqual(set_default_version_response, response_body)
+
+    @_SKIP_IF
+    def test_list_versions(self):
+        """list_versions follows nextPageToken and concatenates all pages."""
+        project = 'test-project'
+        model_name = 'test-model'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        # Each page returns exactly one version; the last page has no token.
+        versions = ['ver_{}'.format(ix) for ix in range(3)]
+
+        response_bodies = [
+            {
+                'name': operation_name,
+                'nextPageToken': ix,
+                'versions': [ver]
+            } for ix, ver in enumerate(versions)]
+        response_bodies[-1].pop('nextPageToken')
+        responses = [({'status': '200'}, json.dumps(body))
+                     for body in response_bodies]
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ] + [
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}'
+             '&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
+             None) for ix in range(len(versions) - 1)
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            list_versions_response = cml_hook.list_versions(
+                project_id=project, model_name=model_name)
+            self.assertEqual(list_versions_response, versions)
+
+    @_SKIP_IF
+    def test_delete_version(self):
+        """delete_version DELETEs, then polls the operation until done."""
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        not_done_response_body = {'name': operation_name, 'done': False}
+        done_response_body = {'name': operation_name, 'done': True}
+        not_done_response = (
+            {'status': '200'}, json.dumps(not_done_response_body))
+        succeeded_response = (
+            {'status': '200'}, json.dumps(done_response_body))
+
+        expected_requests = [
+            (
+                '{}projects/{}/models/{}/versions/{}?alt=json'.format(
+                    self._SERVICE_URI_PREFIX, project, model_name, version),
+                'DELETE',
+                None),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[not_done_response, succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            delete_version_response = cml_hook.delete_version(
+                project_id=project, model_name=model_name,
+                version_name=version)
+            self.assertEqual(delete_version_response, done_response_body)
+
+    @_SKIP_IF
+    def test_create_model(self):
+        """create_model POSTs the model body and returns the API response."""
+        project = 'test-project'
+        model_name = 'test-model'
+        model = {
+            'name': model_name,
+        }
+        response_body = {}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project), 'POST',
+             json.dumps(model)),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            create_model_response = cml_hook.create_model(
+                project_id=project, model=model)
+            self.assertEqual(create_model_response, response_body)
+
+    @_SKIP_IF
+    def test_get_model(self):
+        """get_model GETs the model and returns the API response."""
+        project = 'test-project'
+        model_name = 'test-model'
+        response_body = {'model': model_name}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ]
+
+        with _TestMLEngineHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            get_model_response = cml_hook.get_model(
+                project_id=project, model_name=model_name)
+            self.assertEqual(get_model_response, response_body)
+
+    @_SKIP_IF
+    def test_create_mlengine_job(self):
+        """create_job POSTs the job and polls until it is no longer QUEUED."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        succeeded_response = ({'status': '200'}, response_body)
+        queued_response = ({'status': '200'}, json.dumps({
+            'jobId': job_id,
+            'state': 'QUEUED',
+        }))
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [succeeded_response,
+                     queued_response, succeeded_response]
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job)
+            self.assertEqual(create_job_response, my_job)
+
+    @_SKIP_IF
+    def test_create_mlengine_job_reuse_existing_job_by_default(self):
+        """On an HTTP 409 conflict, create_job reuses the existing job."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+        }
+        response_body = json.dumps(my_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        succeeded_response = ({'status': '200'}, response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+        responses = [job_already_exist_response, succeeded_response]
+
+        # By default, 'create_job' reuses the existing job.
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job)
+            self.assertEqual(create_job_response, my_job)
+
+    @_SKIP_IF
+    def test_create_mlengine_job_check_existing_job(self):
+        """use_existing_job_fn decides whether a conflicting job is reused."""
+        project = 'test-project'
+        job_id = 'test-job-id'
+        my_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someInput'
+            }
+        }
+        different_job = {
+            'jobId': job_id,
+            'foo': 4815162342,
+            'state': 'SUCCEEDED',
+            'someInput': {
+                'input': 'someDifferentInput'
+            }
+        }
+
+        my_job_response_body = json.dumps(my_job)
+        different_job_response_body = json.dumps(different_job)
+        job_already_exist_response = ({'status': '409'}, json.dumps({}))
+        different_job_response = ({'status': '200'},
+                                  different_job_response_body)
+
+        create_job_request = ('{}projects/{}/jobs?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body)
+        ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format(
+            self._SERVICE_URI_PREFIX, project, job_id), 'GET', None)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+        ]
+
+        # Returning a different job (one with a different 'someInput' field)
+        # makes 'create_job' fail with the original HttpError.
+        responses = [job_already_exist_response, different_job_response]
+
+        def check_input(existing_job):
+            return (existing_job.get('someInput', None) ==
+                    my_job['someInput'])
+
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            with self.assertRaises(errors.HttpError):
+                cml_hook.create_job(
+                    project_id=project, job=my_job,
+                    use_existing_job_fn=check_input)
+
+        # Returning the same job makes 'create_job' reuse it.
+        my_job_response = ({'status': '200'}, my_job_response_body)
+        expected_requests = [
+            create_job_request,
+            ask_if_done_request,
+            ask_if_done_request,
+        ]
+        responses = [
+            job_already_exist_response,
+            my_job_response,
+            my_job_response]
+        with _TestMLEngineHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            create_job_response = cml_hook.create_job(
+                project_id=project, job=my_job,
+                use_existing_job_fn=check_input)
+            self.assertEqual(create_job_response, my_job)
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py
deleted file mode 100644
index dc2366e..0000000
--- a/tests/contrib/operators/test_cloudml_operator.py
+++ /dev/null
@@ -1,373 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import datetime
-from apiclient import errors
-import httplib2
-import unittest
-
-from airflow import configuration, DAG
-from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
-from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator
-
-from mock import ANY
-from mock import patch
-
-DEFAULT_DATE = datetime.datetime(2017, 6, 6)
-
-
-class CloudMLBatchPredictionOperatorTest(unittest.TestCase):
-    INPUT_MISSING_ORIGIN = {
-        'dataFormat': 'TEXT',
-        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
-        'outputPath': 'gs://legal-bucket/fake-output-path',
-        'region': 'us-east1',
-    }
-    SUCCESS_MESSAGE_MISSING_INPUT = {
-        'jobId': 'test_prediction',
-        'predictionOutput': {
-            'outputPath': 'gs://fake-output-path',
-            'predictionCount': 5000,
-            'errorCount': 0,
-            'nodeHours': 2.78
-        },
-        'state': 'SUCCEEDED'
-    }
-    BATCH_PREDICTION_DEFAULT_ARGS = {
-        'project_id': 'test-project',
-        'job_id': 'test_prediction',
-        'region': 'us-east1',
-        'data_format': 'TEXT',
-        'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
-        'output_path':
-            'gs://12_legal_bucket_underscore_number/legal-output-path',
-        'task_id': 'test-prediction'
-    }
-
-    def setUp(self):
-        super(CloudMLBatchPredictionOperatorTest, self).setUp()
-        configuration.load_test_config()
-        self.dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-            },
-            schedule_interval='@daily')
-
-    def testSuccessWithModel(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_model['modelName'] = \
-                'projects/test-project/models/test_model'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_model
-
-            hook_instance = mock_hook.return_value
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction',
-                project_id='test-project',
-                region=input_with_model['region'],
-                data_format=input_with_model['dataFormat'],
-                input_paths=input_with_model['inputPaths'],
-                output_path=input_with_model['outputPath'],
-                model_name=input_with_model['modelName'].split('/')[-1],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_once_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_model
-                }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testSuccessWithVersion(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_version = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_version['versionName'] = \
-                'projects/test-project/models/test_model/versions/test_version'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_version
-
-            hook_instance = mock_hook.return_value
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction', project_id='test-project',
-                region=input_with_version['region'],
-                data_format=input_with_version['dataFormat'],
-                input_paths=input_with_version['inputPaths'],
-                output_path=input_with_version['outputPath'],
-                model_name=input_with_version['versionName'].split('/')[-3],
-                version_name=input_with_version['versionName'].split('/')[-1],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_version
-                }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testSuccessWithURI(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-
-            input_with_uri = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel'
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_uri
-
-            hook_instance = mock_hook.return_value
-            hook_instance.get_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': 404
-                }), content=b'some bytes')
-            hook_instance.create_job.return_value = success_message
-
-            prediction_task = CloudMLBatchPredictionOperator(
-                job_id='test_prediction',
-                project_id='test-project',
-                region=input_with_uri['region'],
-                data_format=input_with_uri['dataFormat'],
-                input_paths=input_with_uri['inputPaths'],
-                output_path=input_with_uri['outputPath'],
-                uri=input_with_uri['uri'],
-                dag=self.dag,
-                task_id='test-prediction')
-            prediction_output = prediction_task.execute(None)
-
-            mock_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_with(
-                'test-project',
-                {
-                    'jobId': 'test_prediction',
-                    'predictionInput': input_with_uri
-                }, ANY)
-            self.assertEquals(
-                success_message['predictionOutput'],
-                prediction_output)
-
-    def testInvalidModelOrigin(self):
-        # Test that both uri and model is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['uri'] = 'gs://fake-uri/saved_model'
-        task_args['model_name'] = 'fake_model'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals('Ambiguous model origin.', str(context.exception))
-
-        # Test that both uri and model/version is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['uri'] = 'gs://fake-uri/saved_model'
-        task_args['model_name'] = 'fake_model'
-        task_args['version_name'] = 'fake_version'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals('Ambiguous model origin.', str(context.exception))
-
-        # Test that a version is given without a model
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        task_args['version_name'] = 'bare_version'
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals(
-            'Missing model origin.',
-            str(context.exception))
-
-        # Test that none of uri, model, model/version is given
-        task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-        with self.assertRaises(ValueError) as context:
-            CloudMLBatchPredictionOperator(**task_args).execute(None)
-        self.assertEquals(
-            'Missing model origin.',
-            str(context.exception))
-
-    def testHttpError(self):
-        http_error_code = 403
-
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-            input_with_model['modelName'] = \
-                'projects/experimental/models/test_model'
-
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': http_error_code
-                }), content=b'Forbidden')
-
-            with self.assertRaises(errors.HttpError) as context:
-                prediction_task = CloudMLBatchPredictionOperator(
-                    job_id='test_prediction',
-                    project_id='test-project',
-                    region=input_with_model['region'],
-                    data_format=input_with_model['dataFormat'],
-                    input_paths=input_with_model['inputPaths'],
-                    output_path=input_with_model['outputPath'],
-                    model_name=input_with_model['modelName'].split('/')[-1],
-                    dag=self.dag,
-                    task_id='test-prediction')
-                prediction_task.execute(None)
-
-                mock_hook.assert_called_with('google_cloud_default', None)
-                hook_instance.create_job.assert_called_with(
-                    'test-project',
-                    {
-                        'jobId': 'test_prediction',
-                        'predictionInput': input_with_model
-                    }, ANY)
-
-            self.assertEquals(http_error_code, context.exception.resp.status)
-
-    def testFailedJobError(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = {
-                'state': 'FAILED',
-                'errorMessage': 'A failure message'
-            }
-            task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy()
-            task_args['uri'] = 'a uri'
-
-            with self.assertRaises(RuntimeError) as context:
-                CloudMLBatchPredictionOperator(**task_args).execute(None)
-
-            self.assertEquals('A failure message', str(context.exception))
-
-
-class CloudMLTrainingOperatorTest(unittest.TestCase):
-    TRAINING_DEFAULT_ARGS = {
-        'project_id': 'test-project',
-        'job_id': 'test_training',
-        'package_uris': ['gs://some-bucket/package1'],
-        'training_python_module': 'trainer',
-        'training_args': '--some_arg=\'aaa\'',
-        'region': 'us-east1',
-        'scale_tier': 'STANDARD_1',
-        'task_id': 'test-training'
-    }
-    TRAINING_INPUT = {
-        'jobId': 'test_training',
-        'trainingInput': {
-            'scaleTier': 'STANDARD_1',
-            'packageUris': ['gs://some-bucket/package1'],
-            'pythonModule': 'trainer',
-            'args': '--some_arg=\'aaa\'',
-            'region': 'us-east1'
-        }
-    }
-
-    def testSuccessCreateTrainingJob(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            success_response = self.TRAINING_INPUT.copy()
-            success_response['state'] = 'SUCCEEDED'
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = success_response
-
-            training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS)
-            training_op.execute(None)
-
-            mock_hook.assert_called_with(gcp_conn_id='google_cloud_default',
-                                         delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEquals(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-
-    def testHttpError(self):
-        http_error_code = 403
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.side_effect = errors.HttpError(
-                resp=httplib2.Response({
-                    'status': http_error_code
-                }), content=b'Forbidden')
-
-            with self.assertRaises(errors.HttpError) as context:
-                training_op = CloudMLTrainingOperator(
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEquals(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEquals(http_error_code, context.exception.resp.status)
-
-    def testFailedJobError(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            failure_response = self.TRAINING_INPUT.copy()
-            failure_response['state'] = 'FAILED'
-            failure_response['errorMessage'] = 'A failure message'
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = failure_response
-
-            with self.assertRaises(RuntimeError) as context:
-                training_op = CloudMLTrainingOperator(
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEquals(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEquals('A failure message', str(context.exception))
-
-
-if __name__ == '__main__':
-    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
deleted file mode 100644
index b2a5a30..0000000
--- a/tests/contrib/operators/test_cloudml_operator_utils.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import datetime
-import unittest
-
-from airflow import configuration, DAG
-from airflow.contrib.operators import cloudml_operator_utils
-from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops
-from airflow.exceptions import AirflowException
-
-from mock import ANY
-from mock import patch
-
-DEFAULT_DATE = datetime.datetime(2017, 6, 6)
-
-
-class CreateEvaluateOpsTest(unittest.TestCase):
-
-    INPUT_MISSING_ORIGIN = {
-        'dataFormat': 'TEXT',
-        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
-        'outputPath': 'gs://legal-bucket/fake-output-path',
-        'region': 'us-east1',
-        'versionName': 'projects/test-project/models/test_model/versions/test_version',
-    }
-    SUCCESS_MESSAGE_MISSING_INPUT = {
-        'jobId': 'eval_test_prediction',
-        'predictionOutput': {
-            'outputPath': 'gs://fake-output-path',
-            'predictionCount': 5000,
-            'errorCount': 0,
-            'nodeHours': 2.78
-        },
-        'state': 'SUCCEEDED'
-    }
-
-    def setUp(self):
-        super(CreateEvaluateOpsTest, self).setUp()
-        configuration.load_test_config()
-        self.dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-                'project_id': 'test-project',
-                'region': 'us-east1',
-                'model_name': 'test_model',
-                'version_name': 'test_version',
-            },
-            schedule_interval='@daily')
-        self.metric_fn = lambda x: (0.1,)
-        self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
-            cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))
-
-    def testSuccessfulRun(self):
-        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-
-        pred, summary, validate = create_evaluate_ops(
-            task_prefix='eval-test',
-            batch_prediction_job_id='eval-test-prediction',
-            data_format=input_with_model['dataFormat'],
-            input_paths=input_with_model['inputPaths'],
-            prediction_path=input_with_model['outputPath'],
-            metric_fn_and_keys=(self.metric_fn, ['err']),
-            validate_fn=(lambda x: 'err=%.1f' % x['err']),
-            dag=self.dag)
-
-        with patch('airflow.contrib.operators.cloudml_operator.'
-                   'CloudMLHook') as mock_cloudml_hook:
-
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_model
-            hook_instance = mock_cloudml_hook.return_value
-            hook_instance.create_job.return_value = success_message
-            result = pred.execute(None)
-            mock_cloudml_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_once_with(
-                'test-project',
-                {
-                    'jobId': 'eval_test_prediction',
-                    'predictionInput': input_with_model,
-                },
-                ANY)
-            self.assertEqual(success_message['predictionOutput'], result)
-
-        with patch('airflow.contrib.operators.dataflow_operator.'
-                   'DataFlowHook') as mock_dataflow_hook:
-
-            hook_instance = mock_dataflow_hook.return_value
-            hook_instance.start_python_dataflow.return_value = None
-            summary.execute(None)
-            mock_dataflow_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            hook_instance.start_python_dataflow.assert_called_once_with(
-                'eval-test-summary',
-                {
-                    'prediction_path': 'gs://legal-bucket/fake-output-path',
-                    'metric_keys': 'err',
-                    'metric_fn_encoded': self.metric_fn_encoded,
-                },
-                'airflow.contrib.operators.cloudml_prediction_summary',
-                ['-m'])
-
-        with patch('airflow.contrib.operators.cloudml_operator_utils.'
-                   'GoogleCloudStorageHook') as mock_gcs_hook:
-
-            hook_instance = mock_gcs_hook.return_value
-            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
-            result = validate.execute({})
-            hook_instance.download.assert_called_once_with(
-                'legal-bucket', 'fake-output-path/prediction.summary.json')
-            self.assertEqual('err=0.9', result)
-
-    def testFailures(self):
-        dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-                'project_id': 'test-project',
-                'region': 'us-east1',
-            },
-            schedule_interval='@daily')
-
-        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-        other_params_but_models = {
-            'task_prefix': 'eval-test',
-            'batch_prediction_job_id': 'eval-test-prediction',
-            'data_format': input_with_model['dataFormat'],
-            'input_paths': input_with_model['inputPaths'],
-            'prediction_path': input_with_model['outputPath'],
-            'metric_fn_and_keys': (self.metric_fn, ['err']),
-            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
-            'dag': dag,
-        }
-
-        with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
-            _ = create_evaluate_ops(**other_params_but_models)
-
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', model_name='cde',
-                                    **other_params_but_models)
-
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
-                                    **other_params_but_models)
-
-        with self.assertRaisesRegexp(AirflowException,
-                                     '`metric_fn` param must be callable'):
-            params = other_params_but_models.copy()
-            params['metric_fn_and_keys'] = (None, ['abc'])
-            _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
-        with self.assertRaisesRegexp(AirflowException,
-                                     '`validate_fn` param must be callable'):
-            params = other_params_but_models.copy()
-            params['validate_fn'] = None
-            _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
-
-if __name__ == '__main__':
-    unittest.main()


Mime
View raw message