From: criccomini@apache.org
To: commits@airflow.incubator.apache.org
Subject: incubator-airflow git commit: [AIRFLOW-1273] Add Google Cloud ML version and model operators
Date: Tue, 27 Jun 2017 16:40:04 +0000 (UTC)

Repository: incubator-airflow
Updated Branches:
  refs/heads/master e870a8e2c -> 534a0e078


[AIRFLOW-1273] Add Google Cloud ML version and model operators

Includes Google Cloud ML hooks for version and model operations, and their
unit tests.
https://issues.apache.org/jira/browse/AIRFLOW-1273

Closes #2379 from N3da/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/534a0e07
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/534a0e07
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/534a0e07

Branch: refs/heads/master
Commit: 534a0e078af4890761bd48e26075d7c61f7e202e
Parents: e870a8e
Author: Neda Mirian
Authored: Tue Jun 27 09:39:00 2017 -0700
Committer: Chris Riccomini
Committed: Tue Jun 27 09:39:00 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/__init__.py             |   1 +
 airflow/contrib/hooks/gcp_cloudml_hook.py     | 167 ++++++++++++++
 airflow/contrib/operators/cloudml_operator.py | 178 ++++++++++++++
 airflow/utils/db.py                           |   4 +
 tests/contrib/hooks/test_gcp_cloudml_hook.py  | 255 +++++++++++++++++++++
 5 files changed, 605 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/534a0e07/airflow/contrib/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index 182a49f..4941314 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -40,6 +40,7 @@ _hooks = {
     'qubole_hook': ['QuboleHook'],
     'gcs_hook': ['GoogleCloudStorageHook'],
     'datastore_hook': ['DatastoreHook'],
+    'gcp_cloudml_hook': ['CloudMLHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
     'spark_submit_operator': ['SparkSubmitOperator'],


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/534a0e07/airflow/contrib/hooks/gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/gcp_cloudml_hook.py b/airflow/contrib/hooks/gcp_cloudml_hook.py
new file mode 100644
index 0000000..e722b2a
--- /dev/null
+++ b/airflow/contrib/hooks/gcp_cloudml_hook.py
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import random
+import time
+from airflow import settings
+from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
+from apiclient.discovery import build
+from apiclient import errors
+from oauth2client.client import GoogleCredentials
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
+def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):
+
+    for i in range(0, max_n):
+        try:
+            response = request.execute()
+            if is_error_func(response):
+                raise ValueError('The response contained an error: {}'.format(response))
+            elif is_done_func(response):
+                logging.info('Operation is done: {}'.format(response))
+                return response
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000.0))
+        except errors.HttpError as e:
+            if e.resp.status != 429:
+                logging.info('Something went wrong. Not retrying: {}'.format(e))
+                raise e
+            else:
+                time.sleep((2**i) + (random.randint(0, 1000) / 1000.0))
+
+
+class CloudMLHook(GoogleCloudBaseHook):
+
+    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
+        super(CloudMLHook, self).__init__(gcp_conn_id, delegate_to)
+        self._cloudml = self.get_conn()
+
+    def get_conn(self):
+        """
+        Returns a Google CloudML service object.
+        """
+        credentials = GoogleCredentials.get_application_default()
+        return build('ml', 'v1', credentials=credentials)
+
+    def create_version(self, project_name, model_name, version_spec):
+        """
+        Creates the Version on Cloud ML.
+
+        Returns the operation if the version was created successfully and raises
+        an error otherwise.
+        """
+        parent_name = 'projects/{}/models/{}'.format(project_name, model_name)
+        create_request = self._cloudml.projects().models().versions().create(
+            parent=parent_name, body=version_spec)
+        response = create_request.execute()
+        get_request = self._cloudml.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def set_default_version(self, project_name, model_name, version_name):
+        """
+        Sets a version to be the default. Blocks until finished.
+        """
+        full_version_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_name, model_name, version_name)
+        request = self._cloudml.projects().models().versions().setDefault(
+            name=full_version_name, body={})
+
+        try:
+            response = request.execute()
+            logging.info('Successfully set version: {} to default'.format(response))
+            return response
+        except errors.HttpError as e:
+            logging.error('Something went wrong: {}'.format(e))
+            raise e
+
+    def list_versions(self, project_name, model_name):
+        """
+        Lists all available versions of a model. Blocks until finished.
+        """
+        result = []
+        full_parent_name = 'projects/{}/models/{}'.format(
+            project_name, model_name)
+        request = self._cloudml.projects().models().versions().list(
+            parent=full_parent_name, pageSize=100)
+
+        response = request.execute()
+        next_page_token = response.get('nextPageToken', None)
+        result.extend(response.get('versions', []))
+        while next_page_token is not None:
+            next_request = self._cloudml.projects().models().versions().list(
+                parent=full_parent_name,
+                pageToken=next_page_token,
+                pageSize=100)
+            response = next_request.execute()
+            next_page_token = response.get('nextPageToken', None)
+            result.extend(response.get('versions', []))
+            time.sleep(5)
+        return result
+
+    def delete_version(self, project_name, model_name, version_name):
+        """
+        Deletes the given version of a model. Blocks until finished.
+        """
+        full_name = 'projects/{}/models/{}/versions/{}'.format(
+            project_name, model_name, version_name)
+        delete_request = self._cloudml.projects().models().versions().delete(
+            name=full_name)
+        response = delete_request.execute()
+        get_request = self._cloudml.projects().operations().get(
+            name=response['name'])
+
+        return _poll_with_exponential_delay(
+            request=get_request,
+            max_n=9,
+            is_done_func=lambda resp: resp.get('done', False),
+            is_error_func=lambda resp: resp.get('error', None) is not None)
+
+    def create_model(self, project_name, model):
+        """
+        Create a Model. Blocks until finished.
+        """
+        assert model['name'] is not None and model['name'] != ''
+        project = 'projects/{}'.format(project_name)
+
+        request = self._cloudml.projects().models().create(
+            parent=project, body=model)
+        return request.execute()
+
+    def get_model(self, project_name, model_name):
+        """
+        Gets a Model. Blocks until finished.
+        """
+        assert model_name is not None and model_name != ''
+        full_model_name = 'projects/{}/models/{}'.format(
+            project_name, model_name)
+        request = self._cloudml.projects().models().get(name=full_model_name)
+        try:
+            return request.execute()
+        except errors.HttpError as e:
+            if e.resp.status == 404:
+                logging.error('Model was not found: {}'.format(e))
+                return None
+            raise e


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/534a0e07/airflow/contrib/operators/cloudml_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cloudml_operator.py b/airflow/contrib/operators/cloudml_operator.py
new file mode 100644
index 0000000..b0b6e91
--- /dev/null
+++ b/airflow/contrib/operators/cloudml_operator.py
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the 'License'); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from airflow import settings
+from airflow.contrib.hooks.gcp_cloudml_hook import CloudMLHook
+from airflow.operators import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+logging.getLogger('GoogleCloudML').setLevel(settings.LOGGING_LEVEL)
+
+
+class CloudMLVersionOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML version.
+
+    :param model_name: The name of the Google Cloud ML model that the version
+        belongs to.
+    :type model_name: string
+
+    :param project_name: The Google Cloud project name to which the Cloud ML
+        model belongs.
+    :type project_name: string
+
+    :param version: A dictionary containing the information about the version.
+        If the `operation` is `create`, `version` should contain all the
+        information about this version, such as `name` and `deploymentUri`.
+        If the `operation` is `set_default` or `delete`, the `version`
+        parameter should contain the `name` of the version.
+        If it is None, the only `operation` possible would be `list`.
+    :type version: dict
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new version in the model specified by `model_name`,
+            in which case the `version` parameter should contain all the
+            information to create that version
+            (e.g. `name`, `deploymentUri`).
+        'set_default': Sets the version specified in the `version` parameter
+            as the default version of the model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+
+        'list': Lists all available versions of the model specified
+            by `model_name`.
+
+        'delete': Deletes the version specified in the `version` parameter from
+            the model specified by `model_name`.
+            The name of the version should be specified in the `version`
+            parameter.
+    :type operation: string
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+
+    template_fields = [
+        '_model_name',
+        '_version',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model_name,
+                 project_name,
+                 version=None,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+
+        super(CloudMLVersionOperator, self).__init__(*args, **kwargs)
+        self._model_name = model_name
+        self._version = version
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+        self._operation = operation
+
+    def execute(self, context):
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+
+        if self._operation == 'create':
+            assert self._version is not None
+            return hook.create_version(self._project_name, self._model_name,
+                                       self._version)
+        elif self._operation == 'set_default':
+            return hook.set_default_version(
+                self._project_name, self._model_name,
+                self._version['name'])
+        elif self._operation == 'list':
+            return hook.list_versions(self._project_name, self._model_name)
+        elif self._operation == 'delete':
+            return hook.delete_version(self._project_name, self._model_name,
+                                       self._version['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))
+
+
+class CloudMLModelOperator(BaseOperator):
+    """
+    Operator for managing a Google Cloud ML model.
+
+    :param model: A dictionary containing the information about the model.
+        If the `operation` is `create`, then the `model` parameter should
+        contain all the information about this model such as `name`.
+
+        If the `operation` is `get`, the `model` parameter
+        should contain the `name` of the model.
+    :type model: dict
+
+    :param project_name: The Google Cloud project name to which the Cloud ML
+        model belongs.
+    :type project_name: string
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :type gcp_conn_id: string
+
+    :param operation: The operation to perform. Available operations are:
+        'create': Creates a new model as provided by the `model` parameter.
+        'get': Gets a particular model where the name is specified in `model`.
+
+    :param delegate_to: The account to impersonate, if any.
+        For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :type delegate_to: string
+    """
+
+    template_fields = [
+        '_model',
+    ]
+
+    @apply_defaults
+    def __init__(self,
+                 model,
+                 project_name,
+                 gcp_conn_id='google_cloud_default',
+                 operation='create',
+                 delegate_to=None,
+                 *args,
+                 **kwargs):
+        super(CloudMLModelOperator, self).__init__(*args, **kwargs)
+        self._model = model
+        self._operation = operation
+        self._gcp_conn_id = gcp_conn_id
+        self._delegate_to = delegate_to
+        self._project_name = project_name
+
+    def execute(self, context):
+        hook = CloudMLHook(
+            gcp_conn_id=self._gcp_conn_id, delegate_to=self._delegate_to)
+        if self._operation == 'create':
+            return hook.create_model(self._project_name, self._model)
+        elif self._operation == 'get':
+            return hook.get_model(self._project_name, self._model['name'])
+        else:
+            raise ValueError('Unknown operation: {}'.format(self._operation))


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/534a0e07/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 54254f6..04b1512 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -130,6 +130,10 @@ def initdb():
             schema='hive', port=3400))
     merge_conn(
         models.Connection(
+            conn_id='google_cloud_default', conn_type='google_cloud_platform',
+            schema='default',))
+    merge_conn(
+        models.Connection(
             conn_id='hive_cli_default', conn_type='hive_cli',
             schema='default',))
     merge_conn(


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/534a0e07/tests/contrib/hooks/test_gcp_cloudml_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_gcp_cloudml_hook.py b/tests/contrib/hooks/test_gcp_cloudml_hook.py
new file mode 100644
index 0000000..aa50e69
--- /dev/null
+++ b/tests/contrib/hooks/test_gcp_cloudml_hook.py
@@ -0,0 +1,255 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import mock
+import unittest
+try:  # python 2
+    from urlparse import urlparse, parse_qsl
+except ImportError:  # python 3
+    from urllib.parse import urlparse, parse_qsl
+
+from airflow.contrib.hooks import gcp_cloudml_hook as hook
+from apiclient.discovery import build
+from apiclient.http import HttpMockSequence
+from oauth2client.contrib.gce import HttpAccessTokenRefreshError
+
+cml_available = True
+try:
+    hook.CloudMLHook().get_conn()
+except HttpAccessTokenRefreshError:
+    cml_available = False
+
+
+class _TestCloudMLHook(object):
+
+    def __init__(self, test_cls, responses, expected_requests):
+        """
+        Init method.
+
+        Usage example:
+        with _TestCloudMLHook(self, responses, expected_requests) as hook:
+            self.run_my_test(hook)
+
+        Args:
+          test_cls: The caller's instance used for test communication.
+          responses: A list of (dict_response, response_content) tuples.
+          expected_requests: A list of (uri, http_method, body) tuples.
+        """
+
+        self._test_cls = test_cls
+        self._responses = responses
+        self._expected_requests = [
+            self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in expected_requests]
+        self._actual_requests = []
+
+    def _normalize_requests_for_comparison(self, uri, http_method, body):
+        parts = urlparse(uri)
+        return (parts._replace(query=set(parse_qsl(parts.query))), http_method, body)
+
+    def __enter__(self):
+        http = HttpMockSequence(self._responses)
+        native_request_method = http.request
+
+        # Collecting requests to validate at __exit__.
+        def _request_wrapper(*args, **kwargs):
+            self._actual_requests.append(args + (kwargs['body'],))
+            return native_request_method(*args, **kwargs)
+
+        http.request = _request_wrapper
+        service_mock = build('ml', 'v1', http=http)
+        with mock.patch.object(
+                hook.CloudMLHook, 'get_conn', return_value=service_mock):
+            return hook.CloudMLHook()
+
+    def __exit__(self, *args):
+        # Propagating exceptions here since assert will silence them.
+        if any(args):
+            return None
+        self._test_cls.assertEquals(
+            [self._normalize_requests_for_comparison(x[0], x[1], x[2]) for x in self._actual_requests], self._expected_requests)
+
+
+class TestCloudMLHook(unittest.TestCase):
+
+    def setUp(self):
+        pass
+
+    _SKIP_IF = unittest.skipIf(not cml_available,
+                               'CloudML is not available to run tests')
+    _SERVICE_URI_PREFIX = 'https://ml.googleapis.com/v1/'
+
+    @_SKIP_IF
+    def test_create_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'POST',
+             '"{}"'.format(version)),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response] * 2,
+                expected_requests=expected_requests) as cml_hook:
+            create_version_response = cml_hook.create_version(
+                project_name=project, model_name=model_name, version_spec=version)
+            self.assertEquals(create_version_response, response_body)
+
+    @_SKIP_IF
+    def test_set_default_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        response_body = {'name': operation_name, 'done': True}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}:setDefault?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version), 'POST',
+             '{}'),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            set_default_version_response = cml_hook.set_default_version(
+                project_name=project, model_name=model_name, version_name=version)
+            self.assertEquals(set_default_version_response, response_body)
+
+    @_SKIP_IF
+    def test_list_versions(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        # This test returns the versions one at a time.
+        versions = ['ver_{}'.format(ix) for ix in range(3)]
+
+        response_bodies = [{'name': operation_name, 'nextPageToken': ix, 'versions': [
+            ver]} for ix, ver in enumerate(versions)]
+        response_bodies[-1].pop('nextPageToken')
+        responses = [({'status': '200'}, json.dumps(body))
+                     for body in response_bodies]
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions?alt=json&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ] + [
+            ('{}projects/{}/models/{}/versions?alt=json&pageToken={}&pageSize=100'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, ix), 'GET',
+             None) for ix in range(len(versions) - 1)
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=responses,
+                expected_requests=expected_requests) as cml_hook:
+            list_versions_response = cml_hook.list_versions(
+                project_name=project, model_name=model_name)
+            self.assertEquals(list_versions_response, versions)
+
+    @_SKIP_IF
+    def test_delete_version(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        version = 'test-version'
+        operation_name = 'projects/{}/operations/test-operation'.format(
+            project)
+
+        not_done_response_body = {'name': operation_name, 'done': False}
+        done_response_body = {'name': operation_name, 'done': True}
+        not_done_response = (
+            {'status': '200'}, json.dumps(not_done_response_body))
+        succeeded_response = (
+            {'status': '200'}, json.dumps(done_response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}/versions/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name, version), 'DELETE',
+             None),
+            ('{}{}?alt=json'.format(self._SERVICE_URI_PREFIX, operation_name),
+             'GET', None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[not_done_response, succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            delete_version_response = cml_hook.delete_version(
+                project_name=project, model_name=model_name, version_name=version)
+            self.assertEquals(delete_version_response, done_response_body)
+
+    @_SKIP_IF
+    def test_create_model(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        model = {
+            'name': model_name,
+        }
+        response_body = {}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project), 'POST',
+             json.dumps(model)),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            create_model_response = cml_hook.create_model(
+                project_name=project, model=model)
+            self.assertEquals(create_model_response, response_body)
+
+    @_SKIP_IF
+    def test_get_model(self):
+        project = 'test-project'
+        model_name = 'test-model'
+        response_body = {'model': model_name}
+        succeeded_response = ({'status': '200'}, json.dumps(response_body))
+
+        expected_requests = [
+            ('{}projects/{}/models/{}?alt=json'.format(
+                self._SERVICE_URI_PREFIX, project, model_name), 'GET',
+             None),
+        ]
+
+        with _TestCloudMLHook(
+                self,
+                responses=[succeeded_response],
+                expected_requests=expected_requests) as cml_hook:
+            get_model_response = cml_hook.get_model(
+                project_name=project, model_name=model_name)
+            self.assertEquals(get_model_response, response_body)
+
+
+if __name__ == '__main__':
+    unittest.main()
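The polling helper `_poll_with_exponential_delay` in the hook above waits about 2**i seconds plus up to one second of random jitter between attempts, and retries an HttpError only when its status is 429. The sketch below reproduces that schedule outside Airflow so it can be run and inspected. It rests on a few assumptions. `RateLimitedError` and `FakeRequest` are hypothetical stand-ins for apiclient's 429 `HttpError` and a discovery request object. The injectable `sleep` argument is an addition for testing. One deliberate difference is labeled in the code: the sketch raises once `max_n` polls are used up, whereas the hook falls off the end of its loop and returns None.

```python
import random
import time


class RateLimitedError(Exception):
    """Hypothetical stand-in for an HTTP 429 error (not apiclient's HttpError)."""


def poll_with_exponential_delay(request, max_n, is_done_func, is_error_func,
                                sleep=time.sleep):
    """Poll request.execute() until is_done_func(response) is true.

    Sketch only: mirrors the hook's 2**i + jitter schedule, with an
    injectable sleep so the delays can be recorded instead of waited out.
    """
    for i in range(max_n):
        # 2**i seconds plus 0..1 s of jitter (float division on Python 2 and 3).
        delay = (2 ** i) + random.randint(0, 1000) / 1000.0
        try:
            response = request.execute()
        except RateLimitedError:
            # Only rate-limit errors are retried; any other exception propagates.
            sleep(delay)
            continue
        if is_error_func(response):
            raise ValueError('The response contained an error: {}'.format(response))
        if is_done_func(response):
            return response
        sleep(delay)
    # Assumed behavior, unlike the hook: fail loudly instead of returning None.
    raise RuntimeError('Operation not done after {} polls'.format(max_n))


class FakeRequest(object):
    """Hypothetical request object that replays canned responses or errors."""

    def __init__(self, responses):
        self._responses = list(responses)

    def execute(self):
        item = self._responses.pop(0)
        if isinstance(item, Exception):
            raise item
        return item


# Usage: one 429, one "not done" poll, then a finished operation.
delays = []
request = FakeRequest([RateLimitedError(), {'done': False}, {'done': True, 'name': 'op'}])
result = poll_with_exponential_delay(
    request, max_n=5,
    is_done_func=lambda resp: resp.get('done', False),
    is_error_func=lambda resp: resp.get('error') is not None,
    sleep=delays.append)
print(result)
print(delays)
```

Passing `sleep=delays.append` records each delay instead of waiting. With `max_n=9`, the value the hook uses, the worst case before giving up is 1 + 2 + ... + 256 = 511 seconds of sleeping (about 8.5 minutes), plus up to 9 seconds of jitter.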