spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From t...@apache.org
Subject [1/2] spark git commit: [SPARK-22908][SS] Roll forward continuous processing Kafka support with fix to continuous Kafka data reader
Date Wed, 17 Jan 2018 02:11:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master a9b845ebb -> 166705785


http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index a0f5695..1acff61 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec}
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
-import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
+import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
 import org.apache.spark.util.Utils
@@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
 
   override val streamingTimeout = 30.seconds
 
+  protected val brokerProps = Map[String, Object]()
+
   override def beforeAll(): Unit = {
     super.beforeAll()
-    testUtils = new KafkaTestUtils
+    testUtils = new KafkaTestUtils(brokerProps)
     testUtils.setup()
   }
 
@@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
     if (testUtils != null) {
       testUtils.teardown()
       testUtils = null
-      super.afterAll()
     }
+    super.afterAll()
   }
 
   protected def makeSureGetOffsetCalled = AssertOnQuery { q =>
     // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure
-    // its "getOffset" is called before pushing any data. Otherwise, because of the race contion,
+    // its "getOffset" is called before pushing any data. Otherwise, because of the race condition,
     // we don't know which data should be fetched when `startingOffsets` is latest.
-    q.processAllAvailable()
+    q match {
+      case c: ContinuousExecution => c.awaitEpoch(0)
+      case m: MicroBatchExecution => m.processAllAvailable()
+    }
     true
   }
 
