From: criccomini@apache.org
To: commits@airflow.incubator.apache.org
Date: Wed, 06 Sep 2017 16:51:29 -0000
Message-Id: <2745fab8b3564e05844ab1128bbf4fa4@git.apache.org>
Subject: [2/3] incubator-airflow git commit: [AIRFLOW-1567][Airflow-1567] Renamed cloudml hook and operator to mlengine

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_operator_utils.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_operator_utils.py b/airflow/contrib/operators/mlengine_operator_utils.py
new file mode 100644
index 0000000..5fda6ae
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_operator_utils.py
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+import re
+
+import dill
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.operators.mlengine_operator import MLEngineBatchPredictionOperator
+from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
+from airflow.exceptions import AirflowException
+from airflow.operators.python_operator import PythonOperator
+from six.moves.urllib.parse import urlsplit
+
+def create_evaluate_ops(task_prefix,
+                        data_format,
+                        input_paths,
+                        prediction_path,
+                        metric_fn_and_keys,
+                        validate_fn,
+                        batch_prediction_job_id=None,
+                        project_id=None,
+                        region=None,
+                        dataflow_options=None,
+                        model_uri=None,
+                        model_name=None,
+                        version_name=None,
+                        dag=None):
+    """
+    Creates the Operators needed for model evaluation and returns them.
+
+    It gets a prediction over the inputs via the Cloud ML Engine
+    BatchPrediction API by calling MLEngineBatchPredictionOperator, then
+    summarizes and validates the result via Cloud Dataflow using
+    DataFlowPythonOperator.
+
+    For details and pricing about batch prediction, please refer to
+    https://cloud.google.com/ml-engine/docs/how-tos/batch-predict
+    and, for Cloud Dataflow, https://cloud.google.com/dataflow/docs/
+
+    It returns three chained operators for prediction, summary, and
+    validation, named <task_prefix>-prediction, <task_prefix>-summary, and
+    <task_prefix>-validation, respectively.
+    (<task_prefix> should contain only alphanumeric characters or hyphens.)
+
+    The upstream and downstream dependencies can be set accordingly, like:
+        pred, _, val = create_evaluate_ops(...)
+        pred.set_upstream(upstream_op)
+        ...
+        downstream_op.set_upstream(val)
+
+    Callers provide two python callables, metric_fn and validate_fn, in
+    order to customize the evaluation behavior as they wish.
+    - metric_fn receives a dictionary per instance derived from the json in
+      the batch prediction result. The keys might vary depending on the model.
+      It should return a tuple of metrics.
+    - validate_fn receives a dictionary of the averaged metrics that metric_fn
+      generated over all instances.
+      The keys/values of the dictionary match what is given by the
+      metric_fn_and_keys arg.
+      The dictionary contains one additional metric, 'count', representing
+      the total number of instances received for evaluation.
+      The function should raise an exception to mark the task as failed when
+      the validation result is not okay to proceed (i.e. to set the trained
+      version as default).
+
+    Typical examples look like this:
+
+    def get_metric_fn_and_keys():
+        import math  # imports should be outside of the metric_fn below.
+        def error_and_squared_error(inst):
+            label = float(inst['input_label'])
+            classes = float(inst['classes'])  # 0 or 1
+            err = abs(classes-label)
+            squared_err = math.pow(classes-label, 2)
+            return (err, squared_err)  # returns a tuple.
+        return error_and_squared_error, ['err', 'mse']  # key order must match.
+
+    def validate_err_and_count(summary):
+        if summary['err'] > 0.2:
+            raise ValueError('Too high err>0.2; summary=%s' % summary)
+        if summary['mse'] > 0.05:
+            raise ValueError('Too high mse>0.05; summary=%s' % summary)
+        if summary['count'] < 1000:
+            raise ValueError('Too few instances<1000; summary=%s' % summary)
+        return summary
+
+    For details on the other BatchPrediction-related arguments (project_id,
+    job_id, region, data_format, input_paths, prediction_path, model_uri),
+    please also refer to MLEngineBatchPredictionOperator.
+
+    :param task_prefix: a prefix for the tasks. Only alphanumeric characters
+        and hyphens are allowed (no underscores), since this will be used as
+        the Dataflow job name, which doesn't allow other characters.
+    :type task_prefix: string
+
+    :param data_format: one of 'TEXT', 'TF_RECORD', 'TF_RECORD_GZIP'
+    :type data_format: string
+
+    :param input_paths: a list of input paths to be sent to BatchPrediction.
+    :type input_paths: list of strings
+
+    :param prediction_path: GCS path to put the prediction results in.
+    :type prediction_path: string
+
+    :param metric_fn_and_keys: a tuple of metric_fn and metric_keys:
+        - metric_fn is a function that accepts a dictionary (for an instance),
+          and returns a tuple of metric(s) that it calculates.
+        - metric_keys is a list of strings to denote the key of each metric.
+    :type metric_fn_and_keys: tuple of a function and a list of strings
+
+    :param validate_fn: a function to validate whether the averaged metric(s)
+        are good enough to push the model.
+    :type validate_fn: function
+
+    :param batch_prediction_job_id: the id to use for the Cloud ML Batch
+        prediction job. Passed directly to the MLEngineBatchPredictionOperator
+        as the job_id argument.
+    :type batch_prediction_job_id: string
+
+    :param project_id: the Google Cloud Platform project id in which to
+        execute Cloud ML Batch Prediction and Dataflow jobs. If None, then the
+        `dag`'s `default_args['project_id']` will be used.
+    :type project_id: string
+
+    :param region: the Google Cloud Platform region in which to execute Cloud
+        ML Batch Prediction and Dataflow jobs. If None, then the `dag`'s
+        `default_args['region']` will be used.
+    :type region: string
+
+    :param dataflow_options: options to run Dataflow jobs. If None, then the
+        `dag`'s `default_args['dataflow_default_options']` will be used.
+    :type dataflow_options: dictionary
+
+    :param model_uri: GCS path of the model exported by TensorFlow using
+        tensorflow.estimator.export_savedmodel(). It cannot be used with
+        model_name or version_name below. See MLEngineBatchPredictionOperator
+        for more detail.
+    :type model_uri: string
+
+    :param model_name: used to indicate a model to use for prediction. Can be
+        used in combination with version_name, but cannot be used together
+        with model_uri. See MLEngineBatchPredictionOperator for more detail.
+        If None, then the `dag`'s `default_args['model_name']` will be used.
+    :type model_name: string
+
+    :param version_name: used to indicate a model version to use for
+        prediction, in combination with model_name. Cannot be used together
+        with model_uri. See MLEngineBatchPredictionOperator for more detail.
+        If None, then the `dag`'s `default_args['version_name']` will be used.
+    :type version_name: string
+
+    :param dag: the `DAG` to use for all Operators.
+    :type dag: airflow.DAG
+
+    :returns: a tuple of three operators, (prediction, summary, validation)
+    :rtype: tuple(DataFlowPythonOperator, DataFlowPythonOperator,
+        PythonOperator)
+    """
+
+    # Verify that task_prefix doesn't have any special characters except the
+    # hyphen '-', the only non-alphanumeric character Dataflow allows.
+    if not re.match(r"^[a-zA-Z][-A-Za-z0-9]*$", task_prefix):
+        raise AirflowException(
+            "Malformed task_id for DataFlowPythonOperator (only alphanumeric "
+            "and hyphens are allowed): " + task_prefix)
+
+    metric_fn, metric_keys = metric_fn_and_keys
+    if not callable(metric_fn):
+        raise AirflowException("`metric_fn` param must be callable.")
+    if not callable(validate_fn):
+        raise AirflowException("`validate_fn` param must be callable.")
+
+    if dag is not None and dag.default_args is not None:
+        default_args = dag.default_args
+        project_id = project_id or default_args.get('project_id')
+        region = region or default_args.get('region')
+        model_name = model_name or default_args.get('model_name')
+        version_name = version_name or default_args.get('version_name')
+        dataflow_options = dataflow_options or \
+            default_args.get('dataflow_default_options')
+
+    evaluate_prediction = MLEngineBatchPredictionOperator(
+        task_id=(task_prefix + "-prediction"),
+        project_id=project_id,
+        job_id=batch_prediction_job_id,
+        region=region,
+        data_format=data_format,
+        input_paths=input_paths,
+        output_path=prediction_path,
+        uri=model_uri,
+        model_name=model_name,
+        version_name=version_name,
+        dag=dag)
+
+    metric_fn_encoded = base64.b64encode(dill.dumps(metric_fn, recurse=True))
+    evaluate_summary = DataFlowPythonOperator(
+        task_id=(task_prefix + "-summary"),
+        py_options=["-m"],
+        py_file="airflow.contrib.operators.mlengine_prediction_summary",
+        dataflow_default_options=dataflow_options,
+        options={
+            "prediction_path": prediction_path,
+            "metric_fn_encoded": metric_fn_encoded,
+            "metric_keys": ','.join(metric_keys)
+        },
+        dag=dag)
+    evaluate_summary.set_upstream(evaluate_prediction)
+
+    def apply_validate_fn(*args, **kwargs):
+        prediction_path = kwargs["templates_dict"]["prediction_path"]
+        scheme, bucket, obj, _, _ = urlsplit(prediction_path)
+        if scheme != "gs" or not bucket or not obj:
+            raise ValueError("Wrong format prediction_path: %s" %
+                             prediction_path)
+        summary = os.path.join(obj.strip("/"),
+                               "prediction.summary.json")
+        gcs_hook = GoogleCloudStorageHook()
+        summary = json.loads(gcs_hook.download(bucket, summary))
+        return validate_fn(summary)
+
+    evaluate_validation = PythonOperator(
+        task_id=(task_prefix + "-validation"),
+        python_callable=apply_validate_fn,
+        provide_context=True,
+        templates_dict={"prediction_path": prediction_path},
+        dag=dag)
+    evaluate_validation.set_upstream(evaluate_summary)
+
+    return evaluate_prediction, evaluate_summary, evaluate_validation

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/airflow/contrib/operators/mlengine_prediction_summary.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/mlengine_prediction_summary.py b/airflow/contrib/operators/mlengine_prediction_summary.py
new file mode 100644
index 0000000..1f4d540
--- /dev/null
+++ b/airflow/contrib/operators/mlengine_prediction_summary.py
@@ -0,0 +1,177 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A template called by DataFlowPythonOperator to summarize BatchPrediction.
+
+It accepts a user function that calculates the metric(s) per instance in
+the prediction results, then aggregates them into a summary.
+
+Args:
+  --prediction_path:
+      The GCS folder that contains BatchPrediction results, containing
+      prediction.results-NNNNN-of-NNNNN files in JSON format.
+      Output will also be stored in this folder, as 'prediction.summary.json'.
+
+  --metric_fn_encoded:
+      An encoded function that calculates and returns a tuple of metric(s)
+      for a given instance (as a dictionary). It should be encoded
+      via base64.b64encode(dill.dumps(fn, recurse=True)).
+
+  --metric_keys:
+      Comma-separated keys of the aggregated metric(s) in the summary
+      output. The order and the number of the keys must match the output
+      of metric_fn.
+      The summary will have an additional key, 'count', to represent the
+      total number of instances, so the keys shouldn't include 'count'.
+
+# Usage example:
+def get_metric_fn():
+    import math  # all imports must be outside of the function to be passed.
+    def metric_fn(inst):
+        label = float(inst["input_label"])
+        classes = float(inst["classes"])
+        prediction = float(inst["scores"][1])
+        log_loss = math.log(1 + math.exp(
+            -(label * 2 - 1) * math.log(prediction / (1 - prediction))))
+        squared_err = (classes-label)**2
+        return (log_loss, squared_err)
+    return metric_fn
+metric_fn_encoded = base64.b64encode(dill.dumps(get_metric_fn(), recurse=True))
+
+airflow.contrib.operators.DataFlowPythonOperator(
+    task_id="summary-prediction",
+    py_options=["-m"],
+    py_file="airflow.contrib.operators.mlengine_prediction_summary",
+    options={
+        "prediction_path": prediction_path,
+        "metric_fn_encoded": metric_fn_encoded,
+        "metric_keys": "log_loss,mse"
+    },
+    dataflow_default_options={
+        "project": "xxx", "region": "us-east1",
+        "staging_location": "gs://yy", "temp_location": "gs://zz",
+    })
+    >> dag
+
+# When the input file is like the following:
+{"inputs": "1,x,y,z", "classes": 1, "scores": [0.1, 0.9]}
+{"inputs": "0,o,m,g", "classes": 0, "scores": [0.7, 0.3]}
+{"inputs": "1,o,m,w", "classes": 0, "scores": [0.6, 0.4]}
+{"inputs": "1,b,r,b", "classes": 1, "scores": [0.2, 0.8]}
+
+# The output file will be:
+{"log_loss": 0.43890510565304547, "count": 4, "mse": 0.25}
+
+# To test outside of the dag:
+subprocess.check_call(["python",
+                       "-m",
+                       "airflow.contrib.operators.mlengine_prediction_summary",
+                       "--prediction_path=gs://...",
+                       "--metric_fn_encoded=" + metric_fn_encoded,
+                       "--metric_keys=log_loss,mse",
+                       "--runner=DataflowRunner",
+                       "--staging_location=gs://...",
+                       "--temp_location=gs://...",
+                       ])
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import base64
+import json
+import logging
+import os
+
+import apache_beam as beam
+import dill
+
+
+class JsonCoder(object):
+    def encode(self, x):
+        return json.dumps(x)
+
+    def decode(self, x):
+        return json.loads(x)
+
+
+@beam.ptransform_fn
+def MakeSummary(pcoll, metric_fn, metric_keys):  # pylint: disable=invalid-name
+    return (
+        pcoll
+        | "ApplyMetricFnPerInstance" >> beam.Map(metric_fn)
+        | "PairWith1" >> beam.Map(lambda tup: tup + (1,))
+        | "SumTuple" >> beam.CombineGlobally(beam.combiners.TupleCombineFn(
+            *([sum] * (len(metric_keys) + 1))))
+        | "AverageAndMakeDict" >> beam.Map(
+            lambda tup: dict(
+                [(name, tup[i]/tup[-1]) for i, name in enumerate(metric_keys)]
+                + [("count", tup[-1])])))
+
+
+def run(argv=None):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--prediction_path", required=True,
+        help=(
+            "The GCS folder that contains BatchPrediction results, containing "
+            "prediction.results-NNNNN-of-NNNNN files in JSON format. "
+            "Output will also be stored in this folder, as a file "
+            "'prediction.summary.json'."))
+    parser.add_argument(
+        "--metric_fn_encoded", required=True,
+        help=(
+            "An encoded function that calculates and returns a tuple of "
+            "metric(s) for a given instance (as a dictionary). It should be "
+            "encoded via base64.b64encode(dill.dumps(fn, recurse=True))."))
+    parser.add_argument(
+        "--metric_keys", required=True,
+        help=(
+            "Comma-separated keys of the aggregated metric(s) in the summary "
+            "output. The order and the number of the keys must match the "
+            "output of metric_fn. The summary will have an additional key, "
+            "'count', to represent the total number of instances, so this "
+            "flag shouldn't include 'count'."))
+    known_args, pipeline_args = parser.parse_known_args(argv)
+
+    metric_fn = dill.loads(base64.b64decode(known_args.metric_fn_encoded))
+    if not callable(metric_fn):
+        raise ValueError("--metric_fn_encoded must be an encoded callable.")
+    metric_keys = known_args.metric_keys.split(",")
+
+    with beam.Pipeline(
+        options=beam.pipeline.PipelineOptions(pipeline_args)) as p:
+        # This is the apache-beam ptransform convention.
+        # pylint: disable=no-value-for-parameter
+        _ = (p
+             | "ReadPredictionResult" >> beam.io.ReadFromText(
+                 os.path.join(known_args.prediction_path,
+                              "prediction.results-*-of-*"),
+                 coder=JsonCoder())
+             | "Summary" >> MakeSummary(metric_fn, metric_keys)
+             | "Write" >> beam.io.WriteToText(
+                 os.path.join(known_args.prediction_path,
+                              "prediction.summary.json"),
+                 shard_name_template='',  # without trailing -NNNNN-of-NNNNN.
+                 coder=JsonCoder()))
+        # pylint: enable=no-value-for-parameter
+
+
+if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    run()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
deleted file mode 100644
index f56018d..0000000
--- a/tests/contrib/hooks/test_gcp_cloudml_hook.py
+++ /dev/null
@@ -1,413 +0,0 @@
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
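As an aside on how the two new modules above fit together, the sketch below wires create_evaluate_ops() into a DAG. It is only an illustration, not part of this commit: the DAG id, project, bucket paths, and model/version names are hypothetical, and project_id, region, model_name, and version_name fall back to the DAG's default_args as documented above.

# Hedged usage sketch for create_evaluate_ops(); all names are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.mlengine_operator_utils import create_evaluate_ops


def metric_fn(inst):
    # One absolute-error metric per prediction instance.
    return (abs(float(inst['classes']) - float(inst['input_label'])),)


def validate_fn(summary):
    # Raising here marks the validation task as failed.
    if summary['err'] > 0.2:
        raise ValueError('Too high err; summary=%s' % summary)
    return summary


dag = DAG(
    'mlengine_evaluate_example',  # hypothetical DAG id
    default_args={
        'start_date': datetime(2017, 9, 1),
        'project_id': 'my-project',   # hypothetical project
        'region': 'us-east1',
        'model_name': 'test_model',
        'version_name': 'test_version',
    },
    schedule_interval='@daily')

pred, summary, validate = create_evaluate_ops(
    task_prefix='eval-example',       # alphanumeric and hyphens only
    data_format='TEXT',
    input_paths=['gs://my-bucket/eval-input/*'],   # hypothetical bucket
    prediction_path='gs://my-bucket/eval-output',
    metric_fn_and_keys=(metric_fn, ['err']),
    validate_fn=validate_fn,
    dag=dag)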
- -import json -import mock -import unittest - -try: # python 2 - from urlparse import urlparse, parse_qsl -except ImportError: # python 3 - from urllib.parse import urlparse, parse_qsl - -from airflow.contrib.hooks import gcp_cloudml_hook as hook -from apiclient import errors -from apiclient.discovery import build -from apiclient.http import HttpMockSequence -from oauth2client.contrib.gce import HttpAccessTokenRefreshError - -cml_available = True -try: - hook.CloudMLHook().get_conn() -except HttpAccessTokenRefreshError: - cml_available = False - - -class _TestCloudMLHook(object): - - def __init__(self, test_cls, responses, expected_requests): - """ - Init method. - - Usage example: - with _TestCloudMLHook(self, responses, expected_requests) as hook: - self.run_my_test(hook) - - Args: - test_cls: The caller's instance used for test communication. - responses: A list of (dict_response, response_content) tuples. - expected_requests: A list of (uri, http_method, body) tuples. - """ - - self._test_cls = test_cls - self._responses = responses - self._expected_requests = [ - self._normalize_requests_for_comparison(x[0], x[1], x[2]) - for x in expected_requests] - self._actual_requests = [] - - def _normalize_requests_for_comparison(self, uri, http_method, body): - parts = urlparse(uri) - return ( - parts._replace(query=set(parse_qsl(parts.query))), - http_method, - body) - - def __enter__(self): - http = HttpMockSequence(self._responses) - native_request_method = http.request - - # Collecting requests to validate at __exit__. - def _request_wrapper(*args, **kwargs): - self._actual_requests.append(args + (kwargs['body'],)) - return native_request_method(*args, **kwargs) - - http.request = _request_wrapper - service_mock = build('ml', 'v1', http=http) - with mock.patch.object( - hook.CloudMLHook, 'get_conn', return_value=service_mock): - return hook.CloudMLHook() - - def __exit__(self, *args): - # Propogating exceptions here since assert will silence them. 
- if any(args): - return None - self._test_cls.assertEquals( - [self._normalize_requests_for_comparison(x[0], x[1], x[2]) - for x in self._actual_requests], - self._expected_requests) - - -class TestCloudMLHook(unittest.TestCase): - - def setUp(self): - pass - - _SKIP_IF = unittest.skipIf(not cml_available, - 'CloudML is not available to run tests') - - _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/' - - @_SKIP_IF - def test_create_version(self): - project = 'test-project' - model_name = 'test-model' - version = 'test-version' - operation_name = 'projects/{}/operations/test-operation'.format( - project) - - response_body = {'name': operation_name, 'done': True} - succeeded_response = ({'status': '200'}, json.dumps(response_body)) - - expected_requests = [ - ('{}projects/{}/models/{}/versions?alt=json'.format( - self._SERVICE_URI_PREFIX, project, model_name), 'POST', - '"{}"'.format(version)), - ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name), - 'GET', None), - ] - - with _TestCloudMLHook( - self, - responses=[succeeded_response] * 2, - expected_requests=expected_requests) as cml_hook: - create_version_response = cml_hook.create_version( - project_id=project, model_name=model_name, - version_spec=version) - self.assertEquals(create_version_response, response_body) - - @_SKIP_IF - def test_set_default_version(self): - project = 'test-project' - model_name = 'test-model' - version = 'test-version' - operation_name = 'projects/{}/operations/test-operation'.format( - project) - - response_body = {'name': operation_name, 'done': True} - succeeded_response = ({'status': '200'}, json.dumps(response_body)) - - expected_requests = [ - ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format( - self._SERVICE_URI_PREFIX, project, model_name, version), - 'POST', '{}'), - ] - - with _TestCloudMLHook( - self, - responses=[succeeded_response], - expected_requests=expected_requests) as cml_hook: - set_default_version_response = cml_hook.set_default_version( - project_id=project, model_name=model_name, - version_name=version) - self.assertEquals(set_default_version_response, response_body) - - @_SKIP_IF - def test_list_versions(self): - project = 'test-project' - model_name = 'test-model' - operation_name = 'projects/{}/operations/test-operation'.format( - project) - - # This test returns the versions one at a time. 
- versions = ['ver_{}'.format(ix) for ix in range(3)] - - response_bodies = [ - { - 'name': operation_name, - 'nextPageToken': ix, - 'versions': [ver] - } for ix, ver in enumerate(versions)] - response_bodies[-1].pop('nextPageToken') - responses = [({'status': '200'}, json.dumps(body)) - for body in response_bodies] - - expected_requests = [ - ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format( - self._SERVICE_URI_PREFIX, project, model_name), 'GET', - None), - ] + [ - ('{}projects/{}/models/{}/versions?alt=json&pageToken={}' - '&pageSize=100'.format( - self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET', - None) for ix in range(len(versions) - 1) - ] - - with _TestCloudMLHook( - self, - responses=responses, - expected_requests=expected_requests) as cml_hook: - list_versions_response = cml_hook.list_versions( - project_id=project, model_name=model_name) - self.assertEquals(list_versions_response, versions) - - @_SKIP_IF - def test_delete_version(self): - project = 'test-project' - model_name = 'test-model' - version = 'test-version' - operation_name = 'projects/{}/operations/test-operation'.format( - project) - - not_done_response_body = {'name': operation_name, 'done': False} - done_response_body = {'name': operation_name, 'done': True} - not_done_response = ( - {'status': '200'}, json.dumps(not_done_response_body)) - succeeded_response = ( - {'status': '200'}, json.dumps(done_response_body)) - - expected_requests = [ - ( - '{}projects/{}/models/{}/versions/{}?alt=json'.format( - self._SERVICE_URI_PREFIX, project, model_name, version), - 'DELETE', - None), - ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name), - 'GET', None), - ] - - with _TestCloudMLHook( - self, - responses=[not_done_response, succeeded_response], - expected_requests=expected_requests) as cml_hook: - delete_version_response = cml_hook.delete_version( - project_id=project, model_name=model_name, - version_name=version) - self.assertEquals(delete_version_response, done_response_body) - - @_SKIP_IF - def test_create_model(self): - project = 'test-project' - model_name = 'test-model' - model = { - 'name': model_name, - } - response_body = {} - succeeded_response = ({'status': '200'}, json.dumps(response_body)) - - expected_requests = [ - ('{}projects/{}/models?alt=json'.format( - self._SERVICE_URI_PREFIX, project), 'POST', - json.dumps(model)), - ] - - with _TestCloudMLHook( - self, - responses=[succeeded_response], - expected_requests=expected_requests) as cml_hook: - create_model_response = cml_hook.create_model( - project_id=project, model=model) - self.assertEquals(create_model_response, response_body) - - @_SKIP_IF - def test_get_model(self): - project = 'test-project' - model_name = 'test-model' - response_body = {'model': model_name} - succeeded_response = ({'status': '200'}, json.dumps(response_body)) - - expected_requests = [ - ('{}projects/{}/models/{}?alt=json'.format( - self._SERVICE_URI_PREFIX, project, model_name), 'GET', - None), - ] - - with _TestCloudMLHook( - self, - responses=[succeeded_response], - expected_requests=expected_requests) as cml_hook: - get_model_response = cml_hook.get_model( - project_id=project, model_name=model_name) - self.assertEquals(get_model_response, response_body) - - @_SKIP_IF - def test_create_cloudml_job(self): - project = 'test-project' - job_id = 'test-job-id' - my_job = { - 'jobId': job_id, - 'foo': 4815162342, - 'state': 'SUCCEEDED', - } - response_body = json.dumps(my_job) - succeeded_response = ({'status': '200'}, response_body) - 
queued_response = ({'status': '200'}, json.dumps({ - 'jobId': job_id, - 'state': 'QUEUED', - })) - - create_job_request = ('{}projects/{}/jobs?alt=json'.format( - self._SERVICE_URI_PREFIX, project), 'POST', response_body) - ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( - self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) - expected_requests = [ - create_job_request, - ask_if_done_request, - ask_if_done_request, - ] - responses = [succeeded_response, - queued_response, succeeded_response] - - with _TestCloudMLHook( - self, - responses=responses, - expected_requests=expected_requests) as cml_hook: - create_job_response = cml_hook.create_job( - project_id=project, job=my_job) - self.assertEquals(create_job_response, my_job) - - @_SKIP_IF - def test_create_cloudml_job_reuse_existing_job_by_default(self): - project = 'test-project' - job_id = 'test-job-id' - my_job = { - 'jobId': job_id, - 'foo': 4815162342, - 'state': 'SUCCEEDED', - } - response_body = json.dumps(my_job) - job_already_exist_response = ({'status': '409'}, json.dumps({})) - succeeded_response = ({'status': '200'}, response_body) - - create_job_request = ('{}projects/{}/jobs?alt=json'.format( - self._SERVICE_URI_PREFIX, project), 'POST', response_body) - ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( - self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) - expected_requests = [ - create_job_request, - ask_if_done_request, - ] - responses = [job_already_exist_response, succeeded_response] - - # By default, 'create_job' reuse the existing job. - with _TestCloudMLHook( - self, - responses=responses, - expected_requests=expected_requests) as cml_hook: - create_job_response = cml_hook.create_job( - project_id=project, job=my_job) - self.assertEquals(create_job_response, my_job) - - @_SKIP_IF - def test_create_cloudml_job_check_existing_job(self): - project = 'test-project' - job_id = 'test-job-id' - my_job = { - 'jobId': job_id, - 'foo': 4815162342, - 'state': 'SUCCEEDED', - 'someInput': { - 'input': 'someInput' - } - } - different_job = { - 'jobId': job_id, - 'foo': 4815162342, - 'state': 'SUCCEEDED', - 'someInput': { - 'input': 'someDifferentInput' - } - } - - my_job_response_body = json.dumps(my_job) - different_job_response_body = json.dumps(different_job) - job_already_exist_response = ({'status': '409'}, json.dumps({})) - different_job_response = ({'status': '200'}, - different_job_response_body) - - create_job_request = ('{}projects/{}/jobs?alt=json'.format( - self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body) - ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( - self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) - expected_requests = [ - create_job_request, - ask_if_done_request, - ] - - # Returns a different job (with different 'someInput' field) will - # cause 'create_job' request to fail. 
- responses = [job_already_exist_response, different_job_response] - - def check_input(existing_job): - return existing_job.get('someInput', None) == \ - my_job['someInput'] - with _TestCloudMLHook( - self, - responses=responses, - expected_requests=expected_requests) as cml_hook: - with self.assertRaises(errors.HttpError): - cml_hook.create_job( - project_id=project, job=my_job, - use_existing_job_fn=check_input) - - my_job_response = ({'status': '200'}, my_job_response_body) - expected_requests = [ - create_job_request, - ask_if_done_request, - ask_if_done_request, - ] - responses = [ - job_already_exist_response, - my_job_response, - my_job_response] - with _TestCloudMLHook( - self, - responses=responses, - expected_requests=expected_requests) as cml_hook: - create_job_response = cml_hook.create_job( - project_id=project, job=my_job, - use_existing_job_fn=check_input) - self.assertEquals(create_job_response, my_job) - - -if __name__ == '__main__': - unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/hooks/test_gcp_mlengine_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_gcp_mlengine_hook.py b/tests/contrib/hooks/test_gcp_mlengine_hook.py new file mode 100644 index 0000000..372d47c --- /dev/null +++ b/tests/contrib/hooks/test_gcp_mlengine_hook.py @@ -0,0 +1,413 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import mock +import unittest + +try: # python 2 + from urlparse import urlparse, parse_qsl +except ImportError: # python 3 + from urllib.parse import urlparse, parse_qsl + +from airflow.contrib.hooks import gcp_mlengine_hook as hook +from apiclient import errors +from apiclient.discovery import build +from apiclient.http import HttpMockSequence +from oauth2client.contrib.gce import HttpAccessTokenRefreshError + +cml_available = True +try: + hook.MLEngineHook().get_conn() +except HttpAccessTokenRefreshError: + cml_available = False + + +class _TestMLEngineHook(object): + + def __init__(self, test_cls, responses, expected_requests): + """ + Init method. + + Usage example: + with _TestMLEngineHook(self, responses, expected_requests) as hook: + self.run_my_test(hook) + + Args: + test_cls: The caller's instance used for test communication. + responses: A list of (dict_response, response_content) tuples. + expected_requests: A list of (uri, http_method, body) tuples. + """ + + self._test_cls = test_cls + self._responses = responses + self._expected_requests = [ + self._normalize_requests_for_comparison(x[0], x[1], x[2]) + for x in expected_requests] + self._actual_requests = [] + + def _normalize_requests_for_comparison(self, uri, http_method, body): + parts = urlparse(uri) + return ( + parts._replace(query=set(parse_qsl(parts.query))), + http_method, + body) + + def __enter__(self): + http = HttpMockSequence(self._responses) + native_request_method = http.request + + # Collecting requests to validate at __exit__. 
+ def _request_wrapper(*args, **kwargs): + self._actual_requests.append(args + (kwargs['body'],)) + return native_request_method(*args, **kwargs) + + http.request = _request_wrapper + service_mock = build('ml', 'v1', http=http) + with mock.patch.object( + hook.MLEngineHook, 'get_conn', return_value=service_mock): + return hook.MLEngineHook() + + def __exit__(self, *args): + # Propagating exceptions here since assert will silence them. + if any(args): + return None + self._test_cls.assertEquals( + [self._normalize_requests_for_comparison(x[0], x[1], x[2]) + for x in self._actual_requests], + self._expected_requests) + + +class TestMLEngineHook(unittest.TestCase): + + def setUp(self): + pass + + _SKIP_IF = unittest.skipIf(not cml_available, + 'MLEngine is not available to run tests') + + _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/' + + @_SKIP_IF + def test_create_version(self): + project = 'test-project' + model_name = 'test-model' + version = 'test-version' + operation_name = 'projects/{}/operations/test-operation'.format( + project) + + response_body = {'name': operation_name, 'done': True} + succeeded_response = ({'status': '200'}, json.dumps(response_body)) + + expected_requests = [ + ('{}projects/{}/models/{}/versions?alt=json'.format( + self._SERVICE_URI_PREFIX, project, model_name), 'POST', + '"{}"'.format(version)), + ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name), + 'GET', None), + ] + + with _TestMLEngineHook( + self, + responses=[succeeded_response] * 2, + expected_requests=expected_requests) as cml_hook: + create_version_response = cml_hook.create_version( + project_id=project, model_name=model_name, + version_spec=version) + self.assertEquals(create_version_response, response_body) + + @_SKIP_IF + def test_set_default_version(self): + project = 'test-project' + model_name = 'test-model' + version = 'test-version' + operation_name = 'projects/{}/operations/test-operation'.format( + project) + + response_body = {'name': operation_name, 'done': True} + succeeded_response = ({'status': '200'}, json.dumps(response_body)) + + expected_requests = [ + ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format( + self._SERVICE_URI_PREFIX, project, model_name, version), + 'POST', '{}'), + ] + + with _TestMLEngineHook( + self, + responses=[succeeded_response], + expected_requests=expected_requests) as cml_hook: + set_default_version_response = cml_hook.set_default_version( + project_id=project, model_name=model_name, + version_name=version) + self.assertEquals(set_default_version_response, response_body) + + @_SKIP_IF + def test_list_versions(self): + project = 'test-project' + model_name = 'test-model' + operation_name = 'projects/{}/operations/test-operation'.format( + project) + + # This test returns the versions one at a time.
+ versions = ['ver_{}'.format(ix) for ix in range(3)] + + response_bodies = [ + { + 'name': operation_name, + 'nextPageToken': ix, + 'versions': [ver] + } for ix, ver in enumerate(versions)] + response_bodies[-1].pop('nextPageToken') + responses = [({'status': '200'}, json.dumps(body)) + for body in response_bodies] + + expected_requests = [ + ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format( + self._SERVICE_URI_PREFIX, project, model_name), 'GET', + None), + ] + [ + ('{}projects/{}/models/{}/versions?alt=json&pageToken={}' + '&pageSize=100'.format( + self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET', + None) for ix in range(len(versions) - 1) + ] + + with _TestMLEngineHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + list_versions_response = cml_hook.list_versions( + project_id=project, model_name=model_name) + self.assertEquals(list_versions_response, versions) + + @_SKIP_IF + def test_delete_version(self): + project = 'test-project' + model_name = 'test-model' + version = 'test-version' + operation_name = 'projects/{}/operations/test-operation'.format( + project) + + not_done_response_body = {'name': operation_name, 'done': False} + done_response_body = {'name': operation_name, 'done': True} + not_done_response = ( + {'status': '200'}, json.dumps(not_done_response_body)) + succeeded_response = ( + {'status': '200'}, json.dumps(done_response_body)) + + expected_requests = [ + ( + '{}projects/{}/models/{}/versions/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, model_name, version), + 'DELETE', + None), + ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name), + 'GET', None), + ] + + with _TestMLEngineHook( + self, + responses=[not_done_response, succeeded_response], + expected_requests=expected_requests) as cml_hook: + delete_version_response = cml_hook.delete_version( + project_id=project, model_name=model_name, + version_name=version) + self.assertEquals(delete_version_response, done_response_body) + + @_SKIP_IF + def test_create_model(self): + project = 'test-project' + model_name = 'test-model' + model = { + 'name': model_name, + } + response_body = {} + succeeded_response = ({'status': '200'}, json.dumps(response_body)) + + expected_requests = [ + ('{}projects/{}/models?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', + json.dumps(model)), + ] + + with _TestMLEngineHook( + self, + responses=[succeeded_response], + expected_requests=expected_requests) as cml_hook: + create_model_response = cml_hook.create_model( + project_id=project, model=model) + self.assertEquals(create_model_response, response_body) + + @_SKIP_IF + def test_get_model(self): + project = 'test-project' + model_name = 'test-model' + response_body = {'model': model_name} + succeeded_response = ({'status': '200'}, json.dumps(response_body)) + + expected_requests = [ + ('{}projects/{}/models/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, model_name), 'GET', + None), + ] + + with _TestMLEngineHook( + self, + responses=[succeeded_response], + expected_requests=expected_requests) as cml_hook: + get_model_response = cml_hook.get_model( + project_id=project, model_name=model_name) + self.assertEquals(get_model_response, response_body) + + @_SKIP_IF + def test_create_mlengine_job(self): + project = 'test-project' + job_id = 'test-job-id' + my_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + } + response_body = json.dumps(my_job) + succeeded_response = ({'status': '200'}, response_body) + 
queued_response = ({'status': '200'}, json.dumps({ + 'jobId': job_id, + 'state': 'QUEUED', + })) + + create_job_request = ('{}projects/{}/jobs?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', response_body) + ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) + expected_requests = [ + create_job_request, + ask_if_done_request, + ask_if_done_request, + ] + responses = [succeeded_response, + queued_response, succeeded_response] + + with _TestMLEngineHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + create_job_response = cml_hook.create_job( + project_id=project, job=my_job) + self.assertEquals(create_job_response, my_job) + + @_SKIP_IF + def test_create_mlengine_job_reuse_existing_job_by_default(self): + project = 'test-project' + job_id = 'test-job-id' + my_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + } + response_body = json.dumps(my_job) + job_already_exist_response = ({'status': '409'}, json.dumps({})) + succeeded_response = ({'status': '200'}, response_body) + + create_job_request = ('{}projects/{}/jobs?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', response_body) + ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) + expected_requests = [ + create_job_request, + ask_if_done_request, + ] + responses = [job_already_exist_response, succeeded_response] + + # By default, 'create_job' reuses the existing job. + with _TestMLEngineHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + create_job_response = cml_hook.create_job( + project_id=project, job=my_job) + self.assertEquals(create_job_response, my_job) + + @_SKIP_IF + def test_create_mlengine_job_check_existing_job(self): + project = 'test-project' + job_id = 'test-job-id' + my_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + 'someInput': { + 'input': 'someInput' + } + } + different_job = { + 'jobId': job_id, + 'foo': 4815162342, + 'state': 'SUCCEEDED', + 'someInput': { + 'input': 'someDifferentInput' + } + } + + my_job_response_body = json.dumps(my_job) + different_job_response_body = json.dumps(different_job) + job_already_exist_response = ({'status': '409'}, json.dumps({})) + different_job_response = ({'status': '200'}, + different_job_response_body) + + create_job_request = ('{}projects/{}/jobs?alt=json'.format( + self._SERVICE_URI_PREFIX, project), 'POST', my_job_response_body) + ask_if_done_request = ('{}projects/{}/jobs/{}?alt=json'.format( + self._SERVICE_URI_PREFIX, project, job_id), 'GET', None) + expected_requests = [ + create_job_request, + ask_if_done_request, + ] + + # Returning a different job (with a different 'someInput' field) will + # cause the 'create_job' request to fail.
+ responses = [job_already_exist_response, different_job_response] + + def check_input(existing_job): + return existing_job.get('someInput', None) == \ + my_job['someInput'] + with _TestMLEngineHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + with self.assertRaises(errors.HttpError): + cml_hook.create_job( + project_id=project, job=my_job, + use_existing_job_fn=check_input) + + my_job_response = ({'status': '200'}, my_job_response_body) + expected_requests = [ + create_job_request, + ask_if_done_request, + ask_if_done_request, + ] + responses = [ + job_already_exist_response, + my_job_response, + my_job_response] + with _TestMLEngineHook( + self, + responses=responses, + expected_requests=expected_requests) as cml_hook: + create_job_response = cml_hook.create_job( + project_id=project, job=my_job, + use_existing_job_fn=check_input) + self.assertEquals(create_job_response, my_job) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_cloudml_operator.py b/tests/contrib/operators/test_cloudml_operator.py deleted file mode 100644 index dc2366e..0000000 --- a/tests/contrib/operators/test_cloudml_operator.py +++ /dev/null @@ -1,373 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
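One detail of the test harness above worth calling out: expected and actual requests are compared only after the query string is parsed into a set, so parameter ordering cannot cause spurious mismatches. In isolation, a sketch of the same trick with made-up URIs:

# Sketch of the query normalization used by _TestMLEngineHook above.
try:  # python 2
    from urlparse import urlparse, parse_qsl
except ImportError:  # python 3
    from urllib.parse import urlparse, parse_qsl


def normalize(uri):
    # Mirror _normalize_requests_for_comparison: compare query parameters
    # as an unordered set rather than as a raw string.
    parts = urlparse(uri)
    return parts._replace(query=set(parse_qsl(parts.query)))


a = 'https://ml.googleapis.com/v1/projects/p/models?alt=json&pageSize=100'
b = 'https://ml.googleapis.com/v1/projects/p/models?pageSize=100&alt=json'
assert normalize(a) == normalize(b)  # order-insensitive comparison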
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import datetime -from apiclient import errors -import httplib2 -import unittest - -from airflow import configuration, DAG -from airflow.contrib.operators.cloudml_operator import CloudMLBatchPredictionOperator -from airflow.contrib.operators.cloudml_operator import CloudMLTrainingOperator - -from mock import ANY -from mock import patch - -DEFAULT_DATE = datetime.datetime(2017, 6, 6) - - -class CloudMLBatchPredictionOperatorTest(unittest.TestCase): - INPUT_MISSING_ORIGIN = { - 'dataFormat': 'TEXT', - 'inputPaths': ['gs://legal-bucket/fake-input-path/*'], - 'outputPath': 'gs://legal-bucket/fake-output-path', - 'region': 'us-east1', - } - SUCCESS_MESSAGE_MISSING_INPUT = { - 'jobId': 'test_prediction', - 'predictionOutput': { - 'outputPath': 'gs://fake-output-path', - 'predictionCount': 5000, - 'errorCount': 0, - 'nodeHours': 2.78 - }, - 'state': 'SUCCEEDED' - } - BATCH_PREDICTION_DEFAULT_ARGS = { - 'project_id': 'test-project', - 'job_id': 'test_prediction', - 'region': 'us-east1', - 'data_format': 'TEXT', - 'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'], - 'output_path': - 'gs://12_legal_bucket_underscore_number/legal-output-path', - 'task_id': 'test-prediction' - } - - def setUp(self): - super(CloudMLBatchPredictionOperatorTest, self).setUp() - configuration.load_test_config() - self.dag = DAG( - 'test_dag', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - 'end_date': DEFAULT_DATE, - }, - schedule_interval='@daily') - - def testSuccessWithModel(self): - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - - input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/test-project/models/test_model' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_model - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = errors.HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - hook_instance.create_job.return_value = success_message - - prediction_task = CloudMLBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_model['region'], - data_format=input_with_model['dataFormat'], - input_paths=input_with_model['inputPaths'], - output_path=input_with_model['outputPath'], - model_name=input_with_model['modelName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) - - mock_hook.assert_called_with('google_cloud_default', None) - hook_instance.create_job.assert_called_once_with( - 'test-project', - { - 'jobId': 'test_prediction', - 'predictionInput': input_with_model - }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) - - def testSuccessWithVersion(self): - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - - input_with_version = self.INPUT_MISSING_ORIGIN.copy() - input_with_version['versionName'] = \ - 'projects/test-project/models/test_model/versions/test_version' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_version - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = errors.HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - 
hook_instance.create_job.return_value = success_message - - prediction_task = CloudMLBatchPredictionOperator( - job_id='test_prediction', project_id='test-project', - region=input_with_version['region'], - data_format=input_with_version['dataFormat'], - input_paths=input_with_version['inputPaths'], - output_path=input_with_version['outputPath'], - model_name=input_with_version['versionName'].split('/')[-3], - version_name=input_with_version['versionName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) - - mock_hook.assert_called_with('google_cloud_default', None) - hook_instance.create_job.assert_called_with( - 'test-project', - { - 'jobId': 'test_prediction', - 'predictionInput': input_with_version - }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) - - def testSuccessWithURI(self): - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - - input_with_uri = self.INPUT_MISSING_ORIGIN.copy() - input_with_uri['uri'] = 'gs://my_bucket/my_models/savedModel' - success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy() - success_message['predictionInput'] = input_with_uri - - hook_instance = mock_hook.return_value - hook_instance.get_job.side_effect = errors.HttpError( - resp=httplib2.Response({ - 'status': 404 - }), content=b'some bytes') - hook_instance.create_job.return_value = success_message - - prediction_task = CloudMLBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_uri['region'], - data_format=input_with_uri['dataFormat'], - input_paths=input_with_uri['inputPaths'], - output_path=input_with_uri['outputPath'], - uri=input_with_uri['uri'], - dag=self.dag, - task_id='test-prediction') - prediction_output = prediction_task.execute(None) - - mock_hook.assert_called_with('google_cloud_default', None) - hook_instance.create_job.assert_called_with( - 'test-project', - { - 'jobId': 'test_prediction', - 'predictionInput': input_with_uri - }, ANY) - self.assertEquals( - success_message['predictionOutput'], - prediction_output) - - def testInvalidModelOrigin(self): - # Test that both uri and model is given - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - task_args['uri'] = 'gs://fake-uri/saved_model' - task_args['model_name'] = 'fake_model' - with self.assertRaises(ValueError) as context: - CloudMLBatchPredictionOperator(**task_args).execute(None) - self.assertEquals('Ambiguous model origin.', str(context.exception)) - - # Test that both uri and model/version is given - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - task_args['uri'] = 'gs://fake-uri/saved_model' - task_args['model_name'] = 'fake_model' - task_args['version_name'] = 'fake_version' - with self.assertRaises(ValueError) as context: - CloudMLBatchPredictionOperator(**task_args).execute(None) - self.assertEquals('Ambiguous model origin.', str(context.exception)) - - # Test that a version is given without a model - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - task_args['version_name'] = 'bare_version' - with self.assertRaises(ValueError) as context: - CloudMLBatchPredictionOperator(**task_args).execute(None) - self.assertEquals( - 'Missing model origin.', - str(context.exception)) - - # Test that none of uri, model, model/version is given - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - with self.assertRaises(ValueError) as context: - CloudMLBatchPredictionOperator(**task_args).execute(None) - self.assertEquals( 
- 'Missing model origin.', - str(context.exception)) - - def testHttpError(self): - http_error_code = 403 - - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - input_with_model = self.INPUT_MISSING_ORIGIN.copy() - input_with_model['modelName'] = \ - 'projects/experimental/models/test_model' - - hook_instance = mock_hook.return_value - hook_instance.create_job.side_effect = errors.HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), content=b'Forbidden') - - with self.assertRaises(errors.HttpError) as context: - prediction_task = CloudMLBatchPredictionOperator( - job_id='test_prediction', - project_id='test-project', - region=input_with_model['region'], - data_format=input_with_model['dataFormat'], - input_paths=input_with_model['inputPaths'], - output_path=input_with_model['outputPath'], - model_name=input_with_model['modelName'].split('/')[-1], - dag=self.dag, - task_id='test-prediction') - prediction_task.execute(None) - - mock_hook.assert_called_with('google_cloud_default', None) - hook_instance.create_job.assert_called_with( - 'test-project', - { - 'jobId': 'test_prediction', - 'predictionInput': input_with_model - }, ANY) - - self.assertEquals(http_error_code, context.exception.resp.status) - - def testFailedJobError(self): - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = { - 'state': 'FAILED', - 'errorMessage': 'A failure message' - } - task_args = self.BATCH_PREDICTION_DEFAULT_ARGS.copy() - task_args['uri'] = 'a uri' - - with self.assertRaises(RuntimeError) as context: - CloudMLBatchPredictionOperator(**task_args).execute(None) - - self.assertEquals('A failure message', str(context.exception)) - - -class CloudMLTrainingOperatorTest(unittest.TestCase): - TRAINING_DEFAULT_ARGS = { - 'project_id': 'test-project', - 'job_id': 'test_training', - 'package_uris': ['gs://some-bucket/package1'], - 'training_python_module': 'trainer', - 'training_args': '--some_arg=\'aaa\'', - 'region': 'us-east1', - 'scale_tier': 'STANDARD_1', - 'task_id': 'test-training' - } - TRAINING_INPUT = { - 'jobId': 'test_training', - 'trainingInput': { - 'scaleTier': 'STANDARD_1', - 'packageUris': ['gs://some-bucket/package1'], - 'pythonModule': 'trainer', - 'args': '--some_arg=\'aaa\'', - 'region': 'us-east1' - } - } - - def testSuccessCreateTrainingJob(self): - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - success_response = self.TRAINING_INPUT.copy() - success_response['state'] = 'SUCCEEDED' - hook_instance = mock_hook.return_value - hook_instance.create_job.return_value = success_response - - training_op = CloudMLTrainingOperator(**self.TRAINING_DEFAULT_ARGS) - training_op.execute(None) - - mock_hook.assert_called_with(gcp_conn_id='google_cloud_default', - delegate_to=None) - # Make sure only 'create_job' is invoked on hook instance - self.assertEquals(len(hook_instance.mock_calls), 1) - hook_instance.create_job.assert_called_with( - 'test-project', self.TRAINING_INPUT, ANY) - - def testHttpError(self): - http_error_code = 403 - with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \ - as mock_hook: - hook_instance = mock_hook.return_value - hook_instance.create_job.side_effect = errors.HttpError( - resp=httplib2.Response({ - 'status': http_error_code - }), content=b'Forbidden') - - with self.assertRaises(errors.HttpError) as context: - training_op = CloudMLTrainingOperator( - 
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEquals(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEquals(http_error_code, context.exception.resp.status)
-
-    def testFailedJobError(self):
-        with patch('airflow.contrib.operators.cloudml_operator.CloudMLHook') \
-                as mock_hook:
-            failure_response = self.TRAINING_INPUT.copy()
-            failure_response['state'] = 'FAILED'
-            failure_response['errorMessage'] = 'A failure message'
-            hook_instance = mock_hook.return_value
-            hook_instance.create_job.return_value = failure_response
-
-            with self.assertRaises(RuntimeError) as context:
-                training_op = CloudMLTrainingOperator(
-                    **self.TRAINING_DEFAULT_ARGS)
-                training_op.execute(None)
-
-            mock_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            # Make sure only 'create_job' is invoked on hook instance
-            self.assertEquals(len(hook_instance.mock_calls), 1)
-            hook_instance.create_job.assert_called_with(
-                'test-project', self.TRAINING_INPUT, ANY)
-            self.assertEquals('A failure message', str(context.exception))
-
-
-if __name__ == '__main__':
-    unittest.main()
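
A note on the error-handling pattern in the deleted file above: each failure test patches CloudMLHook and injects googleapiclient's errors.HttpError (wrapping an httplib2.Response) as the side_effect of create_job, then asserts on the surfaced status code. Below is a minimal, self-contained sketch of that simulation technique; the MagicMock is a hypothetical stand-in for the patched hook instance and is not code from this commit.

import unittest

import httplib2
from googleapiclient import errors
from mock import MagicMock


class HttpErrorSimulationTest(unittest.TestCase):
    def test_create_job_surfaces_http_error(self):
        # Hypothetical stand-in for mock_hook.return_value in the tests
        # above.
        hook_instance = MagicMock()
        # googleapiclient raises HttpError built from an httplib2
        # Response; the 'status' entry becomes resp.status.
        hook_instance.create_job.side_effect = errors.HttpError(
            resp=httplib2.Response({'status': 403}),
            content=b'Forbidden')

        with self.assertRaises(errors.HttpError) as context:
            hook_instance.create_job('test-project', {}, None)

        # The wrapped response round-trips, so tests can assert on the
        # HTTP status code exactly as the operator tests do.
        self.assertEqual(403, context.exception.resp.status)


if __name__ == '__main__':
    unittest.main()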
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/af91e2ac/tests/contrib/operators/test_cloudml_operator_utils.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cloudml_operator_utils.py b/tests/contrib/operators/test_cloudml_operator_utils.py
deleted file mode 100644
index b2a5a30..0000000
--- a/tests/contrib/operators/test_cloudml_operator_utils.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import datetime
-import unittest
-
-from airflow import configuration, DAG
-from airflow.contrib.operators import cloudml_operator_utils
-from airflow.contrib.operators.cloudml_operator_utils import create_evaluate_ops
-from airflow.exceptions import AirflowException
-
-from mock import ANY
-from mock import patch
-
-DEFAULT_DATE = datetime.datetime(2017, 6, 6)
-
-
-class CreateEvaluateOpsTest(unittest.TestCase):
-
-    INPUT_MISSING_ORIGIN = {
-        'dataFormat': 'TEXT',
-        'inputPaths': ['gs://legal-bucket/fake-input-path/*'],
-        'outputPath': 'gs://legal-bucket/fake-output-path',
-        'region': 'us-east1',
-        'versionName': 'projects/test-project/models/test_model/versions/test_version',
-    }
-    SUCCESS_MESSAGE_MISSING_INPUT = {
-        'jobId': 'eval_test_prediction',
-        'predictionOutput': {
-            'outputPath': 'gs://fake-output-path',
-            'predictionCount': 5000,
-            'errorCount': 0,
-            'nodeHours': 2.78
-        },
-        'state': 'SUCCEEDED'
-    }
-
-    def setUp(self):
-        super(CreateEvaluateOpsTest, self).setUp()
-        configuration.load_test_config()
-        self.dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-                'project_id': 'test-project',
-                'region': 'us-east1',
-                'model_name': 'test_model',
-                'version_name': 'test_version',
-            },
-            schedule_interval='@daily')
-        self.metric_fn = lambda x: (0.1,)
-        self.metric_fn_encoded = cloudml_operator_utils.base64.b64encode(
-            cloudml_operator_utils.dill.dumps(self.metric_fn, recurse=True))
-
-    def testSuccessfulRun(self):
-        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-
-        pred, summary, validate = create_evaluate_ops(
-            task_prefix='eval-test',
-            batch_prediction_job_id='eval-test-prediction',
-            data_format=input_with_model['dataFormat'],
-            input_paths=input_with_model['inputPaths'],
-            prediction_path=input_with_model['outputPath'],
-            metric_fn_and_keys=(self.metric_fn, ['err']),
-            validate_fn=(lambda x: 'err=%.1f' % x['err']),
-            dag=self.dag)
-
-        with patch('airflow.contrib.operators.cloudml_operator.'
-                   'CloudMLHook') as mock_cloudml_hook:
-
-            success_message = self.SUCCESS_MESSAGE_MISSING_INPUT.copy()
-            success_message['predictionInput'] = input_with_model
-            hook_instance = mock_cloudml_hook.return_value
-            hook_instance.create_job.return_value = success_message
-            result = pred.execute(None)
-            mock_cloudml_hook.assert_called_with('google_cloud_default', None)
-            hook_instance.create_job.assert_called_once_with(
-                'test-project',
-                {
-                    'jobId': 'eval_test_prediction',
-                    'predictionInput': input_with_model,
-                },
-                ANY)
-            self.assertEqual(success_message['predictionOutput'], result)
-
-        with patch('airflow.contrib.operators.dataflow_operator.'
-                   'DataFlowHook') as mock_dataflow_hook:
-
-            hook_instance = mock_dataflow_hook.return_value
-            hook_instance.start_python_dataflow.return_value = None
-            summary.execute(None)
-            mock_dataflow_hook.assert_called_with(
-                gcp_conn_id='google_cloud_default', delegate_to=None)
-            hook_instance.start_python_dataflow.assert_called_once_with(
-                'eval-test-summary',
-                {
-                    'prediction_path': 'gs://legal-bucket/fake-output-path',
-                    'metric_keys': 'err',
-                    'metric_fn_encoded': self.metric_fn_encoded,
-                },
-                'airflow.contrib.operators.cloudml_prediction_summary',
-                ['-m'])
-
-        with patch('airflow.contrib.operators.cloudml_operator_utils.'
-                   'GoogleCloudStorageHook') as mock_gcs_hook:
-
-            hook_instance = mock_gcs_hook.return_value
-            hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
-            result = validate.execute({})
-            hook_instance.download.assert_called_once_with(
-                'legal-bucket', 'fake-output-path/prediction.summary.json')
-            self.assertEqual('err=0.9', result)
-
-    def testFailures(self):
-        dag = DAG(
-            'test_dag',
-            default_args={
-                'owner': 'airflow',
-                'start_date': DEFAULT_DATE,
-                'end_date': DEFAULT_DATE,
-                'project_id': 'test-project',
-                'region': 'us-east1',
-            },
-            schedule_interval='@daily')
-
-        input_with_model = self.INPUT_MISSING_ORIGIN.copy()
-        other_params_but_models = {
-            'task_prefix': 'eval-test',
-            'batch_prediction_job_id': 'eval-test-prediction',
-            'data_format': input_with_model['dataFormat'],
-            'input_paths': input_with_model['inputPaths'],
-            'prediction_path': input_with_model['outputPath'],
-            'metric_fn_and_keys': (self.metric_fn, ['err']),
-            'validate_fn': (lambda x: 'err=%.1f' % x['err']),
-            'dag': dag,
-        }
-
-        with self.assertRaisesRegexp(ValueError, 'Missing model origin'):
-            _ = create_evaluate_ops(**other_params_but_models)
-
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', model_name='cde',
-                                    **other_params_but_models)
-
-        with self.assertRaisesRegexp(ValueError, 'Ambiguous model origin'):
-            _ = create_evaluate_ops(model_uri='abc', version_name='vvv',
-                                    **other_params_but_models)
-
-        with self.assertRaisesRegexp(AirflowException,
-                                     '`metric_fn` param must be callable'):
-            params = other_params_but_models.copy()
-            params['metric_fn_and_keys'] = (None, ['abc'])
-            _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
-        with self.assertRaisesRegexp(AirflowException,
-                                     '`validate_fn` param must be callable'):
-            params = other_params_but_models.copy()
-            params['validate_fn'] = None
-            _ = create_evaluate_ops(model_uri='gs://blah', **params)
-
-
-if __name__ == '__main__':
-    unittest.main()
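
For reference, testSuccessfulRun above mirrors how create_evaluate_ops is meant to be wired into a DAG: it returns a chained (prediction, summary, validation) triple, and callers attach upstream and downstream tasks to the ends of that chain. The sketch below shows that wiring against the renamed module introduced by this commit; the project, bucket paths, model URI, and the 'err'/'classes'/'input_label' keys are placeholders, not values taken from the diff.

import datetime

from airflow import DAG
from airflow.contrib.operators.mlengine_operator_utils import (
    create_evaluate_ops)

dag = DAG(
    'example_evaluate',
    default_args={
        'owner': 'airflow',
        'start_date': datetime.datetime(2017, 6, 6),
        'project_id': 'my-project',    # placeholder project
        'region': 'us-east1',
    },
    schedule_interval='@daily')


def metric_fn(inst):
    # One dict per instance from the batch-prediction output; the
    # 'classes' and 'input_label' keys depend on the model and are
    # assumed here. Must return a tuple of metrics.
    err = abs(float(inst['classes']) - float(inst['input_label']))
    return (err,)


def validate_err(summary):
    # Receives the averaged metrics plus a 'count' key; raising marks
    # the validation task as failed.
    if summary['err'] > 0.2:
        raise ValueError('Too high err: %s' % summary['err'])
    return summary


pred, _, validate = create_evaluate_ops(
    task_prefix='eval-model',                      # alphanumeric/hyphen
    data_format='TEXT',
    input_paths=['gs://my-bucket/eval-input/*'],   # placeholder path
    prediction_path='gs://my-bucket/eval-output',  # placeholder path
    metric_fn_and_keys=(metric_fn, ['err']),
    validate_fn=validate_err,
    batch_prediction_job_id='eval-model-prediction',
    model_uri='gs://my-bucket/exported-model',     # exactly one model origin
    dag=dag)

# Only the ends of the chain need wiring, e.g.:
#   training_op.set_downstream(pred)
#   validate.set_downstream(deploy_op)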