spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lix...@apache.org
Subject [4/7] spark git commit: [SPARK-24882][SQL] Revert [SPARK-24882][SQL] improve data source v2 API from branch 2.4
Date Wed, 12 Sep 2018 18:25:32 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 2cac865..7a007b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.Optional
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{Map => MutableMap}
 
@@ -26,9 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport}
-import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.{Clock, Utils}
 
@@ -49,8 +51,8 @@ class MicroBatchExecution(
 
   @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty
 
-  private val readSupportToDataSourceMap =
-    MutableMap.empty[MicroBatchReadSupport, (DataSourceV2, Map[String, String])]
+  private val readerToDataSourceMap =
+    MutableMap.empty[MicroBatchReader, (DataSourceV2, Map[String, String])]
 
   private val triggerExecutor = trigger match {
     case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
@@ -89,19 +91,20 @@ class MicroBatchExecution(
           StreamingExecutionRelation(source, output)(sparkSession)
         })
       case s @ StreamingRelationV2(
-        dataSourceV2: MicroBatchReadSupportProvider, sourceName, options, output, _) if
+        dataSourceV2: MicroBatchReadSupport, sourceName, options, output, _) if
           !disabledSources.contains(dataSourceV2.getClass.getCanonicalName) =>
         v2ToExecutionRelationMap.getOrElseUpdate(s, {
           // Materialize source to avoid creating it in every batch
           val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
-          val readSupport = dataSourceV2.createMicroBatchReadSupport(
+          val reader = dataSourceV2.createMicroBatchReader(
+            Optional.empty(), // user specified schema
             metadataPath,
             new DataSourceOptions(options.asJava))
           nextSourceId += 1
-          readSupportToDataSourceMap(readSupport) = dataSourceV2 -> options
-          logInfo(s"Using MicroBatchReadSupport [$readSupport] from " +
+          readerToDataSourceMap(reader) = dataSourceV2 -> options
+          logInfo(s"Using MicroBatchReader [$reader] from " +
             s"DataSourceV2 named '$sourceName' [$dataSourceV2]")
-          StreamingExecutionRelation(readSupport, output)(sparkSession)
+          StreamingExecutionRelation(reader, output)(sparkSession)
         })
       case s @ StreamingRelationV2(dataSourceV2, sourceName, _, output, v1Relation) =>
         v2ToExecutionRelationMap.getOrElseUpdate(s, {
@@ -341,19 +344,19 @@ class MicroBatchExecution(
         reportTimeTaken("getOffset") {
           (s, s.getOffset)
         }
-      case s: RateControlMicroBatchReadSupport =>
-        updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("latestOffset") {
-          val startOffset = availableOffsets
-            .get(s).map(off => s.deserializeOffset(off.json))
-            .getOrElse(s.initialOffset())
-          (s, Option(s.latestOffset(startOffset)))
-        }
-      case s: MicroBatchReadSupport =>
+      case s: MicroBatchReader =>
         updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("latestOffset") {
-          (s, Option(s.latestOffset()))
+        reportTimeTaken("setOffsetRange") {
+          // Once v1 streaming source execution is gone, we can refactor this away.
+          // For now, we set the range here to get the source to infer the available end offset,
+          // get that offset, and then set the range again when we later execute.
+          s.setOffsetRange(
+            toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
+            Optional.empty())
         }
+
+        val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
+        (s, Option(currentOffset))
     }.toMap
     availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
 
@@ -393,8 +396,8 @@ class MicroBatchExecution(
           if (prevBatchOff.isDefined) {
             prevBatchOff.get.toStreamProgress(sources).foreach {
               case (src: Source, off) => src.commit(off)
-              case (readSupport: MicroBatchReadSupport, off) =>
-                readSupport.commit(readSupport.deserializeOffset(off.json))
+              case (reader: MicroBatchReader, off) =>
+                reader.commit(reader.deserializeOffset(off.json))
               case (src, _) =>
                 throw new IllegalArgumentException(
                   s"Unknown source is found at constructNextBatch: $src")
@@ -438,34 +441,30 @@ class MicroBatchExecution(
               s"${batch.queryExecution.logical}")
           logDebug(s"Retrieving data from $source: $current -> $available")
           Some(source -> batch.logicalPlan)
-
-        // TODO(cloud-fan): for data source v2, the new batch is just a new `ScanConfigBuilder`, but
-        // to be compatible with streaming source v1, we return a logical plan as a new batch here.
-        case (readSupport: MicroBatchReadSupport, available)
-          if committedOffsets.get(readSupport).map(_ != available).getOrElse(true) =>
-          val current = committedOffsets.get(readSupport).map {
-            off => readSupport.deserializeOffset(off.json)
-          }
-          val endOffset: OffsetV2 = available match {
-            case v1: SerializedOffset => readSupport.deserializeOffset(v1.json)
+        case (reader: MicroBatchReader, available)
+          if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
+          val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+          val availableV2: OffsetV2 = available match {
+            case v1: SerializedOffset => reader.deserializeOffset(v1.json)
             case v2: OffsetV2 => v2
           }
-          val startOffset = current.getOrElse(readSupport.initialOffset)
-          val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset, endOffset)
-          logDebug(s"Retrieving data from $readSupport: $current -> $endOffset")
+          reader.setOffsetRange(
+            toJava(current),
+            Optional.of(availableV2))
+          logDebug(s"Retrieving data from $reader: $current -> $availableV2")
 
-          val (source, options) = readSupport match {
+          val (source, options) = reader match {
             // `MemoryStream` is special. It's for test only and doesn't have a `DataSourceV2`
             // implementation. We provide a fake one here for explain.
             case _: MemoryStream[_] => MemoryStreamDataSource -> Map.empty[String, String]
             // Provide a fake value here just in case something went wrong, e.g. the reader gives
             // a wrong `equals` implementation.
-            case _ => readSupportToDataSourceMap.getOrElse(readSupport, {
+            case _ => readerToDataSourceMap.getOrElse(reader, {
               FakeDataSourceV2 -> Map.empty[String, String]
             })
           }
-          Some(readSupport -> StreamingDataSourceV2Relation(
-            readSupport.fullSchema().toAttributes, source, options, readSupport, scanConfigBuilder))
+          Some(reader -> StreamingDataSourceV2Relation(
+            reader.readSchema().toAttributes, source, options, reader))
         case _ => None
       }
     }
@@ -499,13 +498,13 @@ class MicroBatchExecution(
 
     val triggerLogicalPlan = sink match {
       case _: Sink => newAttributePlan
-      case s: StreamingWriteSupportProvider =>
-        val writer = s.createStreamingWriteSupport(
+      case s: StreamWriteSupport =>
+        val writer = s.createStreamWriter(
           s"$runId",
           newAttributePlan.schema,
           outputMode,
           new DataSourceOptions(extraOptions.asJava))
-        WriteToDataSourceV2(new MicroBatchWritSupport(currentBatchId, writer), newAttributePlan)
+        WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }
 
@@ -533,7 +532,7 @@ class MicroBatchExecution(
       SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
         sink match {
           case s: Sink => s.addBatch(currentBatchId, nextBatch)
-          case _: StreamingWriteSupportProvider =>
+          case _: StreamWriteSupport =>
             // This doesn't accumulate any data - it just forces execution of the microbatch writer.
             nextBatch.collect()
         }
@@ -557,6 +556,10 @@ class MicroBatchExecution(
       awaitProgressLock.unlock()
     }
   }
+
+  private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = {
+    Optional.ofNullable(scalaOption.orNull)
+  }
 }
 
 object MicroBatchExecution {

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index d4b5065..6a380ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
 import org.apache.spark.util.Clock
@@ -251,7 +251,7 @@ trait ProgressReporter extends Logging {
       // Check whether the streaming query's logical plan has only V2 data sources
       val allStreamingLeaves =
         logicalPlan.collect { case s: StreamingExecutionRelation => s }
-      allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReadSupport] }
+      allStreamingLeaves.forall { _.source.isInstanceOf[MicroBatchReader] }
     }
 
     if (onlyDataSourceV2Sources) {
@@ -278,7 +278,7 @@ trait ProgressReporter extends Logging {
         new IdentityHashMap[DataSourceV2ScanExec, DataSourceV2ScanExec]()
 
       lastExecution.executedPlan.collectLeaves().foreach {
-        case s: DataSourceV2ScanExec if s.readSupport.isInstanceOf[BaseStreamingSource] =>
+        case s: DataSourceV2ScanExec if s.reader.isInstanceOf[BaseStreamingSource] =>
           uniqueStreamingExecLeavesMap.put(s, s)
         case _ =>
       }
@@ -286,7 +286,7 @@ trait ProgressReporter extends Logging {
       val sourceToInputRowsTuples =
         uniqueStreamingExecLeavesMap.values.asScala.map { execLeaf =>
           val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
-          val source = execLeaf.readSupport.asInstanceOf[BaseStreamingSource]
+          val source = execLeaf.reader.asInstanceOf[BaseStreamingSource]
           source -> numRows
         }.toSeq
       logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala
deleted file mode 100644
index 1be0716..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/SimpleStreamingScanConfigBuilder.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.sql.sources.v2.reader.{ScanConfig, ScanConfigBuilder}
-import org.apache.spark.sql.types.StructType
-
-/**
- * A very simple [[ScanConfigBuilder]] implementation that creates a simple [[ScanConfig]] to
- * carry schema and offsets for streaming data sources.
- */
-class SimpleStreamingScanConfigBuilder(
-    schema: StructType,
-    start: Offset,
-    end: Option[Offset] = None)
-  extends ScanConfigBuilder {
-
-  override def build(): ScanConfig = SimpleStreamingScanConfig(schema, start, end)
-}
-
-case class SimpleStreamingScanConfig(
-    readSchema: StructType,
-    start: Offset,
-    end: Option[Offset])
-  extends ScanConfig

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 4b696df..24195b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.execution.LeafExecNode
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceV2}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2}
 
 object StreamingRelation {
   def apply(dataSource: DataSource): StreamingRelation = {
@@ -83,7 +83,7 @@ case class StreamingExecutionRelation(
 
 // We have to pack in the V1 data source as a shim, for the case when a source implements
 // continuous processing (which is always V2) but only has V1 microbatch support. We don't
-// know at read time whether the query is continuous or not, so we need to be able to
+// know at read time whether the query is continuous or not, so we need to be able to
 // swap a V1 relation back in.
 /**
  * Used to link a [[DataSourceV2]] into a streaming
@@ -113,7 +113,7 @@ case class StreamingRelationV2(
  * Used to link a [[DataSourceV2]] into a continuous processing execution.
  */
 case class ContinuousExecutionRelation(
-    source: ContinuousReadSupportProvider,
+    source: ContinuousReadSupport,
     extraOptions: Map[String, String],
     output: Seq[Attribute])(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 9c5c16f..cfba100 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -18,10 +18,10 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport
+import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
 import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -31,16 +31,16 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
 }
 
 class ConsoleSinkProvider extends DataSourceV2
-  with StreamingWriteSupportProvider
+  with StreamWriteSupport
   with DataSourceRegister
   with CreatableRelationProvider {
 
-  override def createStreamingWriteSupport(
+  override def createStreamWriter(
       queryId: String,
       schema: StructType,
       mode: OutputMode,
-      options: DataSourceOptions): StreamingWriteSupport = {
-    new ConsoleWriteSupport(schema, options)
+      options: DataSourceOptions): StreamWriter = {
+    new ConsoleWriter(schema, options)
   }
 
   def createRelation(

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index b68f67e..554a0b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -21,13 +21,12 @@ import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousPartitionReaderFactory
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousInputPartitionReader
 import org.apache.spark.util.NextIterator
 
 class ContinuousDataSourceRDDPartition(
     val index: Int,
-    val inputPartition: InputPartition)
+    val inputPartition: InputPartition[InternalRow])
   extends Partition with Serializable {
 
   // This is semantically a lazy val - it's initialized once the first time a call to
@@ -50,22 +49,15 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    private val inputPartitions: Seq[InputPartition],
-    schema: StructType,
-    partitionReaderFactory: ContinuousPartitionReaderFactory)
+    private val readerInputPartitions: Seq[InputPartition[InternalRow]])
   extends RDD[InternalRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    inputPartitions.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }
 
-  private def castPartition(split: Partition): ContinuousDataSourceRDDPartition = split match {
-    case p: ContinuousDataSourceRDDPartition => p
-    case _ => throw new SparkException(s"[BUG] Not a ContinuousDataSourceRDDPartition: $split")
-  }
-
   /**
    * Initialize the shared reader for this partition if needed, then read rows from it until
    * it returns null to signal the end of the epoch.
@@ -77,12 +69,10 @@ class ContinuousDataSourceRDD(
     }
 
     val readerForPartition = {
-      val partition = castPartition(split)
+      val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
       if (partition.queueReader == null) {
-        val partitionReader = partitionReaderFactory.createReader(
-          partition.inputPartition)
-        partition.queueReader = new ContinuousQueuedDataReader(
-          partition.index, partitionReader, schema, context, dataQueueSize, epochPollIntervalMs)
+        partition.queueReader =
+          new ContinuousQueuedDataReader(partition, context, dataQueueSize, epochPollIntervalMs)
       }
 
       partition.queueReader
@@ -103,6 +93,17 @@ class ContinuousDataSourceRDD(
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    castPartition(split).inputPartition.preferredLocations()
+    split.asInstanceOf[ContinuousDataSourceRDDPartition].inputPartition.preferredLocations()
+  }
+}
+
+object ContinuousDataSourceRDD {
+  private[continuous] def getContinuousReader(
+      reader: InputPartitionReader[InternalRow]): ContinuousInputPartitionReader[_] = {
+    reader match {
+      case r: ContinuousInputPartitionReader[InternalRow] => r
+      case _ =>
+        throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index ccca726..f104422 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -29,12 +29,13 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentDate, CurrentTimestamp}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, StreamingDataSourceV2Relation}
+import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
 import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
 import org.apache.spark.sql.sources.v2
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{Clock, Utils}
 
 class ContinuousExecution(
@@ -42,7 +43,7 @@ class ContinuousExecution(
     name: String,
     checkpointRoot: String,
     analyzedPlan: LogicalPlan,
-    sink: StreamingWriteSupportProvider,
+    sink: StreamWriteSupport,
     trigger: Trigger,
     triggerClock: Clock,
     outputMode: OutputMode,
@@ -52,7 +53,7 @@ class ContinuousExecution(
     sparkSession, name, checkpointRoot, analyzedPlan, sink,
     trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
 
-  @volatile protected var continuousSources: Seq[ContinuousReadSupport] = Seq()
+  @volatile protected var continuousSources: Seq[ContinuousReader] = Seq()
   override protected def sources: Seq[BaseStreamingSource] = continuousSources
 
   // For use only in test harnesses.
@@ -62,8 +63,7 @@ class ContinuousExecution(
     val toExecutionRelationMap = MutableMap[StreamingRelationV2, ContinuousExecutionRelation]()
     analyzedPlan.transform {
       case r @ StreamingRelationV2(
-          source: ContinuousReadSupportProvider, _, extraReaderOptions, output, _) =>
-        // TODO: shall we create `ContinuousReadSupport` here instead of each reconfiguration?
+          source: ContinuousReadSupport, _, extraReaderOptions, output, _) =>
         toExecutionRelationMap.getOrElseUpdate(r, {
           ContinuousExecutionRelation(source, extraReaderOptions, output)(sparkSession)
         })
@@ -148,7 +148,8 @@ class ContinuousExecution(
         val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
         nextSourceId += 1
 
-        dataSource.createContinuousReadSupport(
+        dataSource.createContinuousReader(
+          java.util.Optional.empty[StructType](),
           metadataPath,
           new DataSourceOptions(extraReaderOptions.asJava))
     }
@@ -159,9 +160,9 @@ class ContinuousExecution(
     var insertedSourceId = 0
     val withNewSources = logicalPlan transform {
       case ContinuousExecutionRelation(source, options, output) =>
-        val readSupport = continuousSources(insertedSourceId)
+        val reader = continuousSources(insertedSourceId)
         insertedSourceId += 1
-        val newOutput = readSupport.fullSchema().toAttributes
+        val newOutput = reader.readSchema().toAttributes
 
         assert(output.size == newOutput.size,
           s"Invalid reader: ${Utils.truncatedString(output, ",")} != " +
@@ -169,10 +170,9 @@ class ContinuousExecution(
         replacements ++= output.zip(newOutput)
 
         val loggedOffset = offsets.offsets(0)
-        val realOffset = loggedOffset.map(off => readSupport.deserializeOffset(off.json))
-        val startOffset = realOffset.getOrElse(readSupport.initialOffset)
-        val scanConfigBuilder = readSupport.newScanConfigBuilder(startOffset)
-        StreamingDataSourceV2Relation(newOutput, source, options, readSupport, scanConfigBuilder)
+        val realOffset = loggedOffset.map(off => reader.deserializeOffset(off.json))
+        reader.setStartOffset(java.util.Optional.ofNullable(realOffset.orNull))
+        StreamingDataSourceV2Relation(newOutput, source, options, reader)
     }
 
     // Rewire the plan to use the new attributes that were returned by the source.
@@ -185,13 +185,17 @@ class ContinuousExecution(
           "CurrentTimestamp and CurrentDate not yet supported for continuous processing")
     }
 
-    val writer = sink.createStreamingWriteSupport(
+    val writer = sink.createStreamWriter(
       s"$runId",
       triggerLogicalPlan.schema,
       outputMode,
       new DataSourceOptions(extraOptions.asJava))
     val withSink = WriteToContinuousDataSource(writer, triggerLogicalPlan)
 
+    val reader = withSink.collect {
+      case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
+    }.head
+
     reportTimeTaken("queryPlanning") {
       lastExecution = new IncrementalExecution(
         sparkSessionForQuery,
@@ -204,11 +208,6 @@ class ContinuousExecution(
       lastExecution.executedPlan // Force the lazy generation of execution plan
     }
 
-    val (readSupport, scanConfig) = lastExecution.executedPlan.collect {
-      case scan: DataSourceV2ScanExec if scan.readSupport.isInstanceOf[ContinuousReadSupport] =>
-        scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig
-    }.head
-
     sparkSessionForQuery.sparkContext.setLocalProperty(
       StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString)
     sparkSessionForQuery.sparkContext.setLocalProperty(
@@ -226,16 +225,14 @@ class ContinuousExecution(
     // Use the parent Spark session for the endpoint since it's where this query ID is registered.
     val epochEndpoint =
       EpochCoordinatorRef.create(
-        writer, readSupport, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
+        writer, reader, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
     val epochUpdateThread = new Thread(new Runnable {
       override def run: Unit = {
         try {
           triggerExecutor.execute(() => {
             startTrigger()
 
-            val shouldReconfigure = readSupport.needsReconfiguration(scanConfig) &&
-              state.compareAndSet(ACTIVE, RECONFIGURING)
-            if (shouldReconfigure) {
+            if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
               if (queryExecutionThread.isAlive) {
                 queryExecutionThread.interrupt()
               }
@@ -285,12 +282,10 @@ class ContinuousExecution(
    * Report ending partition offsets for the given reader at the given epoch.
    */
   def addOffset(
-      epoch: Long,
-      readSupport: ContinuousReadSupport,
-      partitionOffsets: Seq[PartitionOffset]): Unit = {
+      epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
     assert(continuousSources.length == 1, "only one continuous source supported currently")
 
-    val globalOffset = readSupport.mergeOffsets(partitionOffsets.toArray)
+    val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
     val oldOffset = synchronized {
       offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
       offsetLog.get(epoch - 1)

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index 65c5fc6..ec1dabd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -25,9 +25,8 @@ import scala.util.control.NonFatal
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, PartitionOffset}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -38,14 +37,15 @@ import org.apache.spark.util.ThreadUtils
  * offsets across epochs. Each compute() should call the next() method here until null is returned.
  */
 class ContinuousQueuedDataReader(
-    partitionIndex: Int,
-    reader: ContinuousPartitionReader[InternalRow],
-    schema: StructType,
+    partition: ContinuousDataSourceRDDPartition,
     context: TaskContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long) extends Closeable {
+  private val reader = partition.inputPartition.createPartitionReader()
+
   // Important sequencing - we must get our starting point before the provider threads start running
-  private var currentOffset: PartitionOffset = reader.getOffset
+  private var currentOffset: PartitionOffset =
+    ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
 
   /**
    * The record types in the read buffer.
@@ -66,7 +66,7 @@ class ContinuousQueuedDataReader(
   epochMarkerExecutor.scheduleWithFixedDelay(
     epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
 
-  private val dataReaderThread = new DataReaderThread(schema)
+  private val dataReaderThread = new DataReaderThread
   dataReaderThread.setDaemon(true)
   dataReaderThread.start()
 
@@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
     currentEntry match {
       case EpochMarker =>
         epochCoordEndpoint.send(ReportPartitionOffset(
-          partitionIndex, EpochTracker.getCurrentEpoch.get, currentOffset))
+          partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
         null
       case ContinuousRow(row, offset) =>
         currentOffset = offset
@@ -128,16 +128,16 @@ class ContinuousQueuedDataReader(
 
   /**
    * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
-   * a new row arrives to the [[ContinuousPartitionReader]].
+   * a new row arrives to the [[InputPartitionReader]].
    */
-  class DataReaderThread(schema: StructType) extends Thread(
+  class DataReaderThread extends Thread(
       s"continuous-reader--${context.partitionId()}--" +
         s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging {
     @volatile private[continuous] var failureReason: Throwable = _
-    private val toUnsafe = UnsafeProjection.create(schema)
 
     override def run(): Unit = {
       TaskContext.setTaskContext(context)
+      val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader)
       try {
         while (!shouldStop()) {
           if (!reader.next()) {
@@ -149,9 +149,8 @@ class ContinuousQueuedDataReader(
               return
             }
           }
-          // `InternalRow#copy` may not be properly implemented, for safety we convert to unsafe row
-          // before copy here.
-          queue.put(ContinuousRow(toUnsafe(reader.get()).copy(), reader.getOffset))
+
+          queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset))
         }
       } catch {
         case _: InterruptedException =>

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index a6cde2b..551e07c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -17,22 +17,24 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
+import scala.collection.JavaConverters._
+
 import org.json4s.DefaultFormats
 import org.json4s.jackson.Serialization
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.streaming.{RateStreamOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder, ValueRunTimeMsPair}
+import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
 import org.apache.spark.sql.execution.streaming.sources.RateStreamProvider
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 
 case class RateStreamPartitionOffset(
    partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset
 
-class RateStreamContinuousReadSupport(options: DataSourceOptions) extends ContinuousReadSupport {
+class RateStreamContinuousReader(options: DataSourceOptions) extends ContinuousReader {
   implicit val defaultFormats: DefaultFormats = DefaultFormats
 
   val creationTime = System.currentTimeMillis()
@@ -54,18 +56,18 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin
     RateStreamOffset(Serialization.read[Map[Int, ValueRunTimeMsPair]](json))
   }
 
-  override def fullSchema(): StructType = RateStreamProvider.SCHEMA
+  override def readSchema(): StructType = RateStreamProvider.SCHEMA
 
-  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
-    new SimpleStreamingScanConfigBuilder(fullSchema(), start)
-  }
+  private var offset: Offset = _
 
-  override def initialOffset: Offset = createInitialOffset(numPartitions, creationTime)
+  override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
+    this.offset = offset.orElse(createInitialOffset(numPartitions, creationTime))
+  }
 
-  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-    val startOffset = config.asInstanceOf[SimpleStreamingScanConfig].start
+  override def getStartOffset(): Offset = offset
 
-    val partitionStartMap = startOffset match {
+  override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = {
+    val partitionStartMap = offset match {
       case off: RateStreamOffset => off.partitionToValueAndRunTimeMs
       case off =>
         throw new IllegalArgumentException(
@@ -88,12 +90,8 @@ class RateStreamContinuousReadSupport(options: DataSourceOptions) extends Contin
         i,
         numPartitions,
         perPartitionRate)
-    }.toArray
-  }
-
-  override def createContinuousReaderFactory(
-      config: ScanConfig): ContinuousPartitionReaderFactory = {
-    RateStreamContinuousReaderFactory
+        .asInstanceOf[InputPartition[InternalRow]]
+    }.asJava
   }
 
   override def commit(end: Offset): Unit = {}
@@ -120,23 +118,33 @@ case class RateStreamContinuousInputPartition(
     partitionIndex: Int,
     increment: Long,
     rowsPerSecond: Double)
-  extends InputPartition
-
-object RateStreamContinuousReaderFactory extends ContinuousPartitionReaderFactory {
-  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
-    val p = partition.asInstanceOf[RateStreamContinuousInputPartition]
-    new RateStreamContinuousPartitionReader(
-      p.startValue, p.startTimeMs, p.partitionIndex, p.increment, p.rowsPerSecond)
+  extends ContinuousInputPartition[InternalRow] {
+
+  override def createContinuousReader(
+      offset: PartitionOffset): InputPartitionReader[InternalRow] = {
+    val rateStreamOffset = offset.asInstanceOf[RateStreamPartitionOffset]
+    require(rateStreamOffset.partition == partitionIndex,
+      s"Expected partitionIndex: $partitionIndex, but got: ${rateStreamOffset.partition}")
+    new RateStreamContinuousInputPartitionReader(
+      rateStreamOffset.currentValue,
+      rateStreamOffset.currentTimeMs,
+      partitionIndex,
+      increment,
+      rowsPerSecond)
   }
+
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
+    new RateStreamContinuousInputPartitionReader(
+      startValue, startTimeMs, partitionIndex, increment, rowsPerSecond)
 }
 
-class RateStreamContinuousPartitionReader(
+class RateStreamContinuousInputPartitionReader(
     startValue: Long,
     startTimeMs: Long,
     partitionIndex: Int,
     increment: Long,
     rowsPerSecond: Double)
-  extends ContinuousPartitionReader[InternalRow] {
+  extends ContinuousInputPartitionReader[InternalRow] {
   private var nextReadTime: Long = startTimeMs
   private val readTimeIncrement: Long = (1000 / rowsPerSecond).toLong
 

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index 28ab244..56bfefd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.streaming.continuous
 import java.io.{BufferedReader, InputStreamReader, IOException}
 import java.net.Socket
 import java.sql.Timestamp
-import java.util.Calendar
+import java.util.{Calendar, List => JList}
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
 import org.json4s.{DefaultFormats, NoTypeHints}
@@ -33,26 +34,24 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.streaming.{Offset => _, _}
+import org.apache.spark.sql.execution.streaming.{ContinuousRecordEndpoint, ContinuousRecordPartitionOffset, GetRecord}
 import org.apache.spark.sql.execution.streaming.sources.TextSocketReader
 import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.RpcUtils
 
 
 /**
- * A ContinuousReadSupport that reads text lines through a TCP socket, designed only for tutorials
- * and debugging. This ContinuousReadSupport will *not* work in production applications due to
- * multiple reasons, including no support for fault recovery.
+ * A ContinuousReader that reads text lines through a TCP socket, designed only for tutorials and
+ * debugging. This ContinuousReader will *not* work in production applications due to multiple
+ * reasons, including no support for fault recovery.
  *
  * The driver maintains a socket connection to the host-port, keeps the received messages in
  * buckets and serves the messages to the executors via a RPC endpoint.
  */
-class TextSocketContinuousReadSupport(options: DataSourceOptions)
-  extends ContinuousReadSupport with Logging {
-
+class TextSocketContinuousReader(options: DataSourceOptions) extends ContinuousReader with Logging {
   implicit val defaultFormats: DefaultFormats = DefaultFormats
 
   private val host: String = options.get("host").get()
@@ -74,8 +73,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
   @GuardedBy("this")
   private var currentOffset: Int = -1
 
-  // Exposed for tests.
-  private[spark] var startOffset: TextSocketOffset = _
+  private var startOffset: TextSocketOffset = _
 
   private val recordEndpoint = new ContinuousRecordEndpoint(buckets, this)
   @volatile private var endpointRef: RpcEndpointRef = _
@@ -96,16 +94,16 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
     TextSocketOffset(Serialization.read[List[Int]](json))
   }
 
-  override def initialOffset(): Offset = {
-    startOffset = TextSocketOffset(List.fill(numPartitions)(0))
-    startOffset
+  override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
+    this.startOffset = offset
+      .orElse(TextSocketOffset(List.fill(numPartitions)(0)))
+      .asInstanceOf[TextSocketOffset]
+    recordEndpoint.setStartOffsets(startOffset.offsets)
   }
 
-  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
-    new SimpleStreamingScanConfigBuilder(fullSchema(), start)
-  }
+  override def getStartOffset: Offset = startOffset
 
-  override def fullSchema(): StructType = {
+  override def readSchema(): StructType = {
     if (includeTimestamp) {
       TextSocketReader.SCHEMA_TIMESTAMP
     } else {
@@ -113,10 +111,8 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
     }
   }
 
-  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-    val startOffset = config.asInstanceOf[SimpleStreamingScanConfig]
-      .start.asInstanceOf[TextSocketOffset]
-    recordEndpoint.setStartOffsets(startOffset.offsets)
+  override def planInputPartitions(): JList[InputPartition[InternalRow]] = {
+
     val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}"
     endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
 
@@ -136,13 +132,10 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
 
     startOffset.offsets.zipWithIndex.map {
       case (offset, i) =>
-        TextSocketContinuousInputPartition(endpointName, i, offset, includeTimestamp)
-    }.toArray
-  }
+        TextSocketContinuousInputPartition(
+          endpointName, i, offset, includeTimestamp): InputPartition[InternalRow]
+    }.asJava
 
-  override def createContinuousReaderFactory(
-      config: ScanConfig): ContinuousPartitionReaderFactory = {
-    TextSocketReaderFactory
   }
 
   override def commit(end: Offset): Unit = synchronized {
@@ -197,7 +190,7 @@ class TextSocketContinuousReadSupport(options: DataSourceOptions)
               logWarning(s"Stream closed by $host:$port")
               return
             }
-            TextSocketContinuousReadSupport.this.synchronized {
+            TextSocketContinuousReader.this.synchronized {
               currentOffset += 1
               val newData = (line,
                 Timestamp.valueOf(
@@ -228,30 +221,25 @@ case class TextSocketContinuousInputPartition(
     driverEndpointName: String,
     partitionId: Int,
     startOffset: Int,
-    includeTimestamp: Boolean) extends InputPartition
-
-
-object TextSocketReaderFactory extends ContinuousPartitionReaderFactory {
+    includeTimestamp: Boolean)
+extends InputPartition[InternalRow] {
 
-  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
-    val p = partition.asInstanceOf[TextSocketContinuousInputPartition]
-    new TextSocketContinuousPartitionReader(
-      p.driverEndpointName, p.partitionId, p.startOffset, p.includeTimestamp)
-  }
+  override def createPartitionReader(): InputPartitionReader[InternalRow] =
+    new TextSocketContinuousInputPartitionReader(driverEndpointName, partitionId, startOffset,
+      includeTimestamp)
 }
 
-
 /**
  * Continuous text socket input partition reader.
  *
  * Polls the driver endpoint for new records.
  */
-class TextSocketContinuousPartitionReader(
+class TextSocketContinuousInputPartitionReader(
     driverEndpointName: String,
     partitionId: Int,
     startOffset: Int,
     includeTimestamp: Boolean)
-  extends ContinuousPartitionReader[InternalRow] {
+  extends ContinuousInputPartitionReader[InternalRow] {
 
   private val endpoint = RpcUtils.makeDriverRef(
     driverEndpointName,

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index a08411d..967dbe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.streaming.continuous
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.writer.DataWriter
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
 import org.apache.spark.util.Utils
 
 /**
@@ -32,7 +31,7 @@ import org.apache.spark.util.Utils
  *
  * We keep repeating prev.compute() and writing new epochs until the query is shut down.
  */
-class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDataWriterFactory)
+class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
     extends RDD[Unit](prev) {
 
   override val partitioner = prev.partitioner
@@ -51,7 +50,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writerFactory: StreamingDat
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         try {
           val dataIterator = prev.compute(split, context)
-          dataWriter = writerFactory.createWriter(
+          dataWriter = writeTask.createDataWriter(
             context.partitionId(),
             context.taskAttemptId(),
             EpochTracker.getCurrentEpoch.get)

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 2238ce2..8877ebe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -23,9 +23,9 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, PartitionOffset}
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
 import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 import org.apache.spark.util.RpcUtils
 
 private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
@@ -82,15 +82,15 @@ private[sql] object EpochCoordinatorRef extends Logging {
    * Create a reference to a new [[EpochCoordinator]].
    */
   def create(
-      writeSupport: StreamingWriteSupport,
-      readSupport: ContinuousReadSupport,
+      writer: StreamWriter,
+      reader: ContinuousReader,
       query: ContinuousExecution,
       epochCoordinatorId: String,
       startEpoch: Long,
       session: SparkSession,
       env: SparkEnv): RpcEndpointRef = synchronized {
     val coordinator = new EpochCoordinator(
-      writeSupport, readSupport, query, startEpoch, session, env.rpcEnv)
+      writer, reader, query, startEpoch, session, env.rpcEnv)
     val ref = env.rpcEnv.setupEndpoint(endpointName(epochCoordinatorId), coordinator)
     logInfo("Registered EpochCoordinator endpoint")
     ref
@@ -115,8 +115,8 @@ private[sql] object EpochCoordinatorRef extends Logging {
  *   have both committed and reported an end offset for a given epoch.
  */
 private[continuous] class EpochCoordinator(
-    writeSupport: StreamingWriteSupport,
-    readSupport: ContinuousReadSupport,
+    writer: StreamWriter,
+    reader: ContinuousReader,
     query: ContinuousExecution,
     startEpoch: Long,
     session: SparkSession,
@@ -198,7 +198,7 @@ private[continuous] class EpochCoordinator(
       s"and is ready to be committed. Committing epoch $epoch.")
     // Sequencing is important here. We must commit to the writer before recording the commit
     // in the query, or we will end up dropping the commit if we restart in the middle.
-    writeSupport.commit(epoch, messages.toArray)
+    writer.commit(epoch, messages.toArray)
     query.commit(epoch)
   }
 
@@ -220,7 +220,7 @@ private[continuous] class EpochCoordinator(
         partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
       if (thisEpochOffsets.size == numReaderPartitions) {
         logDebug(s"Epoch $epoch has offsets reported from all partitions: $thisEpochOffsets")
-        query.addOffset(epoch, readSupport, thisEpochOffsets.toSeq)
+        query.addOffset(epoch, reader, thisEpochOffsets.toSeq)
         resolveCommitsAtEpoch(epoch)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
index 7ad21cc..943c731 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 
 /**
  * The logical plan for writing data in a continuous stream.
  */
 case class WriteToContinuousDataSource(
-    writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan {
+    writer: StreamWriter, query: LogicalPlan) extends LogicalPlan {
   override def children: Seq[LogicalPlan] = Seq(query)
   override def output: Seq[Attribute] = Nil
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index c216b61..927d3a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 
 /**
- * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]].
+ * The physical plan for writing data into a continuous processing [[StreamWriter]].
  */
-case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan)
+case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPlan)
     extends SparkPlan with Logging {
   override def children: Seq[SparkPlan] = Seq(query)
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = writeSupport.createStreamingWriterFactory()
+    val writerFactory = writer.createWriterFactory()
     val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
 
-    logInfo(s"Start processing data source write support: $writeSupport. " +
+    logInfo(s"Start processing data source writer: $writer. " +
       s"The input RDD has ${rdd.partitions.length} partitions.")
     EpochCoordinatorRef.get(
       sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index adf52ab..f81abdc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.util.control.NonFatal
 
@@ -31,8 +34,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -64,7 +67,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
     addData(data.toTraversable)
   }
 
-  def fullSchema(): StructType = encoder.schema
+  def readSchema(): StructType = encoder.schema
 
   protected def logicalPlan: LogicalPlan
 
@@ -77,7 +80,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends MemoryStreamBase[A](sqlContext) with MicroBatchReadSupport with Logging {
+    extends MemoryStreamBase[A](sqlContext) with MicroBatchReader with Logging {
 
   protected val logicalPlan: LogicalPlan =
     StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
@@ -119,22 +122,24 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
 
-  override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
+  override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+    synchronized {
+      startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
+      endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
+    }
+  }
 
-  override def initialOffset: OffsetV2 = LongOffset(-1)
+  override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
 
-  override def latestOffset(): OffsetV2 = {
-    if (currentOffset.offset == -1) null else currentOffset
+  override def getStartOffset: OffsetV2 = synchronized {
+    if (startOffset.offset == -1) null else startOffset
   }
 
-  override def newScanConfigBuilder(start: OffsetV2, end: OffsetV2): ScanConfigBuilder = {
-    new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end))
+  override def getEndOffset: OffsetV2 = synchronized {
+    if (endOffset.offset == -1) null else endOffset
   }
 
-  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-    val sc = config.asInstanceOf[SimpleStreamingScanConfig]
-    val startOffset = sc.start.asInstanceOf[LongOffset]
-    val endOffset = sc.end.get.asInstanceOf[LongOffset]
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     synchronized {
       // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
       val startOrdinal = startOffset.offset.toInt + 1
@@ -151,15 +156,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
       logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
 
       newBlocks.map { block =>
-        new MemoryStreamInputPartition(block)
-      }.toArray
+        new MemoryStreamInputPartition(block): InputPartition[InternalRow]
+      }.asJava
     }
   }
 
-  override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = {
-    MemoryStreamReaderFactory
-  }
-
   private def generateDebugString(
       rows: Seq[UnsafeRow],
       startOrdinal: Int,
@@ -200,12 +201,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 }
 
 
-class MemoryStreamInputPartition(val records: Array[UnsafeRow]) extends InputPartition
-
-object MemoryStreamReaderFactory extends PartitionReaderFactory {
-  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    val records = partition.asInstanceOf[MemoryStreamInputPartition].records
-    new PartitionReader[InternalRow] {
+class MemoryStreamInputPartition(records: Array[UnsafeRow])
+  extends InputPartition[InternalRow] {
+  override def createPartitionReader(): InputPartitionReader[InternalRow] = {
+    new InputPartitionReader[InternalRow] {
       private var currentIndex = -1
 
       override def next(): Boolean = {

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
deleted file mode 100644
index 833e62f..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
-import org.apache.spark.sql.types.StructType
-
-/** Common methods used to create writes for the the console sink */
-class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions)
-    extends StreamingWriteSupport with Logging {
-
-  // Number of rows to display, by default 20 rows
-  protected val numRowsToShow = options.getInt("numRows", 20)
-
-  // Truncate the displayed data if it is too long, by default it is true
-  protected val isTruncated = options.getBoolean("truncate", true)
-
-  assert(SparkSession.getActiveSession.isDefined)
-  protected val spark = SparkSession.getActiveSession.get
-
-  def createStreamingWriterFactory(): StreamingDataWriterFactory = PackedRowWriterFactory
-
-  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
-    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
-    // behavior.
-    printRows(messages, schema, s"Batch: $epochId")
-  }
-
-  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
-
-  protected def printRows(
-      commitMessages: Array[WriterCommitMessage],
-      schema: StructType,
-      printMessage: String): Unit = {
-    val rows = commitMessages.collect {
-      case PackedRowCommitMessage(rs) => rs
-    }.flatten
-
-    // scalastyle:off println
-    println("-------------------------------------------")
-    println(printMessage)
-    println("-------------------------------------------")
-    // scalastyle:off println
-    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
-      .show(numRowsToShow, isTruncated)
-  }
-
-  override def toString(): String = {
-    s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
new file mode 100644
index 0000000..fd45ba5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.types.StructType
+
+/** Common methods used to create writes for the console sink */
+class ConsoleWriter(schema: StructType, options: DataSourceOptions)
+    extends StreamWriter with Logging {
+
+  // Number of rows to display, by default 20 rows
+  protected val numRowsToShow = options.getInt("numRows", 20)
+
+  // Truncate the displayed data if it is too long, by default it is true
+  protected val isTruncated = options.getBoolean("truncate", true)
+
+  assert(SparkSession.getActiveSession.isDefined)
+  protected val spark = SparkSession.getActiveSession.get
+
+  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
+    // behavior.
+    printRows(messages, schema, s"Batch: $epochId")
+  }
+
+  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  protected def printRows(
+      commitMessages: Array[WriterCommitMessage],
+      schema: StructType,
+      printMessage: String): Unit = {
+    val rows = commitMessages.collect {
+      case PackedRowCommitMessage(rs) => rs
+    }.flatten
+
+    // scalastyle:off println
+    println("-------------------------------------------")
+    println(printMessage)
+    println("-------------------------------------------")
+    // scalastyle:on println
+    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
+      .show(numRowsToShow, isTruncated)
+  }
+
+  override def toString(): String = {
+    s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index dbcc448..4a32217 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -17,22 +17,26 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization
 
 import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.{Encoder, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.streaming.{Offset => _, _}
-import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions}
-import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig, ScanConfigBuilder}
-import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.InputPartition
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.RpcUtils
 
 /**
@@ -44,9 +48,7 @@ import org.apache.spark.util.RpcUtils
  *    the specified offset within the list, or null if that offset doesn't yet have a record.
  */
 class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
-  extends MemoryStreamBase[A](sqlContext)
-    with ContinuousReadSupportProvider with ContinuousReadSupport {
-
+  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
   private implicit val formats = Serialization.formats(NoTypeHints)
 
   protected val logicalPlan =
@@ -57,6 +59,9 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
   @GuardedBy("this")
   private val records = Seq.fill(numPartitions)(new ListBuffer[A])
 
+  @GuardedBy("this")
+  private var startOffset: ContinuousMemoryStreamOffset = _
+
   private val recordEndpoint = new ContinuousRecordEndpoint(records, this)
   @volatile private var endpointRef: RpcEndpointRef = _
 
@@ -70,8 +75,15 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
     ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
   }
 
-  override def initialOffset(): Offset = {
-    ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
+  override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
+    // Inferred initial offset is position 0 in each partition.
+    startOffset = start.orElse {
+      ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
+    }.asInstanceOf[ContinuousMemoryStreamOffset]
+  }
+
+  override def getStartOffset: Offset = synchronized {
+    startOffset
   }
 
   override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
@@ -86,40 +98,34 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
     )
   }
 
-  override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
-    new SimpleStreamingScanConfigBuilder(fullSchema(), start)
-  }
-
-  override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
-    val startOffset = config.asInstanceOf[SimpleStreamingScanConfig]
-      .start.asInstanceOf[ContinuousMemoryStreamOffset]
+  override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = {
     synchronized {
       val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
       endpointRef =
         recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
 
       startOffset.partitionNums.map {
-        case (part, index) => ContinuousMemoryStreamInputPartition(endpointName, part, index)
-      }.toArray
+        case (part, index) =>
+          new ContinuousMemoryStreamInputPartition(
+            endpointName, part, index): InputPartition[InternalRow]
+      }.toList.asJava
     }
   }
 
-  override def createContinuousReaderFactory(
-      config: ScanConfig): ContinuousPartitionReaderFactory = {
-    ContinuousMemoryStreamReaderFactory
-  }
-
   override def stop(): Unit = {
     if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
   }
 
   override def commit(end: Offset): Unit = {}
 
-  // ContinuousReadSupportProvider implementation
+  // ContinuousReadSupport implementation
   // This is necessary because of how StreamTest finds the source for AddDataMemory steps.
-  override def createContinuousReadSupport(
+  def createContinuousReader(
+      schema: Optional[StructType],
       checkpointLocation: String,
-      options: DataSourceOptions): ContinuousReadSupport = this
+      options: DataSourceOptions): ContinuousReader = {
+    this
+  }
 }
 
 object ContinuousMemoryStream {
@@ -135,16 +141,12 @@ object ContinuousMemoryStream {
 /**
  * An input partition for continuous memory stream.
  */
-case class ContinuousMemoryStreamInputPartition(
+class ContinuousMemoryStreamInputPartition(
     driverEndpointName: String,
     partition: Int,
-    startOffset: Int) extends InputPartition
-
-object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFactory {
-  override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
-    val p = partition.asInstanceOf[ContinuousMemoryStreamInputPartition]
-    new ContinuousMemoryStreamPartitionReader(p.driverEndpointName, p.partition, p.startOffset)
-  }
+    startOffset: Int) extends InputPartition[InternalRow] {
+  override def createPartitionReader: ContinuousMemoryStreamInputPartitionReader =
+    new ContinuousMemoryStreamInputPartitionReader(driverEndpointName, partition, startOffset)
 }
 
 /**
@@ -152,10 +154,10 @@ object ContinuousMemoryStreamReaderFactory extends ContinuousPartitionReaderFact
  *
  * Polls the driver endpoint for new records.
  */
-class ContinuousMemoryStreamPartitionReader(
+class ContinuousMemoryStreamInputPartitionReader(
     driverEndpointName: String,
     partition: Int,
-    startOffset: Int) extends ContinuousPartitionReader[InternalRow] {
+    startOffset: Int) extends ContinuousInputPartitionReader[InternalRow] {
   private val endpoint = RpcUtils.makeDriverRef(
     driverEndpointName,
     SparkEnv.get.conf,

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala
deleted file mode 100644
index 4218fd5..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.sources
-
-import org.apache.spark.sql.{ForeachWriter, SparkSession}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.python.PythonForeachWriter
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
-import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
-
-/**
- * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
- * [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @param converter An object to convert internal rows to target type T. Either it can be
- *                  a [[ExpressionEncoder]] or a direct converter function.
- * @tparam T The expected type of the sink.
- */
-case class ForeachWriteSupportProvider[T](
-    writer: ForeachWriter[T],
-    converter: Either[ExpressionEncoder[T], InternalRow => T])
-  extends StreamingWriteSupportProvider {
-
-  override def createStreamingWriteSupport(
-      queryId: String,
-      schema: StructType,
-      mode: OutputMode,
-      options: DataSourceOptions): StreamingWriteSupport = {
-    new StreamingWriteSupport {
-      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
-      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
-
-      override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
-        val rowConverter: InternalRow => T = converter match {
-          case Left(enc) =>
-            val boundEnc = enc.resolveAndBind(
-              schema.toAttributes,
-              SparkSession.getActiveSession.get.sessionState.analyzer)
-            boundEnc.fromRow
-          case Right(func) =>
-            func
-        }
-        ForeachWriterFactory(writer, rowConverter)
-      }
-
-      override def toString: String = "ForeachSink"
-    }
-  }
-}
-
-object ForeachWriteSupportProvider {
-  def apply[T](
-      writer: ForeachWriter[T],
-      encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = {
-    writer match {
-      case pythonWriter: PythonForeachWriter =>
-        new ForeachWriteSupportProvider[UnsafeRow](
-          pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow]))
-      case _ =>
-        new ForeachWriteSupportProvider[T](writer, Left(encoder))
-    }
-  }
-}
-
-case class ForeachWriterFactory[T](
-    writer: ForeachWriter[T],
-    rowConverter: InternalRow => T)
-  extends StreamingDataWriterFactory {
-  override def createWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): ForeachDataWriter[T] = {
-    new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
-  }
-}
-
-/**
- * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
- *
- * @param writer The [[ForeachWriter]] to process all data.
- * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]]
- * @param partitionId
- * @param epochId
- * @tparam T The type expected by the writer.
- */
-class ForeachDataWriter[T](
-    writer: ForeachWriter[T],
-    rowConverter: InternalRow => T,
-    partitionId: Int,
-    epochId: Long)
-  extends DataWriter[InternalRow] {
-
-  // If open returns false, we should skip writing rows.
-  private val opened = writer.open(partitionId, epochId)
-
-  override def write(record: InternalRow): Unit = {
-    if (!opened) return
-
-    try {
-      writer.process(rowConverter(record))
-    } catch {
-      case t: Throwable =>
-        writer.close(t)
-        throw t
-    }
-  }
-
-  override def commit(): WriterCommitMessage = {
-    writer.close(null)
-    ForeachWriterCommitMessage
-  }
-
-  override def abort(): Unit = {}
-}
-
-/**
- * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
- */
-case object ForeachWriterCommitMessage extends WriterCommitMessage

http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
new file mode 100644
index 0000000..e8ce21c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.{ForeachWriter, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.PythonForeachWriter
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
+ * [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param converter An object to convert internal rows to target type T. Either it can be
+ *                  a [[ExpressionEncoder]] or a direct converter function.
+ * @tparam T The expected type of the sink.
+ */
+case class ForeachWriterProvider[T](
+    writer: ForeachWriter[T],
+    converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport {
+
+  override def createStreamWriter(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceOptions): StreamWriter = {
+    new StreamWriter {
+      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+      override def createWriterFactory(): DataWriterFactory[InternalRow] = {
+        val rowConverter: InternalRow => T = converter match {
+          case Left(enc) =>
+            val boundEnc = enc.resolveAndBind(
+              schema.toAttributes,
+              SparkSession.getActiveSession.get.sessionState.analyzer)
+            boundEnc.fromRow
+          case Right(func) =>
+            func
+        }
+        ForeachWriterFactory(writer, rowConverter)
+      }
+
+      override def toString: String = "ForeachSink"
+    }
+  }
+}
+
+object ForeachWriterProvider {
+  def apply[T](
+      writer: ForeachWriter[T],
+      encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = {
+    writer match {
+      case pythonWriter: PythonForeachWriter =>
+        new ForeachWriterProvider[UnsafeRow](
+          pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow]))
+      case _ =>
+        new ForeachWriterProvider[T](writer, Left(encoder))
+    }
+  }
+}
+
+case class ForeachWriterFactory[T](
+    writer: ForeachWriter[T],
+    rowConverter: InternalRow => T)
+  extends DataWriterFactory[InternalRow] {
+  override def createDataWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): ForeachDataWriter[T] = {
+    new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
+  }
+}
+
+/**
+ * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]].
+ *
+ * @param writer The [[ForeachWriter]] to process all data.
+ * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]]
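+ * @param partitionId The partition id passed to the writer's `open` call.
+ * @param epochId The epoch id passed to the writer's `open` call.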
+ * @tparam T The type expected by the writer.
+ */
+class ForeachDataWriter[T](
+    writer: ForeachWriter[T],
+    rowConverter: InternalRow => T,
+    partitionId: Int,
+    epochId: Long)
+  extends DataWriter[InternalRow] {
+
+  // If open returns false, we should skip writing rows.
+  private val opened = writer.open(partitionId, epochId)
+
+  override def write(record: InternalRow): Unit = {
+    if (!opened) return
+
+    try {
+      writer.process(rowConverter(record))
+    } catch {
+      case t: Throwable =>
+        writer.close(t)
+        throw t
+    }
+  }
+
+  override def commit(): WriterCommitMessage = {
+    writer.close(null)
+    ForeachWriterCommitMessage
+  }
+
+  override def abort(): Unit = {}
+}
+
+/**
+ * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination.
+ */
+case object ForeachWriterCommitMessage extends WriterCommitMessage


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message