+  protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution) : Unit = {
+    testUtils.addPartitions(topic, newCount)
+  }
+
   /**
    * Add data to Kafka.
    *
@@ -82,10 +94,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       message: String = "",
       topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {
 
-    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
-      if (query.get.isActive) {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+      query match {
         // Make sure no Spark job is running when deleting a topic
-        query.get.processAllAvailable()
+        case Some(m: MicroBatchExecution) => m.processAllAvailable()
+        case _ =>
       }
 
       val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap
@@ -97,16 +110,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
         topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
       }
 
-      // Read all topics again in case some topics are delete.
-      val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys
       require(
         query.nonEmpty,
         "Cannot add data when there is no query for finding the active kafka source")
 
       val sources = query.get.logicalPlan.collect {
-        case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] =>
-          source.asInstanceOf[KafkaSource]
-      }
+        case StreamingExecutionRelation(source: KafkaSource, _) => source
+      } ++ (query.get.lastExecution match {
+        case null => Seq()
+        case e => e.logical.collect {
+          case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+        }
+      })
       if (sources.isEmpty) {
         throw new Exception(
           "Could not find Kafka source in the StreamExecution logical plan to add data to")
@@ -137,14 +152,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
     override def toString: String =
       s"AddKafkaData(topics = $topics, data = $data, message = $message)"
   }
-}
 
+  private val topicId = new AtomicInteger(0)
+  protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+}
 
-class KafkaSourceSuite extends KafkaSourceTest {
+class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
 
   import testImplicits._
 
-  private val topicId = new AtomicInteger(0)
+  test("(de)serialization of initial offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 5)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+
+    testStream(reader.load)(
+      makeSureGetOffsetCalled,
+      StopStream,
+      StartStream(),
+      StopStream)
+  }
+
+  test("maxOffsetsPerTrigger") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
+    testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
+    testUtils.sendMessages(topic, Array("1"), Some(2))
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("maxOffsetsPerTrigger", 10)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+    val clock = new StreamManualClock
+
+    val waitUntilBatchProcessed = AssertOnQuery { q =>
+      eventually(Timeout(streamingTimeout)) {
+        if (!q.exception.isDefined) {
+          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+        }
+      }
+      if (q.exception.isDefined) {
+        throw q.exception.get
+      }
+      true
+    }
+
+    testStream(mapped)(
+      StartStream(ProcessingTime(100), clock),
+      waitUntilBatchProcessed,
+      // 1 from smallest, 1 from middle, 8 from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
+      AdvanceManualClock(100),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116
+      ),
+      StopStream,
+      StartStream(ProcessingTime(100), clock),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+        12, 117, 118, 119, 120, 121, 122, 123, 124, 125
+      ),
+      AdvanceManualClock(100),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+        12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
+        13, 126, 127, 128, 129, 130, 131, 132, 133, 134
+      )
+    )
+  }
+
+  test("input row metrics") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 5)
+    testUtils.sendMessages(topic, Array("-1"))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+    val kafka = spark
+      .readStream
+      .format("kafka")
+      .option("subscribe", topic)
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+
+    val mapped = kafka.map(kv => kv._2.toInt + 1)
+    testStream(mapped)(
+      StartStream(trigger = ProcessingTime(1)),
+      makeSureGetOffsetCalled,
+      AddKafkaData(Set(topic), 1, 2, 3),
+      CheckAnswer(2, 3, 4),
+      AssertOnQuery { query =>
+        val recordsRead = query.recentProgress.map(_.numInputRows).sum
+        recordsRead == 3
+      }
+    )
+  }
+
+  test("subscribing topic by pattern with topic deletions") {
+    val topicPrefix = newTopic()
+    val topic = topicPrefix + "-seems"
+    val topic2 = topicPrefix + "-bad"
+    testUtils.createTopic(topic, partitions = 5)
+    testUtils.sendMessages(topic, Array("-1"))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", s"$topicPrefix-.*")
+      .option("failOnDataLoss", "false")
+
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped = kafka.map(kv => kv._2.toInt + 1)
+
+    testStream(mapped)(
+      makeSureGetOffsetCalled,
+      AddKafkaData(Set(topic), 1, 2, 3),
+      CheckAnswer(2, 3, 4),
+      Assert {
+        testUtils.deleteTopic(topic)
+        testUtils.createTopic(topic2, partitions = 5)
+        true
+      },
+      AddKafkaData(Set(topic2), 4, 5, 6),
+      CheckAnswer(2, 3, 4, 5, 6, 7)
+    )
+  }
 
   testWithUninterruptibleThread(
     "deserialization of initial offset with Spark 2.1.0") {
@@ -237,86 +396,94 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  test("(de)serialization of initial offsets") {
+  test("KafkaSource with watermark") {
+    val now = System.currentTimeMillis()
     val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 64)
+    testUtils.createTopic(newTopic(), partitions = 1)
+    testUtils.sendMessages(topic, Array(1).map(_.toString))
 
-    val reader = spark
+    val kafka = spark
       .readStream
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("startingOffsets", s"earliest")
       .option("subscribe", topic)
+      .load()
 
-    testStream(reader.load)(
-      makeSureGetOffsetCalled,
-      StopStream,
-      StartStream(),
-      StopStream)
+    val windowedAggregation = kafka
+      .withWatermark("timestamp", "10 seconds")
+      .groupBy(window($"timestamp", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start") as 'window, $"count")
+
+    val query = windowedAggregation
+      .writeStream
+      .format("memory")
+      .outputMode("complete")
+      .queryName("kafkaWatermark")
+      .start()
+    query.processAllAvailable()
+    val rows = spark.table("kafkaWatermark").collect()
+    assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
+    val row = rows(0)
+    // We cannot check the exact window start time as it depends on the time that messages were
+    // inserted by the producer. So here we just use a lower bound to make sure the internal
+    // conversion works.
+    assert(
+      row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
+      s"Unexpected results: $row")
+    assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
+    query.stop()
   }
 
-  test("maxOffsetsPerTrigger") {
+  test("delete a topic when a Spark job is running") {
+    KafkaSourceSuite.collectedData.clear()
+
     val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 3)
-    testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
-    testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
-    testUtils.sendMessages(topic, Array("1"), Some(2))
+    testUtils.createTopic(topic, partitions = 1)
+    testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
 
     val reader = spark
       .readStream
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
-      .option("maxOffsetsPerTrigger", 10)
       .option("subscribe", topic)
+      // If a topic is deleted and we try to poll data starting from offset 0,
+      // the Kafka consumer will just block until timeout and return an empty result.
+      // So set the timeout to 1 second to make this test fast.
+      .option("kafkaConsumer.pollTimeoutMs", "1000")
       .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
       .as[(String, String)]
-    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
-
-    val clock = new StreamManualClock
-
-    val waitUntilBatchProcessed = AssertOnQuery { q =>
-      eventually(Timeout(streamingTimeout)) {
-        if (!q.exception.isDefined) {
-          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
-        }
+    KafkaSourceSuite.globalTestUtils = testUtils
+    // The following ForeachWriter will delete the topic before fetching data from Kafka
+    // in executors.
+    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+      override def open(partitionId: Long, version: Long): Boolean = {
+        KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
+        true
       }
-      if (q.exception.isDefined) {
-        throw q.exception.get
+
+      override def process(value: Int): Unit = {
+        KafkaSourceSuite.collectedData.add(value)
       }
-      true
-    }
 
-    testStream(mapped)(
-      StartStream(ProcessingTime(100), clock),
-      waitUntilBatchProcessed,
-      // 1 from smallest, 1 from middle, 8 from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
-      AdvanceManualClock(100),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116
-      ),
-      StopStream,
-      StartStream(ProcessingTime(100), clock),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
-        12, 117, 118, 119, 120, 121, 122, 123, 124, 125
-      ),
-      AdvanceManualClock(100),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
-        12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
-        13, 126, 127, 128, 129, 130, 131, 132, 133, 134
-      )
-    )
+      override def close(errorOrNull: Throwable): Unit = {}
+    }).start()
+    query.processAllAvailable()
+    query.stop()
+    // Because `failOnDataLoss` is `false`, deleting the topic must not fail the query.
+    assert(query.exception.isEmpty)
   }
+}
+
+class KafkaSourceSuiteBase extends KafkaSourceTest {
+
+  import testImplicits._
 
   test("SPARK-22956: currentPartitionOffsets should be set when no new data comes in") {
     def getSpecificDF(range: Range.Inclusive): org.apache.spark.sql.Dataset[Int] = {
@@ -393,7 +560,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", s"topic-.*")
+      .option("subscribePattern", s"$topic.*")
 
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -487,65 +654,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  test("subscribing topic by pattern with topic deletions") {
-    val topicPrefix = newTopic()
-    val topic = topicPrefix + "-seems"
-    val topic2 = topicPrefix + "-bad"
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("-1"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", s"$topicPrefix-.*")
-      .option("failOnDataLoss", "false")
-
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    val mapped = kafka.map(kv => kv._2.toInt + 1)
-
-    testStream(mapped)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(2, 3, 4),
-      Assert {
-        testUtils.deleteTopic(topic)
-        testUtils.createTopic(topic2, partitions = 5)
-        true
-      },
-      AddKafkaData(Set(topic2), 4, 5, 6),
-      CheckAnswer(2, 3, 4, 5, 6, 7)
-    )
-  }
-
-  test("starting offset is latest by default") {
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("0"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("subscribe", topic)
-
-    val kafka = reader.load()
-      .selectExpr("CAST(value AS STRING)")
-      .as[String]
-    val mapped = kafka.map(_.toInt)
-
-    testStream(mapped)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(1, 2, 3)  // should not have 0
-    )
-  }
-
   test("bad source options") {
     def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
       val ex = intercept[IllegalArgumentException] {
@@ -605,77 +713,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     testUnsupportedConfig("kafka.auto.offset.reset", "latest")
   }
 
-  test("input row metrics") {
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("-1"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val kafka = spark
-      .readStream
-      .format("kafka")
-      .option("subscribe", topic)
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-
-    val mapped = kafka.map(kv => kv._2.toInt + 1)
-    testStream(mapped)(
-      StartStream(trigger = ProcessingTime(1)),
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(2, 3, 4),
-      AssertOnQuery { query =>
-        val recordsRead = query.recentProgress.map(_.numInputRows).sum
-        recordsRead == 3
-      }
-    )
-  }
-
-  test("delete a topic when a Spark job is running") {
-    KafkaSourceSuite.collectedData.clear()
-
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 1)
-    testUtils.sendMessages(topic, (1 to 10).map(_.toString).toArray)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribe", topic)
-      // If a topic is deleted and we try to poll data starting from offset 0,
-      // the Kafka consumer will just block until timeout and return an empty result.
-      // So set the timeout to 1 second to make this test fast.
-      .option("kafkaConsumer.pollTimeoutMs", "1000")
-      .option("startingOffsets", "earliest")
-      .option("failOnDataLoss", "false")
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    KafkaSourceSuite.globalTestUtils = testUtils
-    // The following ForeachWriter will delete the topic before fetching data from Kafka
-    // in executors.
-    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
-      override def open(partitionId: Long, version: Long): Boolean = {
-        KafkaSourceSuite.globalTestUtils.deleteTopic(topic)
-        true
-      }
-
-      override def process(value: Int): Unit = {
-        KafkaSourceSuite.collectedData.add(value)
-      }
-
-      override def close(errorOrNull: Throwable): Unit = {}
-    }).start()
-    query.processAllAvailable()
-    query.stop()
-    // `failOnDataLoss` is `false`, we should not fail the query
-    assert(query.exception.isEmpty)
-  }
-
   test("get offsets from case insensitive parameters") {
     for ((optionKey, optionValue, answer) <- Seq(
       (STARTING_OFFSETS_OPTION_KEY, "earLiEst", EarliestOffsetRangeLimit),
@@ -694,8 +731,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
-
   private def assignString(topic: String, partitions: Iterable[Int]): String = {
     JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
   }
@@ -741,6 +776,10 @@ class KafkaSourceSuite extends KafkaSourceTest {
 
     testStream(mapped)(
       makeSureGetOffsetCalled,
+      Execute { q =>
+        // wait to reach the last offset in every partition
+        q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
+      },
       CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
       StopStream,
       StartStream(),
@@ -771,10 +810,13 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .format("memory")
       .outputMode("append")
       .queryName("kafkaColumnTypes")
+      .trigger(defaultTrigger)
       .start()
-    query.processAllAvailable()
-    val rows = spark.table("kafkaColumnTypes").collect()
-    assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
+    var rows: Array[Row] = Array()
+    eventually(timeout(streamingTimeout)) {
+      rows = spark.table("kafkaColumnTypes").collect()
+      assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
+    }
     val row = rows(0)
     assert(row.getAs[Array[Byte]]("key") === null, s"Unexpected results: $row")
     assert(row.getAs[Array[Byte]]("value") === "1".getBytes(UTF_8), s"Unexpected results: $row")
@@ -788,47 +830,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     query.stop()
   }
 
-  test("KafkaSource with watermark") {
-    val now = System.currentTimeMillis()
-    val topic = newTopic()
-    testUtils.createTopic(newTopic(), partitions = 1)
-    testUtils.sendMessages(topic, Array(1).map(_.toString))
-
-    val kafka = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("startingOffsets", s"earliest")
-      .option("subscribe", topic)
-      .load()
-
-    val windowedAggregation = kafka
-      .withWatermark("timestamp", "10 seconds")
-      .groupBy(window($"timestamp", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
-      .select($"window".getField("start") as 'window, $"count")
-
-    val query = windowedAggregation
-      .writeStream
-      .format("memory")
-      .outputMode("complete")
-      .queryName("kafkaWatermark")
-      .start()
-    query.processAllAvailable()
-    val rows = spark.table("kafkaWatermark").collect()
-    assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
-    val row = rows(0)
-    // We cannot check the exact window start time as it depands on the time that messages were
-    // inserted by the producer. So here we just use a low bound to make sure the internal
-    // conversion works.
-    assert(
-      row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
-      s"Unexpected results: $row")
-    assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
-    query.stop()
-  }
-
   private def testFromLatestOffsets(
       topic: String,
       addPartitions: Boolean,
@@ -865,9 +866,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       AddKafkaData(Set(topic), 7, 8),
       CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
       AssertOnQuery("Add partitions") { query: StreamExecution =>
-        if (addPartitions) {
-          testUtils.addPartitions(topic, 10)
-        }
+        if (addPartitions) setTopicPartitions(topic, 10, query)
         true
       },
       AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -908,9 +907,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       StartStream(),
       CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
       AssertOnQuery("Add partitions") { query: StreamExecution =>
-        if (addPartitions) {
-          testUtils.addPartitions(topic, 10)
-        }
+        if (addPartitions) setTopicPartitions(topic, 10, query)
         true
       },
       AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -1042,20 +1039,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
     }
   }
 
-  test("stress test for failOnDataLoss=false") {
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", "failOnDataLoss.*")
-      .option("startingOffsets", "earliest")
-      .option("failOnDataLoss", "false")
-      .option("fetchOffset.retryIntervalMs", "3000")
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+  protected def startStream(ds: Dataset[Int]) = {
+    ds.writeStream.foreach(new ForeachWriter[Int] {
 
       override def open(partitionId: Long, version: Long): Boolean = {
         true
@@ -1069,6 +1054,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
       override def close(errorOrNull: Throwable): Unit = {
       }
     }).start()
+  }
+
+  test("stress test for failOnDataLoss=false") {
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", "failOnDataLoss.*")
+      .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
+      .option("fetchOffset.retryIntervalMs", "3000")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val query = startStream(kafka.map(kv => kv._2.toInt))
 
     val testTime = 1.minutes
     val startTime = System.currentTimeMillis()

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index e8d683a..b714a46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
           ds = ds.asInstanceOf[DataSourceV2],
           conf = sparkSession.sessionState.conf)).asJava)
 
+      // Streaming also uses the data source V2 API. So it may be that the data source implements
+      // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
+      // the dataframe as a v1 source.
       val reader = (ds, userSpecifiedSchema) match {
         case (ds: ReadSupportWithSchema, Some(schema)) =>
           ds.createReader(schema, options)
@@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
           }
           reader
 
-        case _ =>
-          throw new AnalysisException(s"$cls does not support data reading.")
+        case _ => null // fall back to v1
       }
 
-      Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+      if (reader == null) {
+        loadV1Source(paths: _*)
+      } else {
+        Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+      }
     } else {
-      // Code path for data source v1.
-      sparkSession.baseRelationToDataFrame(
-        DataSource.apply(
-          sparkSession,
-          paths = paths,
-          userSpecifiedSchema = userSpecifiedSchema,
-          className = source,
-          options = extraOptions.toMap).resolveRelation())
+      loadV1Source(paths: _*)
     }
   }
 
+  private def loadV1Source(paths: String*) = {
+    // Code path for data source v1.
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        userSpecifiedSchema = userSpecifiedSchema,
+        className = source,
+        options = extraOptions.toMap).resolveRelation())
+  }
+
   /**
    * Construct a `DataFrame` representing the database table accessible via JDBC URL
    * url named table and connection properties.

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3304f36..97f12ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
             }
           }
 
-        case _ => throw new AnalysisException(s"$cls does not support data writing.")
+        // Streaming also uses the data source V2 API. So it may be that the data source implements
+        // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
+        // as though it's a V1 source.
+        case _ => saveToV1Source()
       }
     } else {
-      // Code path for data source v1.
-      runCommand(df.sparkSession, "save") {
-        DataSource(
-          sparkSession = df.sparkSession,
-          className = source,
-          partitionColumns = partitioningColumns.getOrElse(Nil),
-          options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
-      }
+      saveToV1Source()
+    }
+  }
+
+  private def saveToV1Source(): Unit = {
+    // Data source v1 path: plan the write through DataSource and run it as a "save" command.
+    runCommand(df.sparkSession, "save") {
+      DataSource(
+        sparkSession = df.sparkSession,
+        className = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index f0bdf84..a4a857f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
         (index, message: WriterCommitMessage) => messages(index) = message
       )
 
-      logInfo(s"Data source writer $writer is committing.")
-      writer.commit(messages)
-      logInfo(s"Data source writer $writer committed.")
+      if (!writer.isInstanceOf[ContinuousWriter]) {
+        logInfo(s"Data source writer $writer is committing.")
+        writer.commit(messages)
+        logInfo(s"Data source writer $writer committed.")
+      }
     } catch {
       case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
         // Interruption is how continuous queries are ended, so accept and ignore the exception.

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 24a8b00..cf27e1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -142,7 +142,8 @@ abstract class StreamExecution(
 
   override val id: UUID = UUID.fromString(streamMetadata.id)
 
-  override val runId: UUID = UUID.randomUUID
+  override def runId: UUID = currentRunId
+  protected var currentRunId = UUID.randomUUID
 
   /**
    * Pretty identified string of printing in logs. Format is
@@ -418,11 +419,17 @@ abstract class StreamExecution(
    * Blocks the current thread until processing for data from the given `source` has reached at
    * least the given `Offset`. This method is intended for use primarily when writing tests.
    */
