airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From criccom...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1359] Add Google CloudML utils for model evaluation
Date Fri, 14 Jul 2017 00:07:04 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master e88ecff6a -> 194d1d6e5


[AIRFLOW-1359] Add Google CloudML utils for model evaluation

Closes #2407 from yk5/evaluate


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/194d1d6e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/194d1d6e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/194d1d6e

Branch: refs/heads/master
Commit: 194d1d6e5b89918f22267ae6a86455a0acc771df
Parents: e88ecff
Author: Younghee Kwon <youngheek@google.com>
Authored: Thu Jul 13 17:06:06 2017 -0700
Committer: Chris Riccomini <criccomini@apache.org>
Committed: Thu Jul 13 17:06:56 2017 -0700

----------------------------------------------------------------------
 .../contrib/operators/cloudml_operator_utils.py | 223 +++++++++++++++++++
 .../operators/cloudml_prediction_summary.py     | 177 +++++++++++++++
 .../operators/test_cloudml_operator_utils.py    | 179 +++++++++++++++
 3 files changed, 579 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator_utils.py b/airflow/contrib/operators/cloudml_operator_utils.py
new file mode 100644
index 0000000..f4abb32
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_operator_utils.py
@@ -0,0 +1,223 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+import re
+try:  # python 2
+    from urlparse import urlsplit
+except ImportError:  # python 3
+    from urllib.parse import urlsplit
+
+import dill
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator
+from airflow.contrib.operators.cloudml_operator import _normalize_cloudml_job_id
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.exceptions import AirflowException
+from airflow.operators.python_operator import PythonOperator
+
+
def create_evaluate_ops(task_prefix,
                        project_id,
                        job_id,
                        region,
                        data_format,
                        input_paths,
                        prediction_path,
                        metric_fn_and_keys,
                        validate_fn,
                        dataflow_options,
                        model_uri=None,
                        model_name=None,
                        version_name=None,
                        dag=None):
    """
    Creates Operators needed for model evaluation and returns.

    It gets prediction over inputs via Cloud ML Engine BatchPrediction API by
    calling CloudMLBatchPredictionOperator, then summarize and validate
    the result via Cloud Dataflow using DataFlowPythonOperator.

    For details and pricing about Batch prediction, please refer to the website
    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
    and for Cloud Dataflow, https://cloud.google.com/dataflow/docs/

    It returns three chained operators for prediction, summary, and validation,
    named as <prefix>-prediction, <prefix>-summary, and <prefix>-validation,
    respectively.
    (<prefix> should contain only alphanumeric characters or hyphen.)

    The upstream and downstream can be set accordingly like:
      pred, _, val = create_evaluate_ops(...)
      pred.set_upstream(upstream_op)
      ...
      downstream_op.set_upstream(val)

    Callers will provide two python callables, metric_fn and validate_fn, in
    order to customize the evaluation behavior as they wish.
    - metric_fn receives a dictionary per instance derived from json in the
      batch prediction result. The keys might vary depending on the model.
      It should return a tuple of metrics.
    - validate_fn receives a dictionary of the averaged metrics that metric_fn
      generated over all instances.
      The key/value of the dictionary matches to what's given by
      metric_fn_and_keys arg.
      The dictionary contains an additional metric, 'count' to represent the
      total number of instances received for evaluation.
      The function would raise an exception to mark the task as failed, in a
      case the validation result is not okay to proceed (i.e. to set the trained
      version as default).

    Typical examples are like this:

    def get_metric_fn_and_keys():
        import math  # imports should be outside of the metric_fn below.
        def error_and_squared_error(inst):
            label = float(inst['input_label'])
            classes = float(inst['classes'])  # 0 or 1
            err = abs(classes-label)
            squared_err = math.pow(classes-label, 2)
            return (err, squared_err)  # returns a tuple.
        return error_and_squared_error, ['err', 'mse']  # key order must match.

    def validate_err_and_count(summary):
        if summary['err'] > 0.2:
            raise ValueError('Too high err>0.2; summary=%s' % summary)
        if summary['mse'] > 0.05:
            raise ValueError('Too high mse>0.05; summary=%s' % summary)
        if summary['count'] < 1000:
            raise ValueError('Too few instances<1000; summary=%s' % summary)
        return summary

    For the details on the other BatchPrediction-related arguments (project_id,
    job_id, region, data_format, input_paths, prediction_path, model_uri),
    please refer to CloudMLBatchPredictionOperator too.

    :param task_prefix: a prefix for the tasks. Only alphanumeric characters and
        hyphen are allowed (no underscores), since this will be used as dataflow
        job name, which doesn't allow other characters.
    :type task_prefix: string

    :param model_uri: GCS path of the model exported by Tensorflow using
        tensorflow.estimator.export_savedmodel(). It cannot be used with
        model_name or version_name below. See CloudMLBatchPredictionOperator for
        more detail.
    :type model_uri: string

    :param model_name: Used to indicate a model to use for prediction. Can be
        used in combination with version_name, but cannot be used together with
        model_uri. See CloudMLBatchPredictionOperator for more detail.
    :type model_name: string

    :param version_name: Used to indicate a model version to use for prediction,
        in combination with model_name. Cannot be used together with model_uri.
        See CloudMLBatchPredictionOperator for more detail.
    :type version_name: string

    :param data_format: either of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
    :type data_format: string

    :param input_paths: a list of input paths to be sent to BatchPrediction.
    :type input_paths: list of strings

    :param prediction_path: GCS path to put the prediction results in.
    :type prediction_path: string

    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
        - metric_fn is a function that accepts a dictionary (for an instance),
          and returns a tuple of metric(s) that it calculates.
        - metric_keys is a list of strings to denote the key of each metric.
    :type metric_fn_and_keys: tuple of a function and a list of strings

    :param validate_fn: a function to validate whether the averaged metric(s) is
        good enough to push the model.
    :type validate_fn: function

    :param dataflow_options: options to run Dataflow jobs.
    :type dataflow_options: dictionary

    :returns: a tuple of three operators, (prediction, summary, validation)
    :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
                  PythonOperator)
    """

    # Verify that task_prefix doesn't have any special characters except hyphen
    # '-', which is the only allowed non-alphanumeric character by Dataflow.
    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
        raise AirflowException(
            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
            "and hyphens are allowed) but got: " + task_prefix)

    metric_fn, metric_keys = metric_fn_and_keys
    if not callable(metric_fn):
        raise AirflowException("`metric_fn` param must be callable.")
    if not callable(validate_fn):
        raise AirflowException("`validate_fn` param must be callable.")

    evaluate_prediction = CloudMLBatchPredictionOperator(
        task_id=(task_prefix + "-prediction"),
        project_id=project_id,
        job_id=_normalize_cloudml_job_id(job_id),
        region=region,
        data_format=data_format,
        input_paths=input_paths,
        output_path=prediction_path,
        uri=model_uri,
        model_name=model_name,
        version_name=version_name,
        dag=dag)

    # The summary dataflow job receives metric_fn serialized with dill so that
    # a closure defined in the DAG file can run on the Dataflow workers.
    # NOTE(review): under python 3, b64encode returns bytes; confirm the
    # Dataflow option plumbing accepts bytes here, or .decode('ascii') it.
    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
    evaluate_summary = DataFlowPythonOperator(
        task_id=(task_prefix + "-summary"),
        py_options=["-m"],
        py_file="airflow.contrib.operators.cloudml_prediction_summary",
        dataflow_default_options=dataflow_options,
        options={
            "prediction_path": prediction_path,
            "metric_fn_encoded": metric_fn_encoded,
            "metric_keys": ','.join(metric_keys)
        },
        dag=dag)
    # TODO: "options" is not template_field of DataFlowPythonOperator (not sure
    # if intended or by mistake); consider fixing in the DataFlowPythonOperator.
    evaluate_summary.template_fields.append("options")
    evaluate_summary.set_upstream(evaluate_prediction)

    def apply_validate_fn(*args, **kwargs):
        # Downloads the summary the dataflow job wrote next to the prediction
        # results, then delegates the pass/fail decision to validate_fn.
        prediction_path = kwargs["templates_dict"]["prediction_path"]
        scheme, bucket, obj, _, _ = urlsplit(prediction_path)
        if scheme != "gs" or not bucket or not obj:
            raise ValueError("Wrong format prediction_path: %s" %
                             prediction_path)
        summary = os.path.join(obj.strip("/"),
                               "prediction.summary.json")
        gcs_hook = GoogleCloudStorageHook()
        summary = json.loads(gcs_hook.download(bucket, summary))
        return validate_fn(summary)

    evaluate_validation = PythonOperator(
        task_id=(task_prefix + "-validation"),
        python_callable=apply_validate_fn,
        provide_context=True,
        templates_dict={"prediction_path": prediction_path},
        dag=dag)
    evaluate_validation.set_upstream(evaluate_summary)

    return evaluate_prediction, evaluate_summary, evaluate_validation

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/airflow/contrib/operators/cloudml_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_prediction_summary.py b/airflow/contrib/operators/cloudml_prediction_summary.py
new file mode 100644
index 0000000..3128dc3
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_prediction_summary.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
+
+It accepts a user function to calculate the metric(s) per instance in
+the prediction results, then aggregates to output as a summary.
+
+Args:
+  --prediction_path:
+      The GCS folder that contains BatchPrediction results, containing
+      prediction.results-NNNNN-of-NNNNN files in the json format.
+      Output will be also stored in this folder, as 'prediction.summary.json'.
+
+  --metric_fn_encoded:
+      An encoded function that calculates and returns a tuple of metric(s)
+      for a given instance (as a dictionary). It should be encoded
+      via base64.b64encode(dill.dumps(fn, recurse=True)).
+
+  --metric_keys:
+      A comma-separated key(s) of the aggregated metric(s) in the summary
+      output. The order and the size of the keys must match to the output
+      of metric_fn.
+      The summary will have an additional key, 'count', to represent the
+      total number of instances, so the keys shouldn't include 'count'.
+
+# Usage example:
+def get_metric_fn():
+    import math  # all imports must be outside of the function to be passed.
+    def metric_fn(inst):
+        label = float(inst["input_label"])
+        classes = float(inst["classes"])
+        prediction = float(inst["scores"][1])
+        log_loss = math.log(1 + math.exp(
+            -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
+        squared_err = (classes-label)**2
+        return (log_loss, squared_err)
+    return metric_fn
+metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
+
+airflow.contrib.operators.DataFlowPythonOperator(
+    task_id="summary-prediction",
+    py_options=["-m"],
+    py_file="airflow.contrib.operators.cloudml_prediction_summary",
+    options={
+        "prediction_path": prediction_path,
+        "metric_fn_encoded": metric_fn_encoded,
+        "metric_keys": "log_loss,mse"
+    },
+    dataflow_default_options={
+        "project": "xxx", "region": "us-east1",
+        "staging_location": "gs://yy", "temp_location": "gs://zz",
+    })
+    >> dag
+
+# When the input file is like the following:
+{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
+{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
+{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
+{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
+
+# The output file will be:
+{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
+
+# To test outside of the dag:
+subprocess.check_call(["python",
+                       "-m",
+                       "airflow.contrib.operators.cloudml_prediction_summary",
+                       "--prediction_path=gs://...",
+                       "--metric_fn_encoded=" + metric_fn_encoded,
+                       "--metric_keys=log_loss,mse",
+                       "--runner=DataflowRunner",
+                       "--staging_location=gs://...",
+                       "--temp_location=gs://...",
+                       ])
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import base64
+import json
+import logging
+import os
+
+import apache_beam as beam
+import dill
+
+
class JsonCoder(object):
    """Beam text-file coder that (de)serializes one JSON document per line."""

    def encode(self, record):
        """Serialize *record* into its JSON string representation."""
        return json.dumps(record)

    def decode(self, record):
        """Parse the JSON string *record* back into a Python object."""
        return json.loads(record)
+
+
@beam.ptransform_fn
def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
    """Summary PTransform: averages metric_fn's outputs over all instances.

    Maps metric_fn over each instance, appends a constant 1 to every metric
    tuple, sums the tuples element-wise (so the trailing sum is the instance
    count), then divides each metric sum by the count to produce a single
    dict keyed by metric_keys plus a "count" entry.
    """
    return (
        pcoll
        | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
        | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(
            *([sum] * (len(metric_keys) + 1))))  # one sum per metric + count
        | "AverageAndMakeDict" >> beam.Map(
            lambda tup: dict(
                [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)] +
                [("count", tup[-1])])))
+
+
def run(argv=None):
    """Entry point: parses the flags and runs the prediction-summary pipeline.

    :param argv: command-line arguments; flags not recognized here are
        forwarded to the Beam pipeline as its options. None means sys.argv.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prediction_path", required=True,
        help=(
            "The GCS folder that contains BatchPrediction results, containing "
            "prediction.results-NNNNN-of-NNNNN files in the json format. "
            "Output will be also stored in this folder, as a file"
            "'prediction.summary.json'."))
    parser.add_argument(
        "--metric_fn_encoded", required=True,
        help=(
            "An encoded function that calculates and returns a tuple of "
            "metric(s) for a given instance (as a dictionary). It should be "
            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
    parser.add_argument(
        "--metric_keys", required=True,
        help=(
            "A comma-separated keys of the aggregated metric(s) in the summary "
            "output. The order and the size of the keys must match to the "
            "output of metric_fn. The summary will have an additional key, "
            "'count', to represent the total number of instances, so this flag "
            "shouldn't include 'count'."))
    known_args, pipeline_args = parser.parse_known_args(argv)

    # NOTE(review): dill.loads executes arbitrary code from the decoded
    # payload; this flag must only ever come from trusted DAG code.
    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
    if not callable(metric_fn):
        raise ValueError("--metric_fn_encoded must be an encoded callable.")
    metric_keys = known_args.metric_keys.split(",")

    # Read the sharded prediction results, aggregate them with MakeSummary,
    # and write a single unsharded summary file back into the same folder.
    with beam.Pipeline(
        options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
        # This is apache-beam ptransform's convention
        # pylint: disable=no-value-for-parameter
        _ = (p
             | "ReadPredictionResult" >> beam.io.ReadFromText(
                 os.path.join(known_args.prediction_path,
                              "prediction.results-*-of-*"),
                 coder=JsonCoder())
             | "Summary" >> MakeSummary(metric_fn, metric_keys)
             | "Write" >> beam.io.WriteToText(
                 os.path.join(known_args.prediction_path,
                              "prediction.summary.json"),
                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
                 coder=JsonCoder()))
        # pylint: enable=no-value-for-parameter
+
+
# Module is invoked with "python -m ..." by DataFlowPythonOperator.
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    run()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/194d1d6e/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
new file mode 100644
index 0000000..91a9f77
--- /dev/null
+++ b/tests/contrib/operators/test_cloudml_operator_utils.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators import cloudml_operator_utils
+from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops
+from airflow.exceptions import AirflowException
+
+from mock import ANY
+from mock import patch
+
+DEFAULT_DATE = datetime.datetime(2017, 6, 6)
+
+
class CreateEvaluateOpsTest(unittest.TestCase):
    """Tests for cloudml_operator_utils.create_evaluate_ops."""

    # BatchPrediction input dict with no model origin (modelName / uri /
    # versionName); each test adds the origin flavor it wants to exercise.
    INPUT_MISSING_ORIGIN = {
        'dataFormat': 'TEXT',
        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
        'outputPath': 'gs://legal-bucket/fake-output-path',
        'region': 'us-east1',
    }
    # Canned CloudML response for a finished BatchPrediction job; tests copy
    # it and fill in 'predictionInput'.
    SUCCESS_MESSAGE_MISSING_INPUT = {
        'jobId': 'eval_test_prediction',
        'predictionOutput': {
            'outputPath': 'gs://fake-output-path',
            'predictionCount': 5000,
            'errorCount': 0,
            'nodeHours': 2.78
        },
        'state': 'SUCCEEDED'
    }

    def setUp(self):
        """Builds a one-day test DAG and an encoded trivial metric_fn."""
        super(CreateEvaluateOpsTest, self).setUp()
        configuration.load_test_config()
        self.dag = DAG(
            'test_dag',
            default_args={
                'owner': 'airflow',
                'start_date': DEFAULT_DATE,
                'end_date': DEFAULT_DATE,
            },
            schedule_interval='@daily')
        # Encode via the module under test's own base64/dill so the expected
        # value matches exactly what create_evaluate_ops produces.
        self.metric_fn = lambda x: (0.1,)
        self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
            cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))


    def testSuccessfulRun(self):
        """Executes each of the three returned operators with its external
        hook mocked out, and verifies the calls they make."""
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = (
            'projects/test-project/models/test_model')

        pred, summary, validate = create_evaluate_ops(
            task_prefix='eval-test',
            project_id='test-project',
            job_id='eval-test-prediction',
            region=input_with_model['region'],
            data_format=input_with_model['dataFormat'],
            input_paths=input_with_model['inputPaths'],
            prediction_path=input_with_model['outputPath'],
            model_name=input_with_model['modelName'].split('/')[-1],
            metric_fn_and_keys=(self.metric_fn, ['err']),
            validate_fn=(lambda x: 'err=%.1f' % x['err']),
            dataflow_options=None,
            dag=self.dag)

        # Prediction task: CloudMLHook.create_job should get a job id with
        # hyphens normalized to underscores plus the prediction input.
        with patch('airflow.contrib.operators.cloudml_operator.'
                   'CloudMLHook') as mock_cloudml_hook:

            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
            success_message['predictionInput'] = input_with_model
            hook_instance = mock_cloudml_hook.return_value
            hook_instance.create_job.return_value = success_message
            result = pred.execute(None)
            mock_cloudml_hook.assert_called_with('google_cloud_default', None)
            hook_instance.create_job.assert_called_once_with(
                'test-project',
                {
                    'jobId': 'eval_test_prediction',
                    'predictionInput': input_with_model
                }, ANY)
            self.assertEqual(success_message['predictionOutput'], result)

        # Summary task: the dataflow job should be started as
        # "python -m airflow.contrib.operators.cloudml_prediction_summary"
        # with the encoded metric_fn and keys as options.
        with patch('airflow.contrib.operators.dataflow_operator.'
                   'DataFlowHook') as mock_dataflow_hook:

            hook_instance = mock_dataflow_hook.return_value
            hook_instance.start_python_dataflow.return_value = None
            summary.execute(None)
            mock_dataflow_hook.assert_called_with(
                gcp_conn_id='google_cloud_default', delegate_to=None)
            hook_instance.start_python_dataflow.assert_called_once_with(
                'eval-test-summary',
                {
                    'prediction_path': 'gs://legal-bucket/fake-output-path',
                    'metric_keys': 'err',
                    'metric_fn_encoded': self.metric_fn_encoded,
                },
                'airflow.contrib.operators.cloudml_prediction_summary',
                ['-m'])

        # Validation task: downloads prediction.summary.json from the bucket
        # parsed out of prediction_path and feeds it to validate_fn.
        with patch('airflow.contrib.operators.cloudml_operator_utils.'
                   'GoogleCloudStorageHook') as mock_gcs_hook:

            hook_instance = mock_gcs_hook.return_value
            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
            result = validate.execute({})
            hook_instance.download.assert_called_once_with(
                'legal-bucket', 'fake-output-path/prediction.summary.json')
            self.assertEqual('err=0.9', result)

    def testFailures(self):
        """Verifies the argument-validation errors of create_evaluate_ops."""
        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
        input_with_model['modelName'] = (
            'projects/test-project/models/test_model')

        # Every valid argument except the model origin, which each case below
        # supplies (or deliberately omits) on its own.
        other_params_but_models = {
            'task_prefix': 'eval-test',
            'project_id': 'test-project',
            'job_id': 'eval-test-prediction',
            'region': input_with_model['region'],
            'data_format': input_with_model['dataFormat'],
            'input_paths': input_with_model['inputPaths'],
            'prediction_path': input_with_model['outputPath'],
            'metric_fn_and_keys': (self.metric_fn, ['err']),
            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
            'dataflow_options': None,
            'dag': self.dag,
        }

        # No model origin at all.
        with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
            _ = create_evaluate_ops(**other_params_but_models)

        # Conflicting origins: uri together with name or version.
        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
            _ = create_evaluate_ops(model_uri='abc', model_name='cde',
                                    **other_params_but_models)

        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
            _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
                                    **other_params_but_models)

        # Non-callable metric_fn / validate_fn are rejected up front.
        with self.assertRaisesRegexp(AirflowException,
                                     '`metric_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['metric_fn_and_keys'] = (None, ['abc'])
            _ = create_evaluate_ops(model_uri='gs://blah', **params)

        with self.assertRaisesRegexp(AirflowException,
                                     '`validate_fn` param must be callable'):
            params = other_params_but_models.copy()
            params['validate_fn'] = None
            _ = create_evaluate_ops(model_uri='gs://blah', **params)
+
+
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()


Mime
View raw message