beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From an...@apache.org
Subject [beam] branch master updated: [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
Date Fri, 09 Aug 2019 15:45:36 GMT
This is an automated email from the ASF dual-hosted git repository.

anton pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9d131d4  [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
     new cd2ab9e  Merge pull request #9298 from riazela/KafkaRateEstimation2
9d131d4 is described below

commit 9d131d490dfa1b4838d0303a3f17f36202c0874b
Author: Alireza Samadian <alireza4263@gmail.com>
AuthorDate: Tue Aug 6 16:56:03 2019 -0700

    [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
---
 sdks/java/extensions/sql/build.gradle              |   1 +
 .../sql/meta/provider/kafka/BeamKafkaTable.java    | 147 +++++++++--
 .../meta/provider/kafka/BeamKafkaCSVTableTest.java | 118 ++++++++-
 .../sql/meta/provider/kafka/KafkaCSVTableIT.java   | 292 +++++++++++++++++++++
 .../sql/meta/provider/kafka/KafkaCSVTestTable.java | 197 ++++++++++++++
 .../sql/meta/provider/kafka/KafkaTestRecord.java   |  39 +++
 6 files changed, 777 insertions(+), 17 deletions(-)

diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index b4a7079..fe07bfe 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -203,6 +203,7 @@ task integrationTest(type: Test) {
   systemProperty "beamTestPipelineOptions", JsonOutput.toJson(pipelineOptions)
 
   include '**/*IT.class'
+  // Test include/exclude patterns match compiled files in the test classes directory, so a
+  // '.java' pattern never matches. The Kafka IT needs a running broker, so exclude its class file.
+  exclude '**/KafkaCSVTableIT.class'
   maxParallelForks 4
   classpath = project(":sdks:java:extensions:sql")
           .sourceSets
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
index 0e1dab3..11c12f6 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
@@ -19,9 +19,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Properties;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
 import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable;
@@ -34,9 +38,15 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.Row;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * {@code BeamKafkaTable} represent a Kafka topic, as source or target. Need to extend to convert
@@ -47,6 +57,10 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
   private List<String> topics;
   private List<TopicPartition> topicPartitions;
   private Map<String, Object> configUpdates;
+  private BeamTableStatistics rowCountStatistics = null;
+  private static final Logger LOGGER = LoggerFactory.getLogger(BeamKafkaTable.class);
+  // This is the number of records looked from each partition when the rate is estimated
+  protected int numberOfRecordsForRate = 50;
 
   protected BeamKafkaTable(Schema beamSchema) {
     super(beamSchema);
@@ -84,7 +98,14 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
 
+  /**
+   * Builds the read path: the Kafka read from {@code createKafkaRead()} (without metadata) is
+   * converted to rows by {@code getPTransformForInput()}, and the rows get this table's schema.
+   */
   @Override
   public PCollection<Row> buildIOReader(PBegin begin) {
-    KafkaIO.Read<byte[], byte[]> kafkaRead = null;
+    return begin
+        .apply("read", createKafkaRead().withoutMetadata())
+        .apply("in_format", getPTransformForInput())
+        .setRowSchema(getSchema());
+  }
+
+  /**
+   * Creates the raw {@code byte[]} key/value {@link KafkaIO.Read} for this table, from the
+   * configured topics when they are set (otherwise, presumably, from the topic partitions).
+   *
+   * @throws IllegalArgumentException if neither topics nor topic partitions are configured
+   */
+  KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+    KafkaIO.Read<byte[], byte[]> kafkaRead;
     if (topics != null) {
       kafkaRead =
           KafkaIO.<byte[], byte[]>read()
@@ -104,28 +125,25 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
     } else {
       throw new IllegalArgumentException("One of topics and topicPartitions must be configurated.");
     }
-
-    return begin
-        .apply("read", kafkaRead.withoutMetadata())
-        .apply("in_format", getPTransformForInput())
-        .setRowSchema(getSchema());
+    return kafkaRead;
   }
 
   @Override
   public POutput buildIOWriter(PCollection<Row> input) {
     checkArgument(
         topics != null && topics.size() == 1, "Only one topic can be acceptable as
output.");
-    assert topics != null;
 
     return input
         .apply("out_reformat", getPTransformForOutput())
-        .apply(
-            "persistent",
-            KafkaIO.<byte[], byte[]>write()
-                .withBootstrapServers(bootstrapServers)
-                .withTopic(topics.get(0))
-                .withKeySerializer(ByteArraySerializer.class)
-                .withValueSerializer(ByteArraySerializer.class));
+        .apply("persistent", createKafkaWrite());
+  }
+
+  /**
+   * Creates the sink that writes raw {@code byte[]} keys and values to the first configured topic.
+   * Callers are expected to have checked that exactly one topic is configured.
+   */
+  private KafkaIO.Write<byte[], byte[]> createKafkaWrite() {
+    String outputTopic = topics.get(0);
+    KafkaIO.Write<byte[], byte[]> write =
+        KafkaIO.<byte[], byte[]>write().withBootstrapServers(bootstrapServers);
+    return write
+        .withTopic(outputTopic)
+        .withKeySerializer(ByteArraySerializer.class)
+        .withValueSerializer(ByteArraySerializer.class);
   }
 
   public String getBootstrapServers() {
@@ -138,6 +156,105 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
 
   @Override
   public BeamTableStatistics getTableStatistics(PipelineOptions options) {
-    return BeamTableStatistics.UNBOUNDED_UNKNOWN;
+    if (rowCountStatistics == null) {
+      try {
+        rowCountStatistics =
+            BeamTableStatistics.createUnboundedTableStatistics(
+                this.computeRate(numberOfRecordsForRate));
+      } catch (Exception e) {
+        LOGGER.warn("Could not get the row count for the topics " + getTopics(), e);
+        rowCountStatistics = BeamTableStatistics.UNBOUNDED_UNKNOWN;
+      }
+    }
+
+    return rowCountStatistics;
+  }
+
+  /**
+   * This method returns the estimate of the computeRate for this table using last numberOfRecords
+   * tuples in each partition.
+   */
+  double computeRate(int numberOfRecords) throws NoEstimationException {
+    Properties props = new Properties();
+
+    props.put("bootstrap.servers", bootstrapServers);
+    props.put("session.timeout.ms", "30000");
+    props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+    props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+
+    KafkaConsumer<String, String> consumer = new KafkaConsumer<String, String>(props);
+
+    return computeRate(consumer, numberOfRecords);
+  }
+
+  /**
+   * Estimates the arrival rate (records per second) of this table using {@code consumer}.
+   *
+   * <p>The consumer is assigned every partition of the table. It seeks each partition back
+   * {@code numberOfRecordsToCheck} records from its end offset, then polls once. The latest of
+   * the per-partition first timestamps starts the measurement window. Polled records newer than
+   * that are counted and divided by the time up to the newest polled timestamp.
+   *
+   * @param consumer consumer used for metadata and sampling; it is not closed by this method
+   * @param numberOfRecordsToCheck number of trailing records to sample from each partition
+   * @return the estimated rate in records per second
+   * @throws NoEstimationException if neither topics nor topic partitions are configured, if no
+   *     partition holds records, or if the sampled records do not span a positive time interval
+   */
+  <T> double computeRate(Consumer<T, T> consumer, int numberOfRecordsToCheck)
+      throws NoEstimationException {
+    List<TopicPartition> partitions = partitionsToSample(consumer);
+
+    consumer.assign(partitions);
+    // endOffsets gives each partition's end offset (one past its last record). The consumer is not
+    // part of any consumer group, so seeking and polling here does not affect other consumers,
+    // including the pipeline that reads this table.
+    Map<TopicPartition, Long> endOffsets = consumer.endOffsets(partitions);
+    long nonEmptyPartitions = 0;
+    for (TopicPartition partition : partitions) {
+      long endOffset = endOffsets.getOrDefault(partition, 0L);
+      if (endOffset != 0) {
+        nonEmptyPartitions++;
+      }
+      consumer.seek(partition, Math.max(0L, endOffset - numberOfRecordsToCheck));
+    }
+
+    if (nonEmptyPartitions == 0) {
+      throw new NoEstimationException("There is no partition with messages in it.");
+    }
+
+    ConsumerRecords<T, T> records = consumer.poll(1000);
+
+    // Kafka delivers the records of one partition in the order they were appended. So the first
+    // record polled from a partition is that partition's earliest sampled record. The latest of
+    // these first timestamps starts the measurement window; records at or before it are not
+    // counted. Records are keyed by topic AND partition number, because a table may read several
+    // topics whose partition numbers overlap.
+    Map<TopicPartition, Long> firstTimestamps = new HashMap<>();
+    long windowStart = 0;
+    long partitionsLeft = nonEmptyPartitions;
+    for (ConsumerRecord<T, T> record : records) {
+      TopicPartition source = new TopicPartition(record.topic(), record.partition());
+      if (!firstTimestamps.containsKey(source)) {
+        firstTimestamps.put(source, record.timestamp());
+        windowStart = Math.max(record.timestamp(), windowStart);
+        partitionsLeft--;
+        if (partitionsLeft == 0) {
+          break;
+        }
+      }
+    }
+
+    int countedRecords = 0;
+    long windowEnd = 0;
+    for (ConsumerRecord<T, T> record : records) {
+      windowEnd = Math.max(windowEnd, record.timestamp());
+      if (record.timestamp() > windowStart) {
+        countedRecords++;
+      }
+    }
+
+    if (windowEnd == windowStart) {
+      throw new NoEstimationException("Arrival time of all records are the same.");
+    }
+
+    return (countedRecords * 1000.) / ((double) windowEnd - windowStart);
+  }
+
+  /**
+   * Returns the partitions to sample. These are all partitions of the configured topics, or the
+   * explicitly configured topic partitions when the table was defined by partitions.
+   */
+  private List<TopicPartition> partitionsToSample(Consumer<?, ?> consumer)
+      throws NoEstimationException {
+    if (getTopics() != null) {
+      Stream<TopicPartition> fromTopics =
+          getTopics().stream()
+              .map(consumer::partitionsFor)
+              // partitionsFor may return null when no metadata is available for a topic.
+              .filter(partitionInfos -> partitionInfos != null)
+              .flatMap(Collection::stream)
+              .map(info -> new TopicPartition(info.topic(), info.partition()));
+      return fromTopics.collect(Collectors.toList());
+    }
+    if (topicPartitions != null) {
+      return topicPartitions;
+    }
+    throw new NoEstimationException("One of topics and topicPartitions must be configured.");
+  }
+
+  /**
+   * Signals that the arrival rate of a Kafka table could not be estimated. Examples: no partition
+   * holds any records, or all sampled records share the same timestamp.
+   */
+  static class NoEstimationException extends Exception {
+    NoEstimationException(String reason) {
+      super(reason);
+    }
   }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
index 710a1a5..c407ff4 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
@@ -20,7 +20,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.extensions.sql.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
+import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableUtils;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -30,11 +36,13 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.calcite.adapter.java.JavaTypeFactory;
 import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
 import org.apache.calcite.rel.type.RelDataTypeSystem;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.commons.csv.CSVFormat;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 
@@ -46,8 +54,101 @@ public class BeamKafkaCSVTableTest {
 
   private static final Row ROW2 = Row.withSchema(genSchema()).addValues(2L, 2, 2.0).build();
 
+  // NOTE(review): these fields do not appear to be referenced by any test added in this change.
+  // Consider removing them, together with their imports, if nothing else relies on them.
+  private static Map<String, BeamSqlTable> tables = new HashMap<>();
+  protected static BeamSqlEnv env = BeamSqlEnv.readOnly("test", tables);
+
+  @Test
+  public void testOrderedArrivalSinglePartitionRate() {
+    KafkaCSVTestTable table = getTable(1);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key1", i + ",1,2", "topic1", 500 * i));
+    }
+
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(2d, stats.getRate(), 0.001);
+  }
+
+  @Test
+  public void testOrderedArrivalMultiplePartitionsRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1", 500 * i));
+    }
+
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(2d, stats.getRate(), 0.001);
+  }
+
+  /**
+   * Key "1" advances 1000 ms per record and key "2" 500 ms per record. With 20 records sampled per
+   * partition, only records newer than the later first sample count, which gives 1 per second.
+   */
+  @Test
+  public void testOnePartitionAheadRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int step = 0; step < 100; step++) {
+      String payload = step + ",1,2";
+      table.addRecord(KafkaTestRecord.create("1", payload, "topic1", 1000 * step));
+      table.addRecord(KafkaTestRecord.create("2", payload, "topic1", 500 * step));
+    }
+    table.setNumberOfRecordsForRate(20);
+
+    BeamTableStatistics statistics = table.getTableStatistics(null);
+    Assert.assertEquals(1d, statistics.getRate(), 0.001);
+  }
+
+  @Test
+  public void testLateRecords() {
+    KafkaCSVTestTable table = getTable(3);
+
+    table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+    for (int i = 0; i < 98; i++) {
+      table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+    }
+    table.addRecord(KafkaTestRecord.create("1", 133 + ",1,2", "topic1", 2000));
+
+    table.setNumberOfRecordsForRate(200);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(1d, stats.getRate(), 0.001);
+  }
+
   @Test
-  public void testCsvRecorderDecoder() throws Exception {
+  public void testAllLate() {
+    KafkaCSVTestTable table = getTable(3);
+
+    table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+    for (int i = 0; i < 98; i++) {
+      table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+    }
+
+    table.setNumberOfRecordsForRate(200);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertTrue(stats.isUnknown());
+  }
+
+  /** A table without any records cannot produce a rate estimate. */
+  @Test
+  public void testEmptyPartitionsRate() {
+    BeamTableStatistics statistics = getTable(3).getTableStatistics(null);
+    Assert.assertTrue(statistics.isUnknown());
+  }
+
+  /** Records that all share one timestamp span no time, so the rate is unknown. */
+  @Test
+  public void allTheRecordsSameTimeRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1", 1000));
+    }
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertTrue(stats.isUnknown());
+  }
+
+  /** Exercises CSV decoding of two input lines, one of which contains a quoted field. */
+  @Test
+  public void testCsvRecorderDecoder() {
     PCollection<Row> result =
         pipeline
             .apply(Create.of("1,\"1\",1.0", "2,2,2.0"))
@@ -60,7 +161,7 @@ public class BeamKafkaCSVTableTest {
   }
 
   @Test
-  public void testCsvRecorderEncoder() throws Exception {
+  public void testCsvRecorderEncoder() {
     PCollection<Row> result =
         pipeline
             .apply(Create.of(ROW1, ROW2))
@@ -90,4 +191,17 @@ public class BeamKafkaCSVTableTest {
       ctx.output(KV.of(new byte[] {}, ctx.element().getBytes(UTF_8)));
     }
   }
+
+  /**
+   * Creates a test table with three INT32 columns (order_id, site_id, price) over the topics
+   * "topic1" and "topic2", using {@code numberOfPartitions} partitions.
+   */
+  private KafkaCSVTestTable getTable(int numberOfPartitions) {
+    Schema schema =
+        TestTableUtils.buildBeamSqlSchema(
+            Schema.FieldType.INT32,
+            "order_id",
+            Schema.FieldType.INT32,
+            "site_id",
+            Schema.FieldType.INT32,
+            "price");
+    return new KafkaCSVTestTable(
+        schema, ImmutableList.of("topic1", "topic2"), numberOfPartitions);
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
new file mode 100644
index 0000000..201a1df
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.toSchema;
+
+import com.alibaba.fastjson.JSON;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.direct.DirectOptions;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.common.base.MoreObjects;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Integration test for the Kafka CSV table. A Kafka server must be running, and its address must
+ * be passed to the test through {@link KafkaOptions}
+ * (https://issues.apache.org/jira/projects/BEAM/issues/BEAM-7523).
+ */
+public class KafkaCSVTableIT {
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  private static final Schema TEST_TABLE_SCHEMA =
+      Schema.builder()
+          .addNullableField("order_id", Schema.FieldType.INT32)
+          .addNullableField("member_id", Schema.FieldType.INT32)
+          .addNullableField("item_name", Schema.FieldType.INT32)
+          .build();
+
+  @BeforeClass
+  public static void prepare() {
+    PipelineOptionsFactory.register(KafkaOptions.class);
+  }
+
+  /**
+   * Produces records at two speeds (one every 20 ms, then one every 10 ms) and checks that the
+   * second rate estimate is higher than the first.
+   */
+  @Test
+  @SuppressWarnings("FutureReturnValueIgnored")
+  public void testFake2() throws BeamKafkaTable.NoEstimationException {
+    KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+    Table table =
+        Table.builder()
+            .name("kafka_table")
+            .comment("kafka" + " table")
+            .location("")
+            .schema(
+                Stream.of(
+                        Schema.Field.nullable("order_id", INT32),
+                        Schema.Field.nullable("member_id", INT32),
+                        Schema.Field.nullable("item_name", INT32))
+                    .collect(toSchema()))
+            .type("kafka")
+            .properties(JSON.parseObject(getKafkaPropertiesString(kafkaOptions)))
+            .build();
+    BeamKafkaTable kafkaTable = (BeamKafkaTable) new KafkaTableProvider().buildBeamSqlTable(table);
+    produceSomeRecordsWithDelay(100, 20);
+    double rate1 = kafkaTable.computeRate(20);
+    produceSomeRecordsWithDelay(100, 10);
+    double rate2 = kafkaTable.computeRate(20);
+    Assert.assertTrue(rate2 > rate1);
+  }
+
+  /** Builds the JSON TBLPROPERTIES string (bootstrap servers and single topic) for the table. */
+  private String getKafkaPropertiesString(KafkaOptions kafkaOptions) {
+    return "{ \"bootstrap.servers\" : \""
+        + kafkaOptions.getKafkaBootstrapServerAddress()
+        + "\",\"topics\":[\""
+        + kafkaOptions.getKafkaTopic()
+        + "\"] }";
+  }
+
+  /**
+   * Completion flags keyed by pipeline options id. {@link StreamAssertEqual} sets a flag once it
+   * has seen all expected rows. The map is static so that the static DoFn can reach it.
+   */
+  static final Map<Long, Boolean> FLAG = new ConcurrentHashMap<>();
+
+  /**
+   * Runs {@code SELECT * FROM kafka_table} on a non-blocking pipeline and produces three records.
+   * It then waits up to about 12 seconds for {@link StreamAssertEqual} to report that all of them
+   * were read.
+   */
+  @Test
+  public void testFake() throws InterruptedException {
+    KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+    pipeline.getOptions().as(DirectOptions.class).setBlockOnRun(false);
+    String createTableString =
+        "CREATE EXTERNAL TABLE kafka_table(\n"
+            + "order_id INTEGER, \n"
+            + "member_id INTEGER, \n"
+            + "item_name INTEGER \n"
+            + ") \n"
+            + "TYPE 'kafka' \n"
+            + "LOCATION '"
+            + "'\n"
+            + "TBLPROPERTIES '"
+            + getKafkaPropertiesString(kafkaOptions)
+            + "'";
+    TableProvider tb = new KafkaTableProvider();
+    BeamSqlEnv env = BeamSqlEnv.inMemory(tb);
+
+    env.executeDdl(createTableString);
+
+    PCollection<Row> queryOutput =
+        BeamSqlRelUtils.toPCollection(pipeline, env.parseQuery("SELECT * FROM kafka_table"));
+
+    queryOutput
+        .apply(ParDo.of(new FakeKvPair()))
+        .apply(
+            "waitForSuccess",
+            ParDo.of(
+                new StreamAssertEqual(
+                    ImmutableSet.of(
+                        row(TEST_TABLE_SCHEMA, 0, 1, 0),
+                        row(TEST_TABLE_SCHEMA, 1, 2, 1),
+                        row(TEST_TABLE_SCHEMA, 2, 3, 2)))));
+    queryOutput.apply(logRecords(""));
+    pipeline.run();
+    TimeUnit.MILLISECONDS.sleep(3000);
+    produceSomeRecords(3);
+
+    for (int i = 0; i < 200; i++) {
+      if (FLAG.getOrDefault(pipeline.getOptions().getOptionsId(), false)) {
+        return;
+      }
+      TimeUnit.MILLISECONDS.sleep(60);
+    }
+    Assert.fail("The pipeline did not observe all expected rows in time.");
+  }
+
+  /** Prints each row's values followed by {@code suffix}; intended as a debugging aid. */
+  private static MapElements<Row, Void> logRecords(String suffix) {
+    return MapElements.via(
+        new SimpleFunction<Row, Void>() {
+          @Override
+          public @Nullable Void apply(Row input) {
+            System.out.println(input.getValues() + suffix);
+            return null;
+          }
+        });
+  }
+
+  /** Wraps every row with the same key, because a stateful DoFn needs KV input. */
+  public static class FakeKvPair extends DoFn<Row, KV<String, Row>> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(KV.of("fake_key", c.element()));
+    }
+  }
+
+  /** Sets this pipeline's entry in {@link #FLAG} once all expected rows have been seen. */
+  public static class StreamAssertEqual extends DoFn<KV<String, Row>, Void> {
+    private final Set<Row> expected;
+
+    StreamAssertEqual(Set<Row> expected) {
+      super();
+      this.expected = expected;
+    }
+
+    @DoFn.StateId("seenValues")
+    private final StateSpec<BagState<Row>> seenRows = StateSpecs.bag();
+
+    @StateId("count")
+    private final StateSpec<ValueState<Integer>> countState = StateSpecs.value();
+
+    @ProcessElement
+    public void process(
+        ProcessContext context,
+        @StateId("seenValues") BagState<Row> seenValues,
+        @StateId("count") ValueState<Integer> countState) {
+      // Every element carries the same key (see FakeKvPair). Assuming the default global window,
+      // this count and bag therefore accumulate all rows of the query output.
+      // NOTE(review): MoreObjects comes from the vendored gRPC Guava. The vendored Guava already
+      // used for ImmutableSet would be the more consistent choice.
+      int count = MoreObjects.firstNonNull(countState.read(), 0);
+      count = count + 1;
+      countState.write(count);
+      seenValues.add(context.element().getValue());
+
+      // Only compare the sets once at least as many rows as expected have arrived.
+      if (count >= expected.size()
+          && StreamSupport.stream(seenValues.read().spliterator(), false)
+              .collect(Collectors.toSet())
+              .containsAll(expected)) {
+        FLAG.put(context.getPipelineOptions().getOptionsId(), true);
+      }
+    }
+  }
+
+  private Row row(Schema schema, Object... values) {
+    return Row.withSchema(schema).addValues(values).build();
+  }
+
+  /** Sends {@code num} CSV records to the test topic without pausing between them. */
+  private void produceSomeRecords(int num) {
+    produceSomeRecordsWithDelay(num, 0);
+  }
+
+  /**
+   * Sends {@code num} CSV records to the test topic, sleeping {@code delayMilis} milliseconds
+   * after each send. Record {@code i} has key {@code "k" + i} and value "i,(i % 3) + 1,i".
+   */
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void produceSomeRecordsWithDelay(int num, int delayMilis) {
+    String topicName = pipeline.getOptions().as(KafkaOptions.class).getKafkaTopic();
+    // try-with-resources closes the producer even if sending or sleeping fails.
+    try (Producer<String, String> producer = new KafkaProducer<String, String>(producerProps())) {
+      for (int i = 0; i < num; i++) {
+        producer.send(
+            new ProducerRecord<String, String>(
+                topicName, "k" + i, i + "," + ((i % 3) + 1) + "," + i));
+        try {
+          TimeUnit.MILLISECONDS.sleep(delayMilis);
+        } catch (InterruptedException e) {
+          // Preserve the interrupt status for callers before failing the test.
+          Thread.currentThread().interrupt();
+          throw new RuntimeException("Could not wait for producing", e);
+        }
+      }
+      producer.flush();
+    }
+  }
+
+  /** Producer configuration for the broker given in {@link KafkaOptions}. */
+  private Properties producerProps() {
+    KafkaOptions options = pipeline.getOptions().as(KafkaOptions.class);
+    Properties props = new Properties();
+    props.put("bootstrap.servers", options.getKafkaBootstrapServerAddress());
+    props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+    props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+    props.put("buffer.memory", 33554432);
+    // "acks" is the producer's acknowledgement setting. The legacy "request.required.acks" key
+    // is not a producer config and contradicted it, so it is omitted.
+    props.put("acks", "all");
+    props.put("retries", 0);
+    props.put("linger.ms", 1);
+    return props;
+  }
+
+  /** Pipeline options specific for this test. */
+  public interface KafkaOptions extends PipelineOptions {
+
+    @Description("Kafka server address")
+    @Validation.Required
+    @Default.String("localhost:9092")
+    String getKafkaBootstrapServerAddress();
+
+    void setKafkaBootstrapServerAddress(String address);
+
+    @Description("Kafka topic")
+    @Validation.Required
+    @Default.String("test")
+    String getKafkaTopic();
+
+    void setKafkaTopic(String topic);
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
new file mode 100644
index 0000000..749adea
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.record.TimestampType;
+
+/**
+ * A {@link BeamKafkaCSVTable} for unit tests that reads from an in-memory {@link MockConsumer}
+ * instead of a Kafka broker.
+ *
+ * <p>Records registered through {@link #addRecord} are spread over {@code partitionsPerTopic}
+ * partitions of their topic by the hash of their key. They are served both to the {@link KafkaIO}
+ * read and to the rate estimation in {@link #computeRate(int)}.
+ */
+public class KafkaCSVTestTable extends BeamKafkaCSVTable {
+  private static final String TIMESTAMP_TYPE_CONFIG = "test.timestamp.type";
+
+  private final int partitionsPerTopic;
+  private final List<KafkaTestRecord> records;
+
+  /**
+   * Creates a mock table.
+   *
+   * @param beamSchema schema of the CSV rows
+   * @param topics topics of the table; every record added later must belong to one of them
+   * @param partitionsPerTopic number of partitions created per topic; expected to be positive,
+   *     since record placement divides by it
+   */
+  public KafkaCSVTestTable(Schema beamSchema, List<String> topics, int partitionsPerTopic) {
+    // Placeholder bootstrap address: reads and rate estimation go through mkMockConsumer.
+    super(beamSchema, "server:123", topics);
+    this.partitionsPerTopic = partitionsPerTopic;
+    this.records = new ArrayList<>();
+  }
+
+  /** Routes the Kafka read through {@code mkMockConsumer} instead of a real consumer. */
+  @Override
+  KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+    return super.createKafkaRead().withConsumerFactoryFn(this::mkMockConsumer);
+  }
+
+  /**
+   * Registers a record that every mock consumer created afterwards will serve. Within a
+   * partition, records get consecutive offsets in the order they were added.
+   */
+  public void addRecord(KafkaTestRecord record) {
+    records.add(record);
+  }
+
+  /** Estimates the rate against a fresh mock consumer built with an empty configuration. */
+  @Override
+  double computeRate(int numberOfRecords) throws NoEstimationException {
+    return super.computeRate(mkMockConsumer(new HashMap<>()), numberOfRecords);
+  }
+
+  /** Sets the inherited {@code numberOfRecordsForRate} field used by rate estimation. */
+  public void setNumberOfRecordsForRate(int numberOfRecordsForRate) {
+    this.numberOfRecordsForRate = numberOfRecordsForRate;
+  }
+
+  /**
+   * Builds a {@link MockConsumer} pre-populated with this table's topics, partitions and records.
+   *
+   * @param config consumer configuration; only {@code test.timestamp.type} (a {@link
+   *     TimestampType} name, {@code LogAppendTime} by default) and {@code inject.error.at.eof}
+   *     are consulted
+   * @throws IllegalArgumentException if a registered record names a topic the table does not have
+   */
+  private MockConsumer<byte[], byte[]> mkMockConsumer(Map<String, Object> config) {
+    OffsetResetStrategy offsetResetStrategy = OffsetResetStrategy.EARLIEST;
+    final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> kafkaRecords = new HashMap<>();
+    Map<String, List<PartitionInfo>> partitionInfoMap = new HashMap<>();
+    Map<String, List<TopicPartition>> partitionMap = new HashMap<>();
+
+    // Create the topic partitions, each starting with an empty record list.
+    for (String topic : this.getTopics()) {
+      List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
+      List<TopicPartition> topicPartitions = new ArrayList<>(partitionsPerTopic);
+      for (int i = 0; i < partitionsPerTopic; i++) {
+        TopicPartition tp = new TopicPartition(topic, i);
+        topicPartitions.add(tp);
+        partIds.add(new PartitionInfo(topic, i, null, null, null));
+        kafkaRecords.put(tp, new ArrayList<>());
+      }
+      partitionInfoMap.put(topic, partIds);
+      partitionMap.put(topic, topicPartitions);
+    }
+
+    TimestampType timestampType =
+        TimestampType.forName(
+            (String)
+                config.getOrDefault(
+                    TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.toString()));
+
+    // Place each record in a partition chosen by its key; offsets follow insertion order.
+    for (KafkaTestRecord record : this.records) {
+      List<TopicPartition> topicPartitions = partitionMap.get(record.getTopic());
+      if (topicPartitions == null) {
+        throw new IllegalArgumentException(
+            "Record topic '"
+                + record.getTopic()
+                + "' is not one of the table's topics "
+                + getTopics());
+      }
+      // floorMod keeps the index in range even when hashCode() is negative.
+      int partitionIndex = Math.floorMod(record.getKey().hashCode(), partitionsPerTopic);
+      TopicPartition tp = topicPartitions.get(partitionIndex);
+      List<ConsumerRecord<byte[], byte[]>> partitionRecords = kafkaRecords.get(tp);
+      byte[] key = record.getKey().getBytes(UTF_8);
+      byte[] value = record.getValue().getBytes(UTF_8);
+      partitionRecords.add(
+          new ConsumerRecord<>(
+              tp.topic(),
+              tp.partition(),
+              partitionRecords.size(),
+              record.getTimeStamp(),
+              timestampType,
+              0,
+              key.length,
+              value.length,
+              key,
+              value));
+    }
+
+    // This is updated when the reader assigns partitions and read by the poll task below.
+    final AtomicReference<List<TopicPartition>> assignedPartitions =
+        new AtomicReference<>(Collections.<TopicPartition>emptyList());
+    final MockConsumer<byte[], byte[]> consumer =
+        new MockConsumer<byte[], byte[]>(offsetResetStrategy) {
+          // Maps the requested partitions onto the canonical TopicPartition instances and
+          // publishes each partition's beginning offset (0) and end offset (its record count).
+          @Override
+          public synchronized void assign(final Collection<TopicPartition> assigned) {
+            Collection<TopicPartition> realPartitions =
+                assigned.stream()
+                    .map(part -> partitionMap.get(part.topic()).get(part.partition()))
+                    .collect(Collectors.toList());
+            super.assign(realPartitions);
+            assignedPartitions.set(ImmutableList.copyOf(realPartitions));
+            for (TopicPartition tp : realPartitions) {
+              updateBeginningOffsets(ImmutableMap.of(tp, 0L));
+              updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size()));
+            }
+          }
+
+          // Mock lookup "by timestamp": each requested timestamp is interpreted directly as an
+          // offset and reported back as the found record's timestamp as well. The records' real
+          // timestamps are not consulted.
+          @Override
+          public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
+              Map<TopicPartition, Long> timestampsToSearch) {
+            return timestampsToSearch.entrySet().stream()
+                .map(
+                    e -> {
+                      long maxOffset = kafkaRecords.get(e.getKey()).size();
+                      long offset = e.getValue();
+                      OffsetAndTimestamp value =
+                          (offset >= maxOffset) ? null : new OffsetAndTimestamp(offset, offset);
+                      return new AbstractMap.SimpleEntry<>(e.getKey(), value);
+                    })
+                // Collectors.toMap() rejects null values, but a partition without a matching
+                // offset must map to null (as KafkaConsumer#offsetsForTimes does), so collect
+                // into a HashMap directly.
+                .collect(
+                    HashMap::new,
+                    (map, entry) -> map.put(entry.getKey(), entry.getValue()),
+                    Map::putAll);
+          }
+        };
+
+    for (String topic : getTopics()) {
+      consumer.updatePartitions(topic, partitionInfoMap.get(topic));
+    }
+
+    // Poll task that reschedules itself, so it runs on every poll and enqueues the records at or
+    // after each assigned partition's current position.
+    Runnable recordEnqueueTask =
+        new Runnable() {
+          @Override
+          public void run() {
+            // Add all the records with offset >= current partition position.
+            int recordsAdded = 0;
+            for (TopicPartition tp : assignedPartitions.get()) {
+              long curPos = consumer.position(tp);
+              for (ConsumerRecord<byte[], byte[]> r : kafkaRecords.get(tp)) {
+                if (r.offset() >= curPos) {
+                  consumer.addRecord(r);
+                  recordsAdded++;
+                }
+              }
+            }
+            if (recordsAdded == 0) {
+              if (config.get("inject.error.at.eof") != null) {
+                consumer.setException(new KafkaException("Injected error in consumer.poll()"));
+              }
+              // MockConsumer.poll(timeout) does not actually wait even when there aren't any
+              // records. Add a small wait here in order to avoid busy looping in the reader.
+              Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
+            }
+            consumer.schedulePollTask(this);
+          }
+        };
+
+    consumer.schedulePollTask(recordEnqueueTask);
+
+    return consumer;
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
new file mode 100644
index 0000000..015ac8b
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+
+/**
+ * Serializable description of a single Kafka record for use in tests.
+ *
+ * <p>This class is created because Kafka Consumer Records are not serializable: tests hold these
+ * value objects instead and turn them into consumer records where needed (for example in
+ * KafkaCSVTestTable).
+ */
+@AutoValue
+public abstract class KafkaTestRecord implements Serializable {
+  // NOTE: AutoValue orders the generated constructor's parameters by the declaration order of
+  // the abstract getters below, so keep that order in sync with the call in create().
+
+  /** Record key. */
+  public abstract String getKey();
+
+  /** Record value (payload). */
+  public abstract String getValue();
+
+  /** Topic the record belongs to. */
+  public abstract String getTopic();
+
+  /** Record timestamp; presumably epoch milliseconds as in Kafka — TODO confirm with callers. */
+  public abstract long getTimeStamp();
+
+  /** Creates a record from its key, value, topic and timestamp (in that order). */
+  public static KafkaTestRecord create(
+      String newKey, String newValue, String newTopic, long newTimeStamp) {
+    return new AutoValue_KafkaTestRecord(newKey, newValue, newTopic, newTimeStamp);
+  }
+}


Mime
View raw message