airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From san...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-345] Add contrib ECSOperator
Date Wed, 23 Nov 2016 18:50:08 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master 41490f9c4 -> 98197d956


[AIRFLOW-345] Add contrib ECSOperator

Closes #1894 from poulainv/ecs_operator


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/98197d95
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/98197d95
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/98197d95

Branch: refs/heads/master
Commit: 98197d95681abaae0ec8f928e0147a8b32132ecb
Parents: 41490f9
Author: Vincent Poulain <vincent.poulain@tinyclues.com>
Authored: Wed Nov 23 10:49:57 2016 -0800
Committer: Siddharth Anand <siddharthanand@yahoo.com>
Committed: Wed Nov 23 10:49:57 2016 -0800

----------------------------------------------------------------------
 airflow/contrib/hooks/aws_hook.py         |  27 +++-
 airflow/contrib/operators/ecs_operator.py | 127 +++++++++++++++
 docs/code.rst                             |   1 +
 tests/contrib/operators/ecs_operator.py   | 207 +++++++++++++++++++++++++
 4 files changed, 356 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index 37a02ee..3eced28 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -12,24 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 import boto3
+
+from airflow.exceptions import AirflowException
 from airflow.hooks.base_hook import BaseHook
 
 
 class AwsHook(BaseHook):
     """
     Interact with AWS.
-
     This class is a thin wrapper around the boto3 python library.
     """
     def __init__(self, aws_conn_id='aws_default'):
         self.aws_conn_id = aws_conn_id
 
