airflow-commits mailing list archives

From: GitBox <...@apache.org>
Subject: [GitHub] Fokko closed pull request #4084: [AIRFLOW-3205] Support multipart uploads to GCS
Date: Mon, 05 Nov 2018 14:48:13 GMT
Fokko closed pull request #4084: [AIRFLOW-3205] Support multipart uploads to GCS
URL: https://github.com/apache/incubator-airflow/pull/4084
 
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py
index 1ece4dde1b..dc92b0cb3e 100644
--- a/airflow/contrib/hooks/gcs_hook.py
+++ b/airflow/contrib/hooks/gcs_hook.py
@@ -177,7 +177,8 @@ def download(self, bucket, object, filename=None):
 
     # pylint:disable=redefined-builtin
     def upload(self, bucket, object, filename,
-               mime_type='application/octet-stream', gzip=False):
+               mime_type='application/octet-stream', gzip=False,
+               multipart=False, num_retries=0):
         """
         Uploads a local file to Google Cloud Storage.
 
@@ -191,6 +192,14 @@ def upload(self, bucket, object, filename,
         :type mime_type: str
         :param gzip: Option to compress file for upload
         :type gzip: bool
+        :param multipart: If True, the upload will be split into multiple HTTP requests. The
+                          default size is 256MiB per request. Pass a number instead of True to
+                          specify the request size, which must be a multiple of 262144 (256KiB).
+        :type multipart: bool or int
+        :param num_retries: The number of times to attempt to re-upload the file (or individual
+                            chunks, in the case of multipart uploads). Retries are attempted
+                            with exponential backoff.
+        :type num_retries: int
         """
         service = self.get_conn()
 
@@ -202,23 +211,45 @@ def upload(self, bucket, object, filename,
                     shutil.copyfileobj(f_in, f_out)
                     filename = filename_gz
 
-        media = MediaFileUpload(filename, mime_type)
-
         try:
-            service \
-                .objects() \
-                .insert(bucket=bucket, name=object, media_body=media) \
-                .execute()
+            if multipart:
+                if multipart is True:
+                    chunksize = 256 * 1024 * 1024
+                else:
+                    chunksize = multipart
+
+                if chunksize % (256 * 1024) > 0 or chunksize < 0:
+                    raise ValueError("Multipart size is not a multiple of 262144 (256KiB)")
+
+                media = MediaFileUpload(filename, mimetype=mime_type,
+                                        chunksize=chunksize, resumable=True)
+
+                request = service.objects().insert(bucket=bucket, name=object, media_body=media)
+                response = None
+                while response is None:
+                    status, response = request.next_chunk(num_retries=num_retries)
+                    if status:
+                        self.log.info("Upload progress %.1f%%", status.progress() * 100)
+
+            else:
+                media = MediaFileUpload(filename, mime_type)
+
+                service \
+                    .objects() \
+                    .insert(bucket=bucket, name=object, media_body=media) \
+                    .execute(num_retries=num_retries)
 
-            # Clean up gzip file
-            if gzip:
-                os.remove(filename)
-            return True
         except errors.HttpError as ex:
             if ex.resp['status'] == '404':
                 return False
             raise
 
+        finally:
+            if gzip:
+                os.remove(filename)
+
+        return True
+
     # pylint:disable=redefined-builtin
     def exists(self, bucket, object):
         """
diff --git a/tests/contrib/hooks/test_gcs_hook.py b/tests/contrib/hooks/test_gcs_hook.py
index ed3dce9a5e..eea79a376c 100644
--- a/tests/contrib/hooks/test_gcs_hook.py
+++ b/tests/contrib/hooks/test_gcs_hook.py
@@ -18,6 +18,8 @@
 # under the License.
 
 import unittest
+import tempfile
+import os
 
 from airflow.contrib.hooks import gcs_hook
 from airflow.exceptions import AirflowException
@@ -339,3 +341,143 @@ def test_delete_nonexisting_object(self, mock_service):
         response = self.gcs_hook.delete(bucket=test_bucket, object=test_object)
 
         self.assertFalse(response)
+
+
+class TestGoogleCloudStorageHookUpload(unittest.TestCase):
+    def setUp(self):
+        with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__')):
+            self.gcs_hook = gcs_hook.GoogleCloudStorageHook(
+                google_cloud_storage_conn_id='test'
+            )
+
+        # generate a 384KiB test file (larger than the minimum 256KiB multipart chunk size)
+        self.testfile = tempfile.NamedTemporaryFile(delete=False)
+        self.testfile.write(b"x" * 393216)
+        self.testfile.flush()
+
+    def tearDown(self):
+        os.unlink(self.testfile.name)
+
+    @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
+    def test_upload(self, mock_service):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+
+        (mock_service.return_value.objects.return_value
+         .insert.return_value.execute.return_value) = {
+            "kind": "storage#object",
+            "id": "{}/{}/0123456789012345".format(test_bucket, test_object),
+            "name": test_object,
+            "bucket": test_bucket,
+            "generation": "0123456789012345",
+            "contentType": "application/octet-stream",
+            "timeCreated": "2018-03-15T16:51:02.502Z",
+            "updated": "2018-03-15T16:51:02.502Z",
+            "storageClass": "MULTI_REGIONAL",
+            "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z",
+            "size": "393216",
+            "md5Hash": "leYUJBUWrRtks1UeUFONJQ==",
+            "crc32c": "xgdNfQ==",
+            "etag": "CLf4hODk7tkCEAE="
+        }
+
+        response = self.gcs_hook.upload(test_bucket,
+                                        test_object,
+                                        self.testfile.name)
+
+        self.assertTrue(response)
+
+    @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
+    def test_upload_gzip(self, mock_service):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+
+        (mock_service.return_value.objects.return_value
+         .insert.return_value.execute.return_value) = {
+            "kind": "storage#object",
+            "id": "{}/{}/0123456789012345".format(test_bucket, test_object),
+            "name": test_object,
+            "bucket": test_bucket,
+            "generation": "0123456789012345",
+            "contentType": "application/octet-stream",
+            "timeCreated": "2018-03-15T16:51:02.502Z",
+            "updated": "2018-03-15T16:51:02.502Z",
+            "storageClass": "MULTI_REGIONAL",
+            "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z",
+            "size": "393216",
+            "md5Hash": "leYUJBUWrRtks1UeUFONJQ==",
+            "crc32c": "xgdNfQ==",
+            "etag": "CLf4hODk7tkCEAE="
+        }
+
+        response = self.gcs_hook.upload(test_bucket,
+                                        test_object,
+                                        self.testfile.name,
+                                        gzip=True)
+        self.assertFalse(os.path.exists(self.testfile.name + '.gz'))
+        self.assertTrue(response)
+
+    @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
+    def test_upload_gzip_error(self, mock_service):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+
+        (mock_service.return_value.objects.return_value
+         .insert.return_value.execute.side_effect) = HttpError(
+            resp={'status': '404'}, content=EMPTY_CONTENT)
+
+        response = self.gcs_hook.upload(test_bucket,
+                                        test_object,
+                                        self.testfile.name,
+                                        gzip=True)
+        self.assertFalse(os.path.exists(self.testfile.name + '.gz'))
+        self.assertFalse(response)
+
+    @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
+    def test_upload_multipart(self, mock_service):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+
+        class MockProgress:
+            def __init__(self, value):
+                self.value = value
+
+            def progress(self):
+                return self.value
+
+        (mock_service.return_value.objects.return_value
+         .insert.return_value.next_chunk.side_effect) = [
+            (MockProgress(0.66), None),
+            (MockProgress(1.0), {
+                "kind": "storage#object",
+                "id": "{}/{}/0123456789012345".format(test_bucket, test_object),
+                "name": test_object,
+                "bucket": test_bucket,
+                "generation": "0123456789012345",
+                "contentType": "application/octet-stream",
+                "timeCreated": "2018-03-15T16:51:02.502Z",
+                "updated": "2018-03-15T16:51:02.502Z",
+                "storageClass": "MULTI_REGIONAL",
+                "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z",
+                "size": "393216",
+                "md5Hash": "leYUJBUWrRtks1UeUFONJQ==",
+                "crc32c": "xgdNfQ==",
+                "etag": "CLf4hODk7tkCEAE="
+            })
+        ]
+
+        response = self.gcs_hook.upload(test_bucket,
+                                        test_object,
+                                        self.testfile.name,
+                                        multipart=True)
+
+        self.assertTrue(response)
+
+    @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn'))
+    def test_upload_multipart_wrong_chunksize(self, mock_service):
+        test_bucket = 'test_bucket'
+        test_object = 'test_object'
+
+        with self.assertRaises(ValueError):
+            self.gcs_hook.upload(test_bucket, test_object,
+                                 self.testfile.name, multipart=123)

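For reference, a minimal usage sketch of the upload() signature introduced by
this PR. It is illustrative only: the connection id, bucket, object names, and
local paths below are assumptions, not taken from the PR.

from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook

# Hypothetical connection id; any configured GCP connection would work.
hook = GoogleCloudStorageHook(google_cloud_storage_conn_id='google_cloud_default')

# Single-request upload (the previous behaviour), now retried up to 3 times.
hook.upload(bucket='my-bucket', object='data/file.bin',
            filename='/tmp/file.bin', num_retries=3)

# Resumable multipart upload using the default 256MiB chunk size.
hook.upload(bucket='my-bucket', object='data/large.bin',
            filename='/tmp/large.bin', multipart=True)

# Custom chunk size: must be a multiple of 262144 bytes (256KiB);
# here 16 * 262144 = 4MiB per request.
hook.upload(bucket='my-bucket', object='data/large.bin',
            filename='/tmp/large.bin', multipart=16 * 262144, num_retries=5)

Passing a chunk size that is not a positive multiple of 262144 raises
ValueError, as exercised by test_upload_multipart_wrong_chunksize above.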
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
