beam-commits mailing list archives

From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Commented] (BEAM-1866) FnAPI support for Metrics
Date Fri, 05 Jan 2018 17:31:00 GMT

    [ https://issues.apache.org/jira/browse/BEAM-1866?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16313491#comment-16313491 ]

ASF GitHub Bot commented on BEAM-1866:
--------------------------------------

robertwb closed pull request #4344: [BEAM-1866] Plumb user metrics through Fn API.
URL: https://github.com/apache/beam/pull/4344
 
 
   

This is a PR merged from a forked repository. As GitHub hides the
original diff on merge, it is displayed below for the sake of provenance:

diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ca23c619f33..36ed4242d6c 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -280,11 +280,44 @@ message Metrics {
 
   // User defined metrics
   message User {
-    // TODO: Define it.
+
+    // A key for identifying a metric at the most granular level.
+    message MetricKey {
+      // The step, if any, this metric is associated with.
+      string step = 1;
+
+      // (Required): The namespace of this metric.
+      string namespace = 2;
+
+      // (Required): The name of this metric.
+      string name = 3;
+    }
+
+    // Data associated with a counter metric.
+    message CounterData {
+      int64 value = 1;
+    }
+
+    // Data associated with a distribution metric.
+    message DistributionData {
+      int64 count = 1;
+      int64 sum = 2;
+      int64 min = 3;
+      int64 max = 4;
+    }
+
+    // (Required) The identifier for this metric.
+    MetricKey key = 1;
+
+    // (Required) The data for this metric.
+    oneof data {
+      CounterData counter_data = 1001;
+      DistributionData distribution_data = 1002;
+    }
   }
 
   map<string, PTransform> ptransforms = 1;
-  map<string, User> user = 2;
+  repeated User user = 2;
 }
 
 message ProcessBundleProgressResponse {
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index ba840f752b1..2b5a4e4094d 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -28,6 +28,7 @@
 
 from apache_beam.metrics.metricbase import Counter
 from apache_beam.metrics.metricbase import Distribution
+from apache_beam.portability.api import beam_fn_api_pb2
 
 __all__ = ['DistributionResult']
 
@@ -286,10 +287,18 @@ def combine(self, other):
         new_min,
         new_max)
 
-  @classmethod
-  def singleton(cls, value):
+  @staticmethod
+  def singleton(value):
     return DistributionData(value, 1, value, value)
 
+  def to_runner_api(self):
+    return beam_fn_api_pb2.Metrics.User.DistributionData(
+        count=self.count, sum=self.sum, min=self.min, max=self.max)
+
+  @staticmethod
+  def from_runner_api(proto):
+    return DistributionData(proto.sum, proto.count, proto.min, proto.max)
+
 
 class MetricAggregator(object):
   """For internal use only; no backwards-compatibility guarantees.
diff --git a/sdks/python/apache_beam/metrics/execution.py b/sdks/python/apache_beam/metrics/execution.py
index 1704b98d46a..9e547a983de 100644
--- a/sdks/python/apache_beam/metrics/execution.py
+++ b/sdks/python/apache_beam/metrics/execution.py
@@ -34,6 +34,8 @@
 
 from apache_beam.metrics.cells import CounterCell
 from apache_beam.metrics.cells import DistributionCell
+from apache_beam.metrics.metricbase import MetricName
+from apache_beam.portability.api import beam_fn_api_pb2
 
 
 class MetricKey(object):
@@ -63,6 +65,14 @@ def __str__(self):
   def __hash__(self):
     return hash((self.step, self.metric))
 
+  def to_runner_api(self):
+    return beam_fn_api_pb2.Metrics.User.MetricKey(
+        step=self.step, namespace=self.metric.namespace, name=self.metric.name)
+
+  @staticmethod
+  def from_runner_api(proto):
+    return MetricKey(proto.step, MetricName(proto.namespace, proto.name))
+
 
 class MetricResult(object):
   """Keeps track of the status of a metric within a single bundle.
@@ -192,6 +202,20 @@ def get_cumulative(self):
     """
     return self._get_updates()
 
+  def to_runner_api(self):
+    return (
+        [beam_fn_api_pb2.Metrics.User(
+            key=beam_fn_api_pb2.Metrics.User.MetricKey(
+                step=self.step_name, namespace=k.namespace, name=k.name),
+            counter_data=beam_fn_api_pb2.Metrics.User.CounterData(
+                value=v.get_cumulative()))
+         for k, v in self.counters.items()] +
+        [beam_fn_api_pb2.Metrics.User(
+            key=beam_fn_api_pb2.Metrics.User.MetricKey(
+                step=self.step_name, namespace=k.namespace, name=k.name),
+            distribution_data=v.get_cumulative().to_runner_api())
+         for k, v in self.distributions.items()])
+
 
 class ScopedMetricsContainer(object):
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index c9218572b36..a5c77f7de0b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -28,6 +28,7 @@
 import grpc
 
 import apache_beam as beam  # pylint: disable=ungrouped-imports
+from apache_beam import metrics
 from apache_beam.coders import WindowedValueCoder
 from apache_beam.coders import registry
 from apache_beam.coders.coder_impl import create_InputStream
@@ -50,53 +51,51 @@
 # This module is experimental. No backwards-compatibility guarantees.
 
 
-def streaming_rpc_handler(cls, method_name):
-  """Un-inverts the flow of control between the runner and the sdk harness."""
+class BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
 
-  class StreamingRpcHandler(cls):
+  _DONE = object()
 
-    _DONE = object()
+  def __init__(self):
+    self._push_queue = queue.Queue()
+    self._futures_by_id = dict()
+    self._read_thread = threading.Thread(
+        name='beam_control_read', target=self._read)
+    self._started = False
+    self._uid_counter = 0
 
-    def __init__(self):
-      self._push_queue = queue.Queue()
-      self._pull_queue = queue.Queue()
-      setattr(self, method_name, self.run)
-      self._read_thread = threading.Thread(
-          name='streaming_rpc_handler_read', target=self._read)
-      self._started = False
-
-    def run(self, iterator, context):
-      self._inputs = iterator
-      # Note: We only support one client for now.
-      self._read_thread.start()
-      self._started = True
-      while True:
-        to_push = self._push_queue.get()
-        if to_push is self._DONE:
-          return
-        yield to_push
-
-    def _read(self):
-      for data in self._inputs:
-        self._pull_queue.put(data)
-
-    def push(self, item):
-      self._push_queue.put(item)
-
-    def pull(self, timeout=None):
-      return self._pull_queue.get(timeout=timeout)
-
-    def empty(self):
-      return self._pull_queue.empty()
-
-    def done(self):
-      self.push(self._DONE)
-      # Can't join a thread before it's started.
-      while not self._started:
-        time.sleep(.01)
-      self._read_thread.join()
-
-  return StreamingRpcHandler()
+  def Control(self, iterator, context):
+    self._inputs = iterator
+    # Note: We only support one client for now.
+    self._read_thread.start()
+    self._started = True
+    while True:
+      to_push = self._push_queue.get()
+      if to_push is self._DONE:
+        return
+      yield to_push
+
+  def _read(self):
+    for data in self._inputs:
+      self._futures_by_id.pop(data.instruction_id).set(data)
+
+  def push(self, item):
+    if item is self._DONE:
+      future = None
+    else:
+      if not item.instruction_id:
+        self._uid_counter += 1
+        item.instruction_id = 'control_%s' % self._uid_counter
+      future = ControlFuture(item.instruction_id)
+      self._futures_by_id[item.instruction_id] = future
+    self._push_queue.put(item)
+    return future
+
+  def done(self):
+    self.push(self._DONE)
+    # Can't join a thread before it's started.
+    while not self._started:
+      time.sleep(.01)
+    self._read_thread.join()
 
 
 class _GroupingBuffer(object):
@@ -185,6 +184,7 @@ def __init__(self, use_grpc=False, sdk_harness_factory=None):
     if sdk_harness_factory and not use_grpc:
       raise ValueError('GRPC must be used if a harness factory is provided.')
     self._sdk_harness_factory = sdk_harness_factory
+    self._progress_frequency = None
 
   def _next_uid(self):
     self._last_uid += 1
@@ -836,7 +836,7 @@ def extract_endpoints(stage):
           pcoll_id = transform.spec.payload
           if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
             target = transform.unique_name, only_element(transform.outputs)
-            data_input[target] = pcoll_id
+            data_input[target] = pcoll_buffers[pcoll_id]
           elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
             target = transform.unique_name, only_element(transform.inputs)
             data_output[target] = pcoll_id
@@ -869,26 +869,6 @@ def extract_endpoints(stage):
             pipeline_components.windowing_strategies.items()),
         environments=dict(pipeline_components.environments.items()))
 
-    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
-        instruction_id=self._next_uid(),
-        register=beam_fn_api_pb2.RegisterRequest(
-            process_bundle_descriptor=[process_bundle_descriptor]))
-
-    process_bundle = beam_fn_api_pb2.InstructionRequest(
-        instruction_id=self._next_uid(),
-        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
-            process_bundle_descriptor_reference=
-            process_bundle_descriptor.id))
-
-    # Write all the input data to the channel.
-    for (transform_id, name), pcoll_id in data_input.items():
-      data_out = controller.data_plane_handler.output_stream(
-          process_bundle.instruction_id, beam_fn_api_pb2.Target(
-              primitive_transform_reference=transform_id, name=name))
-      for element_data in pcoll_buffers[pcoll_id]:
-        data_out.write(element_data)
-      data_out.close()
-
     # Store the required side inputs into state.
     for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
       elements_by_window = _WindowGroupingBuffer(si)
@@ -900,63 +880,39 @@ def extract_endpoints(stage):
                 ptransform_id=transform_id,
                 side_input_id=tag,
                 window=window))
-        controller.state_handler.blocking_append(
-            state_key, elements_data, process_bundle.instruction_id)
+        controller.state_handler.blocking_append(state_key, elements_data, None)
 
-    # Register and start running the bundle.
-    logging.debug('Register and start running the bundle')
-    controller.control_handler.push(process_bundle_registration)
-    controller.control_handler.push(process_bundle)
-
-    # Wait for the bundle to finish.
-    logging.debug('Wait for the bundle to finish.')
-    while True:
-      result = controller.control_handler.pull()
-      if result and result.instruction_id == process_bundle.instruction_id:
-        if result.error:
-          raise RuntimeError(result.error)
-        break
-
-    expected_targets = [
-        beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
-                               name=output_name)
-        for (transform_id, output_name), _ in data_output.items()]
-
-    # Gather all output data.
-    logging.debug('Gather all output data from %s.', expected_targets)
-
-    for output in controller.data_plane_handler.input_elements(
-        process_bundle.instruction_id, expected_targets):
-      target_tuple = (
-          output.target.primitive_transform_reference, output.target.name)
-      if target_tuple in data_output:
-        pcoll_id = data_output[target_tuple]
-        if pcoll_id.startswith('materialize:'):
+    def get_buffer(pcoll_id):
+      if pcoll_id.startswith('materialize:'):
+        if pcoll_id not in pcoll_buffers:
           # Just store the data chunks for replay.
-          pcoll_buffers[pcoll_id].append(output.data)
-        elif pcoll_id.startswith('group:'):
-          # This is a grouping write, create a grouping buffer if needed.
-          if pcoll_id not in pcoll_buffers:
-            original_gbk_transform = pcoll_id.split(':', 1)[1]
-            transform_proto = pipeline_components.transforms[
-                original_gbk_transform]
-            input_pcoll = only_element(transform_proto.inputs.values())
-            output_pcoll = only_element(transform_proto.outputs.values())
-            pre_gbk_coder = context.coders[safe_coders[
-                pipeline_components.pcollections[input_pcoll].coder_id]]
-            post_gbk_coder = context.coders[safe_coders[
-                pipeline_components.pcollections[output_pcoll].coder_id]]
-            windowing_strategy = context.windowing_strategies[
-                pipeline_components
-                .pcollections[output_pcoll].windowing_strategy_id]
-            pcoll_buffers[pcoll_id] = _GroupingBuffer(
-                pre_gbk_coder, post_gbk_coder, windowing_strategy)
-          pcoll_buffers[pcoll_id].append(output.data)
-        else:
-          # These should be the only two identifiers we produce for now,
-          # but special side input writes may go here.
-          raise NotImplementedError(pcoll_id)
-    return result
+          pcoll_buffers[pcoll_id] = list()
+      elif pcoll_id.startswith('group:'):
+        # This is a grouping write, create a grouping buffer if needed.
+        if pcoll_id not in pcoll_buffers:
+          original_gbk_transform = pcoll_id.split(':', 1)[1]
+          transform_proto = pipeline_components.transforms[
+              original_gbk_transform]
+          input_pcoll = only_element(transform_proto.inputs.values())
+          output_pcoll = only_element(transform_proto.outputs.values())
+          pre_gbk_coder = context.coders[safe_coders[
+              pipeline_components.pcollections[input_pcoll].coder_id]]
+          post_gbk_coder = context.coders[safe_coders[
+              pipeline_components.pcollections[output_pcoll].coder_id]]
+          windowing_strategy = context.windowing_strategies[
+              pipeline_components
+              .pcollections[output_pcoll].windowing_strategy_id]
+          pcoll_buffers[pcoll_id] = _GroupingBuffer(
+              pre_gbk_coder, post_gbk_coder, windowing_strategy)
+      else:
+        # These should be the only two identifiers we produce for now,
+        # but special side input writes may go here.
+        raise NotImplementedError(pcoll_id)
+      return pcoll_buffers[pcoll_id]
+
+    return BundleManager(
+        controller, get_buffer, process_bundle_descriptor,
+        self._progress_frequency).process_bundle(data_input, data_output)
 
   # These classes are used to interact with the worker.
 
@@ -1008,22 +964,22 @@ class DirectController(object):
     """An in-memory controller for fn API control, state and data planes."""
 
     def __init__(self):
-      self._responses = []
       self.state_handler = FnApiRunner.StateServicer()
       self.control_handler = self
       self.data_plane_handler = data_plane.InMemoryDataChannel()
       self.worker = sdk_worker.SdkWorker(
           self.state_handler, data_plane.InMemoryDataChannelFactory(
               self.data_plane_handler.inverse()), {})
+      self._uid_counter = 0
 
     def push(self, request):
+      if not request.instruction_id:
+        self._uid_counter += 1
+        request.instruction_id = 'control_%s' % self._uid_counter
       logging.debug('CONTROL REQUEST %s', request)
       response = self.worker.do_instruction(request)
       logging.debug('CONTROL RESPONSE %s', response)
-      self._responses.append(response)
-
-    def pull(self):
-      return self._responses.pop(0)
+      return ControlFuture(request.instruction_id, response)
 
     def done(self):
       pass
@@ -1046,8 +1002,7 @@ def __init__(self, sdk_harness_factory=None):
       self.data_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
       self.data_port = self.data_server.add_insecure_port('[::]:0')
 
-      self.control_handler = streaming_rpc_handler(
-          beam_fn_api_pb2_grpc.BeamFnControlServicer, 'Control')
+      self.control_handler = BeamFnControlServicer()
       beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
           self.control_handler, self.control_server)
 
@@ -1090,14 +1045,170 @@ def close(self):
       self.data_server.stop(5).wait()
 
 
+class BundleManager(object):
+
+  _uid_counter = 0
+
+  def __init__(
+      self, controller, get_buffer, bundle_descriptor, progress_frequency=None):
+    self._controller = controller
+    self._get_buffer = get_buffer
+    self._bundle_descriptor = bundle_descriptor
+    self._registered = False
+    self._progress_frequency = progress_frequency
+
+  def process_bundle(self, inputs, expected_outputs):
+    # Unique id for the instruction processing this bundle.
+    BundleManager._uid_counter += 1
+    process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
+
+    # Register the bundle descriptor, if needed.
+    if not self._registered:
+      process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
+          register=beam_fn_api_pb2.RegisterRequest(
+              process_bundle_descriptor=[self._bundle_descriptor]))
+      self._controller.control_handler.push(process_bundle_registration)
+      self._registered = True
+
+    # Write all the input data to the channel.
+    for (transform_id, name), elements in inputs.items():
+      data_out = self._controller.data_plane_handler.output_stream(
+          process_bundle_id, beam_fn_api_pb2.Target(
+              primitive_transform_reference=transform_id, name=name))
+      for element_data in elements:
+        data_out.write(element_data)
+      data_out.close()
+
+    # Actually start the bundle.
+    process_bundle = beam_fn_api_pb2.InstructionRequest(
+        instruction_id=process_bundle_id,
+        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+            process_bundle_descriptor_reference=self._bundle_descriptor.id))
+    result_future = self._controller.control_handler.push(process_bundle)
+
+    with ProgressRequester(
+        self._controller, process_bundle_id, self._progress_frequency):
+      # Gather all output data.
+      expected_targets = [
+          beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+                                 name=output_name)
+          for (transform_id, output_name), _ in expected_outputs.items()]
+      logging.debug('Gather all output data from %s.', expected_targets)
+      for output in self._controller.data_plane_handler.input_elements(
+          process_bundle_id, expected_targets):
+        target_tuple = (
+            output.target.primitive_transform_reference, output.target.name)
+        if target_tuple in expected_outputs:
+          self._get_buffer(expected_outputs[target_tuple]).append(output.data)
+
+      logging.debug('Wait for the bundle to finish.')
+      result = result_future.get()
+
+    if result.error:
+      raise RuntimeError(result.error)
+    return result
+
+
+class ProgressRequester(threading.Thread):
+  def __init__(self, controller, instruction_id, frequency, callback=None):
+    super(ProgressRequester, self).__init__()
+    self._controller = controller
+    self._instruction_id = instruction_id
+    self._frequency = frequency
+    self._done = False
+    self._latest_progress = None
+    self._callback = callback
+    self.daemon = True
+
+  def __enter__(self):
+    if self._frequency:
+      self.start()
+
+  def __exit__(self, *unused_exc_info):
+    if self._frequency:
+      self.stop()
+
+  def run(self):
+    while not self._done:
+      try:
+        progress_result = self._controller.control_handler.push(
+            beam_fn_api_pb2.InstructionRequest(
+                process_bundle_progress=
+                beam_fn_api_pb2.ProcessBundleProgressRequest(
+                    instruction_reference=self._instruction_id))).get()
+        self._latest_progress = progress_result.process_bundle_progress
+        if self._callback:
+          self._callback(self._latest_progress)
+      except Exception, exn:
+        logging.error("Bad progress: %s", exn)
+      time.sleep(self._frequency)
+
+  def stop(self):
+    self._done = True
+
+
+class ControlFuture(object):
+  def __init__(self, instruction_id, response=None):
+    self.instruction_id = instruction_id
+    if response:
+      self._response = response
+    else:
+      self._response = None
+      self._condition = threading.Condition()
+
+  def set(self, response):
+    with self._condition:
+      self._response = response
+      self._condition.notify_all()
+
+  def get(self, timeout=None):
+    if not self._response:
+      with self._condition:
+        if not self._response:
+          self._condition.wait(timeout)
+    return self._response
+
+
+class FnApiMetrics(metrics.metric.MetricResults):
+  def __init__(self, step_metrics):
+    self._counters = {}
+    self._distributions = {}
+    for step_metric in step_metrics.values():
+      for proto in step_metric.user:
+        key = metrics.execution.MetricKey.from_runner_api(proto.key)
+        if proto.HasField('counter_data'):
+          self._counters[key] = proto.counter_data.value
+        elif proto.HasField('distribution_data'):
+          self._distributions[
+              key] = metrics.cells.DistributionData.from_runner_api(
+                  proto.distribution_data)
+
+  def query(self, filter=None):
+    counters = [metrics.execution.MetricResult(k, v, v)
+                for k, v in self._counters.items()
+                if self.matches(filter, k)]
+    distributions = [metrics.execution.MetricResult(k, v, v)
+                     for k, v in self._distributions.items()
+                     if self.matches(filter, k)]
+
+    return {'counters': counters,
+            'distributions': distributions}
+
+
 class RunnerResult(runner.PipelineResult):
   def __init__(self, state, metrics_by_stage):
     super(RunnerResult, self).__init__(state)
     self._metrics_by_stage = metrics_by_stage
+    self._user_metrics = None
 
   def wait_until_finish(self, duration=None):
     pass
 
+  def metrics(self):
+    if self._user_metrics is None:
+      self._user_metrics = FnApiMetrics(self._metrics_by_stage)
+    return self._user_metrics
+
 
 def only_element(iterable):
   element, = iterable
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 1cffa2652ee..6304f71df5e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -114,12 +114,41 @@ def expand(self, pcolls):
       pcoll_b = p | 'b' >> beam.Create(['b'])
       assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a']))
 
+  def test_metrics(self):
+
+    p = self.create_pipeline()
+    if not isinstance(p.runner, fn_api_runner.FnApiRunner):
+      # This test is inherited by others that may not support the same
+      # internal way of accessing progress metrics.
+      self.skipTest('Metrics not supported.')
+
+    counter = beam.metrics.Metrics.counter('ns', 'counter')
+    distribution = beam.metrics.Metrics.distribution('ns', 'distribution')
+    pcoll = p | beam.Create(['a', 'zzz'])
+    # pylint: disable=expression-not-assigned
+    pcoll | 'count1' >> beam.FlatMap(lambda x: counter.inc())
+    pcoll | 'count2' >> beam.FlatMap(lambda x: counter.inc(len(x)))
+    pcoll | 'dist' >> beam.FlatMap(lambda x: distribution.update(len(x)))
+
+    res = p.run()
+    res.wait_until_finish()
+    c1, = res.metrics().query(beam.metrics.MetricsFilter().with_step('count1'))[
+        'counters']
+    self.assertEqual(c1.committed, 2)
+    c2, = res.metrics().query(beam.metrics.MetricsFilter().with_step('count2'))[
+        'counters']
+    self.assertEqual(c2.committed, 4)
+    dist, = res.metrics().query(beam.metrics.MetricsFilter().with_step('dist'))[
+        'distributions']
+    self.assertEqual(
+        dist.committed, beam.metrics.cells.DistributionData(4, 2, 1, 3))
+
   def test_progress_metrics(self):
     p = self.create_pipeline()
     if not isinstance(p.runner, fn_api_runner.FnApiRunner):
       # This test is inherited by others that may not support the same
       # internal way of accessing progress metrics.
-      return
+      self.skipTest('Progress metrics not supported.')
 
     _ = (p
          | beam.Create([0, 0, 0, 2.1e-3 * DEFAULT_SAMPLING_PERIOD_MS])
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 136f22d0903..1b270b90372 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -273,10 +273,13 @@ def process_bundle(self, instruction_id):
   def metrics(self):
     return beam_fn_api_pb2.Metrics(
         # TODO(robertwb): Rename to progress?
-        ptransforms=
-        {transform_id:
-         self._fix_output_tags(transform_id, op.progress_metrics())
-         for transform_id, op in self.ops.items()})
+        ptransforms={
+            transform_id:
+            self._fix_output_tags(transform_id, op.progress_metrics())
+            for transform_id, op in self.ops.items()},
+        user=sum(
+            [op.metrics_container.to_runner_api() for op in self.ops.values()],
+            []))
 
   def _fix_output_tags(self, transform_id, metrics):
     # Outputs are still referred to by index, not by name, in many Operations.
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 8098a63f3c7..d6838345a8d 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -118,6 +118,12 @@ def __init__(self, operation_name, spec, counter_factory, state_sampler):
     self.counter_factory = counter_factory
     self.consumers = collections.defaultdict(list)
 
+    # These are overwritten in the legacy harness.
+    self.step_name = operation_name
+    self.metrics_container = MetricsContainer(self.step_name)
+    self.scoped_metrics_container = ScopedMetricsContainer(
+        self.metrics_container)
+
     self.state_sampler = state_sampler
     self.scoped_start_state = self.state_sampler.scoped_state(
         self.operation_name, 'start')
@@ -127,7 +133,6 @@ def __init__(self, operation_name, spec, counter_factory, state_sampler):
         self.operation_name, 'finish')
     # TODO(ccy): the '-abort' state can be added when the abort is supported in
     # Operations.
-    self.scoped_metrics_container = None
     self.receivers = []
 
   def start(self):
@@ -260,6 +265,7 @@ def __init__(
       self, name, spec, counter_factory, sampler, side_input_maps=None):
     super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
     self.side_input_maps = side_input_maps
+    self.tagged_receivers = None
 
   def _read_side_inputs(self, tags_and_types):
     """Generator reading side inputs in the order prescribed by tags_and_types.
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 980357ee221..f8ac88d52cb 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -208,7 +208,11 @@ def process_bundle(self, request, instruction_id):
 
   def process_bundle_progress(self, request, instruction_id):
     # It is an error to get progress for a not-in-flight bundle.
-    return self.bundle_processors.get(instruction_id).metrics()
+    processor = self.bundle_processors.get(request.instruction_reference)
+    return beam_fn_api_pb2.InstructionResponse(
+        instruction_id=instruction_id,
+        process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
+            metrics=processor.metrics() if processor else None))
 
 
 class GrpcStateHandler(object):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> FnAPI support for Metrics
> -------------------------
>
>                 Key: BEAM-1866
>                 URL: https://issues.apache.org/jira/browse/BEAM-1866
>             Project: Beam
>          Issue Type: New Feature
>          Components: beam-model
>            Reporter: Daniel Halperin
>              Labels: portability
>
> As part of the Fn API work, we need to define a Metrics interface between the Runner
> and the SDK. Right now, Metrics are simply lost.
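
For orientation, with this change user metrics defined in a pipeline become
queryable from the pipeline result when running on the FnApiRunner. A minimal
usage sketch, condensed from the test_metrics test added in this PR
(illustrative only, not part of the patch):

import apache_beam as beam
from apache_beam.runners.portability import fn_api_runner

p = beam.Pipeline(runner=fn_api_runner.FnApiRunner())
counter = beam.metrics.Metrics.counter('ns', 'counter')
pcoll = p | beam.Create(['a', 'zzz'])
# Each element increments the user counter inside the SDK harness.
_ = pcoll | 'count' >> beam.FlatMap(lambda x: counter.inc(len(x)))

res = p.run()
res.wait_until_finish()
# The counter accumulated by the harness is reported back over the Fn API
# and surfaced through the standard metrics() interface on the result.
c, = res.metrics().query(
    beam.metrics.MetricsFilter().with_step('count'))['counters']
assert c.committed == 4  # len('a') + len('zzz')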



--
This message was sent by Atlassian JIRA
(v6.4.14#64029)
