beam-commits mailing list archives

From al...@apache.org
Subject [20/22] beam git commit: Rename google_cloud_dataflow and google_cloud_platform
Date Thu, 23 Feb 2017 01:23:15 GMT
http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/bigquery_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py
new file mode 100644
index 0000000..fbf073c
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py
@@ -0,0 +1,828 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for BigQuery sources and sinks."""
+
+import datetime
+import json
+import logging
+import time
+import unittest
+
+import hamcrest as hc
+import mock
+
+import apache_beam as beam
+from apache_beam.io.gcp.bigquery import RowAsDictJsonCoder
+from apache_beam.io.gcp.bigquery import TableRowJsonCoder
+from apache_beam.io.gcp.bigquery import parse_table_schema_from_json
+from apache_beam.io.gcp.internal.clients import bigquery
+from apache_beam.internal.gcp.json_value import to_json_value
+from apache_beam.transforms.display import DisplayData
+from apache_beam.transforms.display_test import DisplayDataItemMatcher
+from apache_beam.utils.pipeline_options import PipelineOptions
+
+# Protect against environments where bigquery library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from apitools.base.py.exceptions import HttpError
+except ImportError:
+  HttpError = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestRowAsDictJsonCoder(unittest.TestCase):
+
+  def test_row_as_dict(self):
+    coder = RowAsDictJsonCoder()
+    test_value = {'s': 'abc', 'i': 123, 'f': 123.456, 'b': True}
+    self.assertEqual(test_value, coder.decode(coder.encode(test_value)))
+
+  def json_compliance_exception(self, value):
+    with self.assertRaises(ValueError) as exn:
+      coder = RowAsDictJsonCoder()
+      test_value = {'s': value}
+      coder.decode(coder.encode(test_value))
+    self.assertTrue(bigquery.JSON_COMPLIANCE_ERROR in exn.exception.message)
+
+  def test_invalid_json_nan(self):
+    self.json_compliance_exception(float('nan'))
+
+  def test_invalid_json_inf(self):
+    self.json_compliance_exception(float('inf'))
+
+  def test_invalid_json_neg_inf(self):
+    self.json_compliance_exception(float('-inf'))
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestTableRowJsonCoder(unittest.TestCase):
+
+  def test_row_as_table_row(self):
+    schema_definition = [
+        ('s', 'STRING'),
+        ('i', 'INTEGER'),
+        ('f', 'FLOAT'),
+        ('b', 'BOOLEAN'),
+        ('r', 'RECORD')]
+    data_definition = [
+        'abc',
+        123,
+        123.456,
+        True,
+        {'a': 'b'}]
+    str_def = '{"s": "abc", "i": 123, "f": 123.456, "b": true, "r": {"a": "b"}}'
+    schema = bigquery.TableSchema(
+        fields=[bigquery.TableFieldSchema(name=k, type=v)
+                for k, v in schema_definition])
+    coder = TableRowJsonCoder(table_schema=schema)
+    test_row = bigquery.TableRow(
+        f=[bigquery.TableCell(v=to_json_value(e)) for e in data_definition])
+
+    self.assertEqual(str_def, coder.encode(test_row))
+    self.assertEqual(test_row, coder.decode(coder.encode(test_row)))
+    # A coder without schema can still decode.
+    self.assertEqual(
+        test_row, TableRowJsonCoder().decode(coder.encode(test_row)))
+
+  def test_row_and_no_schema(self):
+    coder = TableRowJsonCoder()
+    test_row = bigquery.TableRow(
+        f=[bigquery.TableCell(v=to_json_value(e))
+           for e in ['abc', 123, 123.456, True]])
+    with self.assertRaises(AttributeError) as ctx:
+      coder.encode(test_row)
+    self.assertTrue(
+        ctx.exception.message.startswith('The TableRowJsonCoder requires'))
+
+  def json_compliance_exception(self, value):
+    with self.assertRaises(ValueError) as exn:
+      schema_definition = [('f', 'FLOAT')]
+      schema = bigquery.TableSchema(
+          fields=[bigquery.TableFieldSchema(name=k, type=v)
+                  for k, v in schema_definition])
+      coder = TableRowJsonCoder(table_schema=schema)
+      test_row = bigquery.TableRow(
+          f=[bigquery.TableCell(v=to_json_value(value))])
+      coder.encode(test_row)
+    self.assertTrue(bigquery.JSON_COMPLIANCE_ERROR in exn.exception.message)
+
+  def test_invalid_json_nan(self):
+    self.json_compliance_exception(float('nan'))
+
+  def test_invalid_json_inf(self):
+    self.json_compliance_exception(float('inf'))
+
+  def test_invalid_json_neg_inf(self):
+    self.json_compliance_exception(float('-inf'))
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestTableSchemaParser(unittest.TestCase):
+  def test_parse_table_schema_from_json(self):
+    string_field = bigquery.TableFieldSchema(
+        name='s', type='STRING', mode='NULLABLE', description='s description')
+    number_field = bigquery.TableFieldSchema(
+        name='n', type='INTEGER', mode='REQUIRED', description='n description')
+    record_field = bigquery.TableFieldSchema(
+        name='r', type='RECORD', mode='REQUIRED', description='r description',
+        fields=[string_field, number_field])
+    expected_schema = bigquery.TableSchema(fields=[record_field])
+    json_str = json.dumps({'fields': [
+        {'name': 'r', 'type': 'RECORD', 'mode': 'REQUIRED',
+         'description': 'r description', 'fields': [
+             {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE',
+              'description': 's description'},
+             {'name': 'n', 'type': 'INTEGER', 'mode': 'REQUIRED',
+              'description': 'n description'}]}]})
+    self.assertEqual(parse_table_schema_from_json(json_str),
+                     expected_schema)
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestBigQuerySource(unittest.TestCase):
+
+  def test_display_data_item_on_validate_true(self):
+    source = beam.io.BigQuerySource('dataset.table', validate=True)
+
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation', True),
+        DisplayDataItemMatcher('table', 'dataset.table')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+  def test_table_reference_display_data(self):
+    source = beam.io.BigQuerySource('dataset.table')
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation', False),
+        DisplayDataItemMatcher('table', 'dataset.table')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+    source = beam.io.BigQuerySource('project:dataset.table')
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation', False),
+        DisplayDataItemMatcher('table', 'project:dataset.table')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+    source = beam.io.BigQuerySource('xyz.com:project:dataset.table')
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation',
+                               False),
+        DisplayDataItemMatcher('table',
+                               'xyz.com:project:dataset.table')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+  def test_parse_table_reference(self):
+    source = beam.io.BigQuerySource('dataset.table')
+    self.assertEqual(source.table_reference.datasetId, 'dataset')
+    self.assertEqual(source.table_reference.tableId, 'table')
+
+    source = beam.io.BigQuerySource('project:dataset.table')
+    self.assertEqual(source.table_reference.projectId, 'project')
+    self.assertEqual(source.table_reference.datasetId, 'dataset')
+    self.assertEqual(source.table_reference.tableId, 'table')
+
+    source = beam.io.BigQuerySource('xyz.com:project:dataset.table')
+    self.assertEqual(source.table_reference.projectId, 'xyz.com:project')
+    self.assertEqual(source.table_reference.datasetId, 'dataset')
+    self.assertEqual(source.table_reference.tableId, 'table')
+
+    source = beam.io.BigQuerySource(query='my_query')
+    self.assertEqual(source.query, 'my_query')
+    self.assertIsNone(source.table_reference)
+    self.assertTrue(source.use_legacy_sql)
+
+  def test_query_only_display_data(self):
+    source = beam.io.BigQuerySource(query='my_query')
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation', False),
+        DisplayDataItemMatcher('query', 'my_query')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+  def test_specify_query_sql_format(self):
+    source = beam.io.BigQuerySource(query='my_query', use_standard_sql=True)
+    self.assertEqual(source.query, 'my_query')
+    self.assertFalse(source.use_legacy_sql)
+
+  def test_specify_query_flattened_records(self):
+    source = beam.io.BigQuerySource(query='my_query', flatten_results=False)
+    self.assertFalse(source.flatten_results)
+
+  def test_specify_query_unflattened_records(self):
+    source = beam.io.BigQuerySource(query='my_query', flatten_results=True)
+    self.assertTrue(source.flatten_results)
+
+  def test_specify_query_without_table(self):
+    source = beam.io.BigQuerySource(query='my_query')
+    self.assertEqual(source.query, 'my_query')
+    self.assertIsNone(source.table_reference)
+
+  def test_date_partitioned_table_name(self):
+    source = beam.io.BigQuerySource('dataset.table$20030102', validate=True)
+    dd = DisplayData.create_from(source)
+    expected_items = [
+        DisplayDataItemMatcher('validation', True),
+        DisplayDataItemMatcher('table', 'dataset.table$20030102')]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestBigQuerySink(unittest.TestCase):
+
+  def test_table_spec_display_data(self):
+    sink = beam.io.BigQuerySink('dataset.table')
+    dd = DisplayData.create_from(sink)
+    expected_items = [
+        DisplayDataItemMatcher('table', 'dataset.table'),
+        DisplayDataItemMatcher('validation', False)]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+  def test_parse_schema_descriptor(self):
+    sink = beam.io.BigQuerySink(
+        'dataset.table', schema='s:STRING, n:INTEGER')
+    self.assertEqual(sink.table_reference.datasetId, 'dataset')
+    self.assertEqual(sink.table_reference.tableId, 'table')
+    result_schema = {
+        field.name: field.type for field in sink.table_schema.fields}
+    self.assertEqual({'n': 'INTEGER', 's': 'STRING'}, result_schema)
+
+  def test_project_table_display_data(self):
+    sinkq = beam.io.BigQuerySink('PROJECT:dataset.table')
+    dd = DisplayData.create_from(sinkq)
+    expected_items = [
+        DisplayDataItemMatcher('table', 'PROJECT:dataset.table'),
+        DisplayDataItemMatcher('validation', False)]
+    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
+
+  def test_simple_schema_as_json(self):
+    sink = beam.io.BigQuerySink(
+        'PROJECT:dataset.table', schema='s:STRING, n:INTEGER')
+    self.assertEqual(
+        json.dumps({'fields': [
+            {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'},
+            {'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'}]}),
+        sink.schema_as_json())
+
+  def test_nested_schema_as_json(self):
+    string_field = bigquery.TableFieldSchema(
+        name='s', type='STRING', mode='NULLABLE', description='s description')
+    number_field = bigquery.TableFieldSchema(
+        name='n', type='INTEGER', mode='REQUIRED', description='n description')
+    record_field = bigquery.TableFieldSchema(
+        name='r', type='RECORD', mode='REQUIRED', description='r description',
+        fields=[string_field, number_field])
+    schema = bigquery.TableSchema(fields=[record_field])
+    sink = beam.io.BigQuerySink('dataset.table', schema=schema)
+    self.assertEqual(
+        {'fields': [
+            {'name': 'r', 'type': 'RECORD', 'mode': 'REQUIRED',
+             'description': 'r description', 'fields': [
+                 {'name': 's', 'type': 'STRING', 'mode': 'NULLABLE',
+                  'description': 's description'},
+                 {'name': 'n', 'type': 'INTEGER', 'mode': 'REQUIRED',
+                  'description': 'n description'}]}]},
+        json.loads(sink.schema_as_json()))
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestBigQueryReader(unittest.TestCase):
+
+  def get_test_rows(self):
+    now = time.time()
+    dt = datetime.datetime.utcfromtimestamp(float(now))
+    ts = dt.strftime('%Y-%m-%d %H:%M:%S.%f UTC')
+    expected_rows = [
+        {
+            'i': 1,
+            's': 'abc',
+            'f': 2.3,
+            'b': True,
+            't': ts,
+            'dt': '2016-10-31',
+            'ts': '22:39:12.627498',
+            'dt_ts': '2008-12-25T07:30:00',
+            'r': {'s2': 'b'},
+            'rpr': [{'s3': 'c', 'rpr2': [{'rs': ['d', 'e'], 's4': None}]}]
+        },
+        {
+            'i': 10,
+            's': 'xyz',
+            'f': -3.14,
+            'b': False,
+            'rpr': [],
+            't': None,
+            'dt': None,
+            'ts': None,
+            'dt_ts': None,
+            'r': None,
+        }]
+
+    nested_schema = [
+        bigquery.TableFieldSchema(
+            name='s2', type='STRING', mode='NULLABLE')]
+    nested_schema_2 = [
+        bigquery.TableFieldSchema(
+            name='s3', type='STRING', mode='NULLABLE'),
+        bigquery.TableFieldSchema(
+            name='rpr2', type='RECORD', mode='REPEATED', fields=[
+                bigquery.TableFieldSchema(
+                    name='rs', type='STRING', mode='REPEATED'),
+                bigquery.TableFieldSchema(
+                    name='s4', type='STRING', mode='NULLABLE')])]
+
+    schema = bigquery.TableSchema(
+        fields=[
+            bigquery.TableFieldSchema(
+                name='b', type='BOOLEAN', mode='REQUIRED'),
+            bigquery.TableFieldSchema(
+                name='f', type='FLOAT', mode='REQUIRED'),
+            bigquery.TableFieldSchema(
+                name='i', type='INTEGER', mode='REQUIRED'),
+            bigquery.TableFieldSchema(
+                name='s', type='STRING', mode='REQUIRED'),
+            bigquery.TableFieldSchema(
+                name='t', type='TIMESTAMP', mode='NULLABLE'),
+            bigquery.TableFieldSchema(
+                name='dt', type='DATE', mode='NULLABLE'),
+            bigquery.TableFieldSchema(
+                name='ts', type='TIME', mode='NULLABLE'),
+            bigquery.TableFieldSchema(
+                name='dt_ts', type='DATETIME', mode='NULLABLE'),
+            bigquery.TableFieldSchema(
+                name='r', type='RECORD', mode='NULLABLE',
+                fields=nested_schema),
+            bigquery.TableFieldSchema(
+                name='rpr', type='RECORD', mode='REPEATED',
+                fields=nested_schema_2)])
+
+    table_rows = [
+        bigquery.TableRow(f=[
+            bigquery.TableCell(v=to_json_value('true')),
+            bigquery.TableCell(v=to_json_value(str(2.3))),
+            bigquery.TableCell(v=to_json_value(str(1))),
+            bigquery.TableCell(v=to_json_value('abc')),
+            # For timestamps we cannot use str() because it would truncate the
+            # number representing the timestamp.
+            bigquery.TableCell(v=to_json_value('%f' % now)),
+            bigquery.TableCell(v=to_json_value('2016-10-31')),
+            bigquery.TableCell(v=to_json_value('22:39:12.627498')),
+            bigquery.TableCell(v=to_json_value('2008-12-25T07:30:00')),
+            # For records we cannot use a dict because it doesn't create
+            # nested schemas correctly, so we have to use this f/v based format.
+            bigquery.TableCell(v=to_json_value({'f': [{'v': 'b'}]})),
+            bigquery.TableCell(v=to_json_value([{'v':{'f':[{'v': 'c'}, {'v':[
+                {'v':{'f':[{'v':[{'v':'d'}, {'v':'e'}]}, {'v':None}]}}]}]}}]))
+            ]),
+        bigquery.TableRow(f=[
+            bigquery.TableCell(v=to_json_value('false')),
+            bigquery.TableCell(v=to_json_value(str(-3.14))),
+            bigquery.TableCell(v=to_json_value(str(10))),
+            bigquery.TableCell(v=to_json_value('xyz')),
+            bigquery.TableCell(v=None),
+            bigquery.TableCell(v=None),
+            bigquery.TableCell(v=None),
+            bigquery.TableCell(v=None),
+            bigquery.TableCell(v=None),
+            bigquery.TableCell(v=to_json_value([]))])]
+    return table_rows, schema, expected_rows
+
+  def test_read_from_table(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    client.jobs.GetQueryResults.return_value = bigquery.GetQueryResultsResponse(
+        jobComplete=True, rows=table_rows, schema=schema)
+    actual_rows = []
+    with beam.io.BigQuerySource('dataset.table').reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, expected_rows)
+    self.assertEqual(schema, reader.schema)
+
+  def test_read_from_query(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    client.jobs.GetQueryResults.return_value = bigquery.GetQueryResultsResponse(
+        jobComplete=True, rows=table_rows, schema=schema)
+    actual_rows = []
+    with beam.io.BigQuerySource(query='query').reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, expected_rows)
+    self.assertEqual(schema, reader.schema)
+    self.assertTrue(reader.use_legacy_sql)
+    self.assertTrue(reader.flatten_results)
+
+  def test_read_from_query_sql_format(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    client.jobs.GetQueryResults.return_value = bigquery.GetQueryResultsResponse(
+        jobComplete=True, rows=table_rows, schema=schema)
+    actual_rows = []
+    with beam.io.BigQuerySource(
+        query='query', use_standard_sql=True).reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, expected_rows)
+    self.assertEqual(schema, reader.schema)
+    self.assertFalse(reader.use_legacy_sql)
+    self.assertTrue(reader.flatten_results)
+
+  def test_read_from_query_unflatten_records(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    client.jobs.GetQueryResults.return_value = bigquery.GetQueryResultsResponse(
+        jobComplete=True, rows=table_rows, schema=schema)
+    actual_rows = []
+    with beam.io.BigQuerySource(
+        query='query', flatten_results=False).reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, expected_rows)
+    self.assertEqual(schema, reader.schema)
+    self.assertTrue(reader.use_legacy_sql)
+    self.assertFalse(reader.flatten_results)
+
+  def test_using_both_query_and_table_fails(self):
+    with self.assertRaises(ValueError) as exn:
+      beam.io.BigQuerySource(table='dataset.table', query='query')
+    self.assertEqual(exn.exception.message, 'Both a BigQuery table and a'
+                     ' query were specified. Please specify only one of '
+                     'these.')
+
+  def test_using_neither_query_nor_table_fails(self):
+    with self.assertRaises(ValueError) as exn:
+      beam.io.BigQuerySource()
+    self.assertEqual(exn.exception.message, 'A BigQuery table or a query'
+                     ' must be specified')
+
+  def test_read_from_table_as_tablerows(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, _ = self.get_test_rows()
+    client.jobs.GetQueryResults.return_value = bigquery.GetQueryResultsResponse(
+        jobComplete=True, rows=table_rows, schema=schema)
+    actual_rows = []
+    # We set the coder to TableRowJsonCoder, which is a signal that
+    # the caller wants to see the rows as TableRows.
+    with beam.io.BigQuerySource(
+        'dataset.table', coder=TableRowJsonCoder).reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, table_rows)
+    self.assertEqual(schema, reader.schema)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_read_from_table_and_job_complete_retry(self, patched_time_sleep):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    # Return jobComplete=False on first call to trigger the code path where
+    # query needs to handle waiting a bit.
+    client.jobs.GetQueryResults.side_effect = [
+        bigquery.GetQueryResultsResponse(
+            jobComplete=False),
+        bigquery.GetQueryResultsResponse(
+            jobComplete=True, rows=table_rows, schema=schema)]
+    actual_rows = []
+    with beam.io.BigQuerySource('dataset.table').reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    self.assertEqual(actual_rows, expected_rows)
+
+  def test_read_from_table_and_multiple_pages(self):
+    client = mock.Mock()
+    client.jobs.Insert.return_value = bigquery.Job(
+        jobReference=bigquery.JobReference(
+            jobId='somejob'))
+    table_rows, schema, expected_rows = self.get_test_rows()
+    # Return a pageToken on first call to trigger the code path where
+    # query needs to handle multiple pages of results.
+    client.jobs.GetQueryResults.side_effect = [
+        bigquery.GetQueryResultsResponse(
+            jobComplete=True, rows=table_rows, schema=schema,
+            pageToken='token'),
+        bigquery.GetQueryResultsResponse(
+            jobComplete=True, rows=table_rows, schema=schema)]
+    actual_rows = []
+    with beam.io.BigQuerySource('dataset.table').reader(client) as reader:
+      for row in reader:
+        actual_rows.append(row)
+    # We return expected rows for each of the two pages of results so we
+    # adjust our expectation below accordingly.
+    self.assertEqual(actual_rows, expected_rows * 2)
+
+  def test_table_schema_without_project(self):
+    # The reader should pick up the executing project by default.
+    source = beam.io.BigQuerySource(table='mydataset.mytable')
+    options = PipelineOptions(flags=['--project', 'myproject'])
+    source.pipeline_options = options
+    reader = source.reader()
+    self.assertEqual('SELECT * FROM [myproject:mydataset.mytable];',
+                     reader.query)
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestBigQueryWriter(unittest.TestCase):
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_no_table_and_create_never(self, patched_time_sleep):
+    client = mock.Mock()
+    client.tables.Get.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER
+    with self.assertRaises(RuntimeError) as exn:
+      with beam.io.BigQuerySink(
+          'project:dataset.table',
+          create_disposition=create_disposition).writer(client):
+        pass
+    self.assertEqual(
+        exn.exception.message,
+        'Table project:dataset.table not found but create disposition is '
+        'CREATE_NEVER.')
+
+  def test_no_table_and_create_if_needed(self):
+    client = mock.Mock()
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tables.Get.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    client.tables.Insert.return_value = table
+    create_disposition = beam.io.BigQueryDisposition.CREATE_IF_NEEDED
+    with beam.io.BigQuerySink(
+        'project:dataset.table',
+        schema='somefield:INTEGER',
+        create_disposition=create_disposition).writer(client):
+      pass
+    self.assertTrue(client.tables.Get.called)
+    self.assertTrue(client.tables.Insert.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_no_table_and_create_if_needed_and_no_schema(
+      self, patched_time_sleep):
+    client = mock.Mock()
+    client.tables.Get.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    create_disposition = beam.io.BigQueryDisposition.CREATE_IF_NEEDED
+    with self.assertRaises(RuntimeError) as exn:
+      with beam.io.BigQuerySink(
+          'project:dataset.table',
+          create_disposition=create_disposition).writer(client):
+        pass
+    self.assertEqual(
+        exn.exception.message,
+        'Table project:dataset.table requires a schema. None can be inferred '
+        'because the table does not exist.')
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_table_not_empty_and_write_disposition_empty(
+      self, patched_time_sleep):
+    client = mock.Mock()
+    client.tables.Get.return_value = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tabledata.List.return_value = bigquery.TableDataList(totalRows=1)
+    write_disposition = beam.io.BigQueryDisposition.WRITE_EMPTY
+    with self.assertRaises(RuntimeError) as exn:
+      with beam.io.BigQuerySink(
+          'project:dataset.table',
+          write_disposition=write_disposition).writer(client):
+        pass
+    self.assertEqual(
+        exn.exception.message,
+        'Table project:dataset.table is not empty but write disposition is '
+        'WRITE_EMPTY.')
+
+  def test_table_empty_and_write_disposition_empty(self):
+    client = mock.Mock()
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tables.Get.return_value = table
+    client.tabledata.List.return_value = bigquery.TableDataList(totalRows=0)
+    client.tables.Insert.return_value = table
+    write_disposition = beam.io.BigQueryDisposition.WRITE_EMPTY
+    with beam.io.BigQuerySink(
+        'project:dataset.table',
+        write_disposition=write_disposition).writer(client):
+      pass
+    self.assertTrue(client.tables.Get.called)
+    self.assertTrue(client.tabledata.List.called)
+    self.assertFalse(client.tables.Delete.called)
+    self.assertFalse(client.tables.Insert.called)
+
+  def test_table_with_write_disposition_truncate(self):
+    client = mock.Mock()
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tables.Get.return_value = table
+    client.tables.Insert.return_value = table
+    write_disposition = beam.io.BigQueryDisposition.WRITE_TRUNCATE
+    with beam.io.BigQuerySink(
+        'project:dataset.table',
+        write_disposition=write_disposition).writer(client):
+      pass
+    self.assertTrue(client.tables.Get.called)
+    self.assertTrue(client.tables.Delete.called)
+    self.assertTrue(client.tables.Insert.called)
+
+  def test_table_with_write_disposition_append(self):
+    client = mock.Mock()
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tables.Get.return_value = table
+    client.tables.Insert.return_value = table
+    write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
+    with beam.io.BigQuerySink(
+        'project:dataset.table',
+        write_disposition=write_disposition).writer(client):
+      pass
+    self.assertTrue(client.tables.Get.called)
+    self.assertFalse(client.tables.Delete.called)
+    self.assertFalse(client.tables.Insert.called)
+
+  def test_rows_are_written(self):
+    client = mock.Mock()
+    table = bigquery.Table(
+        tableReference=bigquery.TableReference(
+            projectId='project', datasetId='dataset', tableId='table'),
+        schema=bigquery.TableSchema())
+    client.tables.Get.return_value = table
+    write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
+
+    insert_response = mock.Mock()
+    insert_response.insertErrors = []
+    client.tabledata.InsertAll.return_value = insert_response
+
+    with beam.io.BigQuerySink(
+        'project:dataset.table',
+        write_disposition=write_disposition).writer(client) as writer:
+      writer.Write({'i': 1, 'b': True, 's': 'abc', 'f': 3.14})
+
+    sample_row = {'i': 1, 'b': True, 's': 'abc', 'f': 3.14}
+    expected_rows = []
+    json_object = bigquery.JsonObject()
+    for k, v in sample_row.iteritems():
+      json_object.additionalProperties.append(
+          bigquery.JsonObject.AdditionalProperty(
+              key=k, value=to_json_value(v)))
+    expected_rows.append(
+        bigquery.TableDataInsertAllRequest.RowsValueListEntry(
+            insertId='_1',  # First row ID generated with prefix ''
+            json=json_object))
+    client.tabledata.InsertAll.assert_called_with(
+        bigquery.BigqueryTabledataInsertAllRequest(
+            projectId='project', datasetId='dataset', tableId='table',
+            tableDataInsertAllRequest=bigquery.TableDataInsertAllRequest(
+                rows=expected_rows)))
+
+  def test_table_schema_without_project(self):
+    # The writer should pick up the executing project by default.
+    sink = beam.io.BigQuerySink(table='mydataset.mytable')
+    options = PipelineOptions(flags=['--project', 'myproject'])
+    sink.pipeline_options = options
+    writer = sink.writer()
+    self.assertEqual('myproject', writer.project_id)
+
+
+@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
+class TestBigQueryWrapper(unittest.TestCase):
+
+  def test_delete_non_existing_dataset(self):
+    client = mock.Mock()
+    client.datasets.Delete.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    wrapper._delete_dataset('', '')
+    self.assertTrue(client.datasets.Delete.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_delete_dataset_retries_fail(self, patched_time_sleep):
+    client = mock.Mock()
+    client.datasets.Delete.side_effect = ValueError("Cannot delete")
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    with self.assertRaises(ValueError) as _:
+      wrapper._delete_dataset('', '')
+    self.assertEqual(
+        beam.io.gcp.bigquery.MAX_RETRIES + 1,
+        client.datasets.Delete.call_count)
+    self.assertTrue(client.datasets.Delete.called)
+
+  def test_delete_non_existing_table(self):
+    client = mock.Mock()
+    client.tables.Delete.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    wrapper._delete_table('', '', '')
+    self.assertTrue(client.tables.Delete.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_delete_table_retries_fail(self, patched_time_sleep):
+    client = mock.Mock()
+    client.tables.Delete.side_effect = ValueError("Cannot delete")
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    with self.assertRaises(ValueError) as _:
+      wrapper._delete_table('', '', '')
+    self.assertTrue(client.tables.Delete.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_delete_dataset_retries_for_timeouts(self, patched_time_sleep):
+    client = mock.Mock()
+    client.datasets.Delete.side_effect = [
+        HttpError(
+            response={'status': '408'}, url='', content=''),
+        bigquery.BigqueryDatasetsDeleteResponse()
+    ]
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    wrapper._delete_dataset('', '')
+    self.assertTrue(client.datasets.Delete.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_delete_table_retries_for_timeouts(self, patched_time_sleep):
+    client = mock.Mock()
+    client.tables.Delete.side_effect = [
+        HttpError(
+            response={'status': '408'}, url='', content=''),
+        bigquery.BigqueryTablesDeleteResponse()
+    ]
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    wrapper._delete_table('', '', '')
+    self.assertTrue(client.tables.Delete.called)
+
+  @mock.patch('time.sleep', return_value=None)
+  def test_temporary_dataset_is_unique(self, patched_time_sleep):
+    client = mock.Mock()
+    client.datasets.Get.return_value = bigquery.Dataset(
+        datasetReference=bigquery.DatasetReference(
+            projectId='project_id', datasetId='dataset_id'))
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    with self.assertRaises(RuntimeError) as _:
+      wrapper.create_temporary_dataset('project_id')
+    self.assertTrue(client.datasets.Get.called)
+
+  def test_get_or_create_dataset_created(self):
+    client = mock.Mock()
+    client.datasets.Get.side_effect = HttpError(
+        response={'status': '404'}, url='', content='')
+    client.datasets.Insert.return_value = bigquery.Dataset(
+        datasetReference=bigquery.DatasetReference(
+            projectId='project_id', datasetId='dataset_id'))
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    new_dataset = wrapper.get_or_create_dataset('project_id', 'dataset_id')
+    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')
+
+  def test_get_or_create_dataset_fetched(self):
+    client = mock.Mock()
+    client.datasets.Get.return_value = bigquery.Dataset(
+        datasetReference=bigquery.DatasetReference(
+            projectId='project_id', datasetId='dataset_id'))
+    wrapper = beam.io.gcp.bigquery.BigQueryWrapper(client)
+    new_dataset = wrapper.get_or_create_dataset('project_id', 'dataset_id')
+    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
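
The coder behavior pinned down by the tests above can be summarized with a short
round trip. The following is a minimal sketch that uses only classes imported at
the top of this test file; the two-field schema and the sample values are
placeholders:

from apache_beam.io.gcp.bigquery import TableRowJsonCoder
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.internal.gcp.json_value import to_json_value

# Schema order determines the order of fields in the encoded JSON.
schema = bigquery.TableSchema(fields=[
    bigquery.TableFieldSchema(name='s', type='STRING'),
    bigquery.TableFieldSchema(name='i', type='INTEGER')])

# A coder constructed with a schema can encode and decode TableRows; a
# schema-less coder can only decode (encoding raises AttributeError, as
# test_row_and_no_schema checks).
coder = TableRowJsonCoder(table_schema=schema)
row = bigquery.TableRow(
    f=[bigquery.TableCell(v=to_json_value(e)) for e in ['abc', 123]])

encoded = coder.encode(row)    # '{"s": "abc", "i": 123}'
assert coder.decode(encoded) == row
assert TableRowJsonCoder().decode(encoded) == row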

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/__init__.py b/sdks/python/apache_beam/io/gcp/datastore/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/__init__.py b/sdks/python/apache_beam/io/gcp/datastore/v1/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
new file mode 100644
index 0000000..af0c72b
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio.py
@@ -0,0 +1,397 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A connector for reading from and writing to Google Cloud Datastore"""
+
+import logging
+
+# Protect against environments where datastore library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud.proto.datastore.v1 import datastore_pb2
+  from googledatastore import helper as datastore_helper
+except ImportError:
+  pass
+# pylint: enable=wrong-import-order, wrong-import-position
+
+from apache_beam.io.gcp.datastore.v1 import helper
+from apache_beam.io.gcp.datastore.v1 import query_splitter
+from apache_beam.transforms import Create
+from apache_beam.transforms import DoFn
+from apache_beam.transforms import FlatMap
+from apache_beam.transforms import GroupByKey
+from apache_beam.transforms import Map
+from apache_beam.transforms import PTransform
+from apache_beam.transforms import ParDo
+from apache_beam.transforms.util import Values
+
+__all__ = ['ReadFromDatastore', 'WriteToDatastore', 'DeleteFromDatastore']
+
+
+class ReadFromDatastore(PTransform):
+  """A ``PTransform`` for reading from Google Cloud Datastore.
+
+  To read a ``PCollection[Entity]`` from a Cloud Datastore ``Query``, use the
+  ``ReadFromDatastore`` transform, providing a `project` id and a `query` to
+  read from. You can optionally provide a `namespace` and/or specify how many
+  splits you want for the query through the `num_splits` option.
+
+  Note: Normally, a runner will read from Cloud Datastore in parallel across
+  many workers. However, when the `query` is configured with a `limit`, or
+  contains inequality filters such as `GREATER_THAN` or `LESS_THAN`, all the
+  returned results are read by a single worker in order to ensure correct
+  results. Because the data is read by a single worker, this can have a
+  significant impact on the performance of the job.
+
+  The semantics of query splitting are defined below:
+    1. If `num_splits` is equal to 0, then the number of splits will be chosen
+    dynamically at runtime based on the query data size.
+
+    2. Any value of `num_splits` greater than
+    `ReadFromDatastore._NUM_QUERY_SPLITS_MAX` will be capped at that value.
+
+    3. If the `query` has a user limit set, or contains inequality filters, then
+    `num_splits` will be ignored and no split will be performed.
+
+    4. In some cases Cloud Datastore is unable to split the query into the
+    requested number of splits. In such cases we simply use however many
+    splits Cloud Datastore returns.
+
+  See https://developers.google.com/datastore/ for more details on Google Cloud
+  Datastore.
+  """
+
+  # An upper bound on the number of splits for a query.
+  _NUM_QUERY_SPLITS_MAX = 50000
+  # A lower bound on the number of splits for a query. This is to ensure that
+  # we parallelize the query even when Datastore statistics are not available.
+  _NUM_QUERY_SPLITS_MIN = 12
+  # Default bundle size of 64MB.
+  _DEFAULT_BUNDLE_SIZE_BYTES = 64 * 1024 * 1024
+
+  def __init__(self, project, query, namespace=None, num_splits=0):
+    """Initialize the ReadFromDatastore transform.
+
+    Args:
+      project: The Project ID
+      query: Cloud Datastore query to be read from.
+      namespace: An optional namespace.
+      num_splits: Number of splits for the query.
+    """
+    logging.warning('datastoreio read transform is experimental.')
+    super(ReadFromDatastore, self).__init__()
+
+    if not project:
+      raise ValueError("Project cannot be empty")
+    if not query:
+      raise ValueError("Query cannot be empty")
+    if num_splits < 0:
+      raise ValueError("num_splits must be greater than or equal to 0")
+
+    self._project = project
+    # using _namespace conflicts with DisplayData._namespace
+    self._datastore_namespace = namespace
+    self._query = query
+    self._num_splits = num_splits
+
+  def expand(self, pcoll):
+    # This is a composite transform that involves the following steps:
+    #   1. Create a singleton of the user-provided `query` and apply a ``ParDo``
+    #   that splits the query into `num_splits` sub-queries, assigning each
+    #   split query a unique `int` as its key. The resulting output is of the
+    #   type ``PCollection[(int, Query)]``.
+    #
+    #   If the value of `num_splits` is less than or equal to 0, then the
+    #   number of splits will be computed dynamically based on the size of the
+    #   data for the `query`.
+    #
+    #   2. The resulting ``PCollection`` is sharded using a ``GroupByKey``
+    #   operation. The queries are extracted from the (int, Iterable[Query]) and
+    #   flattened to output a ``PCollection[Query]``.
+    #
+    #   3. In the third step, a ``ParDo`` reads entities for each query and
+    #   outputs a ``PCollection[Entity]``.
+
+    queries = (pcoll.pipeline
+               | 'User Query' >> Create([self._query])
+               | 'Split Query' >> ParDo(ReadFromDatastore.SplitQueryFn(
+                   self._project, self._query, self._datastore_namespace,
+                   self._num_splits)))
+
+    sharded_queries = (queries
+                       | GroupByKey()
+                       | Values()
+                       | 'flatten' >> FlatMap(lambda x: x))
+
+    entities = sharded_queries | 'Read' >> ParDo(
+        ReadFromDatastore.ReadFn(self._project, self._datastore_namespace))
+    return entities
+
+  def display_data(self):
+    disp_data = {'project': self._project,
+                 'query': str(self._query),
+                 'num_splits': self._num_splits}
+
+    if self._datastore_namespace is not None:
+      disp_data['namespace'] = self._datastore_namespace
+
+    return disp_data
+
+  class SplitQueryFn(DoFn):
+    """A `DoFn` that splits a given query into multiple sub-queries."""
+    def __init__(self, project, query, namespace, num_splits):
+      super(ReadFromDatastore.SplitQueryFn, self).__init__()
+      self._datastore = None
+      self._project = project
+      self._datastore_namespace = namespace
+      self._query = query
+      self._num_splits = num_splits
+
+    def start_bundle(self):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, query, *args, **kwargs):
+      # distinct key to be used to group query splits.
+      key = 1
+
+      # If the query has a user-set limit, then it cannot be split.
+      if query.HasField('limit'):
+        return [(key, query)]
+
+      # Compute the estimated numSplits if not specified by the user.
+      if self._num_splits == 0:
+        estimated_num_splits = ReadFromDatastore.get_estimated_num_splits(
+            self._project, self._datastore_namespace, self._query,
+            self._datastore)
+      else:
+        estimated_num_splits = self._num_splits
+
+      logging.info("Splitting the query into %d splits", estimated_num_splits)
+      try:
+        query_splits = query_splitter.get_splits(
+            self._datastore, query, estimated_num_splits,
+            helper.make_partition(self._project, self._datastore_namespace))
+      except Exception:
+        logging.warning("Unable to parallelize the given query: %s", query,
+                        exc_info=True)
+        query_splits = [query]
+
+      sharded_query_splits = []
+      for split_query in query_splits:
+        sharded_query_splits.append((key, split_query))
+        key += 1
+
+      return sharded_query_splits
+
+    def display_data(self):
+      disp_data = {'project': self._project,
+                   'query': str(self._query),
+                   'num_splits': self._num_splits}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  class ReadFn(DoFn):
+    """A DoFn that reads entities from Cloud Datastore, for a given query."""
+    def __init__(self, project, namespace=None):
+      super(ReadFromDatastore.ReadFn, self).__init__()
+      self._project = project
+      self._datastore_namespace = namespace
+      self._datastore = None
+
+    def start_bundle(self):
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, query, *args, **kwargs):
+      # Returns an iterator of entities that reads in batches.
+      entities = helper.fetch_entities(self._project, self._datastore_namespace,
+                                       query, self._datastore)
+      return entities
+
+    def display_data(self):
+      disp_data = {'project': self._project}
+
+      if self._datastore_namespace is not None:
+        disp_data['namespace'] = self._datastore_namespace
+
+      return disp_data
+
+  @staticmethod
+  def query_latest_statistics_timestamp(project, namespace, datastore):
+    """Fetches the latest timestamp of statistics from Cloud Datastore.
+
+    Cloud Datastore system tables with statistics are periodically updated.
+    This method fetches the latest timestamp (in microseconds) of statistics
+    update using the `__Stat_Total__` table.
+    """
+    query = helper.make_latest_timestamp_query(namespace)
+    req = helper.make_request(project, namespace, query)
+    resp = datastore.run_query(req)
+    if len(resp.batch.entity_results) == 0:
+      raise RuntimeError("Datastore total statistics unavailable.")
+
+    entity = resp.batch.entity_results[0].entity
+    return datastore_helper.micros_from_timestamp(
+        entity.properties['timestamp'].timestamp_value)
+
+  @staticmethod
+  def get_estimated_size_bytes(project, namespace, query, datastore):
+    """Get the estimated size of the data returned by the given query.
+
+    Cloud Datastore provides no way to get a good estimate of how large the
+    result of a query is going to be. Hence we use the __Stat_Kind__ system
+    table to get the size of the entire kind as an approximate estimate,
+    assuming exactly one kind is specified in the query.
+    See https://cloud.google.com/datastore/docs/concepts/stats.
+    """
+    kind = query.kind[0].name
+    latest_timestamp = ReadFromDatastore.query_latest_statistics_timestamp(
+        project, namespace, datastore)
+    logging.info('Latest stats timestamp for kind %s is %s',
+                 kind, latest_timestamp)
+
+    kind_stats_query = (
+        helper.make_kind_stats_query(namespace, kind, latest_timestamp))
+
+    req = helper.make_request(project, namespace, kind_stats_query)
+    resp = datastore.run_query(req)
+    if len(resp.batch.entity_results) == 0:
+      raise RuntimeError("Datastore statistics for kind %s unavailable" % kind)
+
+    entity = resp.batch.entity_results[0].entity
+    return datastore_helper.get_value(entity.properties['entity_bytes'])
+
+  @staticmethod
+  def get_estimated_num_splits(project, namespace, query, datastore):
+    """Computes the number of splits to be performed on the given query."""
+    try:
+      estimated_size_bytes = ReadFromDatastore.get_estimated_size_bytes(
+          project, namespace, query, datastore)
+      logging.info('Estimated size bytes for query: %s', estimated_size_bytes)
+      num_splits = int(min(ReadFromDatastore._NUM_QUERY_SPLITS_MAX, round(
+          (float(estimated_size_bytes) /
+           ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES))))
+
+    except Exception as e:
+      logging.warning('Failed to fetch estimated size bytes: %s', e)
+      # Fallback in case estimated size is unavailable.
+      num_splits = ReadFromDatastore._NUM_QUERY_SPLITS_MIN
+
+    return max(num_splits, ReadFromDatastore._NUM_QUERY_SPLITS_MIN)
+
+
+class _Mutate(PTransform):
+  """A ``PTransform`` that writes mutations to Cloud Datastore.
+
+  Only idempotent Datastore mutation operations (upsert and delete) are
+  supported, as the commits are retried when failures occur.
+  """
+
+  # Max allowed Datastore write batch size.
+  _WRITE_BATCH_SIZE = 500
+
+  def __init__(self, project, mutation_fn):
+    """Initializes a Mutate transform.
+
+    Args:
+      project: The Project ID
+      mutation_fn: A function that converts `entities` or `keys` to
+        `mutations`.
+    """
+    self._project = project
+    self._mutation_fn = mutation_fn
+    logging.warning('datastoreio write transform is experimental.')
+
+  def expand(self, pcoll):
+    return (pcoll
+            | 'Convert to Mutation' >> Map(self._mutation_fn)
+            | 'Write Mutation to Datastore' >> ParDo(_Mutate.DatastoreWriteFn(
+                self._project)))
+
+  def display_data(self):
+    return {'project': self._project,
+            'mutation_fn': self._mutation_fn.__class__.__name__}
+
+  class DatastoreWriteFn(DoFn):
+    """A ``DoFn`` that write mutations to Datastore.
+
+    Mutations are written in batches, where the maximum batch size is
+    `_Mutate._WRITE_BATCH_SIZE`.
+
+    Commits are non-transactional. If a commit fails because of a conflict over
+    an entity group, the commit will be retried. This means that the mutation
+    should be idempotent (`upsert` and `delete` mutations) to prevent duplicate
+    data or errors.
+    """
+    def __init__(self, project):
+      self._project = project
+      self._datastore = None
+      self._mutations = []
+
+    def start_bundle(self):
+      self._mutations = []
+      self._datastore = helper.get_datastore(self._project)
+
+    def process(self, element):
+      self._mutations.append(element)
+      if len(self._mutations) >= _Mutate._WRITE_BATCH_SIZE:
+        self._flush_batch()
+
+    def finish_bundle(self):
+      if self._mutations:
+        self._flush_batch()
+      self._mutations = []
+
+    def _flush_batch(self):
+      # Flush the current batch of mutations to Cloud Datastore.
+      helper.write_mutations(self._datastore, self._project, self._mutations)
+      logging.debug("Successfully wrote %d mutations.", len(self._mutations))
+      self._mutations = []
+
+
+class WriteToDatastore(_Mutate):
+  """A ``PTransform`` to write a ``PCollection[Entity]`` to Cloud Datastore."""
+  def __init__(self, project):
+    super(WriteToDatastore, self).__init__(
+        project, WriteToDatastore.to_upsert_mutation)
+
+  @staticmethod
+  def to_upsert_mutation(entity):
+    if not helper.is_key_valid(entity.key):
+      raise ValueError('Entities to be written to the Cloud Datastore must '
+                       'have complete keys:\n%s' % entity)
+    mutation = datastore_pb2.Mutation()
+    mutation.upsert.CopyFrom(entity)
+    return mutation
+
+
+class DeleteFromDatastore(_Mutate):
+  """A ``PTransform`` to delete a ``PCollection[Key]`` from Cloud Datastore."""
+  def __init__(self, project):
+    super(DeleteFromDatastore, self).__init__(
+        project, DeleteFromDatastore.to_delete_mutation)
+
+  @staticmethod
+  def to_delete_mutation(key):
+    if not helper.is_key_valid(key):
+      raise ValueError('Keys to be deleted from the Cloud Datastore must be '
+                       'complete:\n%s' % key)
+    mutation = datastore_pb2.Mutation()
+    mutation.delete.CopyFrom(key)
+    return mutation
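
For reference, wiring the transforms above into a pipeline looks roughly like
the following. This is a minimal sketch: the project id, kind name, and the
bare beam.Pipeline() construction are placeholders, and the query is built with
the same query_pb2 API that the unit tests below use:

import apache_beam as beam
from apache_beam.io.gcp.datastore.v1.datastoreio import ReadFromDatastore
from apache_beam.io.gcp.datastore.v1.datastoreio import WriteToDatastore
from google.cloud.proto.datastore.v1 import query_pb2

project = 'my-project'            # placeholder project id
query = query_pb2.Query()
query.kind.add().name = 'MyKind'  # placeholder kind; single kind, no limit

p = beam.Pipeline()
entities = p | 'Read' >> ReadFromDatastore(project, query, num_splits=0)
# Entities written back must have complete keys; WriteToDatastore converts
# each one into an idempotent upsert mutation before committing in batches.
entities | 'Write' >> WriteToDatastore(project)
p.run()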

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio_test.py
new file mode 100644
index 0000000..3121d3a
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/datastoreio_test.py
@@ -0,0 +1,245 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from mock import MagicMock, call, patch
+
+from apache_beam.io.gcp.datastore.v1 import fake_datastore
+from apache_beam.io.gcp.datastore.v1 import helper
+from apache_beam.io.gcp.datastore.v1 import query_splitter
+from apache_beam.io.gcp.datastore.v1.datastoreio import _Mutate
+from apache_beam.io.gcp.datastore.v1.datastoreio import ReadFromDatastore
+from apache_beam.io.gcp.datastore.v1.datastoreio import WriteToDatastore
+
+# Protect against environments where datastore library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud.proto.datastore.v1 import datastore_pb2
+  from google.cloud.proto.datastore.v1 import query_pb2
+  from google.protobuf import timestamp_pb2
+  from googledatastore import helper as datastore_helper
+except ImportError:
+  datastore_pb2 = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+@unittest.skipIf(datastore_pb2 is None, 'GCP dependencies are not installed')
+class DatastoreioTest(unittest.TestCase):
+  _PROJECT = 'project'
+  _KIND = 'kind'
+  _NAMESPACE = 'namespace'
+
+  def setUp(self):
+    self._mock_datastore = MagicMock()
+    self._query = query_pb2.Query()
+    self._query.kind.add().name = self._KIND
+
+  def test_get_estimated_size_bytes_without_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp)
+
+  def test_get_estimated_size_bytes_with_namespace(self):
+    entity_bytes = 100
+    timestamp = timestamp_pb2.Timestamp(seconds=1234)
+    self.check_estimated_size_bytes(entity_bytes, timestamp, self._NAMESPACE)
+
+  def test_SplitQueryFn_with_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      num_splits = 23
+
+      def fake_get_splits(datastore, query, num_splits, partition=None):
+        return self.split_query(query, num_splits)
+
+      with patch.object(query_splitter, 'get_splits',
+                        side_effect=fake_get_splits):
+
+        split_query_fn = ReadFromDatastore.SplitQueryFn(
+            self._PROJECT, self._query, None, num_splits)
+        split_query_fn.start_bundle()
+        returned_split_queries = []
+        for split_query in split_query_fn.process(self._query):
+          returned_split_queries.append(split_query)
+
+        self.assertEqual(len(returned_split_queries), num_splits)
+        self.assertEqual(0, len(self._mock_datastore.run_query.call_args_list))
+        self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_without_num_splits(self):
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      # Force SplitQueryFn to compute the number of query splits
+      num_splits = 0
+      expected_num_splits = 23
+      entity_bytes = (expected_num_splits *
+                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
+      with patch.object(ReadFromDatastore, 'get_estimated_size_bytes',
+                        return_value=entity_bytes):
+
+        def fake_get_splits(datastore, query, num_splits, partition=None):
+          return self.split_query(query, num_splits)
+
+        with patch.object(query_splitter, 'get_splits',
+                          side_effect=fake_get_splits):
+          split_query_fn = ReadFromDatastore.SplitQueryFn(
+              self._PROJECT, self._query, None, num_splits)
+          split_query_fn.start_bundle()
+          returned_split_queries = []
+          for split_query in split_query_fn.process(self._query):
+            returned_split_queries.append(split_query)
+
+          self.assertEqual(len(returned_split_queries), expected_num_splits)
+          self.assertEqual(0,
+                           len(self._mock_datastore.run_query.call_args_list))
+          self.verify_unique_keys(returned_split_queries)
+
+  def test_SplitQueryFn_with_query_limit(self):
+    """A test that verifies no split is performed when the query has a limit."""
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      self._query.limit.value = 3
+      split_query_fn = ReadFromDatastore.SplitQueryFn(
+          self._PROJECT, self._query, None, 4)
+      split_query_fn.start_bundle()
+      returned_split_queries = []
+      for split_query in split_query_fn.process(self._query):
+        returned_split_queries.append(split_query)
+
+      self.assertEqual(1, len(returned_split_queries))
+      self.assertEqual(0, len(self._mock_datastore.method_calls))
+
+  def test_SplitQueryFn_with_exception(self):
+    """A test that verifies that no split is performed when failures occur."""
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      # Force SplitQueryFn to compute the number of query splits
+      num_splits = 0
+      expected_num_splits = 1
+      entity_bytes = (expected_num_splits *
+                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
+      with patch.object(ReadFromDatastore, 'get_estimated_size_bytes',
+                        return_value=entity_bytes):
+
+        with patch.object(query_splitter, 'get_splits',
+                          side_effect=ValueError("Testing query split error")):
+          split_query_fn = ReadFromDatastore.SplitQueryFn(
+              self._PROJECT, self._query, None, num_splits)
+          split_query_fn.start_bundle()
+          returned_split_queries = []
+          for split_query in split_query_fn.process(self._query):
+            returned_split_queries.append(split_query)
+
+          self.assertEqual(len(returned_split_queries), expected_num_splits)
+          self.assertEqual(returned_split_queries[0][1], self._query)
+          self.assertEqual(0,
+                           len(self._mock_datastore.run_query.call_args_list))
+          self.verify_unique_keys(returned_split_queries)
+
+  def test_DatastoreWriteFn_with_empty_batch(self):
+    self.check_DatastoreWriteFn(0)
+
+  def test_DatastoreWriteFn_with_one_batch(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 1 - 50
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def test_DatastoreWriteFn_with_multiple_batches(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 3 + 50
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def test_DatastoreWriteFn_with_batch_size_exact_multiple(self):
+    num_entities_to_write = _Mutate._WRITE_BATCH_SIZE * 2
+    self.check_DatastoreWriteFn(num_entities_to_write)
+
+  def check_DatastoreWriteFn(self, num_entities):
+    """A helper function to test DatastoreWriteFn."""
+
+    with patch.object(helper, 'get_datastore',
+                      return_value=self._mock_datastore):
+      entities = [e.entity for e in
+                  fake_datastore.create_entities(num_entities)]
+
+      expected_mutations = map(WriteToDatastore.to_upsert_mutation, entities)
+      actual_mutations = []
+
+      self._mock_datastore.commit.side_effect = (
+          fake_datastore.create_commit(actual_mutations))
+
+      datastore_write_fn = _Mutate.DatastoreWriteFn(self._PROJECT)
+
+      datastore_write_fn.start_bundle()
+      for mutation in expected_mutations:
+        datastore_write_fn.process(mutation)
+      datastore_write_fn.finish_bundle()
+
+      self.assertEqual(actual_mutations, expected_mutations)
+      self.assertEqual((num_entities - 1) / _Mutate._WRITE_BATCH_SIZE + 1,
+                       self._mock_datastore.commit.call_count)
+
+  def verify_unique_keys(self, queries):
+    """A helper function that verifies if all the queries have unique keys."""
+    keys, _ = zip(*queries)
+    keys = set(keys)
+    self.assertEqual(len(keys), len(queries))
+
+  def check_estimated_size_bytes(self, entity_bytes, timestamp, namespace=None):
+    """A helper method to test get_estimated_size_bytes"""
+
+    timestamp_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_latest_timestamp_query(namespace))
+    timestamp_resp = self.make_stats_response(
+        {'timestamp': datastore_helper.from_timestamp(timestamp)})
+    kind_stat_req = helper.make_request(
+        self._PROJECT, namespace, helper.make_kind_stats_query(
+            namespace, self._query.kind[0].name,
+            datastore_helper.micros_from_timestamp(timestamp)))
+    kind_stat_resp = self.make_stats_response(
+        {'entity_bytes': entity_bytes})
+
+    def fake_run_query(req):
+      if req == timestamp_req:
+        return timestamp_resp
+      elif req == kind_stat_req:
+        return kind_stat_resp
+      else:
+        print kind_stat_req
+        raise ValueError("Unknown req: %s" % req)
+
+    self._mock_datastore.run_query.side_effect = fake_run_query
+    self.assertEqual(entity_bytes, ReadFromDatastore.get_estimated_size_bytes(
+        self._PROJECT, namespace, self._query, self._mock_datastore))
+    self.assertEqual(self._mock_datastore.run_query.call_args_list,
+                     [call(timestamp_req), call(kind_stat_req)])
+
+  def make_stats_response(self, property_map):
+    resp = datastore_pb2.RunQueryResponse()
+    entity_result = resp.batch.entity_results.add()
+    datastore_helper.add_properties(entity_result.entity, property_map)
+    return resp
+
+  def split_query(self, query, num_splits):
+    """Generate dummy query splits."""
+    split_queries = []
+    for _ in range(0, num_splits):
+      q = query_pb2.Query()
+      q.CopyFrom(query)
+      split_queries.append(q)
+    return split_queries
+
+if __name__ == '__main__':
+  unittest.main()
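
The commit-count assertion in check_DatastoreWriteFn above is Python 2 ceiling division over the write batch size. A small sketch of the arithmetic, assuming a batch size of 500 purely for illustration (the real constant is _Mutate._WRITE_BATCH_SIZE, defined in datastoreio.py and not shown in this hunk):

ASSUMED_BATCH_SIZE = 500  # illustrative stand-in for _Mutate._WRITE_BATCH_SIZE

for num_entities in (0, 450, 500, 1550):
  expected_commits = (num_entities - 1) / ASSUMED_BATCH_SIZE + 1
  print num_entities, '->', expected_commits
# Prints: 0 -> 0, 450 -> 1, 500 -> 1, 1550 -> 4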

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py b/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py
new file mode 100644
index 0000000..bc4d07f
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/fake_datastore.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Fake datastore used for unit testing."""
+import uuid
+
+# Protect against environments where datastore library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud.proto.datastore.v1 import datastore_pb2
+  from google.cloud.proto.datastore.v1 import query_pb2
+except ImportError:
+  pass
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+def create_run_query(entities, batch_size):
+  """A fake datastore run_query method that returns entities in batches.
+
+  Note: the outer function is needed to make `entities` and `batch_size`
+  available in the scope of the inner `run_query` function.
+
+  Args:
+    entities: list of entities supposed to be contained in the datastore.
+    batch_size: the number of entities that run_query method returns in one
+                request.
+  """
+  def run_query(req):
+    start = int(req.query.start_cursor) if req.query.start_cursor else 0
+    # if query limit is less than batch_size, then only return that much.
+    count = min(batch_size, req.query.limit.value)
+    # cannot go more than the number of entities contained in datastore.
+    end = min(len(entities), start + count)
+    finish = False
+    # Finish reading when there are no more entities to return,
+    # or request query limit has been satisfied.
+    if end == len(entities) or count == req.query.limit.value:
+      finish = True
+    return create_response(entities[start:end], str(end), finish)
+  return run_query
+
+
+def create_commit(mutations):
+  """A fake Datastore commit method that writes the mutations to a list.
+
+  Args:
+    mutations: A list to write mutations to.
+
+  Returns:
+    A fake Datastore commit method
+  """
+
+  def commit(req):
+    for mutation in req.mutations:
+      mutations.append(mutation)
+
+  return commit
+
+
+def create_response(entities, end_cursor, finish):
+  """Creates a query response for a given batch of scatter entities."""
+  resp = datastore_pb2.RunQueryResponse()
+  if finish:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NO_MORE_RESULTS
+  else:
+    resp.batch.more_results = query_pb2.QueryResultBatch.NOT_FINISHED
+
+  resp.batch.end_cursor = end_cursor
+  for entity_result in entities:
+    resp.batch.entity_results.add().CopyFrom(entity_result)
+
+  return resp
+
+
+def create_entities(count):
+  """Creates a list of entities with random keys."""
+  entities = []
+
+  for _ in range(count):
+    entity_result = query_pb2.EntityResult()
+    entity_result.entity.key.path.add().name = str(uuid.uuid4())
+    entities.append(entity_result)
+
+  return entities
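
As the datastoreio tests above show, these fakes are meant to be attached to a MagicMock as side effects. A minimal sketch of the pagination behaviour (the entity count, batch size, and limit are arbitrary):

from mock import MagicMock

from apache_beam.io.gcp.datastore.v1 import fake_datastore
from google.cloud.proto.datastore.v1 import datastore_pb2

entities = fake_datastore.create_entities(5)
mock_datastore = MagicMock()
mock_datastore.run_query.side_effect = fake_datastore.create_run_query(
    entities, 2)

req = datastore_pb2.RunQueryRequest()
req.query.limit.value = 100  # larger than the number of fake entities

# The first call returns two entities with more_results == NOT_FINISHED; the
# caller is expected to resend the request with start_cursor set to the
# returned end_cursor, as QueryIterator in helper.py does.
resp = mock_datastore.run_query(req)
assert len(resp.batch.entity_results) == 2
assert resp.batch.end_cursor == '2'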

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
new file mode 100644
index 0000000..e15e43b
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/helper.py
@@ -0,0 +1,274 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Cloud Datastore helper functions."""
+import sys
+
+# Protect against environments where datastore library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud.proto.datastore.v1 import datastore_pb2
+  from google.cloud.proto.datastore.v1 import entity_pb2
+  from google.cloud.proto.datastore.v1 import query_pb2
+  from googledatastore import PropertyFilter, CompositeFilter
+  from googledatastore import helper as datastore_helper
+  from googledatastore.connection import Datastore
+  from googledatastore.connection import RPCError
+  QUERY_NOT_FINISHED = query_pb2.QueryResultBatch.NOT_FINISHED
+except ImportError:
+  QUERY_NOT_FINISHED = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+from apache_beam.internal import auth
+from apache_beam.utils import retry
+
+
+def key_comparator(k1, k2):
+  """A comparator for Datastore keys.
+
+  Comparison is only valid for keys in the same partition. The comparison here
+  is between the list of paths for each key.
+  """
+
+  if k1.partition_id != k2.partition_id:
+    raise ValueError('Cannot compare keys with different partition ids.')
+
+  k2_iter = iter(k2.path)
+
+  for k1_path in k1.path:
+    k2_path = next(k2_iter, None)
+    if not k2_path:
+      return 1
+
+    result = compare_path(k1_path, k2_path)
+
+    if result != 0:
+      return result
+
+  k2_path = next(k2_iter, None)
+  if k2_path:
+    return -1
+  else:
+    return 0
+
+
+def compare_path(p1, p2):
+  """A comparator for key path.
+
+  A path has either an `id` or a `name` field defined. The
+  comparison works with the following rules:
+
+  1. If one path has `id` defined while the other doesn't, then the
+  one with `id` defined is considered smaller.
+  2. If both paths have `id` defined, then their ids are compared.
+  3. If neither path has `id` defined, then their `name` fields are compared.
+  """
+
+  result = str_compare(p1.kind, p2.kind)
+  if result != 0:
+    return result
+
+  if p1.HasField('id'):
+    if not p2.HasField('id'):
+      return -1
+
+    return p1.id - p2.id
+
+  if p2.HasField('id'):
+    return 1
+
+  return str_compare(p1.name, p2.name)
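
A couple of concrete cases for the rules above, with path elements built the same way as in the helper tests (kinds and values are arbitrary):

from google.cloud.proto.datastore.v1.entity_pb2 import Key
from apache_beam.io.gcp.datastore.v1 import helper

p1 = Key.PathElement()
p1.kind = 'kind'
p1.id = 7          # id-only element

p2 = Key.PathElement()
p2.kind = 'kind'
p2.name = 'seven'  # name-only element

# Rule 1: the element with `id` set sorts before the one with only `name`.
assert helper.compare_path(p1, p2) < 0

p3 = Key.PathElement()
p3.kind = 'kind'
p3.id = 9
# Rule 2: both elements have `id`, so the ids are compared numerically.
assert helper.compare_path(p1, p3) < 0
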
+
+
+def str_compare(s1, s2):
+  if s1 == s2:
+    return 0
+  elif s1 < s2:
+    return -1
+  else:
+    return 1
+
+
+def get_datastore(project):
+  """Returns a Cloud Datastore client."""
+  credentials = auth.get_service_credentials()
+  return Datastore(project, credentials)
+
+
+def make_request(project, namespace, query):
+  """Make a Cloud Datastore request for the given query."""
+  req = datastore_pb2.RunQueryRequest()
+  req.partition_id.CopyFrom(make_partition(project, namespace))
+
+  req.query.CopyFrom(query)
+  return req
+
+
+def make_partition(project, namespace):
+  """Make a PartitionId for the given project and namespace."""
+  partition = entity_pb2.PartitionId()
+  partition.project_id = project
+  if namespace is not None:
+    partition.namespace_id = namespace
+
+  return partition
+
+
+def retry_on_rpc_error(exception):
+  """A retry filter for Cloud Datastore RPCErrors."""
+  if isinstance(exception, RPCError):
+    if exception.code >= 500:
+      return True
+    else:
+      return False
+  else:
+    # TODO(vikasrk): Figure out what other errors should be retried.
+    return False
+
+
+def fetch_entities(project, namespace, query, datastore):
+  """A helper method to fetch entities from Cloud Datastore.
+
+  Args:
+    project: Project ID
+    namespace: Cloud Datastore namespace
+    query: Query to be read from
+    datastore: Cloud Datastore Client
+
+  Returns:
+    An iterator of entities.
+  """
+  return QueryIterator(project, namespace, query, datastore)
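
A minimal read loop built on fetch_entities and the QueryIterator defined below (the project id and kind are placeholders, and get_datastore needs service credentials, so this is a sketch rather than a test):

from apache_beam.io.gcp.datastore.v1 import helper
from google.cloud.proto.datastore.v1 import query_pb2

query = query_pb2.Query()
query.kind.add().name = 'ExampleKind'  # placeholder kind

datastore = helper.get_datastore('my-gcp-project')
for entity in helper.fetch_entities('my-gcp-project', None, query, datastore):
  # Each item is an entity_pb2.Entity; batching and retries are handled
  # internally by QueryIterator.
  print entity.key
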
+
+
+def is_key_valid(key):
+  """Returns True if a Cloud Datastore key is complete.
+
+  A key is complete if its last element has either an id or a name.
+  """
+  if not key.path:
+    return False
+  return key.path[-1].HasField('id') or key.path[-1].HasField('name')
+
+
+def write_mutations(datastore, project, mutations):
+  """A helper function to write a batch of mutations to Cloud Datastore.
+
+  If a commit fails, it will be retried up to 5 times. All mutations in the
+  batch will be committed again, even if the commit was partially successful.
+  If the retry limit is exceeded, the last exception from Cloud Datastore will
+  be raised.
+  """
+  commit_request = datastore_pb2.CommitRequest()
+  commit_request.mode = datastore_pb2.CommitRequest.NON_TRANSACTIONAL
+  commit_request.project_id = project
+  for mutation in mutations:
+    commit_request.mutations.add().CopyFrom(mutation)
+
+  @retry.with_exponential_backoff(num_retries=5,
+                                  retry_filter=retry_on_rpc_error)
+  def commit(req):
+    datastore.commit(req)
+
+  commit(commit_request)
+
+
+def make_latest_timestamp_query(namespace):
+  """Make a Query to fetch the latest timestamp statistics."""
+  query = query_pb2.Query()
+  if namespace is None:
+    query.kind.add().name = '__Stat_Total__'
+  else:
+    query.kind.add().name = '__Stat_Ns_Total__'
+
+  # Descending order of `timestamp`
+  datastore_helper.add_property_orders(query, "-timestamp")
+  # Only get the latest entity
+  query.limit.value = 1
+  return query
+
+
+def make_kind_stats_query(namespace, kind, latest_timestamp):
+  """Make a Query to fetch the latest kind statistics."""
+  kind_stat_query = query_pb2.Query()
+  if namespace is None:
+    kind_stat_query.kind.add().name = '__Stat_Kind__'
+  else:
+    kind_stat_query.kind.add().name = '__Stat_Ns_Kind__'
+
+  kind_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'kind_name', PropertyFilter.EQUAL, unicode(kind))
+  timestamp_filter = datastore_helper.set_property_filter(
+      query_pb2.Filter(), 'timestamp', PropertyFilter.EQUAL,
+      latest_timestamp)
+
+  datastore_helper.set_composite_filter(kind_stat_query.filter,
+                                        CompositeFilter.AND, kind_filter,
+                                        timestamp_filter)
+  return kind_stat_query
+
+
+class QueryIterator(object):
+  """A iterator class for entities of a given query.
+
+  Entities are read in batches. Retries on failures.
+  """
+  _NOT_FINISHED = QUERY_NOT_FINISHED
+  # Maximum number of results to request per query.
+  _BATCH_SIZE = 500
+
+  def __init__(self, project, namespace, query, datastore):
+    self._query = query
+    self._datastore = datastore
+    self._project = project
+    self._namespace = namespace
+    self._start_cursor = None
+    self._limit = self._query.limit.value or sys.maxint
+    self._req = make_request(project, namespace, query)
+
+  @retry.with_exponential_backoff(num_retries=5,
+                                  retry_filter=retry_on_rpc_error)
+  def _next_batch(self):
+    """Fetches the next batch of entities."""
+    if self._start_cursor is not None:
+      self._req.query.start_cursor = self._start_cursor
+
+    # set batch size
+    self._req.query.limit.value = min(self._BATCH_SIZE, self._limit)
+    resp = self._datastore.run_query(self._req)
+    return resp
+
+  def __iter__(self):
+    more_results = True
+    while more_results:
+      resp = self._next_batch()
+      for entity_result in resp.batch.entity_results:
+        yield entity_result.entity
+
+      self._start_cursor = resp.batch.end_cursor
+      num_results = len(resp.batch.entity_results)
+      self._limit -= num_results
+
+      # Check if we need to read more entities.
+      # True when query limit hasn't been satisfied and there are more entities
+      # to be read. The latter is true if the response has a status
+      # `NOT_FINISHED` or if the number of results read in the previous batch
+      # is equal to `_BATCH_SIZE` (both of which indicate that there is more
+      # data to be read).
+      more_results = ((self._limit > 0) and
+                      ((num_results == self._BATCH_SIZE) or
+                       (resp.batch.more_results == self._NOT_FINISHED)))

http://git-wip-us.apache.org/repos/asf/beam/blob/59ad58ac/sdks/python/apache_beam/io/gcp/datastore/v1/helper_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1/helper_test.py b/sdks/python/apache_beam/io/gcp/datastore/v1/helper_test.py
new file mode 100644
index 0000000..582a5b3
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/datastore/v1/helper_test.py
@@ -0,0 +1,265 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for datastore helper."""
+import sys
+import unittest
+
+from mock import MagicMock
+
+from apache_beam.io.gcp.datastore.v1 import fake_datastore
+from apache_beam.io.gcp.datastore.v1 import helper
+from apache_beam.tests.test_utils import patch_retry
+
+
+# Protect against environments where apitools library is not available.
+# pylint: disable=wrong-import-order, wrong-import-position
+try:
+  from google.cloud.proto.datastore.v1 import datastore_pb2
+  from google.cloud.proto.datastore.v1 import entity_pb2
+  from google.cloud.proto.datastore.v1 import query_pb2
+  from google.cloud.proto.datastore.v1.entity_pb2 import Key
+  from googledatastore.connection import RPCError
+  from googledatastore import helper as datastore_helper
+except ImportError:
+  datastore_helper = None
+# pylint: enable=wrong-import-order, wrong-import-position
+
+
+@unittest.skipIf(datastore_helper is None, 'GCP dependencies are not installed')
+class HelperTest(unittest.TestCase):
+
+  def setUp(self):
+    self._mock_datastore = MagicMock()
+    self._query = query_pb2.Query()
+    self._query.kind.add().name = 'dummy_kind'
+    patch_retry(self, helper)
+
+  def permanent_datastore_failure(self, req):
+    raise RPCError("dummy", 500, "failed")
+
+  def transient_datastore_failure(self, req):
+    if self._transient_fail_count:
+      self._transient_fail_count -= 1
+      raise RPCError("dummy", 500, "failed")
+    else:
+      return datastore_pb2.RunQueryResponse()
+
+  def test_query_iterator(self):
+    self._mock_datastore.run_query.side_effect = (
+        self.permanent_datastore_failure)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+    self.assertRaises(RPCError, iter(query_iterator).next)
+    self.assertEqual(6, len(self._mock_datastore.run_query.call_args_list))
+
+  def test_query_iterator_with_transient_failures(self):
+    self._mock_datastore.run_query.side_effect = (
+        self.transient_datastore_failure)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+    fail_count = 2
+    self._transient_fail_count = fail_count
+    for _ in query_iterator:
+      pass
+
+    self.assertEqual(fail_count + 1,
+                     len(self._mock_datastore.run_query.call_args_list))
+
+  def test_query_iterator_with_single_batch(self):
+    num_entities = 100
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_multiple_batches(self):
+    num_entities = 1098
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_exact_batch_multiple(self):
+    num_entities = 1000
+    batch_size = 500
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_query_limit(self):
+    num_entities = 1098
+    batch_size = 500
+    self._query.limit.value = 1004
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def test_query_iterator_with_large_query_limit(self):
+    num_entities = 1098
+    batch_size = 500
+    self._query.limit.value = 10000
+    self.check_query_iterator(num_entities, batch_size, self._query)
+
+  def check_query_iterator(self, num_entities, batch_size, query):
+    """A helper method to test the QueryIterator.
+
+    Args:
+      num_entities: number of entities contained in the fake datastore.
+      batch_size: the number of entities the fake datastore returns per call.
+      query: the query to be executed
+
+    """
+    entities = fake_datastore.create_entities(num_entities)
+    self._mock_datastore.run_query.side_effect = \
+        fake_datastore.create_run_query(entities, batch_size)
+    query_iterator = helper.QueryIterator("project", None, self._query,
+                                          self._mock_datastore)
+
+    i = 0
+    for entity in query_iterator:
+      self.assertEqual(entity, entities[i].entity)
+      i += 1
+
+    limit = query.limit.value if query.HasField('limit') else sys.maxint
+    self.assertEqual(i, min(num_entities, limit))
+
+  def test_is_key_valid(self):
+    key = entity_pb2.Key()
+    # Complete with name, no ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name')
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with id, no ancestor
+    datastore_helper.add_key_path(key, 'kind', 12)
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Incomplete, no ancestor
+    datastore_helper.add_key_path(key, 'kind')
+    self.assertFalse(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with name and ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2', 'name2')
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Complete with id and ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2', 123)
+    self.assertTrue(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    # Incomplete with ancestor
+    datastore_helper.add_key_path(key, 'kind', 'name', 'kind2')
+    self.assertFalse(helper.is_key_valid(key))
+
+    key = entity_pb2.Key()
+    self.assertFalse(helper.is_key_valid(key))
+
+  def test_compare_path_with_different_kind(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy1'
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy2'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_with_different_id(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.id = 10
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.id = 15
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_with_different_name(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.name = "dummy1"
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.name = 'dummy2'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_compare_path_of_different_type(self):
+    p1 = Key.PathElement()
+    p1.kind = 'dummy'
+    p1.id = 10
+
+    p2 = Key.PathElement()
+    p2.kind = 'dummy'
+    p2.name = 'dummy'
+
+    self.assertLess(helper.compare_path(p1, p2), 0)
+
+  def test_key_comparator_with_different_partition(self):
+    k1 = Key()
+    k1.partition_id.namespace_id = 'dummy1'
+    k2 = Key()
+    k2.partition_id.namespace_id = 'dummy2'
+    self.assertRaises(ValueError, helper.key_comparator, k1, k2)
+
+  def test_key_comparator_with_single_path(self):
+    k1 = Key()
+    k2 = Key()
+    p1 = k1.path.add()
+    p2 = k2.path.add()
+    p1.kind = p2.kind = 'dummy'
+    self.assertEqual(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_1(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k1.path.add()
+    p21 = k2.path.add()
+    p11.kind = p12.kind = p21.kind = 'dummy'
+    self.assertGreater(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_2(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p21 = k2.path.add()
+    p22 = k2.path.add()
+    p11.kind = p21.kind = p22.kind = 'dummy'
+    self.assertLess(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_3(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k1.path.add()
+    p21 = k2.path.add()
+    p22 = k2.path.add()
+    p11.kind = p12.kind = p21.kind = p22.kind = 'dummy'
+    self.assertEqual(helper.key_comparator(k1, k2), 0)
+
+  def test_key_comparator_with_multiple_paths_4(self):
+    k1 = Key()
+    k2 = Key()
+    p11 = k1.path.add()
+    p12 = k2.path.add()
+    p21 = k2.path.add()
+    p11.kind = p12.kind = 'dummy'
+    # make path2 greater than path1
+    p21.kind = 'dummy1'
+    self.assertLess(helper.key_comparator(k1, k2), 0)
+
+
+if __name__ == '__main__':
+  unittest.main()

