beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From rober...@apache.org
Subject [1/2] incubator-beam git commit: [BEAM-852] Add validation to file based sources during create time
Date Mon, 14 Nov 2016 23:41:20 GMT
Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 15e78b28a -> 560fe79f8


[BEAM-852] Add validation to file based sources during create time


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/76ad2929
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/76ad2929
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/76ad2929

Branch: refs/heads/python-sdk
Commit: 76ad29296fd57e1eec97bf40d9cf3a1d54a63a3f
Parents: 15e78b2
Author: Sourabh Bajaj <sourabhbajaj@google.com>
Authored: Mon Nov 14 15:40:10 2016 -0800
Committer: Robert Bradshaw <robertwb@google.com>
Committed: Mon Nov 14 15:40:10 2016 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py            |  8 +++-
 sdks/python/apache_beam/io/bigquery.py          |  2 +-
 sdks/python/apache_beam/io/filebasedsource.py   | 16 +++++++-
 .../apache_beam/io/filebasedsource_test.py      | 41 ++++++++++++++------
 sdks/python/apache_beam/io/textio.py            | 13 +++++--
 5 files changed, 60 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 53ed95a..e7e73dd 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -37,7 +37,7 @@ __all__ = ['ReadFromAvro', 'WriteToAvro']
 class ReadFromAvro(PTransform):
   """A ``PTransform`` for reading avro files."""
 
-  def __init__(self, file_pattern=None, min_bundle_size=0):
+  def __init__(self, file_pattern=None, min_bundle_size=0, validate=True):
     """Initializes ``ReadFromAvro``.
 
     Uses source '_AvroSource' to read a set of Avro files defined by a given
@@ -70,13 +70,17 @@ class ReadFromAvro(PTransform):
       file_pattern: the set of files to be read.
       min_bundle_size: the minimum size in bytes, to be considered when
                        splitting the input into bundles.
+      validate: flag to verify that the files exist during the pipeline
+                creation time.
       **kwargs: Additional keyword arguments to be passed to the base class.
     """
     super(ReadFromAvro, self).__init__()
     self._args = (file_pattern, min_bundle_size)
+    self._validate = validate
 
   def apply(self, pvalue):
-    return pvalue.pipeline | Read(_AvroSource(*self._args))
+    return pvalue.pipeline | Read(_AvroSource(*self._args,
+                                              validate=self._validate))
 
 
 class _AvroUtils(object):

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/bigquery.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/bigquery.py b/sdks/python/apache_beam/io/bigquery.py
index f0e88a6..8d7892a 100644
--- a/sdks/python/apache_beam/io/bigquery.py
+++ b/sdks/python/apache_beam/io/bigquery.py
@@ -65,7 +65,7 @@ input entails querying the table for all its rows. The coder argument on
 BigQuerySource controls the reading of the lines in the export files (i.e.,
 transform a JSON object into a PCollection element). The coder is not involved
 when the same table is read as a side input since there is no intermediate
-format involved.  We get the table rows directly from the BigQuery service with
+format involved. We get the table rows directly from the BigQuery service with
 a query.
 
 Users may provide a query to read from rather than reading all of a BigQuery

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/filebasedsource.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index 58ad118..c7bc27e 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -50,7 +50,8 @@ class FileBasedSource(iobase.BoundedSource):
                file_pattern,
                min_bundle_size=0,
                compression_type=fileio.CompressionTypes.AUTO,