-  private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
+  private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
     assertAwaitThread()
     def notDone = {
       val localCommittedOffsets = committedOffsets
-      !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      if (sources == null) {
+        // sources might not be initialized yet
+        false
+      } else {
+        val source = sources(sourceIndex)
+        !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      }
     }
 
     while (notDone) {
@@ -436,7 +443,7 @@ abstract class StreamExecution(
         awaitProgressLock.unlock()
       }
     }
-    logDebug(s"Unblocked at $newOffset for $source")
+    logDebug(s"Unblocked at $newOffset for ${Option(sources).map(_.apply(sourceIndex)).orNull}")
   }
 
   /** A flag to indicate that a batch has completed with no new data available. */

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index b3f1a1a..66eb42d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
     dataReaderThread.start()
 
     context.addTaskCompletionListener(_ => {
-      reader.close()
       dataReaderThread.interrupt()
       epochPollExecutor.shutdown()
     })
@@ -177,6 +176,7 @@ class DataReaderThread(
   private[continuous] var failureReason: Throwable = _
 
   override def run(): Unit = {
+    TaskContext.setTaskContext(context)
     val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
     try {
       while (!context.isInterrupted && !context.isCompleted()) {
@@ -201,6 +201,8 @@ class DataReaderThread(
         failedFlag.set(true)
         // Don't rethrow the exception in this thread. It's not needed, and the default Spark
         // exception handler will kill the executor.
+    } finally {
+      reader.close()
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 9657b5e..667410e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit
+import java.util.function.UnaryOperator
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
@@ -52,7 +54,7 @@ class ContinuousExecution(
     sparkSession, name, checkpointRoot, analyzedPlan, sink,
     trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
 
-  @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
+  @volatile protected var continuousSources: Seq[ContinuousReader] = _
   override protected def sources: Seq[BaseStreamingSource] = continuousSources
 
   override lazy val logicalPlan: LogicalPlan = {
@@ -78,15 +80,17 @@ class ContinuousExecution(
   }
 
   override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
-    do {
-      try {
-        runContinuous(sparkSessionForStream)
-      } catch {
-        case _: InterruptedException if state.get().equals(RECONFIGURING) =>
-          // swallow exception and run again
-          state.set(ACTIVE)
+    val stateUpdate = new UnaryOperator[State] {
+      override def apply(s: State) = s match {
+        // A run that ended to reconfigure goes back to ACTIVE so the loop below restarts it.
+        case RECONFIGURING => ACTIVE
+        case _ => s
       }
-    } while (state.get() == ACTIVE)
+    }
+
+    do {
+      runContinuous(sparkSessionForStream)
+    } while (state.updateAndGet(stateUpdate) == ACTIVE)
   }
 
   /**
@@ -120,12 +124,16 @@ class ContinuousExecution(
         }
         committedOffsets = nextOffsets.toStreamProgress(sources)
 
-        // Forcibly align commit and offset logs by slicing off any spurious offset logs from
-        // a previous run. We can't allow commits to an epoch that a previous run reached but
-        // this run has not.
-        offsetLog.purgeAfter(latestEpochId)
+        // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
+        // commit happens between offset log write and commit log write, this means an epoch ID
+        // which is not in the offset log.
+        val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
+          throw new IllegalStateException(
+            "Offset log had no latest element. This shouldn't be possible because nextOffsets " +
+              "is an element.")
+        }
+        currentBatchId = latestOffsetEpoch + 1
 
-        currentBatchId = latestEpochId + 1
         logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
         nextOffsets
       case None =>
@@ -141,6 +149,7 @@ class ContinuousExecution(
    * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
    */
   private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
+    currentRunId = UUID.randomUUID
     // A list of attributes that will need to be updated.
     val replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Translate from continuous relation to the underlying data source.
@@ -225,13 +234,11 @@ class ContinuousExecution(
           triggerExecutor.execute(() => {
             startTrigger()
 
-            if (reader.needsReconfiguration()) {
-              state.set(RECONFIGURING)
+            if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
               stopSources()
               if (queryExecutionThread.isAlive) {
                 sparkSession.sparkContext.cancelJobGroup(runId.toString)
                 queryExecutionThread.interrupt()
-                // No need to join - this thread is about to end anyway.
               }
               false
             } else if (isActive) {
@@ -259,6 +266,7 @@ class ContinuousExecution(
           sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
       }
     } finally {
+      epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
       SparkEnv.get.rpcEnv.stop(epochEndpoint)
 
       epochUpdateThread.interrupt()
@@ -273,17 +281,22 @@ class ContinuousExecution(
       epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
     assert(continuousSources.length == 1, "only one continuous source supported currently")
 
-    if (partitionOffsets.contains(null)) {
-      // If any offset is null, that means the corresponding partition hasn't seen any data yet, so
-      // there's nothing meaningful to add to the offset log.
-    }
     val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
-    synchronized {
-      if (queryExecutionThread.isAlive) {
-        offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
-      } else {
-        return
-      }
+    val oldOffset = synchronized {
+      offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
+      offsetLog.get(epoch - 1)
+    }
+
+    // If offset hasn't changed since last epoch, there's been no new data.
+    if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
+      noNewData = true
+    }
+
+    awaitProgressLock.lock()
+    try {
+      awaitProgressLockCondition.signalAll()
+    } finally {
+      awaitProgressLock.unlock()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 98017c3..40dcbec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
  */
 private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
 
+/**
+ * Tells the EpochCoordinator to stop writing to its ContinuousExecution. RpcEndpoint stop()
+ * waits to clear out the message queue before terminating, so without this the query could
+ * restart at epoch n with a new EpochCoordinator while the old one still commits epoch n + 1.
+ * There is no handle to wait on the message queue, so this is sent as a synchronous message;
+ * once it is handled, the coordinator drops every later message arriving through `receive`.
+ */
+private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
+
 // Init messages
 /**
  * Set the reader and writer partition counts. Tasks may not be started until the coordinator
@@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
 
+  private var queryWritesStopped: Boolean = false
+
   private var numReaderPartitions: Int = _
   private var numWriterPartitions: Int = _
 
@@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
         partitionCommits.remove(k)
       }
       for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
-        partitionCommits.remove(k)
+        partitionOffsets.remove(k)
       }
     }
   }
 
   override def receive: PartialFunction[Any, Unit] = {
+    // If we just drop these messages, we won't do any writes to the query. The lame duck tasks
+    // won't shed errors or anything.
+    case _ if queryWritesStopped => ()
+
     case CommitPartitionEpoch(partitionId, epoch, message) =>
       logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
       if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
@@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
     case SetWriterPartitions(numPartitions) =>
       numWriterPartitions = numPartitions
       context.reply(())
+
+    case StopContinuousExecutionWrites =>
+      queryWritesStopped = true
+      context.reply(())
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index db588ae..b5b4a05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         useTempCheckpointLocation = true,
         trigger = trigger)
     } else {
-      val dataSource =
-        DataSource(
-          df.sparkSession,
-          className = source,
-          options = extraOptions.toMap,
-          partitionColumns = normalizedParCols.getOrElse(Nil))
+      val sink = trigger match {
+        case _: ContinuousTrigger =>
+          val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+          ds.newInstance() match {
+            case w: ContinuousWriteSupport => w
+            case _ => throw new AnalysisException(
+              s"Data source $source does not support continuous writing")
+          }
+        case _ =>
+          val ds = DataSource(
+            df.sparkSession,
+            className = source,
+            options = extraOptions.toMap,
+            partitionColumns = normalizedParCols.getOrElse(Nil))
+          ds.createSink(outputMode)
+      }
+
       df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),
         df,
         extraOptions.toMap,
-        dataSource.createSink(outputMode),
+        sink,
         outputMode,
         useTempCheckpointLocation = source == "console",
         recoverFromCheckpointLocation = true,

http://git-wip-us.apache.org/repos/asf/spark/blob/16670578/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d46461f..0762895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     StateStore.stop() // stop the state store maintenance thread and unload store providers
   }
 
+  protected val defaultTrigger = Trigger.ProcessingTime(0)
+  protected val defaultUseV2Sink = false
+
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 10.seconds
 
@@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
   /** Starts the stream, resuming if data has already been processed. It must not be running. */
   case class StartStream(
-      trigger: Trigger = Trigger.ProcessingTime(0),
+      trigger: Trigger = defaultTrigger,
       triggerClock: Clock = new SystemClock,
       additionalConfs: Map[String, String] = Map.empty,
       checkpointLocation: String = null)
@@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   def testStream(
       _stream: Dataset[_],
       outputMode: OutputMode = OutputMode.Append,
-      useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
+      useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
     import org.apache.spark.sql.streaming.util.StreamManualClock
 
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
       verify(currentStream != null, "stream not running")
-      // Get the map of source index to the current source objects
-      val indexToSource = currentStream
-        .logicalPlan
-        .collect { case StreamingExecutionRelation(s, _) => s }
-        .zipWithIndex
-        .map(_.swap)
-        .toMap
 
       // Block until all data added has been processed for all the source
       awaiting.foreach { case (sourceIndex, offset) =>
         failAfter(streamingTimeout) {
-          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+          currentStream.awaitOffset(sourceIndex, offset)
         }
       }
 
@@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             // after starting the query.
             try {
               currentStream.awaitInitialization(streamingTimeout.toMillis)
+              currentStream match {
+                case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
+                    s.lastExecution.executedPlan // will fail if lastExecution is null
+                  }
+                case _ =>
+              }
             } catch {
               case _: StreamingQueryException =>
                 // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
@@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
               def findSourceIndex(plan: LogicalPlan): Option[Int] = {
                 plan
-                  .collect { case StreamingExecutionRelation(s, _) => s }
+                  .collect {
+                    case StreamingExecutionRelation(s, _) => s
+                    case DataSourceV2Relation(_, r) => r
+                  }
                   .zipWithIndex
                   .find(_._1 == source)
                   .map(_._2)
@@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
                   findSourceIndex(query.logicalPlan)
                 }.orElse {
                   findSourceIndex(stream.logicalPlan)
+                }.orElse {
+                  queryToUse.flatMap { q =>
+                    Option(q.lastExecution).flatMap(e => findSourceIndex(e.logical))
+                  }
                 }.getOrElse {
                   throw new IllegalArgumentException(
-                    "Could find index of the source to which data was added")
+                    "Could not find index of the source to which data was added")
                 }
 
               // Store the expected offset of added data to wait for it later


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message