airflow-commits mailing list archives

From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-2336] Use hmsclient in hive_hook
Date Wed, 25 Apr 2018 10:24:06 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master fd6f1d1a0 -> 6c45b8c5f


[AIRFLOW-2336] Use hmsclient in hive_hook

The package hmsclient is Python 2/3 compatible and offers a handy
context manager for opening and closing connections.

Closes #3239 from gglanzani/AIRFLOW-2336
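
For illustration only, a minimal sketch of the context-manager pattern this
change adopts; the host, port, and database below are placeholders rather
than values taken from the commit:

    from hmsclient import HMSClient
    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TSocket, TTransport

    # Placeholder metastore address; the hook derives host/port from its
    # Airflow connection, so these values are purely illustrative.
    socket = TSocket.TSocket('localhost', 9083)
    transport = TTransport.TBufferedTransport(socket)
    protocol = TBinaryProtocol.TBinaryProtocol(transport)

    client = HMSClient(iprot=protocol)

    # Entering the context opens the underlying transport; leaving it closes
    # the transport again, replacing the manual
    # self.metastore._oprot.trans.open()/close() calls removed by this commit.
    with client as conn:
        tables = conn.get_tables(db_name='default', pattern='*')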


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/6c45b8c5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/6c45b8c5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/6c45b8c5

Branch: refs/heads/master
Commit: 6c45b8c5f2ad1af8faea13529dae01cee10b4937
Parents: fd6f1d1
Author: Giovanni Lanzani <giovanni@lanzani.nl>
Authored: Wed Apr 25 12:23:59 2018 +0200
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Wed Apr 25 12:23:59 2018 +0200

----------------------------------------------------------------------
 airflow/hooks/hive_hooks.py   | 127 +++++++++++++++-------------------
 setup.py                      |   6 +-
 tests/hooks/test_hive_hook.py | 136 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 192 insertions(+), 77 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c45b8c5/airflow/hooks/hive_hooks.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index d278483..65238df 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -28,7 +28,7 @@ import re
 import subprocess
 import time
 from tempfile import NamedTemporaryFile
-import hive_metastore
+import hmsclient
 
 from airflow import configuration as conf
 from airflow.exceptions import AirflowException
@@ -460,7 +460,6 @@ class HiveMetastoreHook(BaseHook):
         """
         from thrift.transport import TSocket, TTransport
         from thrift.protocol import TBinaryProtocol
-        from hive_service import ThriftHive
         ms = self.metastore_conn
         auth_mechanism = ms.extra_dejson.get('authMechanism', 'NOSASL')
         if configuration.conf.get('core', 'security') == 'kerberos':
@@ -489,7 +488,7 @@ class HiveMetastoreHook(BaseHook):
 
         protocol = TBinaryProtocol.TBinaryProtocol(transport)
 
-        return ThriftHive.Client(protocol)
+        return hmsclient.HMSClient(iprot=protocol)
 
     def get_conn(self):
         return self.metastore
@@ -512,10 +511,10 @@ class HiveMetastoreHook(BaseHook):
         >>> hh.check_for_partition('airflow', t, "ds='2015-01-01'")
         True
         """
-        self.metastore._oprot.trans.open()
-        partitions = self.metastore.get_partitions_by_filter(
-            schema, table, partition, 1)
-        self.metastore._oprot.trans.close()
+        with self.metastore as client:
+            partitions = client.get_partitions_by_filter(
+                schema, table, partition, 1)
+
         if partitions:
             return True
         else:
@@ -540,15 +539,8 @@ class HiveMetastoreHook(BaseHook):
         >>> hh.check_for_named_partition('airflow', t, "ds=xxx")
         False
         """
-        self.metastore._oprot.trans.open()
-        try:
-            self.metastore.get_partition_by_name(
-                schema, table, partition_name)
-            return True
-        except hive_metastore.ttypes.NoSuchObjectException:
-            return False
-        finally:
-            self.metastore._oprot.trans.close()
+        with self.metastore as client:
+            return client.check_for_named_partition(schema, table, partition_name)
 
     def get_table(self, table_name, db='default'):
         """Get a metastore table object
@@ -560,31 +552,25 @@ class HiveMetastoreHook(BaseHook):
         >>> [col.name for col in t.sd.cols]
         ['state', 'year', 'name', 'gender', 'num']
         """
-        self.metastore._oprot.trans.open()
         if db == 'default' and '.' in table_name:
             db, table_name = table_name.split('.')[:2]
-        table = self.metastore.get_table(dbname=db, tbl_name=table_name)
-        self.metastore._oprot.trans.close()
-        return table
+        with self.metastore as client:
+            return client.get_table(dbname=db, tbl_name=table_name)
 
     def get_tables(self, db, pattern='*'):
         """
         Get a metastore table object
         """
-        self.metastore._oprot.trans.open()
-        tables = self.metastore.get_tables(db_name=db, pattern=pattern)
-        objs = self.metastore.get_table_objects_by_name(db, tables)
-        self.metastore._oprot.trans.close()
-        return objs
+        with self.metastore as client:
+            tables = client.get_tables(db_name=db, pattern=pattern)
+            return client.get_table_objects_by_name(db, tables)
 
     def get_databases(self, pattern='*'):
         """
         Get a metastore table object
         """
-        self.metastore._oprot.trans.open()
-        dbs = self.metastore.get_databases(pattern)
-        self.metastore._oprot.trans.close()
-        return dbs
+        with self.metastore as client:
+            return client.get_databases(pattern)
 
     def get_partitions(
             self, schema, table_name, filter=None):
@@ -601,23 +587,22 @@ class HiveMetastoreHook(BaseHook):
         >>> parts
         [{'ds': '2015-01-01'}]
         """
-        self.metastore._oprot.trans.open()
-        table = self.metastore.get_table(dbname=schema, tbl_name=table_name)
-        if len(table.partitionKeys) == 0:
-            raise AirflowException("The table isn't partitioned")
-        else:
-            if filter:
-                parts = self.metastore.get_partitions_by_filter(
-                    db_name=schema, tbl_name=table_name,
-                    filter=filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+        with self.metastore as client:
+            table = client.get_table(dbname=schema, tbl_name=table_name)
+            if len(table.partitionKeys) == 0:
+                raise AirflowException("The table isn't partitioned")
             else:
-                parts = self.metastore.get_partitions(
-                    db_name=schema, tbl_name=table_name,
-                    max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+                if filter:
+                    parts = client.get_partitions_by_filter(
+                        db_name=schema, tbl_name=table_name,
+                        filter=filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+                else:
+                    parts = client.get_partitions(
+                        db_name=schema, tbl_name=table_name,
+                        max_parts=HiveMetastoreHook.MAX_PART_COUNT)
 
-            self.metastore._oprot.trans.close()
-            pnames = [p.name for p in table.partitionKeys]
-            return [dict(zip(pnames, p.values)) for p in parts]
+                pnames = [p.name for p in table.partitionKeys]
+                return [dict(zip(pnames, p.values)) for p in parts]
 
     @staticmethod
     def _get_max_partition_from_part_specs(part_specs, partition_key, filter_map):
@@ -644,8 +629,9 @@ class HiveMetastoreHook(BaseHook):
         if partition_key not in part_specs[0].keys():
             raise AirflowException("Provided partition_key {} "
                                    "is not in part_specs.".format(partition_key))
-
-        if filter_map and not set(filter_map.keys()) < set(part_specs[0].keys()):
+        if filter_map:
+            is_subset = set(filter_map.keys()).issubset(set(part_specs[0].keys()))
+        if filter_map and not is_subset:
             raise AirflowException("Keys in provided filter_map {} "
                                    "are not subset of part_spec keys: {}"
                                    .format(', '.join(filter_map.keys()),
@@ -677,34 +663,33 @@ class HiveMetastoreHook(BaseHook):
         :type filter_map: map
 
         >>> hh = HiveMetastoreHook()
-        >>  filter_map = {'p_key': 'p_val'}
+        >>> filter_map = {'ds': '2015-01-01', 'ds': '2014-01-01'}
         >>> t = 'static_babynames_partitioned'
         >>> hh.max_partition(schema='airflow',\
         ... table_name=t, field='ds', filter_map=filter_map)
         '2015-01-01'
         """
-        self.metastore._oprot.trans.open()
-        table = self.metastore.get_table(dbname=schema, tbl_name=table_name)
-        key_name_set = set(key.name for key in table.partitionKeys)
-        if len(table.partitionKeys) == 1:
-            field = table.partitionKeys[0].name
-        elif not field:
-            raise AirflowException("Please specify the field you want the max "
-                                   "value for.")
-        elif field not in key_name_set:
-            raise AirflowException("Provided field is not a partition key.")
-
-        if filter_map and not set(filter_map.keys()).issubset(key_name_set):
-            raise AirflowException("Provided filter_map contains keys "
-                                   "that are not partition key.")
-
-        part_names = \
-            self.metastore.get_partition_names(schema,
-                                               table_name,
-                                               max_parts=HiveMetastoreHook.MAX_PART_COUNT)
-        part_specs = [self.metastore.partition_name_to_spec(part_name)
-                      for part_name in part_names]
-        self.metastore._oprot.trans.close()
+        with self.metastore as client:
+            table = client.get_table(dbname=schema, tbl_name=table_name)
+            key_name_set = set(key.name for key in table.partitionKeys)
+            if len(table.partitionKeys) == 1:
+                field = table.partitionKeys[0].name
+            elif not field:
+                raise AirflowException("Please specify the field you want the max "
+                                       "value for.")
+            elif field not in key_name_set:
+                raise AirflowException("Provided field is not a partition key.")
+
+            if filter_map and not set(filter_map.keys()).issubset(key_name_set):
+                raise AirflowException("Provided filter_map contains keys "
+                                       "that are not partition key.")
+
+            part_names = \
+                client.get_partition_names(schema,
+                                           table_name,
+                                           max_parts=HiveMetastoreHook.MAX_PART_COUNT)
+            part_specs = [client.partition_name_to_spec(part_name)
+                          for part_name in part_names]
 
         return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs,
                                                                     field,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c45b8c5/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 034fd35..742e01b 100644
--- a/setup.py
+++ b/setup.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -141,7 +141,7 @@ webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4']
 jenkins = ['python-jenkins>=0.4.15']
 jira = ['JIRA>1.0.7']
 hive = [
-    'hive-thrift-py>=0.0.1',
+    'hmsclient>=0.1.0',
     'pyhive>=0.1.3',
     'impyla>=0.13.3',
     'unicodecsv>=0.14.1'

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/6c45b8c5/tests/hooks/test_hive_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_hive_hook.py b/tests/hooks/test_hive_hook.py
index 677132b..f48bed8 100644
--- a/tests/hooks/test_hive_hook.py
+++ b/tests/hooks/test_hive_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -18,13 +18,71 @@
 # under the License.
 #
 
+import datetime
+import random
 import unittest
 
+from hmsclient import HMSClient
+
 from airflow.exceptions import AirflowException
 from airflow.hooks.hive_hooks import HiveMetastoreHook
+from airflow import DAG, configuration, operators
+from airflow.utils import timezone
+
+
+configuration.load_test_config()
+
+
+DEFAULT_DATE = timezone.datetime(2015, 1, 1)
+DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
+DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
+
+
+class HiveEnvironmentTest(unittest.TestCase):
+
+    def setUp(self):
+        configuration.load_test_config()
+        args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
+        self.dag = DAG('test_dag_id', default_args=args)
+        self.next_day = (DEFAULT_DATE +
+                         datetime.timedelta(days=1)).isoformat()[:10]
+        self.database = 'airflow'
+        self.partition_by = 'ds'
+        self.table = 'static_babynames_partitioned'
+        self.hql = """
+        CREATE DATABASE IF NOT EXISTS {{ params.database }};
+        USE {{ params.database }};
+        DROP TABLE IF EXISTS {{ params.table }};
+        CREATE TABLE IF NOT EXISTS {{ params.table }} (
+            state string,
+            year string,
+            name string,
+            gender string,
+            num int)
+        PARTITIONED BY ({{ params.partition_by }} string);
+        ALTER TABLE {{ params.table }}
+        ADD PARTITION({{ params.partition_by }}='{{ ds }}');
+        """
+        self.hook = HiveMetastoreHook()
+        t = operators.hive_operator.HiveOperator(
+            task_id='HiveHook_' + str(random.randint(1, 10000)),
+            params={
+                'database': self.database,
+                'table': self.table,
+                'partition_by': self.partition_by
+            },
+            hive_cli_conn_id='beeline_default',
+            hql=self.hql, dag=self.dag)
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
+              ignore_ti_state=True)
+
+    def tearDown(self):
+        hook = HiveMetastoreHook()
+        with hook.get_conn() as metastore:
+            metastore.drop_table(self.database, self.table, deleteData=True)
 
 
-class TestHiveMetastoreHook(unittest.TestCase):
+class TestHiveMetastoreHook(HiveEnvironmentTest):
     VALID_FILTER_MAP = {'key2': 'value2'}
 
     def test_get_max_partition_from_empty_part_specs(self):
@@ -77,3 +135,75 @@ class TestHiveMetastoreHook(unittest.TestCase):
                 'key1',
                 self.VALID_FILTER_MAP)
         self.assertEqual(max_partition, b'value1')
+
+    def test_get_metastore_client(self):
+        self.assertIsInstance(self.hook.get_metastore_client(), HMSClient)
+
+    def test_get_conn(self):
+        self.assertIsInstance(self.hook.get_conn(), HMSClient)
+
+    def test_check_for_partition(self):
+        partition = "{p_by}='{date}'".format(date=DEFAULT_DATE_DS,
+                                             p_by=self.partition_by)
+        missing_partition = "{p_by}='{date}'".format(date=self.next_day,
+                                                     p_by=self.partition_by)
+        self.assertTrue(
+            self.hook.check_for_partition(self.database, self.table,
+                                          partition)
+        )
+        self.assertFalse(
+            self.hook.check_for_partition(self.database, self.table,
+                                          missing_partition)
+        )
+
+    def test_check_for_named_partition(self):
+        partition = "{p_by}={date}".format(date=DEFAULT_DATE_DS,
+                                           p_by=self.partition_by)
+        missing_partition = "{p_by}={date}".format(date=self.next_day,
+                                                   p_by=self.partition_by)
+        self.assertTrue(
+            self.hook.check_for_named_partition(self.database,
+                                                self.table,
+                                                partition)
+        )
+        self.assertFalse(
+            self.hook.check_for_named_partition(self.database,
+                                                self.table,
+                                                missing_partition)
+        )
+
+    def test_get_table(self):
+        table_info = self.hook.get_table(db=self.database,
+                                         table_name=self.table)
+        self.assertEqual(table_info.tableName, self.table)
+        columns = ['state', 'year', 'name', 'gender', 'num']
+        self.assertEqual([col.name for col in table_info.sd.cols], columns)
+
+    def test_get_tables(self):
+        tables = self.hook.get_tables(db=self.database,
+                                      pattern=self.table + "*")
+        self.assertIn(self.table, {table.tableName for table in tables})
+
+    def get_databases(self):
+        databases = self.hook.get_databases(pattern='*')
+        self.assertIn(self.database, databases)
+
+    def test_get_partitions(self):
+        partitions = self.hook.get_partitions(schema=self.database,
+                                              table_name=self.table)
+        self.assertEqual(len(partitions), 1)
+        self.assertEqual(partitions, [{self.partition_by: DEFAULT_DATE_DS}])
+
+    def test_max_partition(self):
+        filter_map = {self.partition_by: DEFAULT_DATE_DS}
+        partition = self.hook.max_partition(schema=self.database,
+                                            table_name=self.table,
+                                            field=self.partition_by,
+                                            filter_map=filter_map)
+        self.assertEqual(partition, DEFAULT_DATE_DS.encode('utf-8'))
+
+    def test_table_exists(self):
+        self.assertTrue(self.hook.table_exists(self.table, db=self.database))
+        self.assertFalse(
+            self.hook.table_exists(str(random.randint(1, 10000)))
+        )