-               splittable=True):
+               splittable=True,
+               validate=True):
     """Initializes ``FileBasedSource``.
 
     Args:
@@ -68,10 +69,13 @@ class FileBasedSource(iobase.BoundedSource):
                   the file, for example, for compressed files where currently
                   it is not possible to efficiently read a data range without
                   decompressing the whole file.
+      validate: Boolean flag to verify that the files exist during the pipeline
+                creation time.
     Raises:
       TypeError: when compression_type is not valid or if file_pattern is not a
                  string.
       ValueError: when compression and splittable files are specified.
+      IOError: when the file pattern specified yields an empty result.
     """
     if not isinstance(file_pattern, basestring):
       raise TypeError(
@@ -91,6 +95,8 @@ class FileBasedSource(iobase.BoundedSource):
     else:
       # We can't split compressed files efficiently so turn off splitting.
       self._splittable = False
+    if validate:
+      self._validate()
 
   def display_data(self):
     return {'filePattern': DisplayDataItem(self._pattern, label="File Pattern"),
@@ -133,7 +139,6 @@ class FileBasedSource(iobase.BoundedSource):
 
   @staticmethod
   def _estimate_sizes_in_parallel(file_names):
-
     if not file_names:
       return []
     elif len(file_names) == 1:
@@ -150,6 +155,13 @@ class FileBasedSource(iobase.BoundedSource):
       finally:
         pool.terminate()
 
+  def _validate(self):
+    """Validate if there are actual files in the specified glob pattern
+    """
+    if len(fileio.ChannelFactory.glob(self._pattern)) <= 0:
+      raise IOError(
+          'No files found based on the file pattern %s' % self._pattern)
+
   def split(
       self, desired_bundle_size=None, start_position=None, stop_position=None):
     return self._get_concat_source().split(

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/filebasedsource_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/filebasedsource_test.py b/sdks/python/apache_beam/io/filebasedsource_test.py
index 7bc31fd..7f4d8d3 100644
--- a/sdks/python/apache_beam/io/filebasedsource_test.py
+++ b/sdks/python/apache_beam/io/filebasedsource_test.py
@@ -220,6 +220,26 @@ class TestFileBasedSource(unittest.TestCase):
     # environments with limited amount of resources.
     filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
 
+  def test_validation_file_exists(self):
+    file_name, _ = write_data(10)
+    LineSource(file_name)
+
+  def test_validation_directory_non_empty(self):
+    temp_dir = tempfile.mkdtemp()
+    file_name, _ = write_data(10, directory=temp_dir)
+    LineSource(file_name)
+
+  def test_validation_failing(self):
+    no_files_found_error = 'No files found based on the file pattern*'
+    with self.assertRaisesRegexp(IOError, no_files_found_error):
+      LineSource('dummy_pattern')
+    with self.assertRaisesRegexp(IOError, no_files_found_error):
+      temp_dir = tempfile.mkdtemp()
+      LineSource(os.path.join(temp_dir, '*'))
+
+  def test_validation_file_missing_verification_disabled(self):
+    LineSource('dummy_pattern', validate=False)
+
   def test_fully_read_single_file(self):
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -525,7 +545,7 @@ class TestSingleFileSource(unittest.TestCase):
     start_not_a_number_error = 'start_offset must be a number*'
     stop_not_a_number_error = 'stop_offset must be a number*'
     file_name = 'dummy_pattern'
-    fbs = LineSource(file_name)
+    fbs = LineSource(file_name, validate=False)
 
     with self.assertRaisesRegexp(TypeError, start_not_a_number_error):
       SingleFileSource(
@@ -545,7 +565,7 @@ class TestSingleFileSource(unittest.TestCase):
 
   def test_source_creation_display_data(self):
     file_name = 'dummy_pattern'
-    fbs = LineSource(file_name)
+    fbs = LineSource(file_name, validate=False)
     dd = DisplayData.create_from(fbs)
     expected_items = [
         DisplayDataItemMatcher('compression', 'auto'),
@@ -556,8 +576,7 @@ class TestSingleFileSource(unittest.TestCase):
   def test_source_creation_fails_if_start_lg_stop(self):
     start_larger_than_stop_error = (
         'start_offset must be smaller than stop_offset*')
-
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
     SingleFileSource(
         fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
     with self.assertRaisesRegexp(ValueError, start_larger_than_stop_error):
@@ -568,7 +587,7 @@ class TestSingleFileSource(unittest.TestCase):
           fbs, file_name='dummy_file', start_offset=100, stop_offset=100)
 
   def test_estimates_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     # Should simply return stop_offset - start_offset
     source = SingleFileSource(
@@ -580,7 +599,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertEquals(90, source.estimate_size())
 
   def test_read_range_at_beginning(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -591,7 +610,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[:4], read_data)
 
   def test_read_range_at_end(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -602,7 +621,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[-3:], read_data)
 
   def test_read_range_at_middle(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -613,7 +632,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data[4:7], read_data)
 
   def test_produces_splits_desiredsize_large_than_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -629,7 +648,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data, read_data)
 
   def test_produces_splits_desiredsize_smaller_than_size(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10
@@ -647,7 +666,7 @@ class TestSingleFileSource(unittest.TestCase):
     self.assertItemsEqual(expected_data, read_data)
 
   def test_produce_split_with_start_and_end_positions(self):
-    fbs = LineSource('dummy_pattern')
+    fbs = LineSource('dummy_pattern', validate=False)
 
     file_name, expected_data = write_data(10)
     assert len(expected_data) == 10

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/76ad2929/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 01f6ef6..e031572 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -72,9 +72,10 @@ class _TextSource(filebasedsource.FileBasedSource):
 
   def __init__(self, file_pattern, min_bundle_size,
                compression_type, strip_trailing_newlines, coder,
-               buffer_size=DEFAULT_READ_BUFFER_SIZE):
+               buffer_size=DEFAULT_READ_BUFFER_SIZE, validate=True):
     super(_TextSource, self).__init__(file_pattern, min_bundle_size,
-                                      compression_type=compression_type)
+                                      compression_type=compression_type,
+                                      validate=validate)
 
     self._strip_trailing_newlines = strip_trailing_newlines
     self._compression_type = compression_type
@@ -206,7 +207,6 @@ class ReadFromText(PTransform):
 
   This implementation only supports reading text encoded using UTF-8 or ASCII.
   This does not support other encodings such as UTF-16 or UTF-32."""
-
   def __init__(
       self,
       file_pattern=None,
@@ -214,6 +214,7 @@ class ReadFromText(PTransform):
       compression_type=fileio.CompressionTypes.AUTO,
       strip_trailing_newlines=True,
       coder=coders.StrUtf8Coder(),
+      validate=True,
       **kwargs):
     """Initialize the ReadFromText transform.
 
@@ -230,15 +231,19 @@ class ReadFromText(PTransform):
       strip_trailing_newlines: Indicates whether this source should remove
                                the newline char in each line it reads before
                                decoding that line.
+      validate: flag to verify that the files exist during the pipeline
+                creation time.
       coder: Coder used to decode each line.
     """
 
     super(ReadFromText, self).__init__(**kwargs)
     self._args = (file_pattern, min_bundle_size, compression_type,
                   strip_trailing_newlines, coder)
+    self._validate = validate
 
   def apply(self, pvalue):
-    return pvalue.pipeline | Read(_TextSource(*self._args))
+    return pvalue.pipeline | Read(_TextSource(*self._args,
+                                              validate=self._validate))
 
 
 class WriteToText(PTransform):


Mime
View raw message