From: davor@apache.org
To: commits@beam.incubator.apache.org
Reply-To: dev@beam.incubator.apache.org
Date: Tue, 14 Jun 2016 23:13:11 -0000
Message-Id: <5d8c8a4e944246b8a65eccacc85380e8@git.apache.org>
In-Reply-To: <95df9c9428334e3980c0c77c4ddc9382@git.apache.org>
References: <95df9c9428334e3980c0c77c4ddc9382@git.apache.org>
X-Mailer: ASF-Git Admin Mailer
Subject: [36/50] [abbrv] incubator-beam git commit: Move all files to apache_beam folder

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/fileio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/fileio.py b/sdks/python/apache_beam/io/fileio.py
new file mode 100644
index 0000000..9a003f0
--- /dev/null
+++ b/sdks/python/apache_beam/io/fileio.py
@@ -0,0 +1,747 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""File-based sources and sinks.""" + +from __future__ import absolute_import + +import glob +import gzip +import logging +from multiprocessing.pool import ThreadPool +import os +import re +import shutil +import tempfile +import time + +from google.cloud.dataflow import coders +from google.cloud.dataflow.io import iobase +from google.cloud.dataflow.io import range_trackers +from google.cloud.dataflow.utils import processes +from google.cloud.dataflow.utils import retry + + +__all__ = ['TextFileSource', 'TextFileSink'] + +DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN' + + +# Retrying is needed because there are transient errors that can happen. +@retry.with_exponential_backoff(num_retries=4, retry_filter=lambda _: True) +def _gcs_file_copy(from_path, to_path, encoding=''): + """Copy a local file to a GCS location with retries for transient errors.""" + if not encoding: + command_args = ['gsutil', '-m', '-q', 'cp', from_path, to_path] + else: + encoding = 'Content-Type:' + encoding + command_args = ['gsutil', '-m', '-q', '-h', encoding, 'cp', from_path, + to_path] + logging.info('Executing command: %s', command_args) + popen = processes.Popen(command_args, stdout=processes.PIPE, + stderr=processes.PIPE) + stdoutdata, stderrdata = popen.communicate() + if popen.returncode != 0: + raise ValueError( + 'Failed to copy GCS file from %s to %s (stdout=%s, stderr=%s).' % ( + from_path, to_path, stdoutdata, stderrdata)) + + +# ----------------------------------------------------------------------------- +# TextFileSource, TextFileSink. + + +class TextFileSource(iobase.NativeSource): + """A source for a GCS or local text file. + + Parses a text file as newline-delimited elements, by default assuming + UTF-8 encoding. + """ + + def __init__(self, file_path, start_offset=None, end_offset=None, + compression_type='AUTO', strip_trailing_newlines=True, + coder=coders.StrUtf8Coder()): + """Initialize a TextSource. + + Args: + file_path: The file path to read from as a local file path or a GCS + gs:// path. The path can contain glob characters (*, ?, and [...] + sets). + start_offset: The byte offset in the source text file that the reader + should start reading. By default is 0 (beginning of file). + end_offset: The byte offset in the file that the reader should stop + reading. By default it is the end of the file. + compression_type: Used to handle compressed input files. Typical value + is 'AUTO'. + strip_trailing_newlines: Indicates whether this source should remove + the newline char in each line it reads before decoding that line. + coder: Coder used to decode each line. + + Raises: + TypeError: if file_path is not a string. + + If the file_path contains glob characters then the start_offset and + end_offset must not be specified. + + The 'start_offset' and 'end_offset' pair provide a mechanism to divide the + text file into multiple pieces for individual sources. 
Because the offset + is measured by bytes, some complication arises when the offset splits in + the middle of a text line. To avoid the scenario where two adjacent sources + each get a fraction of a line, we adopt the following rules: + + If start_offset falls inside a line (any character except the first one) + then the source will skip the line and start with the next one. + + If end_offset falls inside a line (any character except the first one) then + the source will contain that entire line. + """ + if not isinstance(file_path, basestring): + raise TypeError( + '%s: file_path must be a string; got %r instead' % + (self.__class__.__name__, file_path)) + + self.file_path = file_path + self.start_offset = start_offset + self.end_offset = end_offset + self.compression_type = compression_type + self.strip_trailing_newlines = strip_trailing_newlines + self.coder = coder + + self.is_gcs_source = file_path.startswith('gs://') + + @property + def format(self): + """Source format name required for remote execution.""" + return 'text' + + def __eq__(self, other): + return (self.file_path == other.file_path and + self.start_offset == other.start_offset and + self.end_offset == other.end_offset and + self.strip_trailing_newlines == other.strip_trailing_newlines and + self.coder == other.coder) + + @property + def path(self): + return self.file_path + + def reader(self): + # If a multi-file pattern was specified as a source then make sure the + # start/end offsets use the default values for reading the entire file. + if re.search(r'[*?\[\]]', self.file_path) is not None: + if self.start_offset is not None: + raise ValueError( + 'Start offset cannot be specified for a multi-file source: ' + '%s' % self.file_path) + if self.end_offset is not None: + raise ValueError( + 'End offset cannot be specified for a multi-file source: ' + '%s' % self.file_path) + return TextMultiFileReader(self) + else: + return TextFileReader(self) + + +class ChannelFactory(object): + # TODO(robertwb): Generalize into extensible framework.
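# ChannelFactory keys every operation off the path prefix: 'gs://' paths are
# handed to gcsio.GcsIO(), anything else falls through to the local
# filesystem. A minimal usage sketch (the paths below are hypothetical):
#
#   with ChannelFactory.open('/tmp/example.txt', 'wb', 'text/plain') as f:
#     f.write('hello\n')                          # builtin open()
#   with ChannelFactory.open('gs://some-bucket/example.txt', 'wb',
#                            'text/plain') as f:
#     f.write('hello\n')                          # GcsIO-backed writer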
+ + @staticmethod + def mkdir(path): + if path.startswith('gs://'): + return + else: + try: + os.makedirs(path) + except OSError as err: + raise IOError(err) + + @staticmethod + def open(path, mode, mime_type): + if path.startswith('gs://'): + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + return gcsio.GcsIO().open(path, mode, mime_type=mime_type) + else: + return open(path, mode) + + @staticmethod + def rename(src, dst): + if src.startswith('gs://'): + assert dst.startswith('gs://'), dst + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + gcsio.GcsIO().rename(src, dst) + else: + try: + os.rename(src, dst) + except OSError as err: + raise IOError(err) + + @staticmethod + def copytree(src, dst): + if src.startswith('gs://'): + assert dst.startswith('gs://'), dst + assert src.endswith('/'), src + assert dst.endswith('/'), dst + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + gcsio.GcsIO().copytree(src, dst) + else: + try: + if os.path.exists(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + except OSError as err: + raise IOError(err) + + @staticmethod + def exists(path): + if path.startswith('gs://'): + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + return gcsio.GcsIO().exists(path) + else: + return os.path.exists(path) + + @staticmethod + def rmdir(path): + if path.startswith('gs://'): + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + gcs = gcsio.GcsIO() + if not path.endswith('/'): + path += '/' + # TODO(robertwb): Threadpool? + for entry in gcs.glob(path + '*'): + gcs.delete(entry) + else: + try: + shutil.rmtree(path) + except OSError as err: + raise IOError(err) + + @staticmethod + def rm(path): + if path.startswith('gs://'): + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + gcsio.GcsIO().delete(path) + else: + try: + os.remove(path) + except OSError as err: + raise IOError(err) + + @staticmethod + def glob(path): + if path.startswith('gs://'): + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + return gcsio.GcsIO().glob(path) + else: + return glob.glob(path) + + +class _CompressionType(object): + """Object representing a single compression type.""" + + def __init__(self, identifier): + self.identifier = identifier + + def __eq__(self, other): + return self.identifier == other.identifier + + +class CompressionTypes(object): + """Enum-like class representing known compression types.""" + NO_COMPRESSION = _CompressionType(1) # No compression. + DEFLATE = _CompressionType(2) # 'Deflate', i.e. gzip compression. + + @staticmethod + def valid_compression_type(compression_type): + """Returns true for valid compression types, false otherwise.""" + return isinstance(compression_type, _CompressionType) + + +class FileSink(iobase.Sink): + """A sink to GCS or local files. + + To implement a file-based sink, extend this class and override + either ``write_record()`` or ``write_encoded_record()``. + + If needed, also override ``open()`` and/or ``close()`` to customize the + file handling or write headers and footers. + + The output of this write is a PCollection of all written shards. + """ + + # Approximate number of write results to be assigned for each rename thread. + _WRITE_RESULTS_PER_RENAME_THREAD = 100 + + # Max number of threads to be used for renaming even if it means each thread + # will process more write results.
+ _MAX_RENAME_THREADS = 64 + + def __init__(self, + file_path_prefix, + coder, + file_name_suffix='', + num_shards=0, + shard_name_template=None, + mime_type='application/octet-stream'): + if shard_name_template is None: + shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE + elif shard_name_template == '': + num_shards = 1 + self.file_path_prefix = file_path_prefix + self.file_name_suffix = file_name_suffix + self.num_shards = num_shards + self.coder = coder + self.mime_type = mime_type + self.shard_name_format = self._template_to_format(shard_name_template) + + def open(self, temp_path): + """Opens ``temp_path``, returning an opaque file handle object. + + The returned file handle is passed to ``write_[encoded_]record`` and + ``close``. + """ + return ChannelFactory.open(temp_path, 'wb', self.mime_type) + + def write_record(self, file_handle, value): + """Writes a single record to the file handle returned by ``open()``. + + By default, calls ``write_encoded_record`` after encoding the record with + this sink's Coder. + """ + self.write_encoded_record(file_handle, self.coder.encode(value)) + + def write_encoded_record(self, file_handle, encoded_value): + """Writes a single encoded record to the file handle returned by ``open()``. + """ + raise NotImplementedError + + def close(self, file_handle): + """Finalize and close the file handle returned from ``open()``. + + Called after all records are written. + + By default, calls ``file_handle.close()`` iff it is not None. + """ + if file_handle is not None: + file_handle.close() + + def initialize_write(self): + tmp_dir = self.file_path_prefix + self.file_name_suffix + time.strftime( + '-temp-%Y-%m-%d_%H-%M-%S') + ChannelFactory().mkdir(tmp_dir) + return tmp_dir + + def open_writer(self, init_result, uid): + return FileSinkWriter(self, os.path.join(init_result, uid)) + + def finalize_write(self, init_result, writer_results): + writer_results = sorted(writer_results) + num_shards = len(writer_results) + channel_factory = ChannelFactory() + num_threads = max(1, min( + num_shards / FileSink._WRITE_RESULTS_PER_RENAME_THREAD, + FileSink._MAX_RENAME_THREADS)) + + rename_ops = [] + for shard_num, shard in enumerate(writer_results): + final_name = ''.join([ + self.file_path_prefix, + self.shard_name_format % dict(shard_num=shard_num, + num_shards=num_shards), + self.file_name_suffix]) + rename_ops.append((shard, final_name)) + + logging.info( + 'Starting finalize_write threads with num_shards: %d, num_threads: %d', + num_shards, num_threads) + start_time = time.time() + + # Use a thread pool for renaming operations. + def _rename_file(rename_op): + """_rename_file executes a single (old_name, new_name) rename operation.""" + old_name, final_name = rename_op + try: + channel_factory.rename(old_name, final_name) + except IOError as e: + # May have already been copied. + exists = channel_factory.exists(final_name) + if not exists: + logging.warning(('IOError in _rename_file. old_name: %s, ' + 'final_name: %s, err: %s'), old_name, final_name, e) + return (None, e) + except Exception as e: # pylint: disable=broad-except + logging.warning(('Exception in _rename_file.
old_name: %s, ' + 'final_name: %s, err: %s'), old_name, final_name, e) + return(None, e) + return (final_name, None) + + rename_results = ThreadPool(num_threads).map(_rename_file, rename_ops) + + for final_name, err in rename_results: + if err: + logging.warning('Error when processing rename_results: %s', err) + raise err + else: + yield final_name + + logging.info('Renamed %d shards in %.2f seconds.', + num_shards, time.time() - start_time) + + try: + channel_factory.rmdir(init_result) + except IOError: + # May have already been removed. + pass + + @staticmethod + def _template_to_format(shard_name_template): + if not shard_name_template: + return '' + m = re.search('S+', shard_name_template) + if m is None: + raise ValueError("Shard number pattern S+ not found in template '%s'" + % shard_name_template) + shard_name_format = shard_name_template.replace( + m.group(0), '%%(shard_num)0%dd' % len(m.group(0))) + m = re.search('N+', shard_name_format) + if m: + shard_name_format = shard_name_format.replace( + m.group(0), '%%(num_shards)0%dd' % len(m.group(0))) + return shard_name_format + + def __eq__(self, other): + # TODO(robertwb): Clean up workitem_test which uses this. + # pylint: disable=unidiomatic-typecheck + return type(self) == type(other) and self.__dict__ == other.__dict__ + + +class FileSinkWriter(iobase.Writer): + """The writer for FileSink. + """ + + def __init__(self, sink, temp_shard_path): + self.sink = sink + self.temp_shard_path = temp_shard_path + self.temp_handle = self.sink.open(temp_shard_path) + + def write(self, value): + self.sink.write_record(self.temp_handle, value) + + def close(self): + self.sink.close(self.temp_handle) + return self.temp_shard_path + + +class TextFileSink(FileSink): + """A sink to a GCS or local text file or files.""" + + def __init__(self, + file_path_prefix, + file_name_suffix='', + append_trailing_newlines=True, + num_shards=0, + shard_name_template=None, + coder=coders.ToStringCoder(), + compression_type=CompressionTypes.NO_COMPRESSION, + ): + """Initialize a TextFileSink. + + Args: + file_path_prefix: The file path to write to. The files written will begin + with this prefix, followed by a shard identifier (see num_shards), and + end in a common extension, if given by file_name_suffix. In most cases, + only this argument is specified and num_shards, shard_name_template, and + file_name_suffix use default values. + file_name_suffix: Suffix for the files written. + append_trailing_newlines: indicate whether this sink should write an + additional newline char after writing each element. + num_shards: The number of files (shards) used for output. If not set, the + service will decide on the optimal number of shards. + Constraining the number of shards is likely to reduce + the performance of a pipeline. Setting this value is not recommended + unless you require a specific number of output files. + shard_name_template: A template string containing placeholders for + the shard number and shard count. Currently only '' and + '-SSSSS-of-NNNNN' are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters 'S' and 'N' are replaced with the 0-padded shard + number and shard count respectively. This argument can be '' in which + case it behaves as if num_shards was set to 1 and only one file will be + generated. The default pattern used is '-SSSSS-of-NNNNN'. + coder: Coder used to encode each line. + compression_type: Type of compression to use for this sink. 
+ + Raises: + TypeError: if file path parameters are not a string or if compression_type + is not member of CompressionTypes. + ValueError: if shard_name_template is not of expected format. + + Returns: + A TextFileSink object usable for writing. + """ + if not isinstance(file_path_prefix, basestring): + raise TypeError( + 'TextFileSink: file_path_prefix must be a string; got %r instead' % + file_path_prefix) + if not isinstance(file_name_suffix, basestring): + raise TypeError( + 'TextFileSink: file_name_suffix must be a string; got %r instead' % + file_name_suffix) + + if not CompressionTypes.valid_compression_type(compression_type): + raise TypeError('compression_type must be CompressionType object but ' + 'was %s' % type(compression_type)) + if compression_type == CompressionTypes.DEFLATE: + mime_type = 'application/x-gzip' + else: + mime_type = 'text/plain' + + super(TextFileSink, self).__init__(file_path_prefix, + file_name_suffix=file_name_suffix, + num_shards=num_shards, + shard_name_template=shard_name_template, + coder=coder, + mime_type=mime_type) + + self.compression_type = compression_type + self.append_trailing_newlines = append_trailing_newlines + + def open(self, temp_path): + """Opens ''temp_path'', returning a writeable file object.""" + fobj = ChannelFactory.open(temp_path, 'wb', self.mime_type) + if self.compression_type == CompressionTypes.DEFLATE: + return gzip.GzipFile(fileobj=fobj) + return fobj + + def write_encoded_record(self, file_handle, encoded_value): + file_handle.write(encoded_value) + if self.append_trailing_newlines: + file_handle.write('\n') + + +class NativeTextFileSink(iobase.NativeSink): + """A sink to a GCS or local text file or files.""" + + def __init__(self, file_path_prefix, + append_trailing_newlines=True, + file_name_suffix='', + num_shards=0, + shard_name_template=None, + validate=True, + coder=coders.ToStringCoder()): + # We initialize a file_path attribute containing just the prefix part for + # local runner environment. For now, sharding is not supported in the local + # runner and sharding options (template, num, suffix) are ignored. + # The attribute is also used in the worker environment when we just write + # to a specific file. + # TODO(silviuc): Add support for file sharding in the local runner. + self.file_path = file_path_prefix + self.append_trailing_newlines = append_trailing_newlines + self.coder = coder + + self.is_gcs_sink = self.file_path.startswith('gs://') + + self.file_name_prefix = file_path_prefix + self.file_name_suffix = file_name_suffix + self.num_shards = num_shards + # TODO(silviuc): Update this when the service supports more patterns. + self.shard_name_template = ('-SSSSS-of-NNNNN' if shard_name_template is None + else shard_name_template) + # TODO(silviuc): Implement sink validation. 
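# Shard naming, as a rough illustration: FileSink._template_to_format turns a
# run of 'S' into the zero-padded shard number and a run of 'N' into the
# zero-padded shard count, so the default '-SSSSS-of-NNNNN' template expands to:
#
#   fmt = '-%(shard_num)05d-of-%(num_shards)05d'
#   fmt % dict(shard_num=2, num_shards=10)   # yields '-00002-of-00010'
#
# The final file name is file_path_prefix + this expansion + file_name_suffix.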
+ self.validate = validate + + @property + def format(self): + """Sink format name required for remote execution.""" + return 'text' + + @property + def path(self): + return self.file_path + + def writer(self): + return TextFileWriter(self) + + def __eq__(self, other): + return (self.file_path == other.file_path and + self.append_trailing_newlines == other.append_trailing_newlines and + self.coder == other.coder and + self.file_name_prefix == other.file_name_prefix and + self.file_name_suffix == other.file_name_suffix and + self.num_shards == other.num_shards and + self.shard_name_template == other.shard_name_template and + self.validate == other.validate) + + +# ----------------------------------------------------------------------------- +# TextFileReader, TextMultiFileReader. + + +class TextFileReader(iobase.NativeSourceReader): + """A reader for a text file source.""" + + def __init__(self, source): + self.source = source + self.start_offset = self.source.start_offset or 0 + self.end_offset = self.source.end_offset + self.current_offset = self.start_offset + + def __enter__(self): + if self.source.is_gcs_source: + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.io import gcsio + self._file = gcsio.GcsIO().open(self.source.file_path, 'rb') + else: + self._file = open(self.source.file_path, 'rb') + # Determine the real end_offset. + # If not specified it will be the length of the file. + if self.end_offset is None: + self._file.seek(0, os.SEEK_END) + self.end_offset = self._file.tell() + + if self.start_offset is None: + self.start_offset = 0 + self.current_offset = self.start_offset + if self.start_offset > 0: + # Read one byte before. This operation will either consume a previous + # newline if start_offset was at the beginning of a line or consume the + # line if we were in the middle of it. Either way we get the read position + # exactly where we wanted: at the begining of the first full line. + self._file.seek(self.start_offset - 1) + self.current_offset -= 1 + line = self._file.readline() + self.current_offset += len(line) + else: + self._file.seek(self.start_offset) + + # Initializing range tracker after start and end offsets are finalized. + self.range_tracker = range_trackers.OffsetRangeTracker(self.start_offset, + self.end_offset) + + return self + + def __exit__(self, exception_type, exception_value, traceback): + self._file.close() + + def __iter__(self): + while True: + if not self.range_tracker.try_claim( + record_start=self.current_offset): + # Reader has completed reading the set of records in its range. Note + # that the end offset of the range may be smaller than the original + # end offset defined when creating the reader due to reader accepting + # a dynamic split request from the service. 
+ return + line = self._file.readline() + self.current_offset += len(line) + if self.source.strip_trailing_newlines: + line = line.rstrip('\n') + yield self.source.coder.decode(line) + + def get_progress(self): + return iobase.ReaderProgress(position=iobase.ReaderPosition( + byte_offset=self.range_tracker.last_record_start)) + + def request_dynamic_split(self, dynamic_split_request): + assert dynamic_split_request is not None + progress = dynamic_split_request.progress + split_position = progress.position + if split_position is None: + percent_complete = progress.percent_complete + if percent_complete is not None: + if percent_complete <= 0 or percent_complete >= 1: + logging.warning( + 'FileBasedReader cannot be split since the provided percentage ' + 'of work to be completed is out of the valid range (0, ' + '1). Requested: %r', + dynamic_split_request) + return + split_position = iobase.ReaderPosition() + split_position.byte_offset = ( + self.range_tracker.position_at_fraction(percent_complete)) + else: + logging.warning( + 'TextReader requires either a position or a percentage of work to ' + 'be complete to perform a dynamic split request. Requested: %r', + dynamic_split_request) + return + + if self.range_tracker.try_split(split_position.byte_offset): + return iobase.DynamicSplitResultWithPosition(split_position) + else: + return + + +class TextMultiFileReader(iobase.NativeSourceReader): + """A reader for a multi-file text source.""" + + def __init__(self, source): + self.source = source + self.file_paths = ChannelFactory.glob(self.source.file_path) + if not self.file_paths: + raise RuntimeError( + 'No files found for path: %s' % self.source.file_path) + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + pass + + def __iter__(self): + index = 0 + for path in self.file_paths: + index += 1 + logging.info('Reading from %s (%d/%d)', path, index, len(self.file_paths)) + with TextFileSource( + path, strip_trailing_newlines=self.source.strip_trailing_newlines, + coder=self.source.coder).reader() as reader: + for line in reader: + yield line + + +# ----------------------------------------------------------------------------- +# TextFileWriter. + + +class TextFileWriter(iobase.NativeSinkWriter): + """The sink writer for a TextFileSink.""" + + def __init__(self, sink): + self.sink = sink + + def __enter__(self): + if self.sink.is_gcs_sink: + # TODO(silviuc): Use the storage library instead of gsutil for writes. + self.temp_path = os.path.join(tempfile.mkdtemp(), 'gcsfile') + self._file = open(self.temp_path, 'wb') + else: + self._file = open(self.sink.file_path, 'wb') + return self + + def __exit__(self, exception_type, exception_value, traceback): + self._file.close() + if hasattr(self, 'temp_path'): + _gcs_file_copy(self.temp_path, self.sink.file_path, 'text/plain') + + def Write(self, line): + self._file.write(self.sink.coder.encode(line)) + if self.sink.append_trailing_newlines: + self._file.write('\n') http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/fileio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/fileio_test.py b/sdks/python/apache_beam/io/fileio_test.py new file mode 100644 index 0000000..70192d1 --- /dev/null +++ b/sdks/python/apache_beam/io/fileio_test.py @@ -0,0 +1,522 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 Google Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for local and GCS sources and sinks.""" + +import glob +import gzip +import logging +import os +import tempfile +import unittest + +import google.cloud.dataflow as df +from google.cloud.dataflow import coders +from google.cloud.dataflow.io import fileio +from google.cloud.dataflow.io import iobase + + +class TestTextFileSource(unittest.TestCase): + + def create_temp_file(self, text): + temp = tempfile.NamedTemporaryFile(delete=False) + with temp.file as tmp: + tmp.write(text) + return temp.name + + def read_with_offsets(self, input_lines, output_lines, + start_offset=None, end_offset=None): + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(input_lines)), + start_offset=start_offset, end_offset=end_offset) + read_lines = [] + with source.reader() as reader: + for line in reader: + read_lines.append(line) + self.assertEqual(read_lines, output_lines) + + def progress_with_offsets(self, input_lines, + start_offset=None, end_offset=None): + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(input_lines)), + start_offset=start_offset, end_offset=end_offset) + progress_record = [] + with source.reader() as reader: + self.assertEqual(reader.get_progress().position.byte_offset, -1) + for line in reader: + self.assertIsNotNone(line) + progress_record.append(reader.get_progress().position.byte_offset) + + previous = 0 + for current in progress_record: + self.assertGreater(current, previous) + previous = current + + def test_read_entire_file(self): + lines = ['First', 'Second', 'Third'] + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(lines))) + read_lines = [] + with source.reader() as reader: + for line in reader: + read_lines.append(line) + self.assertEqual(read_lines, lines) + + def test_progress_entire_file(self): + lines = ['First', 'Second', 'Third'] + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(lines))) + progress_record = [] + with source.reader() as reader: + self.assertEqual(-1, reader.get_progress().position.byte_offset) + for line in reader: + self.assertIsNotNone(line) + progress_record.append(reader.get_progress().position.byte_offset) + self.assertEqual(13, reader.get_progress().position.byte_offset) + + self.assertEqual(len(progress_record), 3) + self.assertEqual(progress_record, [0, 6, 13]) + + def try_splitting_reader_at(self, reader, split_request, expected_response): + actual_response = reader.request_dynamic_split(split_request) + + if expected_response is None: + self.assertIsNone(actual_response) + else: + self.assertIsNotNone(actual_response.stop_position) + self.assertIsInstance(actual_response.stop_position, + iobase.ReaderPosition) + self.assertIsNotNone(actual_response.stop_position.byte_offset) + self.assertEqual(expected_response.stop_position.byte_offset, + actual_response.stop_position.byte_offset) + + return actual_response + + def test_update_stop_position_for_percent_complete(self): + lines = 
['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'] + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(lines))) + with source.reader() as reader: + # Read the first three lines + reader_iter = iter(reader) + next(reader_iter) + next(reader_iter) + next(reader_iter) + + # Splitting at end of the range should be unsuccessful + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=0)), + None) + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete=1)), + None) + + # Splitting at positions on or before start offset of the last record + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete= + 0.2)), + None) + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete= + 0.4)), + None) + + # Splitting at a position after the start offset of the last record should + # be successful + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress(percent_complete= + 0.6)), + iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition( + byte_offset=15))) + + def test_update_stop_position_percent_complete_for_position(self): + lines = ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'] + source = fileio.TextFileSource( + file_path=self.create_temp_file('\n'.join(lines))) + with source.reader() as reader: + # Read the first three lines + reader_iter = iter(reader) + next(reader_iter) + next(reader_iter) + next(reader_iter) + + # Splitting at end of the range should be unsuccessful + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress( + position=iobase.ReaderPosition(byte_offset=0))), + None) + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress( + position=iobase.ReaderPosition(byte_offset=25))), + None) + + # Splitting at positions on or before start offset of the last record + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress( + position=iobase.ReaderPosition(byte_offset=5))), + None) + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress( + position=iobase.ReaderPosition(byte_offset=10))), + None) + + # Splitting at a position after the start offset of the last record should + # be successful + self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(iobase.ReaderProgress( + position=iobase.ReaderPosition(byte_offset=15))), + iobase.DynamicSplitResultWithPosition(iobase.ReaderPosition( + byte_offset=15))) + + def run_update_stop_position_exhaustive(self, lines, newline): + """An exhaustive test for dynamic splitting. + + For the given set of data items, try to perform a split at all possible + combinations of the following: + + * start position + * original stop position + * updated stop position + * number of items read + + Args: + lines: set of data items to be used to create the file + newline: separator to be used when writing the given set of lines to a text + file.
+ """ + + file_path = self.create_temp_file(newline.join(lines)) + + total_records = len(lines) + total_bytes = 0 + + for line in lines: + total_bytes += len(line) + total_bytes += len(newline) * (total_records - 1) + + for start in xrange(0, total_bytes - 1): + for end in xrange(start + 1, total_bytes): + for stop in xrange(start, end): + for records_to_read in range(0, total_records): + self.run_update_stop_position(start, end, stop, records_to_read, + file_path) + + def test_update_stop_position_exhaustive(self): + self.run_update_stop_position_exhaustive( + ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\n') + + def test_update_stop_position_exhaustive_with_empty_lines(self): + self.run_update_stop_position_exhaustive( + ['', 'aaaa', '', 'bbbb', 'cccc', '', 'dddd', 'eeee', ''], '\n') + + def test_update_stop_position_exhaustive_windows_newline(self): + self.run_update_stop_position_exhaustive( + ['aaaa', 'bbbb', 'cccc', 'dddd', 'eeee'], '\r\n') + + def test_update_stop_position_exhaustive_multi_byte(self): + self.run_update_stop_position_exhaustive( + [u'අඅඅඅ'.encode('utf-8'), u'බබබබ'.encode('utf-8'), + u'කකකක'.encode('utf-8')], '\n') + + def run_update_stop_position(self, start_offset, end_offset, stop_offset, + records_to_read, + file_path): + source = fileio.TextFileSource(file_path, start_offset, end_offset) + + records_of_first_split = '' + + with source.reader() as reader: + reader_iter = iter(reader) + i = 0 + + try: + while i < records_to_read: + records_of_first_split += next(reader_iter) + i += 1 + except StopIteration: + # Invalid case, given source does not contain this many records. + return + + last_record_start_after_reading = reader.range_tracker.last_record_start + + if stop_offset <= last_record_start_after_reading: + expected_split_response = None + elif stop_offset == start_offset or stop_offset == end_offset: + expected_split_response = None + elif records_to_read == 0: + expected_split_response = None # unstarted + else: + expected_split_response = iobase.DynamicSplitResultWithPosition( + stop_position=iobase.ReaderPosition(byte_offset=stop_offset)) + + split_response = self.try_splitting_reader_at( + reader, + iobase.DynamicSplitRequest(progress=iobase.ReaderProgress( + iobase.ReaderPosition(byte_offset=stop_offset))), + expected_split_response) + + # Reading remaining records from the updated reader. + for line in reader: + records_of_first_split += line + + if split_response is not None: + # Total contents received by reading the two splits should be equal to the + # result obtained by reading the original source. 
+ records_of_original = '' + records_of_second_split = '' + + with source.reader() as original_reader: + for line in original_reader: + records_of_original += line + + new_source = fileio.TextFileSource( + file_path, + split_response.stop_position.byte_offset, + end_offset) + with new_source.reader() as reader: + for line in reader: + records_of_second_split += line + + self.assertEqual(records_of_original, + records_of_first_split + records_of_second_split) + + def test_various_offset_combination_with_local_file_for_read(self): + lines = ['01234', '6789012', '456789012'] + self.read_with_offsets(lines, lines[1:], start_offset=5) + self.read_with_offsets(lines, lines[1:], start_offset=6) + self.read_with_offsets(lines, lines[2:], start_offset=7) + self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=13) + self.read_with_offsets(lines, lines[1:2], start_offset=5, end_offset=14) + self.read_with_offsets(lines, lines[1:], start_offset=5, end_offset=16) + self.read_with_offsets(lines, lines[2:], start_offset=14, end_offset=20) + self.read_with_offsets(lines, lines[2:], start_offset=14) + self.read_with_offsets(lines, [], start_offset=20, end_offset=20) + + def test_various_offset_combination_with_local_file_for_progress(self): + lines = ['01234', '6789012', '456789012'] + self.progress_with_offsets(lines, start_offset=5) + self.progress_with_offsets(lines, start_offset=6) + self.progress_with_offsets(lines, start_offset=7) + self.progress_with_offsets(lines, start_offset=5, end_offset=13) + self.progress_with_offsets(lines, start_offset=5, end_offset=14) + self.progress_with_offsets(lines, start_offset=5, end_offset=16) + self.progress_with_offsets(lines, start_offset=14, end_offset=20) + self.progress_with_offsets(lines, start_offset=14) + self.progress_with_offsets(lines, start_offset=20, end_offset=20) + + +class NativeTestTextFileSink(unittest.TestCase): + + def create_temp_file(self): + temp = tempfile.NamedTemporaryFile(delete=False) + return temp.name + + def test_write_entire_file(self): + lines = ['First', 'Second', 'Third'] + file_path = self.create_temp_file() + sink = fileio.NativeTextFileSink(file_path) + with sink.writer() as writer: + for line in lines: + writer.Write(line) + with open(file_path, 'r') as f: + self.assertEqual(f.read().splitlines(), lines) + + +class TestPureTextFileSink(unittest.TestCase): + + def setUp(self): + self.lines = ['Line %d' % d for d in range(100)] + self.path = tempfile.NamedTemporaryFile().name + + def _write_lines(self, sink, lines): + f = sink.open(self.path) + for line in lines: + sink.write_record(f, line) + sink.close(f) + + def test_write_text_file(self): + sink = fileio.TextFileSink(self.path) + self._write_lines(sink, self.lines) + + with open(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), self.lines) + + def test_write_gzip_file(self): + sink = fileio.TextFileSink( + self.path, compression_type=fileio.CompressionTypes.DEFLATE) + self._write_lines(sink, self.lines) + + with gzip.GzipFile(self.path, 'r') as f: + self.assertEqual(f.read().splitlines(), self.lines) + + +class MyFileSink(fileio.FileSink): + + def open(self, temp_path): + # TODO(robertwb): Fix main session pickling. 
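# (The direct fileio.FileSink.open(self, ...) call below stands in for the
# commented-out super() call; per the TODO above, the super() form does not
# yet survive main-session pickling.)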
+ # file_handle = super(MyFileSink, self).open(temp_path) + file_handle = fileio.FileSink.open(self, temp_path) + file_handle.write('[start]') + return file_handle + + def write_encoded_record(self, file_handle, encoded_value): + file_handle.write('[') + file_handle.write(encoded_value) + file_handle.write(']') + + def close(self, file_handle): + file_handle.write('[end]') + # TODO(robertwb): Fix main session pickling. + # file_handle = super(MyFileSink, self).close(file_handle) + file_handle = fileio.FileSink.close(self, file_handle) + + +class TestFileSink(unittest.TestCase): + + def test_file_sink_writing(self): + temp_path = tempfile.NamedTemporaryFile().name + sink = MyFileSink(temp_path, + file_name_suffix='.foo', + coder=coders.ToStringCoder()) + + # Manually invoke the generic Sink API. + init_token = sink.initialize_write() + + writer1 = sink.open_writer(init_token, '1') + writer1.write('a') + writer1.write('b') + res1 = writer1.close() + + writer2 = sink.open_writer(init_token, '2') + writer2.write('x') + writer2.write('y') + writer2.write('z') + res2 = writer2.close() + + res = list(sink.finalize_write(init_token, [res1, res2])) + # Retry the finalize operation (as if the first attempt was lost). + res = list(sink.finalize_write(init_token, [res1, res2])) + + # Check the results. + shard1 = temp_path + '-00000-of-00002.foo' + shard2 = temp_path + '-00001-of-00002.foo' + self.assertEqual(res, [shard1, shard2]) + self.assertEqual(open(shard1).read(), '[start][a][b][end]') + self.assertEqual(open(shard2).read(), '[start][x][y][z][end]') + + # Check that any temp files are deleted. + self.assertItemsEqual([shard1, shard2], glob.glob(temp_path + '*')) + + def test_empty_write(self): + temp_path = tempfile.NamedTemporaryFile().name + sink = MyFileSink(temp_path, + file_name_suffix='.foo', + coder=coders.ToStringCoder()) + p = df.Pipeline('DirectPipelineRunner') + p | df.Create([]) | df.io.Write(sink) # pylint: disable=expression-not-assigned + p.run() + self.assertEqual(open(temp_path + '-00000-of-00001.foo').read(), + '[start][end]') + + def test_fixed_shard_write(self): + temp_path = tempfile.NamedTemporaryFile().name + sink = MyFileSink(temp_path, + file_name_suffix='.foo', + num_shards=3, + shard_name_template='_NN_SSS_', + coder=coders.ToStringCoder()) + p = df.Pipeline('DirectPipelineRunner') + p | df.Create(['a', 'b']) | df.io.Write(sink) # pylint: disable=expression-not-assigned + + p.run() + + concat = ''.join(open(temp_path + '_03_%03d_.foo' % shard_num).read() + for shard_num in range(3)) + self.assertTrue('][a][' in concat, concat) + self.assertTrue('][b][' in concat, concat) + + def test_file_sink_multi_shards(self): + temp_path = tempfile.NamedTemporaryFile().name + sink = MyFileSink(temp_path, + file_name_suffix='.foo', + coder=coders.ToStringCoder()) + + # Manually invoke the generic Sink API. + init_token = sink.initialize_write() + + num_shards = 1000 + writer_results = [] + for i in range(num_shards): + uuid = 'uuid-%05d' % i + writer = sink.open_writer(init_token, uuid) + writer.write('a') + writer.write('b') + writer.write(uuid) + writer_results.append(writer.close()) + + res_first = list(sink.finalize_write(init_token, writer_results)) + # Retry the finalize operation (as if the first attempt was lost). 
+ res_second = list(sink.finalize_write(init_token, writer_results)) + + self.assertItemsEqual(res_first, res_second) + + res = sorted(res_second) + for i in range(num_shards): + shard_name = '%s-%05d-of-%05d.foo' % (temp_path, i, num_shards) + uuid = 'uuid-%05d' % i + self.assertEqual(res[i], shard_name) + self.assertEqual( + open(shard_name).read(), ('[start][a][b][%s][end]' % uuid)) + + # Check that any temp files are deleted. + self.assertItemsEqual(res, glob.glob(temp_path + '*')) + + def test_file_sink_io_error(self): + temp_path = tempfile.NamedTemporaryFile().name + sink = MyFileSink(temp_path, + file_name_suffix='.foo', + coder=coders.ToStringCoder()) + + # Manually invoke the generic Sink API. + init_token = sink.initialize_write() + + writer1 = sink.open_writer(init_token, '1') + writer1.write('a') + writer1.write('b') + res1 = writer1.close() + + writer2 = sink.open_writer(init_token, '2') + writer2.write('x') + writer2.write('y') + writer2.write('z') + res2 = writer2.close() + + os.remove(res2) + with self.assertRaises(IOError): + list(sink.finalize_write(init_token, [res1, res2])) + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/gcsio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/gcsio.py b/sdks/python/apache_beam/io/gcsio.py new file mode 100644 index 0000000..8157b76 --- /dev/null +++ b/sdks/python/apache_beam/io/gcsio.py @@ -0,0 +1,602 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google Cloud Storage client. + +This library evolved from the Google App Engine GCS client available at +https://github.com/GoogleCloudPlatform/appengine-gcs-client. +""" + +import errno +import fnmatch +import logging +import multiprocessing +import os +import re +import StringIO +import threading + +from google.cloud.dataflow.internal import auth +from google.cloud.dataflow.utils import retry + +from apitools.base.py.exceptions import HttpError +import apitools.base.py.transfer as transfer + +# Issue a friendlier error message if the storage library is not available. +# TODO(silviuc): Remove this guard when storage is available everywhere. 
+try: + # pylint: disable=g-import-not-at-top + from google.cloud.dataflow.internal.clients import storage +except ImportError: + raise RuntimeError( + 'Google Cloud Storage I/O not supported for this execution environment ' + '(could not import storage API client).') + + +DEFAULT_READ_BUFFER_SIZE = 1024 * 1024 + + +def parse_gcs_path(gcs_path): + """Return the bucket and object names of the given gs:// path.""" + match = re.match('^gs://([^/]+)/(.+)$', gcs_path) + if match is None: + raise ValueError('GCS path must be in the form gs://<bucket>/<object>.') + return match.group(1), match.group(2) + + +class GcsIOError(IOError, retry.PermanentException): + """GCS IO error that should not be retried.""" + pass + + +class GcsIO(object): + """Google Cloud Storage I/O client.""" + + def __new__(cls, storage_client=None): + if storage_client: + return super(GcsIO, cls).__new__(cls, storage_client) + else: + # Create a single storage client for each thread. We would like to avoid + # creating more than one storage client for each thread, since each + # initialization requires the relatively expensive step of initializing + # credentials. + local_state = threading.local() + if getattr(local_state, 'gcsio_instance', None) is None: + credentials = auth.get_service_credentials() + storage_client = storage.StorageV1(credentials=credentials) + local_state.gcsio_instance = ( + super(GcsIO, cls).__new__(cls, storage_client)) + local_state.gcsio_instance.client = storage_client + return local_state.gcsio_instance + + def __init__(self, storage_client=None): + # We must do this check on storage_client because the client attribute may + # have already been set in __new__ for the singleton case when + # storage_client is None. + if storage_client is not None: + self.client = storage_client + + def open(self, filename, mode='r', + read_buffer_size=DEFAULT_READ_BUFFER_SIZE, + mime_type='application/octet-stream'): + """Open a GCS file path for reading or writing. + + Args: + filename: GCS file path in the form gs://<bucket>/<object>. + mode: 'r' for reading or 'w' for writing. + read_buffer_size: Buffer size to use during read operations. + mime_type: Mime type to set for write operations. + + Returns: + file object. + + Raises: + ValueError: Invalid open file mode. + """ + if mode == 'r' or mode == 'rb': + return GcsBufferedReader(self.client, filename, + buffer_size=read_buffer_size) + elif mode == 'w' or mode == 'wb': + return GcsBufferedWriter(self.client, filename, mime_type=mime_type) + else: + raise ValueError('Invalid file open mode: %s.' % mode) + + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def glob(self, pattern): + """Return the GCS path names matching a given path name pattern. + + Path name patterns are those recognized by fnmatch.fnmatch(). The path + can contain glob characters (*, ?, and [...] sets). + + Args: + pattern: GCS file path pattern in the form gs://<bucket>/<name_pattern>. + + Returns: + list of GCS file paths matching the given pattern. + """ + bucket, name_pattern = parse_gcs_path(pattern) + # Get the prefix with which we can list objects in the given bucket.
+ prefix = re.match('^[^[*?]*', name_pattern).group(0) + request = storage.StorageObjectsListRequest(bucket=bucket, prefix=prefix) + object_paths = [] + while True: + response = self.client.objects.List(request) + for item in response.items: + if fnmatch.fnmatch(item.name, name_pattern): + object_paths.append('gs://%s/%s' % (item.bucket, item.name)) + if response.nextPageToken: + request.pageToken = response.nextPageToken + else: + break + return object_paths + + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def delete(self, path): + """Deletes the object at the given GCS path. + + Args: + path: GCS file path pattern in the form gs:///. + """ + bucket, object_path = parse_gcs_path(path) + request = storage.StorageObjectsDeleteRequest(bucket=bucket, + object=object_path) + try: + self.client.objects.Delete(request) + except HttpError as http_error: + if http_error.status_code == 404: + # Return success when the file doesn't exist anymore for idempotency. + return + raise + + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def copy(self, src, dest): + """Copies the given GCS object from src to dest. + + Args: + src: GCS file path pattern in the form gs:///. + dest: GCS file path pattern in the form gs:///. + """ + src_bucket, src_path = parse_gcs_path(src) + dest_bucket, dest_path = parse_gcs_path(dest) + request = storage.StorageObjectsCopyRequest(sourceBucket=src_bucket, + sourceObject=src_path, + destinationBucket=dest_bucket, + destinationObject=dest_path) + try: + self.client.objects.Copy(request) + except HttpError as http_error: + if http_error.status_code == 404: + # This is a permanent error that should not be retried. Note that + # FileSink.finalize_write expects an IOError when the source file does + # not exist. + raise GcsIOError(errno.ENOENT, 'Source file not found: %s' % src) + raise + + # We intentionally do not decorate this method with a retry, since the + # underlying copy and delete operations are already idempotent operations + # protected by retry decorators. + def copytree(self, src, dest): + """Renames the given GCS "directory" recursively from src to dest. + + Args: + src: GCS file path pattern in the form gs:////. + dest: GCS file path pattern in the form gs:////. + """ + assert src.endswith('/') + assert dest.endswith('/') + for entry in self.glob(src + '*'): + rel_path = entry[len(src):] + self.copy(entry, dest + rel_path) + + # We intentionally do not decorate this method with a retry, since the + # underlying copy and delete operations are already idempotent operations + # protected by retry decorators. + def rename(self, src, dest): + """Renames the given GCS object from src to dest. + + Args: + src: GCS file path pattern in the form gs:///. + dest: GCS file path pattern in the form gs:///. + """ + self.copy(src, dest) + self.delete(src) + + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def exists(self, path): + """Returns whether the given GCS object exists. + + Args: + path: GCS file path pattern in the form gs:///. 
+ """ + bucket, object_path = parse_gcs_path(path) + try: + request = storage.StorageObjectsGetRequest(bucket=bucket, + object=object_path) + self.client.objects.Get(request) # metadata + return True + except IOError: + return False + + +class GcsBufferedReader(object): + """A class for reading Google Cloud Storage files.""" + + def __init__(self, client, path, buffer_size=DEFAULT_READ_BUFFER_SIZE): + self.client = client + self.path = path + self.bucket, self.name = parse_gcs_path(path) + self.buffer_size = buffer_size + + # Get object state. + get_request = ( + storage.StorageObjectsGetRequest( + bucket=self.bucket, + object=self.name)) + try: + metadata = self._get_object_metadata(get_request) + except HttpError as http_error: + if http_error.status_code == 404: + raise IOError(errno.ENOENT, 'Not found: %s' % self.path) + else: + logging.error( + 'HTTP error while requesting file %s: %s', self.path, http_error) + raise + self.size = metadata.size + + # Ensure read is from file of the correct generation. + get_request.generation = metadata.generation + + # Initialize read buffer state. + self.download_stream = StringIO.StringIO() + self.downloader = transfer.Download( + self.download_stream, auto_transfer=False) + self.client.objects.Get(get_request, download=self.downloader) + self.position = 0 + self.buffer = '' + self.buffer_start_position = 0 + self.closed = False + + @retry.with_exponential_backoff( + retry_filter=retry.retry_on_server_errors_and_timeout_filter) + def _get_object_metadata(self, get_request): + return self.client.objects.Get(get_request) + + def read(self, size=-1): + """Read data from a GCS file. + + Args: + size: Number of bytes to read. Actual number of bytes read is always + equal to size unless EOF is reached. If size is negative or + unspecified, read the entire file. + + Returns: + data read as str. + + Raises: + IOError: When this buffer is closed. + """ + return self._read_inner(size=size, readline=False) + + def readline(self, size=-1): + """Read one line delimited by '\\n' from the file. + + Mimics behavior of the readline() method on standard file objects. + + A trailing newline character is kept in the string. It may be absent when a + file ends with an incomplete line. If the size argument is non-negative, + it specifies the maximum string size (counting the newline) to return. + A negative size is the same as unspecified. Empty string is returned + only when EOF is encountered immediately. + + Args: + size: Maximum number of bytes to read. If not specified, readline stops + only on '\\n' or EOF. + + Returns: + The data read as a string. + + Raises: + IOError: When this buffer is closed. + """ + return self._read_inner(size=size, readline=True) + + def _read_inner(self, size=-1, readline=False): + """Shared implementation of read() and readline().""" + self._check_open() + if not self._remaining(): + return '' + + # Prepare to read. + data_list = [] + if size is None: + size = -1 + to_read = min(size, self._remaining()) + if to_read < 0: + to_read = self._remaining() + break_after = False + + while to_read > 0: + # If we have exhausted the buffer, get the next segment. + # TODO(ccy): We should consider prefetching the next block in another + # thread. + self._fetch_next_if_buffer_exhausted() + + # Determine number of bytes to read from buffer. 
+ buffer_bytes_read = self.position - self.buffer_start_position + bytes_to_read_from_buffer = min( + len(self.buffer) - buffer_bytes_read, to_read) + + # If readline is set, we only want to read up to and including the next + # newline character. + if readline: + next_newline_position = self.buffer.find( + '\n', buffer_bytes_read, len(self.buffer)) + if next_newline_position != -1: + bytes_to_read_from_buffer = (1 + next_newline_position - + buffer_bytes_read) + break_after = True + + # Read bytes. + data_list.append( + self.buffer[buffer_bytes_read:buffer_bytes_read + + bytes_to_read_from_buffer]) + self.position += bytes_to_read_from_buffer + to_read -= bytes_to_read_from_buffer + + if break_after: + break + + return ''.join(data_list) + + def _fetch_next_if_buffer_exhausted(self): + if not self.buffer or (self.buffer_start_position + len(self.buffer) + <= self.position): + bytes_to_request = min(self._remaining(), self.buffer_size) + self.buffer_start_position = self.position + self.buffer = self._get_segment(self.position, bytes_to_request) + + def _remaining(self): + return self.size - self.position + + def close(self): + """Close the current GCS file.""" + self.closed = True + self.download_stream = None + self.downloader = None + self.buffer = None + + def _get_segment(self, start, size): + """Get the given segment of the current GCS file.""" + if size == 0: + return '' + end = start + size - 1 + self.downloader.GetRange(start, end) + value = self.download_stream.getvalue() + # Clear the StringIO object after we've read its contents. + self.download_stream.truncate(0) + assert len(value) == size + return value + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.close() + + def seek(self, offset, whence=os.SEEK_SET): + """Set the file's current offset. + + Note if the new offset is out of bound, it is adjusted to either 0 or EOF. + + Args: + offset: seek offset as number. + whence: seek mode. Supported modes are os.SEEK_SET (absolute seek), + os.SEEK_CUR (seek relative to the current position), and os.SEEK_END + (seek relative to the end, offset should be negative). + + Raises: + IOError: When this buffer is closed. + ValueError: When whence is invalid. + """ + self._check_open() + + self.buffer = '' + self.buffer_start_position = -1 + + if whence == os.SEEK_SET: + self.position = offset + elif whence == os.SEEK_CUR: + self.position += offset + elif whence == os.SEEK_END: + self.position = self.size + offset + else: + raise ValueError('Whence mode %r is invalid.' % whence) + + self.position = min(self.position, self.size) + self.position = max(self.position, 0) + + def tell(self): + """Tell the file's current offset. + + Returns: + current offset in reading this file. + + Raises: + IOError: When this buffer is closed. + """ + self._check_open() + return self.position + + def _check_open(self): + if self.closed: + raise IOError('Buffer is closed.') + + def seekable(self): + return True + + def readable(self): + return True + + def writable(self): + return False + + +class GcsBufferedWriter(object): + """A class for writing Google Cloud Storage files.""" + + class PipeStream(object): + """A class that presents a pipe connection as a readable stream.""" + + def __init__(self, recv_pipe): + self.conn = recv_pipe + self.closed = False + self.position = 0 + self.remaining = '' + + def read(self, size): + """Read data from the wrapped pipe connection. + + Args: + size: Number of bytes to read. 
Actual number of bytes read is always + equal to size unless EOF is reached. + + Returns: + data read as str. + """ + data_list = [] + bytes_read = 0 + while bytes_read < size: + bytes_from_remaining = min(size - bytes_read, len(self.remaining)) + data_list.append(self.remaining[0:bytes_from_remaining]) + self.remaining = self.remaining[bytes_from_remaining:] + self.position += bytes_from_remaining + bytes_read += bytes_from_remaining + if not self.remaining: + try: + self.remaining = self.conn.recv_bytes() + except EOFError: + break + return ''.join(data_list) + + def tell(self): + """Tell the file's current offset. + + Returns: + current offset in reading this file. + + Raises: + IOError: When this stream is closed. + """ + self._check_open() + return self.position + + def seek(self, offset, whence=os.SEEK_SET): + # The apitools.base.py.transfer.Upload class insists on seeking to the end + # of a stream to do a check before completing an upload, so we must have + # this no-op method here in that case. + if whence == os.SEEK_END and offset == 0: + return + elif whence == os.SEEK_SET and offset == self.position: + return + raise NotImplementedError + + def _check_open(self): + if self.closed: + raise IOError('Stream is closed.') + + def __init__(self, client, path, mime_type='application/octet-stream'): + self.client = client + self.path = path + self.bucket, self.name = parse_gcs_path(path) + + self.closed = False + self.position = 0 + + # Set up communication with uploading thread. + parent_conn, child_conn = multiprocessing.Pipe() + self.conn = parent_conn + + # Set up uploader. + self.insert_request = ( + storage.StorageObjectsInsertRequest( + bucket=self.bucket, + name=self.name)) + self.upload = transfer.Upload(GcsBufferedWriter.PipeStream(child_conn), + mime_type) + self.upload.strategy = transfer.RESUMABLE_UPLOAD + + # Start uploading thread. + self.upload_thread = threading.Thread(target=self._start_upload) + self.upload_thread.daemon = True + self.upload_thread.start() + + # TODO(silviuc): Refactor so that retry logic can be applied. + # There is retry logic in the underlying transfer library but we should make + # it more explicit so we can control the retry parameters. + @retry.no_retries # Using no_retries marks this as an integration point. + def _start_upload(self): + # This starts the uploader thread. We are forced to run the uploader in + # another thread because the apitools uploader insists on taking a stream + # as input. Happily, this also means we get asynchronous I/O to GCS. + # + # The uploader by default transfers data in chunks of 1024 * 1024 bytes at + # a time, buffering writes until that size is reached. + self.client.objects.Insert(self.insert_request, upload=self.upload) + + def write(self, data): + """Write data to a GCS file. + + Args: + data: data to write as str. + + Raises: + IOError: When this buffer is closed. 
+ """ + self._check_open() + if not data: + return + self.conn.send_bytes(data) + self.position += len(data) + + def tell(self): + """Return the total number of bytes passed to write() so far.""" + return self.position + + def close(self): + """Close the current GCS file.""" + self.conn.close() + self.upload_thread.join() + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + self.close() + + def _check_open(self): + if self.closed: + raise IOError('Buffer is closed.') + + def seekable(self): + return False + + def readable(self): + return False + + def writable(self): + return True http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/gcsio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/gcsio_test.py b/sdks/python/apache_beam/io/gcsio_test.py new file mode 100644 index 0000000..702c834 --- /dev/null +++ b/sdks/python/apache_beam/io/gcsio_test.py @@ -0,0 +1,503 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Google Cloud Storage client.""" + +import logging +import multiprocessing +import os +import random +import threading +import unittest + + +import httplib2 + +from google.cloud.dataflow.io import gcsio +from apitools.base.py.exceptions import HttpError +from google.cloud.dataflow.internal.clients import storage + + +class FakeGcsClient(object): + # Fake storage client. Usage in gcsio.py is client.objects.Get(...) and + # client.objects.Insert(...). + + def __init__(self): + self.objects = FakeGcsObjects() + + +class FakeFile(object): + + def __init__(self, bucket, obj, contents, generation): + self.bucket = bucket + self.object = obj + self.contents = contents + self.generation = generation + + def get_metadata(self): + return storage.Object(bucket=self.bucket, + name=self.object, + generation=self.generation, + size=len(self.contents)) + + +class FakeGcsObjects(object): + + def __init__(self): + self.files = {} + # Store the last generation used for a given object name. Note that this + # has to persist even past the deletion of the object. 
+ self.last_generation = {} + self.list_page_tokens = {} + + def add_file(self, f): + self.files[(f.bucket, f.object)] = f + self.last_generation[(f.bucket, f.object)] = f.generation + + def get_file(self, bucket, obj): + return self.files.get((bucket, obj), None) + + def delete_file(self, bucket, obj): + del self.files[(bucket, obj)] + + def get_last_generation(self, bucket, obj): + return self.last_generation.get((bucket, obj), 0) + + def Get(self, get_request, download=None): # pylint: disable=invalid-name + f = self.get_file(get_request.bucket, get_request.object) + if f is None: + raise ValueError('Specified object does not exist.') + if download is None: + return f.get_metadata() + else: + stream = download.stream + + def get_range_callback(start, end): + assert start >= 0 and end >= start and end < len(f.contents) + stream.write(f.contents[start:end + 1]) + download.GetRange = get_range_callback + + def Insert(self, insert_request, upload=None): # pylint: disable=invalid-name + assert upload is not None + generation = self.get_last_generation(insert_request.bucket, + insert_request.name) + 1 + f = FakeFile(insert_request.bucket, insert_request.name, '', generation) + + # Stream data into file. + stream = upload.stream + data_list = [] + while True: + data = stream.read(1024 * 1024) + if not data: + break + data_list.append(data) + f.contents = ''.join(data_list) + + self.add_file(f) + + def Copy(self, copy_request): # pylint: disable=invalid-name + src_file = self.get_file(copy_request.sourceBucket, + copy_request.sourceObject) + if not src_file: + raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found', + 'https://fake/url') + generation = self.get_last_generation(copy_request.destinationBucket, + copy_request.destinationObject) + 1 + dest_file = FakeFile(copy_request.destinationBucket, + copy_request.destinationObject, + src_file.contents, generation) + self.add_file(dest_file) + + def Delete(self, delete_request): # pylint: disable=invalid-name + # Here, we emulate the behavior of the GCS service in raising a 404 error + # if this object does not exist. + if self.get_file(delete_request.bucket, delete_request.object): + self.delete_file(delete_request.bucket, delete_request.object) + else: + raise HttpError(httplib2.Response({'status': '404'}), '404 Not Found', + 'https://fake/url') + + def List(self, list_request): # pylint: disable=invalid-name + bucket = list_request.bucket + prefix = list_request.prefix or '' + matching_files = [] + for file_bucket, file_name in sorted(iter(self.files)): + if bucket == file_bucket and file_name.startswith(prefix): + file_object = self.files[(file_bucket, file_name)].get_metadata() + matching_files.append(file_object) + + # Handle pagination.
+ items_per_page = 5 + if not list_request.pageToken: + range_start = 0 + else: + if list_request.pageToken not in self.list_page_tokens: + raise ValueError('Invalid page token.') + range_start = self.list_page_tokens[list_request.pageToken] + del self.list_page_tokens[list_request.pageToken] + + result = storage.Objects( + items=matching_files[range_start:range_start + items_per_page]) + if range_start + items_per_page < len(matching_files): + next_range_start = range_start + items_per_page + next_page_token = '_page_token_%s_%s_%d' % (bucket, prefix, + next_range_start) + self.list_page_tokens[next_page_token] = next_range_start + result.nextPageToken = next_page_token + return result + + +class TestGCSPathParser(unittest.TestCase): + + def test_gcs_path(self): + self.assertEqual( + gcsio.parse_gcs_path('gs://bucket/name'), ('bucket', 'name')) + self.assertEqual( + gcsio.parse_gcs_path('gs://bucket/name/sub'), ('bucket', 'name/sub')) + + def test_bad_gcs_path(self): + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://') + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket') + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs://bucket/') + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///name') + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:///') + self.assertRaises(ValueError, gcsio.parse_gcs_path, 'gs:/blah/bucket/name') + + +class TestGCSIO(unittest.TestCase): + + def _insert_random_file(self, client, path, size, generation=1): + bucket, name = gcsio.parse_gcs_path(path) + f = FakeFile(bucket, name, os.urandom(size), generation) + client.objects.add_file(f) + return f + + def setUp(self): + self.client = FakeGcsClient() + self.gcs = gcsio.GcsIO(self.client) + + def test_delete(self): + file_name = 'gs://gcsio-test/delete_me' + file_size = 1024 + + # Test deletion of non-existent file. 
+ self.gcs.delete(file_name) + + self._insert_random_file(self.client, file_name, file_size) + self.assertTrue(gcsio.parse_gcs_path(file_name) in + self.client.objects.files) + + self.gcs.delete(file_name) + + self.assertFalse(gcsio.parse_gcs_path(file_name) in + self.client.objects.files) + + def test_copy(self): + src_file_name = 'gs://gcsio-test/source' + dest_file_name = 'gs://gcsio-test/dest' + file_size = 1024 + self._insert_random_file(self.client, src_file_name, + file_size) + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + self.gcs.copy(src_file_name, dest_file_name) + + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + self.assertRaises(IOError, self.gcs.copy, + 'gs://gcsio-test/non-existent', + 'gs://gcsio-test/non-existent-destination') + + def test_copytree(self): + src_dir_name = 'gs://gcsio-test/source/' + dest_dir_name = 'gs://gcsio-test/dest/' + file_size = 1024 + paths = ['a', 'b/c', 'b/d'] + for path in paths: + src_file_name = src_dir_name + path + dest_file_name = dest_dir_name + path + self._insert_random_file(self.client, src_file_name, + file_size) + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + self.gcs.copytree(src_dir_name, dest_dir_name) + + for path in paths: + src_file_name = src_dir_name + path + dest_file_name = dest_dir_name + path + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + def test_rename(self): + src_file_name = 'gs://gcsio-test/source' + dest_file_name = 'gs://gcsio-test/dest' + file_size = 1024 + self._insert_random_file(self.client, src_file_name, + file_size) + self.assertTrue(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertFalse(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + self.gcs.rename(src_file_name, dest_file_name) + + self.assertFalse(gcsio.parse_gcs_path(src_file_name) in + self.client.objects.files) + self.assertTrue(gcsio.parse_gcs_path(dest_file_name) in + self.client.objects.files) + + def test_full_file_read(self): + file_name = 'gs://gcsio-test/full_file' + file_size = 5 * 1024 * 1024 + 100 + random_file = self._insert_random_file(self.client, file_name, file_size) + f = self.gcs.open(file_name) + f.seek(0, os.SEEK_END) + self.assertEqual(f.tell(), file_size) + self.assertEqual(f.read(), '') + f.seek(0) + self.assertEqual(f.read(), random_file.contents) + + def test_file_random_seek(self): + file_name = 'gs://gcsio-test/seek_file' + file_size = 5 * 1024 * 1024 - 100 + random_file = self._insert_random_file(self.client, file_name, file_size) + + f = self.gcs.open(file_name) + random.seed(0) + for _ in range(0, 10): + a = random.randint(0, file_size - 1) + b = random.randint(0, file_size - 1) + start, end = min(a, b), max(a, b) + f.seek(start) + self.assertEqual(f.tell(), start) + self.assertEqual(f.read(end - start + 1), + random_file.contents[start:end + 1]) + self.assertEqual(f.tell(), end + 1) + + def test_file_read_line(self): + file_name = 'gs://gcsio-test/read_line_file' + lines = [] + + # Set a small buffer size to exercise refilling the buffer. 
+ # First line is carefully crafted so the newline falls as the last character + # of the buffer to exercise this code path. + read_buffer_size = 1024 + lines.append('x' * 1023 + '\n') + + for _ in range(1, 1000): + line_length = random.randint(100, 500) + line = os.urandom(line_length).replace('\n', ' ') + '\n' + lines.append(line) + contents = ''.join(lines) + + file_size = len(contents) + bucket, name = gcsio.parse_gcs_path(file_name) + self.client.objects.add_file(FakeFile(bucket, name, contents, 1)) + + f = self.gcs.open(file_name, read_buffer_size=read_buffer_size) + + # Test read of first two lines. + f.seek(0) + self.assertEqual(f.readline(), lines[0]) + self.assertEqual(f.tell(), len(lines[0])) + self.assertEqual(f.readline(), lines[1]) + + # Test read at line boundary. + f.seek(file_size - len(lines[-1]) - 1) + self.assertEqual(f.readline(), '\n') + + # Test read at end of file. + f.seek(file_size) + self.assertEqual(f.readline(), '') + + # Test reads at random positions. + random.seed(0) + for _ in range(0, 10): + start = random.randint(0, file_size - 1) + line_index = 0 + # Find line corresponding to start index. + chars_left = start + while True: + next_line_length = len(lines[line_index]) + if chars_left - next_line_length < 0: + break + chars_left -= next_line_length + line_index += 1 + f.seek(start) + self.assertEqual(f.readline(), lines[line_index][chars_left:]) + + def test_file_write(self): + file_name = 'gs://gcsio-test/write_file' + file_size = 5 * 1024 * 1024 + 2000 + contents = os.urandom(file_size) + f = self.gcs.open(file_name, 'w') + f.write(contents[0:1000]) + f.write(contents[1000:1024 * 1024]) + f.write(contents[1024 * 1024:]) + f.close() + bucket, name = gcsio.parse_gcs_path(file_name) + self.assertEqual( + self.client.objects.get_file(bucket, name).contents, contents) + + def test_context_manager(self): + # Test writing with a context manager. + file_name = 'gs://gcsio-test/context_manager_file' + file_size = 1024 + contents = os.urandom(file_size) + with self.gcs.open(file_name, 'w') as f: + f.write(contents) + bucket, name = gcsio.parse_gcs_path(file_name) + self.assertEqual( + self.client.objects.get_file(bucket, name).contents, contents) + + # Test reading with a context manager. + with self.gcs.open(file_name) as f: + self.assertEqual(f.read(), contents) + + # Test that exceptions are not swallowed by the context manager. 
+ with self.assertRaises(ZeroDivisionError): + with self.gcs.open(file_name) as f: + f.read(0 / 0) + + def test_glob(self): + bucket_name = 'gcsio-test' + object_names = [ + 'cow/cat/fish', + 'cow/cat/blubber', + 'cow/dog/blubber', + 'apple/dog/blubber', + 'apple/fish/blubber', + 'apple/fish/blowfish', + 'apple/fish/bambi', + 'apple/fish/balloon', + 'apple/fish/cat', + 'apple/fish/cart', + 'apple/fish/carl', + 'apple/dish/bat', + 'apple/dish/cat', + 'apple/dish/carl', + ] + for object_name in object_names: + file_name = 'gs://%s/%s' % (bucket_name, object_name) + self._insert_random_file(self.client, file_name, 0) + test_cases = [ + ('gs://gcsio-test/*', [ + 'cow/cat/fish', + 'cow/cat/blubber', + 'cow/dog/blubber', + 'apple/dog/blubber', + 'apple/fish/blubber', + 'apple/fish/blowfish', + 'apple/fish/bambi', + 'apple/fish/balloon', + 'apple/fish/cat', + 'apple/fish/cart', + 'apple/fish/carl', + 'apple/dish/bat', + 'apple/dish/cat', + 'apple/dish/carl', + ]), + ('gs://gcsio-test/cow/*', [ + 'cow/cat/fish', + 'cow/cat/blubber', + 'cow/dog/blubber', + ]), + ('gs://gcsio-test/cow/ca*', [ + 'cow/cat/fish', + 'cow/cat/blubber', + ]), + ('gs://gcsio-test/apple/[df]ish/ca*', [ + 'apple/fish/cat', + 'apple/fish/cart', + 'apple/fish/carl', + 'apple/dish/cat', + 'apple/dish/carl', + ]), + ('gs://gcsio-test/apple/fish/car?', [ + 'apple/fish/cart', + 'apple/fish/carl', + ]), + ('gs://gcsio-test/apple/fish/b*', [ + 'apple/fish/blubber', + 'apple/fish/blowfish', + 'apple/fish/bambi', + 'apple/fish/balloon', + ]), + ('gs://gcsio-test/apple/dish/[cb]at', [ + 'apple/dish/bat', + 'apple/dish/cat', + ]), + ] + for file_pattern, expected_object_names in test_cases: + expected_file_names = ['gs://%s/%s' % (bucket_name, o) for o in + expected_object_names] + self.assertEqual(set(self.gcs.glob(file_pattern)), + set(expected_file_names)) + + +class TestPipeStream(unittest.TestCase): + + def _read_and_verify(self, stream, expected, buffer_size): + data_list = [] + bytes_read = 0 + seen_last_block = False + while True: + data = stream.read(buffer_size) + self.assertLessEqual(len(data), buffer_size) + if len(data) < buffer_size: + # Test the constraint that the pipe stream returns less than the buffer + # size only when at the end of the stream. + if data: + self.assertFalse(seen_last_block) + seen_last_block = True + if not data: + break + data_list.append(data) + bytes_read += len(data) + self.assertEqual(stream.tell(), bytes_read) + self.assertEqual(''.join(data_list), expected) + + def test_pipe_stream(self): + block_sizes = list(4 ** i for i in range(0, 12)) + data_blocks = list(os.urandom(size) for size in block_sizes) + expected = ''.join(data_blocks) + + buffer_sizes = [100001, 512 * 1024, 1024 * 1024] + + for buffer_size in buffer_sizes: + parent_conn, child_conn = multiprocessing.Pipe() + stream = gcsio.GcsBufferedWriter.PipeStream(child_conn) + child_thread = threading.Thread(target=self._read_and_verify, + args=(stream, expected, buffer_size)) + child_thread.start() + for data in data_blocks: + parent_conn.send_bytes(data) + parent_conn.close() + child_thread.join() + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()
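
For readers following the classes above, a minimal usage sketch of how GcsBufferedReader and GcsBufferedWriter are reached through the GcsIO wrapper that gcsio_test.py exercises. This is not part of the commit: the bucket and object names are placeholders, it reuses the in-memory FakeGcsClient defined in the tests so it runs without real GCS access, and the import path for the test module is an assumption (this commit is mid-rename from google.cloud.dataflow to apache_beam).

    from google.cloud.dataflow.io import gcsio
    # Assumed import path for the fake client defined in gcsio_test.py above.
    from google.cloud.dataflow.io.gcsio_test import FakeGcsClient

    # In a real pipeline GcsIO would wrap an authenticated storage client;
    # the fake keeps this sketch self-contained.
    gcs = gcsio.GcsIO(FakeGcsClient())

    # Writing pushes data through GcsBufferedWriter's pipe to a background
    # resumable-upload thread; the context manager's close() joins that thread.
    with gcs.open('gs://example-bucket/greeting.txt', 'w') as f:
      f.write('hello\n')
      f.write('world\n')

    # Reading pulls the object back in buffer_size-byte segments on demand.
    with gcs.open('gs://example-bucket/greeting.txt') as f:
      print f.readline()  # 'hello\n'
      print f.read()      # 'world\n'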