airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1695] Add RedshiftHook using boto3
Date Mon, 30 Oct 2017 19:36:23 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master c800632bb -> 4fb7a90b3


[AIRFLOW-1695] Add RedshiftHook using boto3

Adds RedshiftHook class, allowing for management
of AWS Redshift
clusters and snapshots using boto3 library. Also
adds new test file and
unit tests for class methods.

Closes #2717 from andyxhadji/1695


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/4fb7a90b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/4fb7a90b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/4fb7a90b

Branch: refs/heads/master
Commit: 4fb7a90b36ec1daf169a65aa4adf28a31b30fbc5
Parents: c800632
Author: Andy Hadjigeorgiou <ahh2131@columbia.edu>
Authored: Mon Oct 30 20:36:18 2017 +0100
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Mon Oct 30 20:36:18 2017 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/__init__.py         |   1 +
 airflow/contrib/hooks/aws_hook.py         |  16 ++--
 airflow/contrib/hooks/redshift_hook.py    | 100 +++++++++++++++++++++++++
 airflow/hooks/__init__.py                 |   1 -
 tests/contrib/hooks/test_redshift_hook.py |  77 +++++++++++++++++++
 5 files changed, 186 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4fb7a90b/airflow/contrib/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index 2891980..6d45ace 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -41,6 +41,7 @@ _hooks = {
     'gcs_hook': ['GoogleCloudStorageHook'],
     'datastore_hook': ['DatastoreHook'],
     'gcp_cloudml_hook': ['CloudMLHook'],
+    'redshift_hook': ['RedshiftHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
     'spark_submit_operator': ['SparkSubmitOperator'],

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4fb7a90b/airflow/contrib/hooks/aws_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
index ca2ee05..8573db3 100644
--- a/airflow/contrib/hooks/aws_hook.py
+++ b/airflow/contrib/hooks/aws_hook.py
@@ -85,28 +85,28 @@ class AwsHook(BaseHook):
         aws_access_key_id = None
         aws_secret_access_key = None
         s3_endpoint_url = None
-        
+
         if self.aws_conn_id:
             try:
                 connection_object = self.get_connection(self.aws_conn_id)
                 if connection_object.login:
                     aws_access_key_id = connection_object.login
                     aws_secret_access_key = connection_object.password
-    
+
                 elif 'aws_secret_access_key' in connection_object.extra_dejson:
                     aws_access_key_id = connection_object.extra_dejson['aws_access_key_id']
                     aws_secret_access_key = connection_object.extra_dejson['aws_secret_access_key']
-    
+
                 elif 's3_config_file' in connection_object.extra_dejson:
                     aws_access_key_id, aws_secret_access_key = \
                         _parse_s3_config(connection_object.extra_dejson['s3_config_file'],
                                          connection_object.extra_dejson.get('s3_config_format'))
-    
+
                 if region_name is None:
                     region_name = connection_object.extra_dejson.get('region_name')
-    
-                s3_endpoint_url = connection_object.extra_dejson.get('host') 
-    
+
+                s3_endpoint_url = connection_object.extra_dejson.get('host')
+
             except AirflowException:
                 # No connection found: fallback on boto3 credential strategy
                 # http://boto3.readthedocs.io/en/latest/guide/configuration.html
@@ -129,7 +129,7 @@ class AwsHook(BaseHook):
     def get_resource_type(self, resource_type, region_name=None):
         aws_access_key_id, aws_secret_access_key, region_name, endpoint_url = \
             self._get_credentials(region_name)
-        
+
         return boto3.resource(
             resource_type,
             region_name=region_name,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4fb7a90b/airflow/contrib/hooks/redshift_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/redshift_hook.py b/airflow/contrib/hooks/redshift_hook.py
new file mode 100644
index 0000000..071caf2
--- /dev/null
+++ b/airflow/contrib/hooks/redshift_hook.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+
class RedshiftHook(AwsHook):
    """
    Interact with AWS Redshift, using the boto3 library.
    """

    def get_conn(self):
        """Return a boto3 Redshift client (credentials resolved by AwsHook)."""
        return self.get_client_type('redshift')

    def cluster_status(self, cluster_identifier):
        """
        Return the status of a cluster.

        :param cluster_identifier: unique identifier of the cluster
        :type cluster_identifier: str
        :return: the cluster's ``ClusterStatus`` string, or ``None`` if no
            cluster with that identifier was returned
        """
        response = self.get_conn().describe_clusters(
            ClusterIdentifier=cluster_identifier)
        return response['Clusters'][0]['ClusterStatus'] if response['Clusters'] else None

    def delete_cluster(self, cluster_identifier, skip_final_cluster_snapshot=True,
                       final_cluster_snapshot_identifier=''):
        """
        Delete a cluster and optionally create a final snapshot.

        :param cluster_identifier: unique identifier of the cluster to delete
        :type cluster_identifier: str
        :param skip_final_cluster_snapshot: determines whether a final cluster
            snapshot is created before shut-down
        :type skip_final_cluster_snapshot: bool
        :param final_cluster_snapshot_identifier: name of the final cluster snapshot
        :type final_cluster_snapshot_identifier: str
        :return: the deleted cluster's description dict, or ``None``
        """
        response = self.get_conn().delete_cluster(
            ClusterIdentifier=cluster_identifier,
            SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
            FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
        )
        # .get() avoids a KeyError when 'Cluster' is absent from the response.
        return response.get('Cluster') or None

    def describe_cluster_snapshots(self, cluster_identifier):
        """
        Get a list of snapshots for a cluster, newest first.

        :param cluster_identifier: unique identifier of the cluster whose
            snapshots are requested
        :type cluster_identifier: str
        :return: list of snapshot dicts sorted by ``SnapshotCreateTime``
            (descending), or ``None`` if the response has no ``Snapshots`` key
        """
        response = self.get_conn().describe_cluster_snapshots(
            ClusterIdentifier=cluster_identifier
        )
        if 'Snapshots' not in response:
            return None
        # Materialize to a list: in Python 3, filter() returns an iterator,
        # which has no .sort() method.
        snapshots = [s for s in response['Snapshots'] if s['Status']]
        snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True)
        return snapshots

    def restore_from_cluster_snapshot(self, cluster_identifier, snapshot_identifier):
        """
        Restore a cluster from its snapshot.

        :param cluster_identifier: unique identifier of the cluster to create
        :type cluster_identifier: str
        :param snapshot_identifier: unique identifier of the snapshot to restore from
        :type snapshot_identifier: str
        :return: the restored cluster's description dict, or ``None``
        """
        response = self.get_conn().restore_from_cluster_snapshot(
            ClusterIdentifier=cluster_identifier,
            SnapshotIdentifier=snapshot_identifier
        )
        # .get() avoids a KeyError when 'Cluster' is absent from the response.
        return response.get('Cluster') or None

    def create_cluster_snapshot(self, snapshot_identifier, cluster_identifier):
        """
        Create a snapshot of a cluster.

        :param snapshot_identifier: unique identifier for the new snapshot
        :type snapshot_identifier: str
        :param cluster_identifier: unique identifier of the cluster to snapshot
        :type cluster_identifier: str
        :return: the new snapshot's description dict, or ``None``
        """
        response = self.get_conn().create_cluster_snapshot(
            SnapshotIdentifier=snapshot_identifier,
            ClusterIdentifier=cluster_identifier,
        )
        # .get() avoids a KeyError when 'Snapshot' is absent from the response.
        return response.get('Snapshot') or None

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4fb7a90b/airflow/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index bb02967..6e96e2a 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -85,4 +85,3 @@ def _integrate_plugins():
                     "import from 'airflow.hooks.[plugin_module]' "
                     "instead. Support for direct imports will be dropped "
                     "entirely in Airflow 2.0.".format(i=hook_name))
-

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4fb7a90b/tests/contrib/hooks/test_redshift_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_redshift_hook.py b/tests/contrib/hooks/test_redshift_hook.py
new file mode 100644
index 0000000..185be5e
--- /dev/null
+++ b/tests/contrib/hooks/test_redshift_hook.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.hooks.redshift_hook import RedshiftHook
+from airflow.contrib.hooks.aws_hook import AwsHook
+
try:
    from moto import mock_redshift
except ImportError:
    mock_redshift = None


def _mock_redshift_if_available(cls):
    """Apply moto's mock_redshift to the test case only when moto is installed.

    Decorating the class directly with ``@mock_redshift`` raises a TypeError
    at import time when moto is missing (``mock_redshift`` is ``None``); this
    no-op fallback lets the module import so the ``skipIf`` guards can run.
    """
    return mock_redshift(cls) if mock_redshift is not None else cls


@_mock_redshift_if_available
class TestRedshiftHook(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()
        client = boto3.client('redshift', region_name='us-east-1')
        client.create_cluster(
            ClusterIdentifier='test_cluster',
            NodeType='dc1.large',
            MasterUsername='admin',
            MasterUserPassword='mock_password'
        )
        client.create_cluster(
            ClusterIdentifier='test_cluster_2',
            NodeType='dc1.large',
            MasterUsername='admin',
            MasterUserPassword='mock_password'
        )
        # Guard against the test silently hitting real AWS.
        if not client.describe_clusters()['Clusters']:
            raise ValueError('AWS not properly mocked')

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
        hook = AwsHook(aws_conn_id='aws_default')
        client_from_hook = hook.get_client_type('redshift')

        clusters = client_from_hook.describe_clusters()['Clusters']
        self.assertEqual(len(clusters), 2)

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')
        hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
        self.assertEqual(
            hook.restore_from_cluster_snapshot(
                'test_cluster_3', 'test_snapshot')['ClusterIdentifier'],
            'test_cluster_3')

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_delete_cluster_returns_a_dict_with_cluster_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')

        cluster = hook.delete_cluster('test_cluster_2')
        self.assertNotEqual(cluster, None)

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_create_cluster_snapshot_returns_snapshot_data(self):
        hook = RedshiftHook(aws_conn_id='aws_default')

        snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster')
        self.assertNotEqual(snapshot, None)


if __name__ == '__main__':
    unittest.main()


Mime
View raw message