beam-commits mailing list archives

From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Commented] (BEAM-1630) Add Splittable DoFn to Python SDK
Date Tue, 19 Dec 2017 22:54:00 GMT

    [ https://issues.apache.org/jira/browse/BEAM-1630?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16297554#comment-16297554 ] 

ASF GitHub Bot commented on BEAM-1630:
--------------------------------------

chamikaramj closed pull request #4064: [BEAM-1630] Adds support for processing Splittable DoFns using DirectRunner.
URL: https://github.com/apache/beam/pull/4064

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/sdks/python/apache_beam/examples/snippets/snippets_test.py b/sdks/python/apache_beam/examples/snippets/snippets_test.py
index 505858a4327..f05dc392b1f 100644
--- a/sdks/python/apache_beam/examples/snippets/snippets_test.py
+++ b/sdks/python/apache_beam/examples/snippets/snippets_test.py
@@ -423,8 +423,8 @@ def __init__(self, file_to_write):
 
       def start_bundle(self):
         assert self.file_to_write
-        self.file_to_write += str(uuid.uuid4())
-        self.file_obj = open(self.file_to_write, 'w')
+        # Append a UUID so that each bundle invocation writes to a unique file.
+        self.file_obj = open(self.file_to_write + str(uuid.uuid4()), 'w')
 
       def process(self, element):
         assert self.file_obj
diff --git a/sdks/python/apache_beam/io/filebasedsource.py b/sdks/python/apache_beam/io/filebasedsource.py
index 052c2f32ff2..900825043b8 100644
--- a/sdks/python/apache_beam/io/filebasedsource.py
+++ b/sdks/python/apache_beam/io/filebasedsource.py
@@ -33,7 +33,7 @@
 from apache_beam.io import range_trackers
 from apache_beam.io.filesystem import CompressionTypes
 from apache_beam.io.filesystems import FileSystems
-from apache_beam.io.range_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRange
 from apache_beam.options.value_provider import StaticValueProvider
 from apache_beam.options.value_provider import ValueProvider
 from apache_beam.options.value_provider import check_accessible
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 7cffa7f834e..fc7a2f3f7b1 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -47,7 +47,8 @@
 from apache_beam.utils import urns
 from apache_beam.utils.windowed_value import WindowedValue
 
-__all__ = ['BoundedSource', 'RangeTracker', 'Read', 'Sink', 'Write', 'Writer']
+__all__ = ['BoundedSource', 'RangeTracker', 'Read', 'RestrictionTracker',
+           'Sink', 'Write', 'Writer']
 
 
 # Encapsulates information about a bundle of a source generated when method
@@ -961,6 +962,7 @@ def display_data(self):
 
   def process(self, element, init_result):
     if self.writer is None:
+      # We ignore UUID collisions here since they are extremely rare.
       self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
     self.writer.write(element)
 
@@ -1073,15 +1075,18 @@ def check_done(self):
     Called by the runner after the iterator returned by ``DoFn.process()``
     has been fully read.
 
-    Returns: ``True`` if current restriction has been fully processed.
-    Raises ValueError: if there is still any unclaimed work remaining in the
-      restriction invoking this method. Exception raised must have an
-      informative error message.
+    This method must raise a `ValueError` if there is still any unclaimed work
+    remaining in the restriction when this method is invoked. The raised
+    exception must have an informative error message.
 
     ** Thread safety **
 
     Methods of the class ``RestrictionTracker`` including this method may get
     invoked by different threads, hence must be made thread-safe, e.g. by using
     a single lock object.
+
+    Returns: ``True`` if the current restriction has been fully processed.
+    Raises:
+      ~exceptions.ValueError: if there is still any unclaimed work remaining.
     """
     raise NotImplementedError
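
For illustration only (not part of this diff): a minimal sketch of a tracker
honoring the contract documented above, assuming a hypothetical restriction
over the index range [0, num_items). It guards all state with a single lock,
as the thread-safety note requires, and check_done() raises ValueError exactly
when unclaimed work remains:

    import threading

    from apache_beam.io.iobase import RestrictionTracker


    class ItemCountRestrictionTracker(RestrictionTracker):
      """Hypothetical tracker for the index range [0, num_items)."""

      def __init__(self, num_items):
        self._num_items = num_items
        self._last_claimed = None
        self._lock = threading.Lock()

      def try_claim(self, index):
        with self._lock:
          if 0 <= index < self._num_items:
            self._last_claimed = index
            return True
          return False

      def check_done(self):
        with self._lock:
          # Raise ValueError if any part of [0, num_items) is unclaimed.
          if self._num_items > 0 and (
              self._last_claimed is None or
              self._last_claimed < self._num_items - 1):
            raise ValueError(
                'Indices in [%d, %d) have not been claimed.' % (
                    0 if self._last_claimed is None
                    else self._last_claimed + 1,
                    self._num_items))
          return True
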
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index 1339b911efc..7106aef057d 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-"""iobase.RangeTracker implementations provided with Dataflow SDK.
+"""iobase.RangeTracker implementations provided with Apache Beam.
 """
 
 import logging
@@ -28,48 +28,6 @@
            'OrderedPositionRangeTracker', 'UnsplittableRangeTracker']
 
 
-class OffsetRange(object):
-
-  def __init__(self, start, stop):
-    if start >= stop:
-      raise ValueError(
-          'Start offset must be smaller than the stop offset. '
-          'Received %d and %d respectively.', start, stop)
-    self.start = start
-    self.stop = stop
-
-  def __eq__(self, other):
-    if not isinstance(other, OffsetRange):
-      return False
-
-    return self.start == other.start and self.stop == other.stop
-
-  def __ne__(self, other):
-    if not isinstance(other, OffsetRange):
-      return True
-
-    return not (self.start == other.start and self.stop == other.stop)
-
-  def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1):
-    current_split_start = self.start
-    max_split_size = max(desired_num_offsets_per_split,
-                         min_num_offsets_per_split)
-    while current_split_start < self.stop:
-      current_split_stop = min(current_split_start + max_split_size, self.stop)
-      remaining = self.stop - current_split_stop
-
-      # Avoiding a small split at the end.
-      if (remaining < desired_num_offsets_per_split / 4 or
-          remaining < min_num_offsets_per_split):
-        current_split_stop = self.stop
-
-      yield OffsetRange(current_split_start, current_split_stop)
-      current_split_start = current_split_stop
-
-  def new_tracker(self):
-    return OffsetRangeTracker(self.start, self.stop)
-
-
 class OffsetRangeTracker(iobase.RangeTracker):
   """A 'RangeTracker' for non-negative positions of type 'long'."""
 
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 762d6547891..3e926634c85 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -23,43 +23,6 @@
 import unittest
 
 from apache_beam.io import range_trackers
-from apache_beam.io.range_trackers import OffsetRange
-
-
-class OffsetRangeTest(unittest.TestCase):
-
-  def test_create(self):
-    OffsetRange(0, 10)
-    OffsetRange(10, 100)
-
-    with self.assertRaises(ValueError):
-      OffsetRange(10, 9)
-
-  def test_split_respects_desired_num_splits(self):
-    range = OffsetRange(10, 100)
-    splits = list(range.split(desired_num_offsets_per_split=25))
-    self.assertEqual(4, len(splits))
-    self.assertIn(OffsetRange(10, 35), splits)
-    self.assertIn(OffsetRange(35, 60), splits)
-    self.assertIn(OffsetRange(60, 85), splits)
-    self.assertIn(OffsetRange(85, 100), splits)
-
-  def test_split_respects_min_num_splits(self):
-    range = OffsetRange(10, 100)
-    splits = list(range.split(desired_num_offsets_per_split=5,
-                              min_num_offsets_per_split=25))
-    self.assertEqual(3, len(splits))
-    self.assertIn(OffsetRange(10, 35), splits)
-    self.assertIn(OffsetRange(35, 60), splits)
-    self.assertIn(OffsetRange(60, 100), splits)
-
-  def test_split_no_small_split_at_end(self):
-    range = OffsetRange(10, 90)
-    splits = list(range.split(desired_num_offsets_per_split=25))
-    self.assertEqual(3, len(splits))
-    self.assertIn(OffsetRange(10, 35), splits)
-    self.assertIn(OffsetRange(35, 60), splits)
-    self.assertIn(OffsetRange(60, 90), splits)
 
 
 class OffsetRangeTrackerTest(unittest.TestCase):
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
new file mode 100644
index 00000000000..3e49e26022a
--- /dev/null
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -0,0 +1,131 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""`iobase.RestrictionTracker` implementations provided with Apache Beam."""
+
+import threading
+
+from apache_beam.io.iobase import RestrictionTracker
+from apache_beam.io.range_trackers import OffsetRangeTracker
+
+
+class OffsetRange(object):
+
+  def __init__(self, start, stop):
+    if start > stop:
+      raise ValueError(
+          'Start offset must not be larger than the stop offset. '
+          'Received %d and %d respectively.' % (start, stop))
+    self.start = start
+    self.stop = stop
+
+  def __eq__(self, other):
+    if not isinstance(other, OffsetRange):
+      return False
+
+    return self.start == other.start and self.stop == other.stop
+
+  def split(self, desired_num_offsets_per_split, min_num_offsets_per_split=1):
+    current_split_start = self.start
+    max_split_size = max(desired_num_offsets_per_split,
+                         min_num_offsets_per_split)
+    while current_split_start < self.stop:
+      current_split_stop = min(current_split_start + max_split_size, self.stop)
+      remaining = self.stop - current_split_stop
+
+      # Avoiding a small split at the end.
+      if (remaining < desired_num_offsets_per_split / 4 or
+          remaining < min_num_offsets_per_split):
+        current_split_stop = self.stop
+
+      yield OffsetRange(current_split_start, current_split_stop)
+      current_split_start = current_split_stop
+
+  def new_tracker(self):
+    return OffsetRangeTracker(self.start, self.stop)
+
+
+class OffsetRestrictionTracker(RestrictionTracker):
+  """An `iobase.RestrictionTracker` implementations for an offset range.
+
+  Offset range is represented as a pair of integers
+  [start_position, stop_position}.
+  """
+
+  def __init__(self, start_position, stop_position):
+    self._range = OffsetRange(start_position, stop_position)
+    self._current_position = None
+    self._last_claim_attempt = None
+    self._checkpointed = False
+    self._lock = threading.Lock()
+
+  def check_done(self):
+    with self._lock:
+      if self._last_claim_attempt < self._range.stop - 1:
+        raise ValueError(
+            'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+            'has not been claimed.' % (
+                self._last_claim_attempt if self._last_claim_attempt is not None
+                else self._range.start,
+                self._range.stop))
+
+  def current_restriction(self):
+    with self._lock:
+      return (self._range.start, self._range.stop)
+
+  def start_position(self):
+    with self._lock:
+      return self._range.start
+
+  def stop_position(self):
+    with self._lock:
+      return self._range.stop
+
+  def try_claim(self, position):
+    with self._lock:
+      if (self._last_claim_attempt is not None and
+          position <= self._last_claim_attempt):
+        raise ValueError(
+            'Positions claimed should strictly increase. Trying to claim '
+            'position %d while last claim attempt was %d.'
+            % (position, self._last_claim_attempt))
+
+      self._last_claim_attempt = position
+      if position < self._range.start:
+        raise ValueError(
+            'Position to be claimed cannot be smaller than the start position '
+            'of the range. Tried to claim position %r for the range [%r, %r).'
+            % (position, self._range.start, self._range.stop))
+
+      if position >= self._range.start and position < self._range.stop:
+        self._current_position = position
+        return True
+
+      return False
+
+  def checkpoint(self):
+    with self._lock:
+      # If self._current_position is 'None' no records have been claimed so
+      # residual should start from self._range.start.
+      if self._current_position is None:
+        end_position = self._range.start
+      else:
+        end_position = self._current_position + 1
+
+      residual_range = (end_position, self._range.stop)
+
+      self._range = OffsetRange(self._range.start, end_position)
+      return residual_range
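
A short usage sketch of the two classes above (for illustration; the expected
values mirror the unit tests added below):

    from apache_beam.io.restriction_trackers import OffsetRange
    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker

    # Splitting an offset range into roughly equal pieces; the final piece
    # absorbs what would otherwise be a small split at the end.
    splits = list(OffsetRange(10, 100).split(desired_num_offsets_per_split=25))
    assert splits == [OffsetRange(10, 35), OffsetRange(35, 60),
                      OffsetRange(60, 85), OffsetRange(85, 100)]

    # Claiming positions, then checkpointing the unprocessed remainder.
    tracker = OffsetRestrictionTracker(100, 200)
    assert tracker.try_claim(100)
    assert tracker.try_claim(150)
    residual = tracker.checkpoint()  # primary shrinks to [100, 151)
    assert tracker.current_restriction() == (100, 151)
    assert residual == (151, 200)
    tracker.check_done()             # primary fully claimed; does not raise
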
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
new file mode 100644
index 00000000000..e8a799f28f1
--- /dev/null
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -0,0 +1,159 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the range_trackers module."""
+
+import logging
+import unittest
+
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+
+
+class OffsetRangeTest(unittest.TestCase):
+
+  def test_create(self):
+    OffsetRange(0, 10)
+    OffsetRange(10, 10)
+    OffsetRange(10, 100)
+
+    with self.assertRaises(ValueError):
+      OffsetRange(10, 9)
+
+  def test_split_respects_desired_num_splits(self):
+    range = OffsetRange(10, 100)
+    splits = list(range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(4, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 85), splits)
+    self.assertIn(OffsetRange(85, 100), splits)
+
+  def test_split_respects_min_num_splits(self):
+    range = OffsetRange(10, 100)
+    splits = list(range.split(desired_num_offsets_per_split=5,
+                              min_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 100), splits)
+
+  def test_split_no_small_split_at_end(self):
+    range = OffsetRange(10, 90)
+    splits = list(range.split(desired_num_offsets_per_split=25))
+    self.assertEqual(3, len(splits))
+    self.assertIn(OffsetRange(10, 35), splits)
+    self.assertIn(OffsetRange(35, 60), splits)
+    self.assertIn(OffsetRange(60, 90), splits)
+
+
+class OffsetRestrictionTrackerTest(unittest.TestCase):
+
+  def test_try_claim(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertEqual((100, 200), tracker.current_restriction())
+    self.assertTrue(tracker.try_claim(100))
+    self.assertTrue(tracker.try_claim(150))
+    self.assertTrue(tracker.try_claim(199))
+    self.assertFalse(tracker.try_claim(200))
+
+  def test_checkpoint_unstarted(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    checkpoint = tracker.checkpoint()
+    self.assertEqual((100, 100), tracker.current_restriction())
+    self.assertEqual((100, 200), checkpoint)
+
+  def test_checkpoint_just_started(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(100))
+    checkpoint = tracker.checkpoint()
+    self.assertEqual((100, 101), tracker.current_restriction())
+    self.assertEqual((101, 200), checkpoint)
+
+  def test_checkpoint_regular(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(105))
+    self.assertTrue(tracker.try_claim(110))
+    checkpoint = tracker.checkpoint()
+    self.assertEqual((100, 111), tracker.current_restriction())
+    self.assertEqual((111, 200), checkpoint)
+
+  def test_checkpoint_claimed_last(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(105))
+    self.assertTrue(tracker.try_claim(110))
+    self.assertTrue(tracker.try_claim(199))
+    checkpoint = tracker.checkpoint()
+    self.assertEqual((100, 200), tracker.current_restriction())
+    self.assertEqual((200, 200), checkpoint)
+
+  def test_checkpoint_after_failed_claim(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(105))
+    self.assertTrue(tracker.try_claim(110))
+    self.assertTrue(tracker.try_claim(160))
+    self.assertFalse(tracker.try_claim(240))
+
+    checkpoint = tracker.checkpoint()
+    self.assertEqual((100, 161), tracker.current_restriction())
+    self.assertEqual((161, 200), checkpoint)
+
+  def test_non_monotonic_claim(self):
+    with self.assertRaises(ValueError):
+      tracker = OffsetRestrictionTracker(100, 200)
+      self.assertTrue(tracker.try_claim(105))
+      self.assertTrue(tracker.try_claim(110))
+      self.assertTrue(tracker.try_claim(103))
+
+  def test_claim_before_starting_range(self):
+    with self.assertRaises(ValueError):
+      tracker = OffsetRestrictionTracker(100, 200)
+      tracker.try_claim(90)
+
+  def test_check_done_after_try_claim_past_end_of_range(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(150))
+    self.assertTrue(tracker.try_claim(175))
+    self.assertFalse(tracker.try_claim(220))
+    tracker.check_done()
+
+  def test_check_done_after_try_claim_at_end_of_range(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(150))
+    self.assertTrue(tracker.try_claim(175))
+    self.assertFalse(tracker.try_claim(200))
+    tracker.check_done()
+
+  def test_check_done_after_try_claim_right_before_end_of_range(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(150))
+    self.assertTrue(tracker.try_claim(175))
+    self.assertTrue(tracker.try_claim(199))
+    tracker.check_done()
+
+  def test_check_done_when_not_done(self):
+    tracker = OffsetRestrictionTracker(100, 200)
+    self.assertTrue(tracker.try_claim(150))
+    self.assertTrue(tracker.try_claim(175))
+
+    with self.assertRaises(ValueError):
+      tracker.check_done()
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index d4481e2b834..4c7d6012261 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -69,7 +69,7 @@
 from apache_beam.utils import urns
 from apache_beam.utils.annotations import deprecated
 
-__all__ = ['Pipeline']
+__all__ = ['Pipeline', 'PTransformOverride']
 
 
 class Pipeline(object):
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 567ab9293ce..3f608f313c1 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -22,6 +22,8 @@
 import unittest
 from collections import defaultdict
 
+from mock import mock
+
 import apache_beam as beam
 from apache_beam.io import Read
 from apache_beam.metrics import Metrics
@@ -30,7 +32,6 @@
 from apache_beam.pipeline import PipelineVisitor
 from apache_beam.pipeline import PTransformOverride
 from apache_beam.pvalue import AsSingleton
-from apache_beam.runners import DirectRunner
 from apache_beam.runners.dataflow.native_io.iobase import NativeSource
 from apache_beam.runners.direct.evaluation_context import _ExecutionContext
 from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
@@ -303,7 +304,9 @@ def raise_exception(exn):
   #   p = Pipeline('EagerRunner')
   #   self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
 
-  def test_ptransform_overrides(self):
+  @mock.patch(
+      'apache_beam.runners.direct.direct_runner._get_transform_overrides')
+  def test_ptransform_overrides(self, get_overrides_mock):
 
     def my_par_do_matcher(applied_ptransform):
       return isinstance(applied_ptransform.transform, DoubleParDo)
@@ -318,8 +321,11 @@ def get_replacement_transform(self, ptransform):
           return TripleParDo()
         raise ValueError('Unsupported type of transform: %r', ptransform)
 
-    # Using following private variable for testing.
-    DirectRunner._PTRANSFORM_OVERRIDES.append(MyParDoOverride())
+    def get_overrides():
+      return [MyParDoOverride()]
+
+    get_overrides_mock.side_effect = get_overrides
+
     with Pipeline() as p:
       pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
       assert_that(pcoll, equal_to([3, 6, 9]))
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index dcfac2e1b2b..dd7f3e45953 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -28,16 +28,20 @@ cdef class Receiver(object):
   cpdef receive(self, WindowedValue windowed_value)
 
 
-cdef class DoFnMethodWrapper(object):
+cdef class MethodWrapper(object):
   cdef public object args
   cdef public object defaults
   cdef public object method_value
 
 
 cdef class DoFnSignature(object):
-  cdef public DoFnMethodWrapper process_method
-  cdef public DoFnMethodWrapper start_bundle_method
-  cdef public DoFnMethodWrapper finish_bundle_method
+  cdef public MethodWrapper process_method
+  cdef public MethodWrapper start_bundle_method
+  cdef public MethodWrapper finish_bundle_method
+  cdef public MethodWrapper initial_restriction_method
+  cdef public MethodWrapper restriction_coder_method
+  cdef public MethodWrapper create_tracker_method
+  cdef public MethodWrapper split_method
   cdef public object do_fn
 
 
@@ -45,11 +49,14 @@ cdef class DoFnInvoker(object):
   cdef public DoFnSignature signature
   cdef _OutputProcessor output_processor
 
-  cpdef invoke_process(self, WindowedValue windowed_value)
+  cpdef invoke_process(self, WindowedValue windowed_value,
+                       restriction_tracker=*, output_processor=*)
   cpdef invoke_start_bundle(self)
   cpdef invoke_finish_bundle(self)
-
-  # TODO(chamikara) define static method create_invoker() here.
+  cpdef invoke_split(self, element, restriction)
+  cpdef invoke_initial_restriction(self, element)
+  cpdef invoke_restriction_coder(self)
+  cpdef invoke_create_tracker(self, restriction)
 
 
 cdef class SimpleInvoker(DoFnInvoker):
@@ -77,7 +84,10 @@ cdef class DoFnRunner(Receiver):
   cpdef process(self, WindowedValue windowed_value)
 
 
-cdef class _OutputProcessor(object):
+cdef class OutputProcessor(object):
+  pass
+
+cdef class _OutputProcessor(OutputProcessor):
   cdef object window_fn
   cdef Receiver main_receivers
   cdef object tagged_receivers
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 08ddf6593aa..57c9f389cd9 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -28,7 +28,9 @@
 from apache_beam.internal import util
 from apache_beam.metrics.execution import ScopedMetricsContainer
 from apache_beam.pvalue import TaggedOutput
+from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
+from apache_beam.transforms.core import RestrictionProvider
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import TimestampedValue
 from apache_beam.transforms.window import WindowFn
@@ -58,23 +60,30 @@ def receive(self, windowed_value):
     raise NotImplementedError
 
 
-class DoFnMethodWrapper(object):
+class MethodWrapper(object):
   """For internal use only; no backwards-compatibility guarantees.
 
-  Represents a method of a DoFn object."""
+  Represents a method that can be invoked by `DoFnInvoker`."""
 
-  def __init__(self, do_fn, method_name):
+  def __init__(self, obj_to_invoke, method_name):
     """
-    Initiates a ``DoFnMethodWrapper``.
+    Initiates a ``MethodWrapper``.
 
     Args:
-      do_fn: A DoFn object that contains the method.
+      obj_to_invoke: the object that contains the method. Has to either be a
+                    `DoFn` object or a `RestrictionProvider` object.
       method_name: name of the method as a string.
     """
 
-    args, _, _, defaults = do_fn.get_function_arguments(method_name)
+    if not isinstance(obj_to_invoke, (DoFn, RestrictionProvider)):
+      raise ValueError('\'obj_to_invoke\' has to be either a \'DoFn\' or '
+                       'a \'RestrictionProvider\'. Received %r instead.' %
+                       obj_to_invoke)
+
+    args, _, _, defaults = core.get_function_arguments(
+        obj_to_invoke, method_name)
     defaults = defaults if defaults else []
-    method_value = getattr(do_fn, method_name)
+    method_value = getattr(obj_to_invoke, method_name)
     self.method_value = method_value
     self.args = args
     self.defaults = defaults
@@ -98,11 +107,31 @@ def __init__(self, do_fn):
     assert isinstance(do_fn, core.DoFn)
     self.do_fn = do_fn
 
-    self.process_method = DoFnMethodWrapper(do_fn, 'process')
-    self.start_bundle_method = DoFnMethodWrapper(do_fn, 'start_bundle')
-    self.finish_bundle_method = DoFnMethodWrapper(do_fn, 'finish_bundle')
+    self.process_method = MethodWrapper(do_fn, 'process')
+    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
+    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
+
+    restriction_provider = self._get_restriction_provider(do_fn)
+    self.initial_restriction_method = (
+        MethodWrapper(restriction_provider, 'initial_restriction')
+        if restriction_provider else None)
+    self.restriction_coder_method = (
+        MethodWrapper(restriction_provider, 'restriction_coder')
+        if restriction_provider else None)
+    self.create_tracker_method = (
+        MethodWrapper(restriction_provider, 'create_tracker')
+        if restriction_provider else None)
+    self.split_method = (
+        MethodWrapper(restriction_provider, 'split')
+        if restriction_provider else None)
+
     self._validate()
 
+  def _get_restriction_provider(self, do_fn):
+    result = _find_param_with_default(self.process_method,
+                                      default_as_type=RestrictionProvider)
+    return result[1] if result else None
+
   def _validate(self):
     self._validate_process()
     self._validate_bundle_method(self.start_bundle_method)
@@ -120,6 +149,10 @@ def _validate_bundle_method(self, method_wrapper):
     for param in core.DoFn.DoFnParams:
       assert param not in method_wrapper.defaults
 
+  def is_splittable_dofn(self):
+    return any([isinstance(default, RestrictionProvider) for default in
+                self.process_method.defaults])
+
 
 class DoFnInvoker(object):
   """An abstraction that can be used to execute DoFn methods.
@@ -133,19 +166,27 @@ def __init__(self, output_processor, signature):
 
   @staticmethod
   def create_invoker(
-      output_processor,
-      signature, context, side_inputs, input_args, input_kwargs):
+      signature,
+      output_processor=None,
+      context=None, side_inputs=None, input_args=None, input_kwargs=None,
+      process_invocation=True):
     """ Creates a new DoFnInvoker based on given arguments.
 
     Args:
+        output_processor: an OutputProcessor for receiving elements produced by
+                          invoking functions of the DoFn.
         signature: a DoFnSignature for the DoFn being invoked.
         context: Context to be used when invoking the DoFn (deprecated).
        side_inputs: side inputs to be used when invoking the process method.
         input_args: arguments to be used when invoking the process method
         input_kwargs: kwargs to be used when invoking the process method.
+        process_invocation: If True, this function may return an invoker that
+                            performs extra optimizations for invoking process()
+                            method efficiently.
     """
+    side_inputs = side_inputs or []
     default_arg_values = signature.process_method.defaults
-    use_simple_invoker = (
+    use_simple_invoker = not process_invocation or (
         not side_inputs and not input_args and not input_kwargs and
         not default_arg_values)
     if use_simple_invoker:
@@ -155,13 +196,15 @@ def create_invoker(
           output_processor,
           signature, context, side_inputs, input_args, input_kwargs)
 
-  def invoke_process(self, windowed_value):
+  def invoke_process(self, windowed_value, restriction_tracker=None,
+                     output_processor=None):
     """Invokes the DoFn.process() function.
 
     Args:
       windowed_value: a WindowedValue object that gives the element for which
                       process() method should be invoked along with the window
                       the element belongs to.
+      output_processor: if provided, the given OutputProcessor will be used.
     """
     raise NotImplementedError
 
@@ -177,6 +220,40 @@ def invoke_finish_bundle(self):
     self.output_processor.finish_bundle_outputs(
         self.signature.finish_bundle_method.method_value())
 
+  def invoke_split(self, element, restriction):
+    return self.signature.split_method.method_value(element, restriction)
+
+  def invoke_initial_restriction(self, element):
+    return self.signature.initial_restriction_method.method_value(element)
+
+  def invoke_restriction_coder(self):
+    return self.signature.restriction_coder_method.method_value()
+
+  def invoke_create_tracker(self, restriction):
+    return self.signature.create_tracker_method.method_value(restriction)
+
+
+def _find_param_with_default(
+    method, default_as_value=None, default_as_type=None):
+  if ((default_as_value and default_as_type) or
+      not (default_as_value or default_as_type)):
+    raise ValueError(
+        'Exactly one of \'default_as_value\' and \'default_as_type\' should be '
+        'provided. Received %r and %r.' % (default_as_value, default_as_type))
+
+  defaults = method.defaults
+  ret = None
+  for i, value in enumerate(defaults):
+    if default_as_value and value == default_as_value:
+      ret = (method.args[len(method.args) - len(defaults) + i], value)
+    elif default_as_type and isinstance(value, default_as_type):
+      index = len(method.args) - len(defaults) + i
+      ret = (method.args[index], value)
+
+  return ret
+
 
 class SimpleInvoker(DoFnInvoker):
   """An invoker that processes elements ignoring windowing information."""
@@ -185,8 +262,10 @@ def __init__(self, output_processor, signature):
     super(SimpleInvoker, self).__init__(output_processor, signature)
     self.process_method = signature.process_method.method_value
 
-  def invoke_process(self, windowed_value):
-    self.output_processor.process_outputs(
+  def invoke_process(self, windowed_value, restriction_tracker=None,
+                     output_processor=None):
+    output_processor = output_processor or self.output_processor
+    output_processor.process_outputs(
         windowed_value, self.process_method(windowed_value.value))
 
 
@@ -268,19 +347,35 @@ def __init__(self, placeholder):
     self.args_for_process = args_with_placeholders
     self.kwargs_for_process = input_kwargs
 
-  def invoke_process(self, windowed_value):
+  def invoke_process(self, windowed_value, restriction_tracker=None,
+                     output_processor=None):
+    output_processor = output_processor or self.output_processor
     self.context.set_element(windowed_value)
     # Call for the process function for each window if has windowed side inputs
     # or if the process accesses the window parameter. We can just call it once
     # otherwise as none of the arguments are changing
+
+    additional_kwargs = {}
+    if restriction_tracker:
+      restriction_tracker_param = _find_param_with_default(
+          self.signature.process_method,
+          default_as_type=core.RestrictionProvider)
+      if not restriction_tracker_param:
+        raise ValueError(
+            'A RestrictionTracker %r was provided but the DoFn does not have '
+            'a RestrictionTrackerParam defined.' % restriction_tracker)
+      additional_kwargs[restriction_tracker_param[0]] = restriction_tracker
     if self.has_windowed_inputs and len(windowed_value.windows) != 1:
       for w in windowed_value.windows:
         self._invoke_per_window(
-            WindowedValue(windowed_value.value, windowed_value.timestamp, (w,)))
+            WindowedValue(windowed_value.value, windowed_value.timestamp, (w,)),
+            additional_kwargs, output_processor)
     else:
-      self._invoke_per_window(windowed_value)
+      self._invoke_per_window(
+          windowed_value, additional_kwargs, output_processor)
 
-  def _invoke_per_window(self, windowed_value):
+  def _invoke_per_window(
+      self, windowed_value, additional_kwargs, output_processor):
     if self.has_windowed_inputs:
       window, = windowed_value.windows
       args_for_process, kwargs_for_process = util.insert_values_in_args(
@@ -298,12 +393,19 @@ def _invoke_per_window(self, windowed_value):
       elif p == core.DoFn.TimestampParam:
         args_for_process[i] = windowed_value.timestamp
 
+    if additional_kwargs:
+      if kwargs_for_process is None:
+        kwargs_for_process = additional_kwargs
+      else:
+        for key in additional_kwargs:
+          kwargs_for_process[key] = additional_kwargs[key]
+
     if kwargs_for_process:
-      self.output_processor.process_outputs(
+      output_processor.process_outputs(
           windowed_value,
           self.process_method(*args_for_process, **kwargs_for_process))
     else:
-      self.output_processor.process_outputs(
+      output_processor.process_outputs(
           windowed_value, self.process_method(*args_for_process))
 
 
@@ -355,8 +457,8 @@ def __init__(self,
         windowing.windowfn, main_receivers, tagged_receivers)
 
     self.do_fn_invoker = DoFnInvoker.create_invoker(
-        output_processor, do_fn_signature, self.context,
-        side_inputs, args, kwargs)
+        do_fn_signature, output_processor, self.context, side_inputs, args,
+        kwargs)
 
   def receive(self, windowed_value):
     self.process(windowed_value)
@@ -411,7 +513,13 @@ def _reraise_augmented(self, exn):
     raise new_exn, None, original_traceback
 
 
-class _OutputProcessor(object):
+class OutputProcessor(object):
+
+  def process_outputs(self, windowed_input_element, results):
+    raise NotImplementedError
+
+
+class _OutputProcessor(OutputProcessor):
   """Processes output produced by DoFn method invocations."""
 
   def __init__(self, window_fn, main_receivers, tagged_receivers):
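
The signature machinery above discovers splittability from default argument
values: a RestrictionProvider instance among the defaults of process() is what
_find_param_with_default() and is_splittable_dofn() look for. A minimal sketch
with hypothetical class names, assuming the RestrictionProvider base class in
transforms.core imported above (which supplies default restriction_coder() and
split() implementations):

    import apache_beam as beam
    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
    from apache_beam.runners.common import DoFnSignature
    from apache_beam.transforms.core import RestrictionProvider


    class MyProvider(RestrictionProvider):

      def initial_restriction(self, element):
        return (0, len(element))

      def create_tracker(self, restriction):
        return OffsetRestrictionTracker(*restriction)


    class PlainFn(beam.DoFn):

      def process(self, element):
        yield element


    class SplittableFn(beam.DoFn):

      # The provider default below is the marker that DoFnSignature inspects.
      def process(self, element, restriction_tracker=MyProvider()):
        yield element


    assert not DoFnSignature(PlainFn()).is_splittable_dofn()
    assert DoFnSignature(SplittableFn()).is_splittable_dofn()
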
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 74406927bf5..2bd6b45bdcf 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -88,18 +88,27 @@ def from_runner_api_parameter(payload, context):
         context.windowing_strategies.get_by_id(payload.value))
 
 
-class DirectRunner(PipelineRunner):
-  """Executes a single pipeline on the local machine."""
-
+def _get_transform_overrides():
   # A list of PTransformOverride objects to be applied before running a pipeline
-  # using DirectRunner. Currently, this only works for overrides where the input
-  # and output types do not change.
+  # using DirectRunner.
+  # Currently this only works for overrides where the input and output types do
+  # not change.
   # For internal use only; no backwards-compatibility guarantees.
-  _PTRANSFORM_OVERRIDES = []
+
+  # Importing following locally to avoid a circular dependency.
+  from apache_beam.runners.sdf_common import SplittableParDoOverride
+  from apache_beam.runners.direct.sdf_direct_runner import ProcessKeyedElementsViaKeyedWorkItemsOverride
+  return [SplittableParDoOverride(),
+          ProcessKeyedElementsViaKeyedWorkItemsOverride()]
+
+
+class DirectRunner(PipelineRunner):
+  """Executes a single pipeline on the local machine."""
 
   def __init__(self):
     self._cache = None
     self._use_test_clock = False  # use RealClock() in production
+    self._ptransform_overrides = _get_transform_overrides()
 
   def apply_CombinePerKey(self, transform, pcoll):
     # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
@@ -192,7 +201,7 @@ def run_pipeline(self, pipeline):
     """Execute the entire pipeline and returns an DirectPipelineResult."""
 
     # Performing configured PTransform overrides.
-    pipeline.replace_all(DirectRunner._PTRANSFORM_OVERRIDES)
+    pipeline.replace_all(self._ptransform_overrides)
 
     # TODO: Move imports to top. Pipeline <-> Runner dependency cause problems
     # with resolving imports when they are at top.
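
For reference, an override supplied through this list follows the same shape
as MyParDoOverride in pipeline_test.py above: a matcher plus a replacement
transform whose input and output types match. A hypothetical example
(DoubleParDo/TripleParDo are illustrative stand-ins like the ones used in
pipeline_test.py):

    import apache_beam as beam
    from apache_beam.pipeline import PTransformOverride


    class DoubleParDo(beam.PTransform):

      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x: x * 2)


    class TripleParDo(beam.PTransform):

      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x: x * 3)


    class DoubleToTripleOverride(PTransformOverride):

      def get_matcher(self):
        def _matcher(applied_ptransform):
          return isinstance(applied_ptransform.transform, DoubleParDo)
        return _matcher

      def get_replacement_transform(self, ptransform):
        # The replacement must preserve input and output types.
        return TripleParDo()

    # pipeline.replace_all([DoubleToTripleOverride()]) then swaps every
    # matched DoubleParDo application for a TripleParDo before execution.
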
diff --git a/sdks/python/apache_beam/runners/direct/executor.py b/sdks/python/apache_beam/runners/direct/executor.py
index 93490536ed8..1cbabc4badf 100644
--- a/sdks/python/apache_beam/runners/direct/executor.py
+++ b/sdks/python/apache_beam/runners/direct/executor.py
@@ -326,6 +326,9 @@ def attempt_call(self, metrics_container,
         self._applied_ptransform, self._input_bundle,
         side_input_values, scoped_metrics_container)
 
+    with scoped_metrics_container:
+      evaluator.start_bundle()
+
     if self._fired_timers:
       for timer_firing in self._fired_timers:
         evaluator.process_timer_wrapper(timer_firing)
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py
new file mode 100644
index 00000000000..ddbe9649b42
--- /dev/null
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner.py
@@ -0,0 +1,354 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module contains Splittable DoFn logic that is specific to DirectRunner.
+"""
+
+from threading import Lock
+from threading import Timer
+
+import apache_beam as beam
+from apache_beam import TimeDomain
+from apache_beam import pvalue
+from apache_beam.io.iobase import RestrictionTracker
+from apache_beam.pipeline import PTransformOverride
+from apache_beam.runners.common import DoFnContext
+from apache_beam.runners.common import DoFnInvoker
+from apache_beam.runners.common import DoFnSignature
+from apache_beam.runners.common import OutputProcessor
+from apache_beam.runners.direct.evaluation_context import DirectStepContext
+from apache_beam.runners.direct.util import KeyedWorkItem
+from apache_beam.runners.direct.watermark_manager import WatermarkManager
+from apache_beam.runners.sdf_common import ElementAndRestriction
+from apache_beam.runners.sdf_common import ProcessKeyedElements
+from apache_beam.transforms.core import ProcessContinuation
+from apache_beam.transforms.ptransform import PTransform
+from apache_beam.transforms.trigger import _ValueStateTag
+from apache_beam.utils.windowed_value import WindowedValue
+
+
+class ProcessKeyedElementsViaKeyedWorkItemsOverride(PTransformOverride):
+  """A transform override for ProcessElements transform."""
+
+  def get_matcher(self):
+    def _matcher(applied_ptransform):
+      return isinstance(
+          applied_ptransform.transform, ProcessKeyedElements)
+
+    return _matcher
+
+  def get_replacement_transform(self, ptransform):
+    return ProcessKeyedElementsViaKeyedWorkItems(ptransform)
+
+
+class ProcessKeyedElementsViaKeyedWorkItems(PTransform):
+  """A transform that processes Splittable DoFn input via KeyedWorkItems."""
+
+  def __init__(self, process_keyed_elements_transform):
+    self._process_keyed_elements_transform = process_keyed_elements_transform
+
+  def expand(self, pcoll):
+    return pcoll | beam.core.GroupByKey() | ProcessElements(
+        self._process_keyed_elements_transform)
+
+
+class ProcessElements(PTransform):
+  """A primitive transform for processing keyed elements or KeyedWorkItems.
+
+  Will be evaluated by
+  `runners.direct.transform_evaluator._ProcessElementsEvaluator`.
+  """
+
+  def __init__(self, process_keyed_elements_transform):
+    self._process_keyed_elements_transform = process_keyed_elements_transform
+    self.sdf = self._process_keyed_elements_transform.sdf
+
+  def expand(self, pcoll):
+    return pvalue.PCollection(pcoll.pipeline)
+
+  def new_process_fn(self, sdf):
+    return ProcessFn(
+        sdf,
+        self._process_keyed_elements_transform.ptransform_args,
+        self._process_keyed_elements_transform.ptransform_kwargs)
+
+
+class ProcessFn(beam.DoFn):
+  """A `DoFn` that executes machineary for invoking a Splittable `DoFn`.
+
+  Input to the `ParDo` step that includes a `ProcessFn` will be a `PCollection`
+  of `ElementAndRestriction` objects.
+
+  This class is mainly responsible for the following:
+  (1) setting up the environment for properly invoking a Splittable `DoFn`.
+  (2) invoking the `process()` method of a Splittable `DoFn`.
+  (3) after the `process()` invocation of the Splittable `DoFn`, determining
+  whether a re-invocation of the element is needed. If so, setting state and
+  a timer for the re-invocation and holding the output watermark until that
+  re-invocation.
+  (4) after the final invocation of a given element, clearing any previous
+  state set for re-invoking the element and releasing the output watermark.
+  """
+
+  def __init__(
+      self, sdf, args_for_invoker, kwargs_for_invoker):
+    self.sdf = sdf
+    self._element_tag = _ValueStateTag('element')
+    self._restriction_tag = _ValueStateTag('restriction')
+    self.watermark_hold_tag = _ValueStateTag('watermark_hold')
+    self._process_element_invoker = None
+
+    self.sdf_invoker = DoFnInvoker.create_invoker(
+        DoFnSignature(self.sdf), context=DoFnContext('unused_context'),
+        input_args=args_for_invoker, input_kwargs=kwargs_for_invoker)
+
+    self._step_context = None
+
+  @property
+  def step_context(self):
+    return self._step_context
+
+  @step_context.setter
+  def step_context(self, step_context):
+    assert isinstance(step_context, DirectStepContext)
+    self._step_context = step_context
+
+  def set_process_element_invoker(self, process_element_invoker):
+    assert isinstance(process_element_invoker, SDFProcessElementInvoker)
+    self._process_element_invoker = process_element_invoker
+
+  def process(self, element, timestamp=beam.DoFn.TimestampParam,
+              window=beam.DoFn.WindowParam, *args, **kwargs):
+    if isinstance(element, KeyedWorkItem):
+      # Must be a timer firing.
+      key = element.encoded_key
+    else:
+      key, values = element
+      values = list(values)
+      # Value here will either be a WindowedValue or an ElementAndRestriction
+      # object.
+      # TODO: handle key collisions here.
+      assert len(values) == 1, ('Internal error. Processing of splittable '
+                                'DoFn cannot continue since elements did not '
+                                'have unique keys.')
+      value = values[0]
+
+    state = self._step_context.get_keyed_state(key)
+    element_state = state.get_state(window, self._element_tag)
+    # Initially element_state is an empty list.
+    is_seed_call = not element_state
+
+    if not is_seed_call:
+      element = state.get_state(window, self._element_tag)
+      restriction = state.get_state(window, self._restriction_tag)
+      windowed_element = WindowedValue(element, timestamp, [window])
+    else:
+      # After values iterator is expanded above we should have gotten a list
+      # with a single ElementAndRestriction object.
+      assert isinstance(value, ElementAndRestriction)
+      element_and_restriction = value
+      element = element_and_restriction.element
+      restriction = element_and_restriction.restriction
+
+      if isinstance(value, WindowedValue):
+        windowed_element = WindowedValue(
+            element, value.timestamp, value.windows)
+      else:
+        windowed_element = WindowedValue(element, timestamp, [window])
+
+    tracker = self.sdf_invoker.invoke_create_tracker(restriction)
+    assert self._process_element_invoker
+    assert isinstance(self._process_element_invoker,
+                      SDFProcessElementInvoker)
+
+    output_values = self._process_element_invoker.invoke_process_element(
+        self.sdf_invoker, windowed_element, tracker)
+
+    sdf_result = None
+    for output in output_values:
+      if isinstance(output, SDFProcessElementInvoker.Result):
+        # SDFProcessElementInvoker.Result should be the last item yielded.
+        sdf_result = output
+        break
+      yield output
+
+    assert sdf_result, ('SDFProcessElementInvoker must return a '
+                        'SDFProcessElementInvoker.Result object as the last '
+                        'value of a SDF invoke_process_element() invocation.')
+
+    if not sdf_result.residual_restriction:
+      # All work for current residual and restriction pair is complete.
+      state.clear_state(window, self._element_tag)
+      state.clear_state(window, self._restriction_tag)
+      # Releasing output watermark by setting it to positive infinity.
+      state.add_state(window, self.watermark_hold_tag,
+                      WatermarkManager.WATERMARK_POS_INF)
+    else:
+      state.add_state(window, self._element_tag, element)
+      state.add_state(window, self._restriction_tag,
+                      sdf_result.residual_restriction)
+      # Holding output watermark by setting it to negative infinity.
+      state.add_state(window, self.watermark_hold_tag,
+                      WatermarkManager.WATERMARK_NEG_INF)
+
+      # Setting a timer to be reinvoked to continue processing the element.
+      # Currently Python SDK only supports setting timers based on watermark. So
+      # forcing a reinvocation by setting a timer for watermark negative
+      # infinity.
+      # TODO(chamikara): update this by setting a timer for the proper
+      # processing time when Python SDK supports that.
+      state.set_timer(
+          window, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_NEG_INF)
+
+
+class SDFProcessElementInvoker(object):
+  """A utility that invokes SDF `process()` method and requests checkpoints.
+
+  This class is responsible for invoking the `process()` method of a Splittable
+  `DoFn` and making sure that the invocation terminates properly. Based on the
+  input configuration, this class may decide to request a checkpoint for a
+  `process()` execution so that the runner can process the current output and
+  resume the invocation at a later time.
+
+  More specifically, when initializing a `SDFProcessElementInvoker`, the caller
+  may specify the number of output elements or the processing time after which
+  a checkpoint should be requested. This class is responsible for properly
+  requesting a checkpoint based on either of these criteria.
+  When the `process()` call of the Splittable `DoFn` ends, this class performs
+  validations to make sure that processing ended gracefully and returns a
+  `SDFProcessElementInvoker.Result` that contains information which can be used
+  by the caller to perform another `process()` invocation for the residual.
+
+  A `process()` invocation may decide to give up processing voluntarily by
+  returning a `ProcessContinuation` object (see the documentation of
+  `ProcessContinuation` for more details). So if a `ProcessContinuation` is
+  produced, this class ends the execution and performs steps to finalize the
+  current invocation.
+  """
+
+  class Result(object):
+    def __init__(
+        self, residual_restriction=None, process_continuation=None,
+        future_output_watermark=None):
+      """Returned as a result of a `invoke_process_element()` invocation.
+
+      Args:
+        residual_restriction: a restriction for the unprocessed part of the
+                             element.
+        process_continuation: a `ProcessContinuation` if one was returned as the
+                              last element of the SDF `process()` invocation.
+        future_output_watermark: output watermark of the results that will be
+                                 produced when invoking the Splittable `DoFn`
+                                 for the current element with
+                                 `residual_restriction`.
+      """
+
+      self.residual_restriction = residual_restriction
+      self.process_continuation = process_continuation
+      self.future_output_watermark = future_output_watermark
+
+  def __init__(
+      self, max_num_outputs, max_duration):
+    self._max_num_outputs = max_num_outputs
+    self._max_duration = max_duration
+    self._checkpoint_lock = Lock()
+
+  def invoke_process_element(self, sdf_invoker, element, tracker):
+    """Invokes `process()` method of a Splittable `DoFn` for a given element.
+
+     Args:
+       sdf_invoker: a `DoFnInvoker` for the Splittable `DoFn`.
+       element: the element to process
+       tracker: a `RestrictionTracker` for the element that will be passed when
+                invoking the `process()` method of the Splittable `DoFn`.
+     Returns:
+       a `SDFProcessElementInvoker.Result` object.
+     """
+    assert isinstance(sdf_invoker, DoFnInvoker)
+    assert isinstance(tracker, RestrictionTracker)
+
+    class CheckpointState(object):
+
+      def __init__(self):
+        self.checkpointed = None
+        self.residual_restriction = None
+
+    checkpoint_state = CheckpointState()
+
+    def initiate_checkpoint():
+      with self._checkpoint_lock:
+        if checkpoint_state.checkpointed:
+          return
+        checkpoint_state.residual_restriction = tracker.checkpoint()
+        checkpoint_state.checkpointed = object()
+
+    output_processor = _OutputProcessor()
+    Timer(self._max_duration, initiate_checkpoint).start()
+    sdf_invoker.invoke_process(
+        element, restriction_tracker=tracker, output_processor=output_processor)
+
+    assert output_processor.output_iter is not None
+    output_count = 0
+
+    # We have to expand and re-yield here to support ending execution for a
+    # given number of output elements as well as to capture the
+    # ProcessContinuation if one was returned.
+    process_continuation = None
+    for output in output_processor.output_iter:
+      # A ProcessContinuation, if returned, should be the last element.
+      assert not process_continuation
+      if isinstance(output, ProcessContinuation):
+        # Taking a checkpoint so that we can determine primary and residual
+        # restrictions.
+        initiate_checkpoint()
+
+        # A ProcessContinuation should always be the last element produced by
+        # the output iterator.
+        # TODO: support continuing after the specified amount of delay.
+
+        # Continuing here instead of breaking to enforce that this is the last
+        # element.
+        process_continuation = output
+        continue
+
+      yield output
+      output_count += 1
+      if self._max_num_outputs and output_count >= self._max_num_outputs:
+        initiate_checkpoint()
+
+    tracker.check_done()
+    result = (
+        SDFProcessElementInvoker.Result(
+            residual_restriction=checkpoint_state.residual_restriction)
+        if checkpoint_state.residual_restriction
+        else SDFProcessElementInvoker.Result())
+    yield result
+
+
+class _OutputProcessor(OutputProcessor):
+
+  def __init__(self):
+    self.output_iter = None
+
+  def process_outputs(self, windowed_input_element, output_iter):
+    self.output_iter = output_iter
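
End to end, the overrides above let a splittable ParDo be applied like any
other ParDo on DirectRunner. A sketch of a complete pipeline, with a
hypothetical DoFn of the same shape as ExpandStrings in the tests that follow:

    import apache_beam as beam
    from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
    from apache_beam.testing.test_pipeline import TestPipeline
    from apache_beam.testing.util import assert_that
    from apache_beam.testing.util import equal_to
    from apache_beam.transforms.core import RestrictionProvider


    class RangeProvider(RestrictionProvider):

      # The restriction for an integer element n is the offset range [0, n).
      def initial_restriction(self, element):
        return (0, element)

      def create_tracker(self, restriction):
        return OffsetRestrictionTracker(*restriction)


    class EmitOffsets(beam.DoFn):

      def process(self, element, restriction_tracker=RangeProvider(),
                  *args, **kwargs):
        for i in range(restriction_tracker.start_position(),
                       restriction_tracker.stop_position()):
          # A failed claim means the runner checkpointed this restriction;
          # the residual is resumed by a later process() invocation.
          if not restriction_tracker.try_claim(i):
            return
          yield i


    with TestPipeline() as p:
      pcoll = p | beam.Create([3]) | beam.ParDo(EmitOffsets())
      assert_that(pcoll, equal_to([0, 1, 2]))
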
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
new file mode 100644
index 00000000000..7ab6dde9397
--- /dev/null
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -0,0 +1,235 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for SDF implementation for DirectRunner."""
+
+import logging
+import os
+import unittest
+
+import apache_beam as beam
+from apache_beam import DoFn
+from apache_beam.io import filebasedsource_test
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms.core import ProcessContinuation
+from apache_beam.transforms.core import RestrictionProvider
+from apache_beam.transforms.trigger import AccumulationMode
+from apache_beam.transforms.window import SlidingWindows
+from apache_beam.transforms.window import TimestampedValue
+
+
+class ReadFilesProvider(RestrictionProvider):
+
+  def initial_restriction(self, element):
+    size = os.path.getsize(element)
+    return (0, size)
+
+  def create_tracker(self, restriction):
+    return OffsetRestrictionTracker(*restriction)
+
+
+class ReadFiles(DoFn):
+
+  def __init__(self, resume_count=None):
+    self._resume_count = resume_count
+
+  def process(
+      self, element, restriction_tracker=ReadFilesProvider(), *args, **kwargs):
+    file_name = element
+    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
+
+    with open(file_name, 'rb') as file:
+      pos = restriction_tracker.start_position()
+      if restriction_tracker.start_position() > 0:
+        file.seek(restriction_tracker.start_position() - 1)
+        line = file.readline()
+        pos = pos - 1 + len(line)
+
+      output_count = 0
+      while restriction_tracker.try_claim(pos):
+        line = file.readline()
+        len_line = len(line)
+        line = line.strip()
+        if not line:
+          break
+        yield line
+        output_count += 1
+
+        if self._resume_count and output_count == self._resume_count:
+          yield ProcessContinuation()
+          break
+
+        pos += len_line
+
+
+class ExpandStringsProvider(RestrictionProvider):
+
+  def initial_restriction(self, element):
+    return (0, len(element[0]))
+
+  def create_tracker(self, restriction):
+    return OffsetRestrictionTracker(restriction[0], restriction[1])
+
+  def split(self, element, restriction):
+    return [restriction,]
+
+
+class ExpandStrings(DoFn):
+
+  def __init__(self, record_window=False):
+    self._record_window = record_window
+
+  def process(
+      self, element, window=beam.DoFn.WindowParam,
+      restriction_tracker=ExpandStringsProvider(), side=None,
+      *args, **kwargs):
+    side = side or []
+    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
+    side = list(side)
+    for i in range(restriction_tracker.start_position(),
+                   restriction_tracker.stop_position()):
+      if restriction_tracker.try_claim(i):
+        if not side:
+          yield (
+              element[0] + ':' + str(element[1]) + ':' + str(int(window.start))
+              if self._record_window else element)
+        else:
+          for val in side:
+            ret = (
+                element[0] + ':' + str(element[1]) + ':' +
+                str(int(window.start)) if self._record_window else element)
+            yield ret + ':' + val
+      else:
+        break
+
+
+class SDFDirectRunnerTest(unittest.TestCase):
+
+  def setUp(self):
+    super(SDFDirectRunnerTest, self).setUp()
+    # Import the following for the DirectRunner SDF implementation under test.
+    from apache_beam.runners.direct import transform_evaluator
+    self._default_max_num_outputs = (
+        transform_evaluator._ProcessElementsEvaluator.DEFAULT_MAX_NUM_OUTPUTS)
+
+  def run_sdf_read_pipeline(
+      self, num_files, num_records_per_file, resume_count=None):
+    expected_data = []
+    file_names = []
+    for _ in range(num_files):
+      new_file_name, new_expected_data = filebasedsource_test.write_data(
+          num_records_per_file)
+      assert len(new_expected_data) == num_records_per_file
+      file_names.append(new_file_name)
+      expected_data.extend(new_expected_data)
+
+    assert len(expected_data) > 0
+
+    with TestPipeline() as p:
+      pc1 = (p
+             | 'Create1' >> beam.Create(file_names)
+             | 'SDF' >> beam.ParDo(ReadFiles(resume_count)))
+
+      assert_that(pc1, equal_to(expected_data))
+
+      # TODO(chamikara): verify the number of times the process method was
+      # invoked using a side output, once SDF supports producing side outputs.
+
+  def test_sdf_no_checkpoint_single_element(self):
+    self.run_sdf_read_pipeline(
+        1,
+        self._default_max_num_outputs - 1)
+
+  def test_sdf_one_checkpoint_single_element(self):
+    self.run_sdf_read_pipeline(
+        1,
+        int(self._default_max_num_outputs + 1))
+
+  def test_sdf_multiple_checkpoints_single_element(self):
+    self.run_sdf_read_pipeline(
+        1,
+        int(self._default_max_num_outputs * 3))
+
+  def test_sdf_no_checkpoint_multiple_element(self):
+    self.run_sdf_read_pipeline(
+        5,
+        int(self._default_max_num_outputs - 1))
+
+  def test_sdf_one_checkpoint_multiple_element(self):
+    self.run_sdf_read_pipeline(
+        5,
+        int(self._default_max_num_outputs + 1))
+
+  def test_sdf_multiple_checkpoints_multiple_element(self):
+    self.run_sdf_read_pipeline(
+        5,
+        int(self._default_max_num_outputs * 3))
+
+  def test_sdf_with_resume_single_element(self):
+    resume_count = self._default_max_num_outputs // 10
+    # Makes sure that resume_count is not trivial.
+    assert resume_count > 0
+
+    self.run_sdf_read_pipeline(
+        1,
+        self._default_max_num_outputs - 1,
+        resume_count)
+
+  def test_sdf_with_resume_multiple_elements(self):
+    resume_count = self._default_max_num_outputs // 10
+    assert resume_count > 0
+
+    self.run_sdf_read_pipeline(
+        5,
+        int(self._default_max_num_outputs - 1),
+        resume_count)
+
+  def test_sdf_with_windowed_timestamped_input(self):
+    with TestPipeline() as p:
+      result = (p
+                | beam.Create([1, 3, 5, 10])
+                | beam.FlatMap(lambda t: [TimestampedValue(('A', t), t),
+                                          TimestampedValue(('B', t), t)])
+                | beam.WindowInto(SlidingWindows(10, 5),
+                                  accumulation_mode=AccumulationMode.DISCARDING)
+                | beam.ParDo(ExpandStrings(record_window=True)))
+
+      expected_result = [
+          'A:1:-5', 'A:1:0', 'A:3:-5', 'A:3:0', 'A:5:0', 'A:5:5', 'A:10:5',
+          'A:10:10', 'B:1:-5', 'B:1:0', 'B:3:-5', 'B:3:0', 'B:5:0', 'B:5:5',
+          'B:10:5', 'B:10:10',]
+      assert_that(result, equal_to(expected_result))
+
+  def test_sdf_with_side_inputs(self):
+    with TestPipeline() as p:
+      result = (p
+                | 'create_main' >> beam.Create(['1', '3', '5'])
+                | beam.ParDo(ExpandStrings(), side=['1', '3']))
+
+      expected_result = ['1:1', '3:1', '5:1', '1:3', '3:3', '5:3']
+      assert_that(result, equal_to(expected_result))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
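
The ReadFiles DoFn above also exercises the resume protocol: yielding a
ProcessContinuation tells the runner to end the current invocation and
reschedule the unclaimed remainder of the restriction. Distilled to its core,
the pattern looks like the following sketch (the provider and DoFn here are
hypothetical, written only to illustrate the protocol):

from apache_beam import DoFn
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.transforms.core import ProcessContinuation
from apache_beam.transforms.core import RestrictionProvider


class CountToTenProvider(RestrictionProvider):
  # Hypothetical provider: every element's restriction is the offset
  # range [0, 10).
  def initial_restriction(self, element):
    return (0, 10)

  def create_tracker(self, restriction):
    return OffsetRestrictionTracker(*restriction)


class ResumingCountFn(DoFn):
  # Hypothetical DoFn: emits three offsets per invocation, then defers the
  # rest of the range by yielding a ProcessContinuation.
  def process(self, element, restriction_tracker=CountToTenProvider(),
              *args, **kwargs):
    pos = restriction_tracker.start_position()
    emitted = 0
    while restriction_tracker.try_claim(pos):
      yield pos
      pos += 1
      emitted += 1
      if emitted == 3:
        yield ProcessContinuation()
        return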
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index ce67f737b0a..a39d7bd59eb 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -34,6 +34,9 @@
 from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite  # pylint: disable=protected-access
 from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow
 from apache_beam.runners.direct.direct_runner import _StreamingGroupByKeyOnly
+from apache_beam.runners.direct.sdf_direct_runner import ProcessElements
+from apache_beam.runners.direct.sdf_direct_runner import ProcessFn
+from apache_beam.runners.direct.sdf_direct_runner import SDFProcessElementInvoker
 from apache_beam.runners.direct.util import KeyedWorkItem
 from apache_beam.runners.direct.util import TransformResult
 from apache_beam.runners.direct.watermark_manager import WatermarkManager
@@ -75,6 +78,7 @@ def __init__(self, evaluation_context):
         _StreamingGroupAlsoByWindow: _StreamingGroupAlsoByWindowEvaluator,
         _NativeWrite: _NativeWriteEvaluator,
         TestStream: _TestStreamEvaluator,
+        ProcessElements: _ProcessElementsEvaluator
     }
     self._root_bundle_providers = {
         core.PTransform: DefaultRootBundleProvider,
@@ -192,8 +196,6 @@ def __init__(self, evaluation_context, applied_ptransform,
     self._execution_context = evaluation_context.get_execution_context(
         applied_ptransform)
     self.scoped_metrics_container = scoped_metrics_container
-    with scoped_metrics_container:
-      self.start_bundle()
 
   def _expand_outputs(self):
     outputs = set()
@@ -516,6 +518,17 @@ def __missing__(self, key):
 
 class _ParDoEvaluator(_TransformEvaluator):
   """TransformEvaluator for ParDo transform."""
+
+  def __init__(self, evaluation_context, applied_ptransform,
+               input_committed_bundle, side_inputs, scoped_metrics_container,
+               perform_dofn_pickle_test=True):
+    super(_ParDoEvaluator, self).__init__(
+        evaluation_context, applied_ptransform, input_committed_bundle,
+        side_inputs, scoped_metrics_container)
+    # This is a workaround for the SDF implementation, which adds state to the
+    # SDF that is not picklable.
+    self._perform_dofn_pickle_test = perform_dofn_pickle_test
+
   def start_bundle(self):
     transform = self._applied_ptransform.transform
 
@@ -530,7 +543,8 @@ def start_bundle(self):
     self._counter_factory = counters.CounterFactory()
 
     # TODO(aaltay): Consider storing the serialized form as an optimization.
-    dofn = pickler.loads(pickler.dumps(transform.dofn))
+    dofn = (pickler.loads(pickler.dumps(transform.dofn))
+            if self._perform_dofn_pickle_test else transform.dofn)
 
     pipeline_options = self._evaluation_context.pipeline_options
     if (pipeline_options is not None
@@ -538,8 +552,11 @@ def start_bundle(self):
       dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())
 
     dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
+    args = transform.args if hasattr(transform, 'args') else []
+    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}
+
     self.runner = DoFnRunner(
-        dofn, transform.args, transform.kwargs,
+        dofn, args, kwargs,
         self._side_inputs,
         self._applied_ptransform.inputs[0].windowing,
         tagged_receivers=self._tagged_receivers,
@@ -827,3 +844,80 @@ def finish_bundle(self):
           None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF)
 
     return TransformResult(self, [], [], None, {None: hold})
+
+
+class _ProcessElementsEvaluator(_TransformEvaluator):
+  """An evaluator for sdf_direct_runner.ProcessElements transform."""
+
+  # Maximum number of elements that will be produced by a Splittable DoFn before
+  # a checkpoint is requested by the runner.
+  DEFAULT_MAX_NUM_OUTPUTS = 100
+  # Maximum duration for which a Splittable DoFn will process an element before
+  # a checkpoint is requested by the runner.
+  DEFAULT_MAX_DURATION = 1
+
+  def __init__(self, evaluation_context, applied_ptransform,
+               input_committed_bundle, side_inputs, scoped_metrics_container):
+    super(_ProcessElementsEvaluator, self).__init__(
+        evaluation_context, applied_ptransform, input_committed_bundle,
+        side_inputs, scoped_metrics_container)
+
+    transform = applied_ptransform.transform
+    assert isinstance(transform, ProcessElements)
+
+    # Replacing the do_fn of the transform with a wrapper do_fn that performs
+    # SDF magic.
+    sdf = transform.sdf
+    self._process_fn = transform.new_process_fn(sdf)
+    transform.dofn = self._process_fn
+
+    assert isinstance(self._process_fn, ProcessFn)
+
+    self.step_context = self._execution_context.get_step_context()
+    self._process_fn.step_context = self.step_context
+
+    process_element_invoker = (
+        SDFProcessElementInvoker(
+            max_num_outputs=self.DEFAULT_MAX_NUM_OUTPUTS,
+            max_duration=self.DEFAULT_MAX_DURATION))
+    self._process_fn.set_process_element_invoker(process_element_invoker)
+
+    self._par_do_evaluator = _ParDoEvaluator(
+        evaluation_context, applied_ptransform, input_committed_bundle,
+        side_inputs, scoped_metrics_container, perform_dofn_pickle_test=False)
+    self.keyed_holds = {}
+
+  def start_bundle(self):
+    self._par_do_evaluator.start_bundle()
+
+  def process_element(self, element):
+    assert isinstance(element, WindowedValue)
+    assert len(element.windows) == 1
+    window = element.windows[0]
+    if isinstance(element.value, KeyedWorkItem):
+      key = element.value.encoded_key
+    else:
+      # If not a `KeyedWorkItem`, this must be a tuple whose first item is a
+      # randomly generated key and whose second item is a `WindowedValue` that
+      # contains an `ElementAndRestriction` object.
+      assert isinstance(element.value, tuple)
+      key = element.value[0]
+
+    self._par_do_evaluator.process_element(element)
+
+    state = self.step_context.get_keyed_state(key)
+    self.keyed_holds[key] = state.get_state(
+        window, self._process_fn.watermark_hold_tag)
+
+  def finish_bundle(self):
+    par_do_result = self._par_do_evaluator.finish_bundle()
+
+    transform_result = TransformResult(
+        self, par_do_result.uncommitted_output_bundles,
+        par_do_result.unprocessed_bundles, par_do_result.counters,
+        par_do_result.keyed_watermark_holds,
+        par_do_result.undeclared_tag_values)
+    for key in self.keyed_holds:
+      transform_result.keyed_watermark_holds[key] = self.keyed_holds[key]
+    return transform_result
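
The two DEFAULT_* constants above bound a single process() invocation: once
either budget is exhausted, the invoker requests a checkpoint. A self-contained
sketch of that policy (the helper below is hypothetical, and the unit of
DEFAULT_MAX_DURATION is assumed to be seconds):

import time


def make_checkpoint_policy(max_num_outputs=100, max_duration=1):
  # Returns a callable that reports whether a checkpoint should be requested,
  # mirroring the output-count and duration limits used above.
  start = time.time()

  def should_checkpoint(num_outputs):
    return (num_outputs >= max_num_outputs
            or time.time() - start >= max_duration)

  return should_checkpoint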
diff --git a/sdks/python/apache_beam/runners/sdf_common.py b/sdks/python/apache_beam/runners/sdf_common.py
new file mode 100644
index 00000000000..a7d80ac8b18
--- /dev/null
+++ b/sdks/python/apache_beam/runners/sdf_common.py
@@ -0,0 +1,168 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module contains Splittable DoFn logic that's common to all runners."""
+
+import uuid
+
+import apache_beam as beam
+from apache_beam import pvalue
+from apache_beam.coders import typecoders
+from apache_beam.pipeline import AppliedPTransform
+from apache_beam.pipeline import PTransformOverride
+from apache_beam.runners.common import DoFnInvoker
+from apache_beam.runners.common import DoFnSignature
+from apache_beam.transforms.core import ParDo
+from apache_beam.transforms.ptransform import PTransform
+
+
+class SplittableParDoOverride(PTransformOverride):
+  """A transform override for ParDo transforms of SplittableDoFns.
+
+  Replaces the ParDo transform with a SplittableParDo transform that performs
+  SDF specific logic.
+  """
+
+  def get_matcher(self):
+    def _matcher(applied_ptransform):
+      assert isinstance(applied_ptransform, AppliedPTransform)
+      transform = applied_ptransform.transform
+      if isinstance(transform, ParDo):
+        signature = DoFnSignature(transform.fn)
+        return signature.is_splittable_dofn()
+
+    return _matcher
+
+  def get_replacement_transform(self, ptransform):
+    assert isinstance(ptransform, ParDo)
+    do_fn = ptransform.fn
+    signature = DoFnSignature(do_fn)
+    if signature.is_splittable_dofn():
+      return SplittableParDo(ptransform)
+    else:
+      return ptransform
+
+
+class SplittableParDo(PTransform):
+  """A transform that processes a PCollection using a Splittable DoFn."""
+
+  def __init__(self, ptransform):
+    assert isinstance(ptransform, ParDo)
+    self._ptransform = ptransform
+
+  def expand(self, pcoll):
+    sdf = self._ptransform.fn
+    signature = DoFnSignature(sdf)
+    invoker = DoFnInvoker.create_invoker(signature, process_invocation=False)
+
+    element_coder = typecoders.registry.get_coder(pcoll.element_type)
+    restriction_coder = invoker.invoke_restriction_coder()
+
+    keyed_elements = (pcoll
+                      | 'pair' >> ParDo(PairWithRestrictionFn(sdf))
+                      | 'split' >> ParDo(SplitRestrictionFn(sdf))
+                      | 'explode' >> ParDo(ExplodeWindowsFn())
+                      | 'random' >> ParDo(RandomUniqueKeyFn()))
+
+    return keyed_elements | ProcessKeyedElements(
+        sdf, element_coder, restriction_coder,
+        pcoll.windowing, self._ptransform.args, self._ptransform.kwargs)
+
+
+class ElementAndRestriction(object):
+  """A holder for an element and a restriction."""
+
+  def __init__(self, element, restriction):
+    self.element = element
+    self.restriction = restriction
+
+
+class PairWithRestrictionFn(beam.DoFn):
+  """A transform that pairs each element with a restriction."""
+
+  def __init__(self, do_fn):
+    self._do_fn = do_fn
+
+  def start_bundle(self):
+    signature = DoFnSignature(self._do_fn)
+    self._invoker = DoFnInvoker.create_invoker(
+        signature, process_invocation=False)
+
+  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
+    initial_restriction = self._invoker.invoke_initial_restriction(element)
+    yield ElementAndRestriction(element, initial_restriction)
+
+
+class SplitRestrictionFn(beam.DoFn):
+  """A transform that performs initial splitting of Splittable DoFn inputs."""
+
+  def __init__(self, do_fn):
+    self._do_fn = do_fn
+
+  def start_bundle(self):
+    signature = DoFnSignature(self._do_fn)
+    self._invoker = DoFnInvoker.create_invoker(
+        signature, process_invocation=False)
+
+  def process(self, element_and_restriction, *args, **kwargs):
+    element = element_and_restriction.element
+    restriction = element_and_restriction.restriction
+    restriction_parts = self._invoker.invoke_split(
+        element,
+        restriction)
+    for part in restriction_parts:
+      yield ElementAndRestriction(element, part)
+
+
+class ExplodeWindowsFn(beam.DoFn):
+  """A transform that forces the runner to explode windows.
+
+  This is done to make sure that the Splittable DoFn processes an element once
+  for each of the windows that the element belongs to.
+  """
+
+  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
+    yield element
+
+
+class RandomUniqueKeyFn(beam.DoFn):
+  """A transform that assigns a unique key to each element."""
+
+  def process(self, element, window=beam.DoFn.WindowParam, *args, **kwargs):
+    # We ignore UUID collisions here since they are extremely rare.
+    yield (uuid.uuid4().bytes, element)
+
+
+class ProcessKeyedElements(PTransform):
+  """A primitive transform that performs SplittableDoFn magic.
+
+  Input to this transform should be a PCollection of keyed ElementAndRestriction
+  objects.
+  """
+
+  def __init__(
+      self, sdf, element_coder, restriction_coder, windowing_strategy,
+      ptransform_args, ptransform_kwargs):
+    self.sdf = sdf
+    self.element_coder = element_coder
+    self.restriction_coder = restriction_coder
+    self.windowing_strategy = windowing_strategy
+    self.ptransform_args = ptransform_args
+    self.ptransform_kwargs = ptransform_kwargs
+
+  def expand(self, pcoll):
+    return pvalue.PCollection(pcoll.pipeline)
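
To make the override concrete: a ParDo whose DoFn declares a RestrictionProvider
as the default value of its restriction_tracker argument is what _matcher
recognizes as splittable, and SplittableParDo.expand then rewrites it into the
pair/split/explode/key chain above. A minimal matching DoFn, patterned on the
tests in this PR (both classes below are hypothetical examples):

import apache_beam as beam
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.transforms.core import RestrictionProvider


class CharRangeProvider(RestrictionProvider):
  # Hypothetical provider: the restriction is the element's character range.
  def initial_restriction(self, element):
    return (0, len(element))

  def create_tracker(self, restriction):
    return OffsetRestrictionTracker(*restriction)


class EmitCharsFn(beam.DoFn):
  # The RestrictionProvider default below is what makes
  # DoFnSignature(fn).is_splittable_dofn() return True.
  def process(self, element, restriction_tracker=CharRangeProvider(),
              *args, **kwargs):
    for i in range(restriction_tracker.start_position(),
                   restriction_tracker.stop_position()):
      if not restriction_tracker.try_claim(i):
        return
      yield element[i]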
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 533634dba58..39c337fab27 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -259,6 +259,19 @@ def restriction_coder(self):
     return coders.registry.get_coder(object)
 
 
+def get_function_arguments(obj, func):
+  """Return the function arguments based on the name provided. If the object
+  has a _inspect_<func> method attached to its class then use that; otherwise
+  default to the Python inspect library.
+  """
+  func_name = '_inspect_%s' % func
+  if hasattr(obj, func_name):
+    f = getattr(obj, func_name)
+    return f()
+  f = getattr(obj, func)
+  return inspect.getargspec(f)
+
+
 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   """A function object used by a transform with custom processing.
 
@@ -326,16 +339,7 @@ def finish_bundle(self):
     pass
 
   def get_function_arguments(self, func):
-    """Return the function arguments based on the name provided. If they have
-    a _inspect_function attached to the class then use that otherwise default
-    to the python inspect library.
-    """
-    func_name = '_inspect_%s' % func
-    if hasattr(self, func_name):
-      f = getattr(self, func_name)
-      return f()
-    f = getattr(self, func)
-    return inspect.getargspec(f)
+    return get_function_arguments(self, func)
 
   # TODO(sourabhbajaj): Do we want to remove the responsibility of these from
   # the DoFn or maybe the runner
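
For reference, the _inspect_<func> hook honored by get_function_arguments can
be exercised as in the sketch below (the class is hypothetical; only the
hook-naming convention comes from the code above):

import inspect


class HookedFn(object):
  # get_function_arguments(obj, 'process') finds _inspect_process and calls
  # it instead of running inspect.getargspec on process() directly.
  def process(self, element, side=None):
    pass

  def _inspect_process(self):
    return inspect.getargspec(self.process)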


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> Add Splittable DoFn to Python SDK
> ---------------------------------
>
>                 Key: BEAM-1630
>                 URL: https://issues.apache.org/jira/browse/BEAM-1630
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-py-core
>            Reporter: Chamikara Jayalath
>            Assignee: Chamikara Jayalath
>
> Splittable DoFn [1] is currently being implemented for Java SDK [2]. We should add this to Python SDK as well.
> Following document proposes an API for this.
> https://docs.google.com/document/d/1h_zprJrOilivK2xfvl4L42vaX4DMYGfH1YDmi-s_ozM/edit?usp=sharing
> [1] https://s.apache.org/splittable-do-fn
> [2] https://lists.apache.org/thread.html/0ce61ac162460a149d5c93cdface37cc383f8030fe86ca09e5699b18@%3Cdev.beam.apache.org%3E



--
This message was sent by Atlassian JIRA
(v6.4.14#64029)