-    def get_client_type(self, client_type):
-        connection_object = self.get_connection(self.aws_conn_id)
+    def get_client_type(self, client_type, region_name=None):
+        try:
+            connection_object = self.get_connection(self.aws_conn_id)
+            aws_access_key_id = connection_object.login
+            aws_secret_access_key = connection_object.password
+
+            if region_name is None:
+                region_name = connection_object.extra_dejson.get('region_name')
+
+        except AirflowException:
+            # No connection found: fallback on boto3 credential strategy
+            # http://boto3.readthedocs.io/en/latest/guide/configuration.html
+            aws_access_key_id = None
+            aws_secret_access_key = None
+
         return boto3.client(
             client_type,
-            region_name=connection_object.extra_dejson.get('region_name'),
-            aws_access_key_id=connection_object.login,
-            aws_secret_access_key=connection_object.password,
+            region_name=region_name,
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key
         )

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/airflow/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..7415d32
--- /dev/null
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import logging
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils import apply_defaults
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
+
class ECSOperator(BaseOperator):
    """
    Execute a task on AWS EC2 Container Service (ECS).

    :param task_definition: the task definition name on EC2 Container Service
    :type task_definition: str
    :param cluster: the cluster name on EC2 Container Service
    :type cluster: str
    :param overrides: the same parameter that boto3 will receive:
        http://boto3.readthedocs.org/en/latest/reference/services/ecs.html#ECS.Client.run_task
    :type overrides: dict
    :param aws_conn_id: connection id of AWS credentials / region name. If
        None, the boto3 credential strategy will be used
        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
    :type aws_conn_id: str
    :param region_name: region name to use in the AWS Hook. Overrides the
        region_name stored in the connection (if provided).
    :type region_name: str
    """

    ui_color = '#f0ede4'
    client = None   # boto3 ECS client, created lazily in execute()
    arn = None      # ARN of the task started by run_task()
    template_fields = ('overrides',)

    @apply_defaults
    def __init__(self, task_definition, cluster, overrides,
                 aws_conn_id=None, region_name=None, **kwargs):
        super(ECSOperator, self).__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.task_definition = task_definition
        self.cluster = cluster
        self.overrides = overrides
        # Hook is built eagerly so misconfiguration surfaces at construction.
        self.hook = self.get_hook()

    def execute(self, context):
        """Start the ECS task, wait for it to stop, and verify success."""
        logging.info('Running ECS Task - Task definition: {} - on cluster {}'.format(
            self.task_definition,
            self.cluster
        ))
        logging.info('ECSOperator overrides: {}'.format(self.overrides))

        self.client = self.hook.get_client_type(
            'ecs',
            region_name=self.region_name
        )

        response = self.client.run_task(
            cluster=self.cluster,
            taskDefinition=self.task_definition,
            overrides=self.overrides,
            startedBy=self.owner
        )

        # Any scheduling failure (e.g. no container instance) aborts the task.
        if response['failures']:
            raise AirflowException(response)
        logging.info('ECS Task started: {}'.format(response))

        self.arn = response['tasks'][0]['taskArn']
        self._wait_for_task_ended()
        self._check_success_task()
        logging.info('ECS Task has been successfully executed: {}'.format(response))

    def _wait_for_task_ended(self):
        """Block until the started task reaches the STOPPED state."""
        waiter = self.client.get_waiter('tasks_stopped')
        # Timeout is managed by airflow, so effectively disable the waiter's
        # own attempt limit.  NOTE(review): sys.maxint is Python 2 only.
        waiter.config.max_attempts = sys.maxint
        waiter.wait(
            cluster=self.cluster,
            tasks=[self.arn]
        )

    def _check_success_task(self):
        """Raise AirflowException unless every container exited cleanly."""
        response = self.client.describe_tasks(
            cluster=self.cluster,
            tasks=[self.arn]
        )
        logging.info('ECS Task stopped, check status: {}'.format(response))

        if response.get('failures', []):
            raise AirflowException(response)

        for task in response['tasks']:
            for container in task['containers']:
                status = container.get('lastStatus')
                if status == 'STOPPED' and container['exitCode'] != 0:
                    raise AirflowException('This task is not in success state {}'.format(task))
                elif status == 'PENDING':
                    raise AirflowException('This task is still pending {}'.format(task))
                elif 'error' in container.get('reason', '').lower():
                    raise AirflowException('This containers encounter an error during launching : {}'.format(
                        container.get('reason', '').lower()))

    def get_hook(self):
        """Return the AwsHook used to build the ECS client."""
        return AwsHook(
            aws_conn_id=self.aws_conn_id
        )

    def on_kill(self):
        # Best effort: ask ECS to stop the remote task when Airflow kills us.
        response = self.client.stop_task(
            cluster=self.cluster,
            task=self.arn,
            reason='Task killed by the user')
        logging.info(response)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 8548120..0e1993e 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -97,6 +97,7 @@ Community-contributed Operators
 
 .. autoclass:: airflow.contrib.operators.bigquery_operator.BigQueryOperator
 .. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.ecs_operator.ECSOperator
 .. autoclass:: airflow.contrib.operators.gcs_download_operator.GoogleCloudStorageDownloadOperator
 .. autoclass:: airflow.contrib.operators.QuboleOperator
 .. autoclass:: airflow.contrib.operators.hipchat_operator.HipChatAPIOperator

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98197d95/tests/contrib/operators/ecs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/ecs_operator.py
new file mode 100644
index 0000000..5a593a6
--- /dev/null
+++ b/tests/contrib/operators/ecs_operator.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import unittest
+from copy import deepcopy
+
+from airflow import configuration
+from airflow.exceptions import AirflowException
+from airflow.contrib.operators.ecs_operator import ECSOperator
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
# Canned ECS run_task response with an empty "failures" list, shared by the
# tests below.  Values mirror the sample responses in the boto3 ECS docs.
RESPONSE_WITHOUT_FAILURES = {
    "failures": [],
    "tasks": [{
        "containers": [{
            "containerArn":
                "arn:aws:ecs:us-east-1:012345678910:container/e1ed7aac-d9b2-4315-8726-d2432bf11868",
            "lastStatus": "PENDING",
            "name": "wordpress",
            "taskArn":
                "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
        }],
        "desiredStatus": "RUNNING",
        "lastStatus": "PENDING",
        "taskArn":
            "arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55",
        "taskDefinitionArn":
            "arn:aws:ecs:us-east-1:012345678910:task-definition/hello_world:11",
    }],
}
+
+
class TestECSOperator(unittest.TestCase):
    """Unit tests for ECSOperator; all AWS access is mocked out."""

    @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook')
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()

        self.aws_hook_mock = aws_hook_mock
        self.ecs = ECSOperator(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1')

    def test_init(self):
        # Constructor stores its arguments and builds the hook exactly once.
        self.assertEqual(self.ecs.region_name, 'eu-west-1')
        self.assertEqual(self.ecs.task_definition, 't')
        self.assertIsNone(self.ecs.aws_conn_id)
        self.assertEqual(self.ecs.cluster, 'c')
        self.assertEqual(self.ecs.overrides, {})
        self.assertEqual(self.ecs.hook, self.aws_hook_mock.return_value)

        self.aws_hook_mock.assert_called_once_with(aws_conn_id=None)

    def test_template_fields_overrides(self):
        self.assertEqual(self.ecs.template_fields, ('overrides',))

    @mock.patch.object(ECSOperator, '_wait_for_task_ended')
    @mock.patch.object(ECSOperator, '_check_success_task')
    def test_execute_without_failures(self, check_mock, wait_mock):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES

        self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )

        wait_mock.assert_called_once_with()
        check_mock.assert_called_once_with()
        self.assertEqual(
            self.ecs.arn,
            'arn:aws:ecs:us-east-1:012345678910:task/d8c67b3c-ac87-4ffe-a847-4785bc3a8b55')

    def test_execute_with_failures(self):
        client_mock = self.aws_hook_mock.return_value.get_client_type.return_value
        resp_failures = deepcopy(RESPONSE_WITHOUT_FAILURES)
        resp_failures['failures'].append('dummy error')
        client_mock.run_task.return_value = resp_failures

        with self.assertRaises(AirflowException):
            self.ecs.execute(None)

        self.aws_hook_mock.return_value.get_client_type.assert_called_once_with(
            'ecs', region_name='eu-west-1')
        client_mock.run_task.assert_called_once_with(
            cluster='c',
            overrides={},
            startedBy='Airflow',
            taskDefinition='t'
        )

    def test_wait_end_tasks(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock

        self.ecs._wait_for_task_ended()
        client_mock.get_waiter.assert_called_once_with('tasks_stopped')
        client_mock.get_waiter.return_value.wait.assert_called_once_with(
            cluster='c', tasks=['arn'])
        self.assertEqual(
            sys.maxint, client_mock.get_waiter.return_value.config.max_attempts)

    def test_check_success_tasks_raises(self):
        client_mock = mock.Mock()
        self.ecs.arn = 'arn'
        self.ecs.client = client_mock

        # Keep a reference to the task dict so the expected message can be
        # built from the same object: dict repr ordering is not guaranteed,
        # so a hard-coded string would make this test flaky.
        task = {
            'containers': [{
                'name': 'foo',
                'lastStatus': 'STOPPED',
                'exitCode': 1
            }]
        }
        client_mock.describe_tasks.return_value = {'tasks': [task]}

        with self.assertRaises(AirflowException) as e:
            self.ecs._check_success_task()

        self.assertEqual(
            str(e.exception),
            'This task is not in success state {}'.format(task))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_pending(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'

        task = {
            'containers': [{
                'name': 'container-name',
                'lastStatus': 'PENDING'
            }]
        }
        client_mock.describe_tasks.return_value = {'tasks': [task]}

        with self.assertRaises(AirflowException) as e:
            self.ecs._check_success_task()

        self.assertEqual(
            str(e.exception),
            'This task is still pending {}'.format(task))
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_tasks_raises_mutliple(self):
        # Exit codes are only inspected for STOPPED containers, so the first
        # container (exitCode 1 but no lastStatus) must not trigger a failure.
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'foo',
                    'exitCode': 1
                }, {
                    'name': 'bar',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])

    def test_check_success_task_not_raises(self):
        client_mock = mock.Mock()
        self.ecs.client = client_mock
        self.ecs.arn = 'arn'
        client_mock.describe_tasks.return_value = {
            'tasks': [{
                'containers': [{
                    'name': 'container-name',
                    'lastStatus': 'STOPPED',
                    'exitCode': 0
                }]
            }]
        }
        self.ecs._check_success_task()
        client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn'])


if __name__ == '__main__':
    unittest.main()


Mime
View raw message