airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject [12/45] incubator-airflow git commit: [AIRFLOW-793] Enable compressed loading in S3ToHiveTransfer
Date Mon, 13 Mar 2017 04:45:10 GMT
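[AIRFLOW-793] Enable compressed loading in S3ToHiveTransfer

Add an input_compressed option to S3ToHiveTransfer. When headers is set and
the S3 file is gzip (.gz) or bzip2 (.bz2) compressed, the operator first
decompresses the file. It then checks the header row against field_dict (if
check_headers is set), strips the header row, and recompresses the file with
its original extension before loading it into Hive. The constructor now also
rejects check_headers unless both field_dict and headers are provided.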

Testing Done:
- Added new unit tests for the S3ToHiveTransfer module and for the
  airflow.utils.compression helper

Closes #2012 from krishnabhupatiraju/S3ToHiveTransfer_compress_loading

(cherry picked from commit ad15f5efd6c663bd5f0c8cd3f556d08182cc778c)
Signed-off-by: Bolke de Bruin <bolke@xs4all.nl>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/1c231333
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/1c231333
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/1c231333

Branch: refs/heads/v1-8-stable
Commit: 1c2313338a586aae4a7752c3fb3b9de4e3564415
Parents: 3658bf3
Author: Krishna Bhupatiraju <krishna.bhupatiraju@airbnb.com>
Authored: Mon Feb 6 16:52:11 2017 -0800
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Sat Feb 18 15:56:37 2017 +0100

----------------------------------------------------------------------
 airflow/operators/s3_to_hive_operator.py | 151 ++++++++++++----
 airflow/utils/compression.py             |  38 ++++
 tests/operators/__init__.py              |   1 +
 tests/operators/s3_to_hive_operator.py   | 247 ++++++++++++++++++++++++++
 tests/utils/compression.py               |  97 ++++++++++
 5 files changed, 497 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index 3e01c29..92340f8 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -16,13 +16,18 @@ from builtins import next
 from builtins import zip
 import logging
 from tempfile import NamedTemporaryFile
+from airflow.utils.file import TemporaryDirectory
+import gzip
+import bz2
+import tempfile
+import os
 
 from airflow.exceptions import AirflowException
 from airflow.hooks.S3_hook import S3Hook
 from airflow.hooks.hive_hooks import HiveCliHook
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
-
+from airflow.utils.compression import uncompress_file
 
 class S3ToHiveTransfer(BaseOperator):
     """
@@ -68,8 +73,11 @@ class S3ToHiveTransfer(BaseOperator):
     :type delimiter: str
     :param s3_conn_id: source s3 connection
     :type s3_conn_id: str
-    :param hive_conn_id: destination hive connection
-    :type hive_conn_id: str
+    :param hive_cli_conn_id: destination hive connection
+    :type hive_cli_conn_id: str
+    :param input_compressed: Boolean to determine if file decompression is
+        required to process headers
+    :type input_compressed: bool
     """
 
     template_fields = ('s3_key', 'partition', 'hive_table')
@@ -91,6 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
             wildcard_match=False,
             s3_conn_id='s3_default',
             hive_cli_conn_id='hive_cli_default',
+            input_compressed=False,
             *args, **kwargs):
         super(S3ToHiveTransfer, self).__init__(*args, **kwargs)
         self.s3_key = s3_key
@@ -105,28 +114,41 @@ class S3ToHiveTransfer(BaseOperator):
         self.wildcard_match = wildcard_match
         self.hive_cli_conn_id = hive_cli_conn_id
         self.s3_conn_id = s3_conn_id
+        self.input_compressed = input_compressed
+
+        if (self.check_headers and
+                not (self.field_dict is not None and self.headers)):
+            raise AirflowException("To check_headers provide " +
+                                   "field_dict and headers")
 
     def execute(self, context):
-        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
+        # Downloading file from S3
         self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
         logging.info("Downloading S3 file")
+
         if self.wildcard_match:
             if not self.s3.check_for_wildcard_key(self.s3_key):
-                raise AirflowException("No key matches {0}".format(self.s3_key))
+                raise AirflowException("No key matches {0}"
+                                       .format(self.s3_key))
             s3_key_object = self.s3.get_wildcard_key(self.s3_key)
         else:
             if not self.s3.check_for_key(self.s3_key):
                 raise AirflowException(
                     "The key {0} does not exists".format(self.s3_key))
             s3_key_object = self.s3.get_key(self.s3_key)
-        with NamedTemporaryFile("w") as f:
+        root, file_ext = os.path.splitext(s3_key_object.key)
+        with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
+                NamedTemporaryFile(mode="w",
+                                   dir=tmp_dir,
+                                   suffix=file_ext) as f:
             logging.info("Dumping S3 key {0} contents to local"
                          " file {1}".format(s3_key_object.key, f.name))
             s3_key_object.get_contents_to_file(f)
             f.flush()
             self.s3.connection.close()
             if not self.headers:
-                logging.info("Loading file into Hive")
+                logging.info("Loading file {0} into Hive".format(f.name))
                 self.hive.load_file(
                     f.name,
                     self.hive_table,
@@ -136,33 +158,88 @@ class S3ToHiveTransfer(BaseOperator):
                     delimiter=self.delimiter,
                     recreate=self.recreate)
             else:
-                with open(f.name, 'r') as tmpf:
-                    if self.check_headers:
-                        header_l = tmpf.readline()
-                        header_line = header_l.rstrip()
-                        header_list = header_line.split(self.delimiter)
-                        field_names = list(self.field_dict.keys())
-                        test_field_match = [h1.lower() == h2.lower() for h1, h2
-                                            in zip(header_list, field_names)]
-                        if not all(test_field_match):
-                            logging.warning("Headers do not match field names"
-                                            "File headers:\n {header_list}\n"
-                                            "Field names: \n {field_names}\n"
-                                            "".format(**locals()))
-                            raise AirflowException("Headers do not match the "
-                                            "field_dict keys")
-                    with NamedTemporaryFile("w") as f_no_headers:
-                        tmpf.seek(0)
-                        next(tmpf)
-                        for line in tmpf:
-                            f_no_headers.write(line)
-                        f_no_headers.flush()
-                        logging.info("Loading file without headers into Hive")
-                        self.hive.load_file(
-                            f_no_headers.name,
-                            self.hive_table,
-                            field_dict=self.field_dict,
-                            create=self.create,
-                            partition=self.partition,
-                            delimiter=self.delimiter,
-                            recreate=self.recreate)
+                # Decompressing file
+                if self.input_compressed:
+                    logging.info("Uncompressing file {0}".format(f.name))
+                    fn_uncompressed = uncompress_file(f.name,
+                                                      file_ext,
+                                                      tmp_dir)
+                    logging.info("Uncompressed to {0}".format(fn_uncompressed))
+                    # uncompressed file available now so deleting
+                    # compressed file to save disk space
+                    f.close()
+                else:
+                    fn_uncompressed = f.name
+
+                # Testing if header matches field_dict
+                if self.check_headers:
+                    logging.info("Matching file header against field_dict")
+                    header_list = self._get_top_row_as_list(fn_uncompressed)
+                    if not self._match_headers(header_list):
+                        raise AirflowException("Header check failed")
+
+                # Deleting top header row
+                logging.info("Removing header from file {0}".
+                             format(fn_uncompressed))
+                headless_file = (
+                    self._delete_top_row_and_compress(fn_uncompressed,
+                                                      file_ext,
+                                                      tmp_dir))
+                logging.info("Headless file {0}".format(headless_file))
+                logging.info("Loading file {0} into Hive".format(headless_file))
+                self.hive.load_file(headless_file,
+                                    self.hive_table,
+                                    field_dict=self.field_dict,
+                                    create=self.create,
+                                    partition=self.partition,
+                                    delimiter=self.delimiter,
+                                    recreate=self.recreate)
+
+    def _get_top_row_as_list(self, file_name):
+        with open(file_name, 'rt') as f:
+            header_line = f.readline().strip()
+            header_list = header_line.split(self.delimiter)
+            return header_list
+
+    def _match_headers(self, header_list):
+        if not header_list:
+            raise AirflowException("Unable to retrieve header row from file")
+        field_names = self.field_dict.keys()
+        if len(field_names) != len(header_list):
+            logging.warning("Headers count mismatch"
+                            "File headers:\n {header_list}\n"
+                            "Field names: \n {field_names}\n"
+                            "".format(**locals()))
+            return False
+        test_field_match = [h1.lower() == h2.lower()
+                            for h1, h2 in zip(header_list, field_names)]
+        if not all(test_field_match):
+            logging.warning("Headers do not match field names"
+                            "File headers:\n {header_list}\n"
+                            "Field names: \n {field_names}\n"
+                            "".format(**locals()))
+            return False
+        else:
+            return True
+
+    def _delete_top_row_and_compress(
+            self,
+            input_file_name,
+            output_file_ext,
+            dest_dir):
+        # When output_file_ext is not defined, file is not compressed
+        open_fn = open
+        if output_file_ext.lower() == '.gz':
+            open_fn = gzip.GzipFile
+        elif output_file_ext.lower() == '.bz2':
+            open_fn = bz2.BZ2File
+
+        os_fh_output, fn_output = \
+            tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
+        with open(input_file_name, 'rb') as f_in,\
+                open_fn(fn_output, 'wb') as f_out:
+            f_in.seek(0)
+            next(f_in)
+            for line in f_in:
+                f_out.write(line)
+        return fn_output

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/airflow/utils/compression.py
----------------------------------------------------------------------
diff --git a/airflow/utils/compression.py b/airflow/utils/compression.py
new file mode 100644
index 0000000..9d0785f
--- /dev/null
+++ b/airflow/utils/compression.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tempfile import NamedTemporaryFile
+import shutil
+import gzip
+import bz2
+
+
+def uncompress_file(input_file_name, file_extension, dest_dir):
+    """
+    Uncompress gz and bz2 files
+    """
+    if file_extension.lower() not in ('.gz', '.bz2'):
+        raise NotImplementedError("Received {} format. Only gz and bz2 "
+                                  "files can currently be uncompressed."
+                                  .format(file_extension))
+    if file_extension.lower() == '.gz':
+        fmodule = gzip.GzipFile
+    elif file_extension.lower() == '.bz2':
+        fmodule = bz2.BZ2File
+    with fmodule(input_file_name, mode='rb') as f_compressed,\
+        NamedTemporaryFile(dir=dest_dir,
+                           mode='wb',
+                           delete=False) as f_uncompressed:
+        shutil.copyfileobj(f_compressed, f_uncompressed)
+    return f_uncompressed.name

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 63ff2a0..1fb0e5e 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -17,3 +17,4 @@ from .subdag_operator import *
 from .operators import *
 from .sensors import *
 from .hive_operator import *
+from .s3_to_hive_operator import *

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py
new file mode 100644
index 0000000..faab11e
--- /dev/null
+++ b/tests/operators/s3_to_hive_operator.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+import logging
+from itertools import product
+from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer
+from collections import OrderedDict
+from airflow.exceptions import AirflowException
+from tempfile import NamedTemporaryFile, mkdtemp
+import gzip
+import bz2
+import shutil
+import filecmp
+import errno
+
+
+class S3ToHiveTransferTest(unittest.TestCase):
+
+    def setUp(self):
+        self.fn = {}
+        self.task_id = 'S3ToHiveTransferTest'
+        self.s3_key = 'S32hive_test_file'
+        self.field_dict = OrderedDict([('Sno', 'BIGINT'), ('Some,Text', 'STRING')])
+        self.hive_table = 'S32hive_test_table'
+        self.delimiter = '\t'
+        self.create = True
+        self.recreate = True
+        self.partition = {'ds': 'STRING'}
+        self.headers = True
+        self.check_headers = True
+        self.wildcard_match = False
+        self.input_compressed = False
+        self.kwargs = {'task_id': self.task_id,
+                       's3_key': self.s3_key,
+                       'field_dict': self.field_dict,
+                       'hive_table': self.hive_table,
+                       'delimiter': self.delimiter,
+                       'create': self.create,
+                       'recreate': self.recreate,
+                       'partition': self.partition,
+                       'headers': self.headers,
+                       'check_headers': self.check_headers,
+                       'wildcard_match': self.wildcard_match,
+                       'input_compressed': self.input_compressed
+                       }
+        try:
+            header = "Sno\tSome,Text \n".encode()
+            line1 = "1\tAirflow Test\n".encode()
+            line2 = "2\tS32HiveTransfer\n".encode()
+            self.tmp_dir = mkdtemp(prefix='test_tmps32hive_')
+            # create sample txt, gz and bz2 with and without headers
+            with NamedTemporaryFile(mode='wb+',
+                                    dir=self.tmp_dir,
+                                    delete=False) as f_txt_h:
+                self._set_fn(f_txt_h.name, '.txt', True)
+                f_txt_h.writelines([header, line1, line2])
+            fn_gz = self._get_fn('.txt', True) + ".gz"
+            with gzip.GzipFile(filename=fn_gz,
+                               mode="wb") as f_gz_h:
+                self._set_fn(fn_gz, '.gz', True)
+                f_gz_h.writelines([header, line1, line2])
+            fn_bz2 = self._get_fn('.txt', True) + '.bz2'
+            with bz2.BZ2File(filename=fn_bz2,
+                             mode="wb") as f_bz2_h:
+                self._set_fn(fn_bz2, '.bz2', True)
+                f_bz2_h.writelines([header, line1, line2])
+            # create sample txt, bz and bz2 without header
+            with NamedTemporaryFile(mode='wb+',
+                                    dir=self.tmp_dir,
+                                    delete=False) as f_txt_nh:
+                self._set_fn(f_txt_nh.name, '.txt', False)
+                f_txt_nh.writelines([line1, line2])
+            fn_gz = self._get_fn('.txt', False) + ".gz"
+            with gzip.GzipFile(filename=fn_gz,
+                               mode="wb") as f_gz_nh:
+                self._set_fn(fn_gz, '.gz', False)
+                f_gz_nh.writelines([line1, line2])
+            fn_bz2 = self._get_fn('.txt', False) + '.bz2'
+            with bz2.BZ2File(filename=fn_bz2,
+                             mode="wb") as f_bz2_nh:
+                self._set_fn(fn_bz2, '.bz2', False)
+                f_bz2_nh.writelines([line1, line2])
+        # Base Exception so it catches Keyboard Interrupt
+        except BaseException as e:
+            logging.error(e)
+            self.tearDown()
+
+    def tearDown(self):
+        try:
+            shutil.rmtree(self.tmp_dir)
+        except OSError as e:
+            # ENOENT - no such file or directory
+            if e.errno != errno.ENOENT:
+                raise e
+
+    # Helper method to create a dictionary of file names and
+    # file types (file extension and header)
+    def _set_fn(self, fn, ext, header):
+        key = self._get_key(ext, header)
+        self.fn[key] = fn
+
+    # Helper method to fetch a file of a
+    # certain format (file extension and header)
+    def _get_fn(self, ext, header):
+        key = self._get_key(ext, header)
+        return self.fn[key]
+
+    def _get_key(self, ext, header):
+        key = ext + "_" + ('h' if header else 'nh')
+        return key
+
+    def _cp_file_contents(self, fn_src, fn_dest):
+        with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
+            shutil.copyfileobj(f_src, f_dest)
+
+    def _check_file_equality(self, fn_1, fn_2, ext):
+        # gz files contain mtime and filename in the header that
+        # causes filecmp to return False even if contents are identical
+        # Hence decompress to test for equality
+        if(ext == '.gz'):
+            with gzip.GzipFile(fn_1, 'rb') as f_1,\
+                 NamedTemporaryFile(mode='wb') as f_txt_1,\
+                 gzip.GzipFile(fn_2, 'rb') as f_2,\
+                 NamedTemporaryFile(mode='wb') as f_txt_2:
+                shutil.copyfileobj(f_1, f_txt_1)
+                shutil.copyfileobj(f_2, f_txt_2)
+                f_txt_1.flush()
+                f_txt_2.flush()
+                return filecmp.cmp(f_txt_1.name, f_txt_2.name, shallow=False)
+        else:
+            return filecmp.cmp(fn_1, fn_2, shallow=False)
+
+    def test_bad_parameters(self):
+        self.kwargs['check_headers'] = True
+        self.kwargs['headers'] = False
+        self.assertRaisesRegexp(AirflowException,
+                                "To check_headers.*",
+                                S3ToHiveTransfer,
+                                **self.kwargs)
+
+    def test__get_top_row_as_list(self):
+        self.kwargs['delimiter'] = '\t'
+        fn_txt = self._get_fn('.txt', True)
+        header_list = S3ToHiveTransfer(**self.kwargs).\
+            _get_top_row_as_list(fn_txt)
+        self.assertEqual(header_list, ['Sno', 'Some,Text'],
+                         msg="Top row from file doesnt matched expected value")
+
+        self.kwargs['delimiter'] = ','
+        header_list = S3ToHiveTransfer(**self.kwargs).\
+            _get_top_row_as_list(fn_txt)
+        self.assertEqual(header_list, ['Sno\tSome', 'Text'],
+                         msg="Top row from file doesnt matched expected value")
+
+    def test__match_headers(self):
+        self.kwargs['field_dict'] = OrderedDict([('Sno', 'BIGINT'),
+                                                ('Some,Text', 'STRING')])
+        self.assertTrue(S3ToHiveTransfer(**self.kwargs).
+                        _match_headers(['Sno', 'Some,Text']),
+                        msg="Header row doesnt match expected value")
+        # Testing with different column order
+        self.assertFalse(S3ToHiveTransfer(**self.kwargs).
+                         _match_headers(['Some,Text', 'Sno']),
+                         msg="Header row doesnt match expected value")
+        # Testing with extra column in header
+        self.assertFalse(S3ToHiveTransfer(**self.kwargs).
+                         _match_headers(['Sno', 'Some,Text', 'ExtraColumn']),
+                         msg="Header row doesnt match expected value")
+
+    def test__delete_top_row_and_compress(self):
+        s32hive = S3ToHiveTransfer(**self.kwargs)
+        # Testing gz file type
+        fn_txt = self._get_fn('.txt', True)
+        gz_txt_nh = s32hive._delete_top_row_and_compress(fn_txt,
+                                                         '.gz',
+                                                         self.tmp_dir)
+        fn_gz = self._get_fn('.gz', False)
+        self.assertTrue(self._check_file_equality(gz_txt_nh, fn_gz, '.gz'),
+                        msg="gz Compressed file not as expected")
+        # Testing bz2 file type
+        bz2_txt_nh = s32hive._delete_top_row_and_compress(fn_txt,
+                                                          '.bz2',
+                                                          self.tmp_dir)
+        fn_bz2 = self._get_fn('.bz2', False)
+        self.assertTrue(self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'),
+                        msg="bz2 Compressed file not as expected")
+
+    @unittest.skipIf(mock is None, 'mock package not present')
+    @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
+    @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
+    def test_execute(self, mock_s3hook, mock_hiveclihook):
+        # Testing txt, zip, bz2 files with and without header row
+        for test in product(['.txt', '.gz', '.bz2'], [True, False]):
+            ext = test[0]
+            has_header = test[1]
+            self.kwargs['headers'] = has_header
+            self.kwargs['check_headers'] = has_header
+            logging.info("Testing {0} format {1} header".
+                         format(ext,
+                                ('with' if has_header else 'without'))
+                         )
+            self.kwargs['input_compressed'] = (False if ext == '.txt' else True)
+            self.kwargs['s3_key'] = self.s3_key + ext
+            ip_fn = self._get_fn(ext, self.kwargs['headers'])
+            op_fn = self._get_fn(ext, False)
+            # Mock s3 object returned by S3Hook
+            mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
+            mock_s3_object.get_contents_to_file.side_effect = \
+                lambda dest_file: \
+                self._cp_file_contents(ip_fn, dest_file.name)
+            mock_s3hook().get_key.return_value = mock_s3_object
+            # file paramter to HiveCliHook.load_file is compared
+            # against expected file oputput
+            mock_hiveclihook().load_file.side_effect = \
+                lambda *args, **kwargs: \
+                self.assertTrue(
+                    self._check_file_equality(args[0],
+                                              op_fn,
+                                              ext
+                                              ),
+                    msg='{0} output file not as expected'.format(ext))
+            # Execute S3ToHiveTransfer
+            s32hive = S3ToHiveTransfer(**self.kwargs)
+            s32hive.execute(None)
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/1c231333/tests/utils/compression.py
----------------------------------------------------------------------
diff --git a/tests/utils/compression.py b/tests/utils/compression.py
new file mode 100644
index 0000000..f8e0ebb
--- /dev/null
+++ b/tests/utils/compression.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.utils import compression
+import unittest
+from tempfile import NamedTemporaryFile, mkdtemp
+import bz2
+import gzip
+import shutil
+import logging
+import errno
+import filecmp
+
+
+class Compression(unittest.TestCase):
+
+    def setUp(self):
+        self.fn = {}
+        try:
+            header = "Sno\tSome,Text \n".encode()
+            line1 = "1\tAirflow Test\n".encode()
+            line2 = "2\tCompressionUtil\n".encode()
+            self.tmp_dir = mkdtemp(prefix='test_utils_compression_')
+            # create sample txt, gz and bz2 files
+            with NamedTemporaryFile(mode='wb+',
+                                    dir=self.tmp_dir,
+                                    delete=False) as f_txt:
+                self._set_fn(f_txt.name, '.txt')
+                f_txt.writelines([header, line1, line2])
+            fn_gz = self._get_fn('.txt') + ".gz"
+            with gzip.GzipFile(filename=fn_gz,
+                               mode="wb") as f_gz:
+                self._set_fn(fn_gz, '.gz')
+                f_gz.writelines([header, line1, line2])
+            fn_bz2 = self._get_fn('.txt') + '.bz2'
+            with bz2.BZ2File(filename=fn_bz2,
+                             mode="wb") as f_bz2:
+                self._set_fn(fn_bz2, '.bz2')
+                f_bz2.writelines([header, line1, line2])
+        # Base Exception so it catches Keyboard Interrupt
+        except BaseException as e:
+            logging.error(e)
+            self.tearDown()
+
+    def tearDown(self):
+        try:
+            shutil.rmtree(self.tmp_dir)
+        except OSError as e:
+            # ENOENT - no such file or directory
+            if e.errno != errno.ENOENT:
+                raise e
+
+    # Helper method to create a dictionary of file names and
+    # file extension
+    def _set_fn(self, fn, ext):
+        self.fn[ext] = fn
+
+    # Helper method to fetch a file of a
+    # certain extension
+    def _get_fn(self, ext):
+        return self.fn[ext]
+
+    def test_uncompress_file(self):
+        # Testing txt file type
+        self.assertRaisesRegexp(NotImplementedError,
+                                "^Received .txt format. Only gz and bz2.*",
+                                compression.uncompress_file,
+                                **{'input_file_name': None,
+                                   'file_extension': '.txt',
+                                   'dest_dir': None
+                                   })
+        # Testing gz file type
+        fn_txt = self._get_fn('.txt')
+        fn_gz = self._get_fn('.gz')
+        txt_gz = compression.uncompress_file(fn_gz, '.gz', self.tmp_dir)
+        self.assertTrue(filecmp.cmp(txt_gz, fn_txt, shallow=False),
+                        msg="Uncompressed file doest match original")
+        # Testing bz2 file type
+        fn_bz2 = self._get_fn('.bz2')
+        txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir)
+        self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False),
+                        msg="Uncompressed file doest match original")
+
+
+if __name__ == '__main__':
+    unittest.main()


Mime
View raw message