spark-commits mailing list archives

From t...@apache.org
Subject spark git commit: [SPARK-23097][SQL][SS] Migrate text socket source to V2
Date Fri, 02 Mar 2018 20:27:47 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3a4d15e5d -> 707e6506d


[SPARK-23097][SQL][SS] Migrate text socket source to V2

## What changes were proposed in this pull request?

This PR migrates the Structured Streaming text socket source to the data source V2 API.

Question: do we need to remove the old "socket" source?
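
For context, here is a minimal, illustrative sketch of how the "socket" source is exercised from the user-facing API (the host and port values are placeholders; this migration keeps the "socket" format name and its options unchanged):

```scala
import org.apache.spark.sql.SparkSession

object SocketSourceExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SocketSourceExample")
      .master("local[2]")               // local mode, for demonstration only
      .getOrCreate()

    // Read text lines from a TCP socket; intended only for tutorials and debugging,
    // since the source does not support fault recovery.
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")      // placeholder host
      .option("port", 9999)             // placeholder port
      .option("includeTimestamp", true) // adds a "timestamp" column next to "value"
      .load()

    // Echo each micro-batch to the console.
    val query = lines.writeStream
      .format("console")
      .outputMode("append")
      .start()

    query.awaitTermination()
  }
}
```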

## How was this patch tested?

Unit tests and manual verification.

Author: jerryshao <sshao@hortonworks.com>

Closes #20382 from jerryshao/SPARK-23097.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/707e6506
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/707e6506
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/707e6506

Branch: refs/heads/master
Commit: 707e6506d0dbdb598a6c99d666f3c66746113b67
Parents: 3a4d15e
Author: jerryshao <sshao@hortonworks.com>
Authored: Fri Mar 2 12:27:42 2018 -0800
Committer: Tathagata Das <tathagata.das1565@gmail.com>
Committed: Fri Mar 2 12:27:42 2018 -0800

----------------------------------------------------------------------
 ....apache.spark.sql.sources.DataSourceRegister |   2 +-
 .../sql/execution/datasources/DataSource.scala  |   5 +-
 .../spark/sql/execution/streaming/socket.scala  | 219 -------------
 .../execution/streaming/sources/socket.scala    | 255 ++++++++++++++++
 .../spark/sql/streaming/DataStreamReader.scala  |  21 +-
 .../streaming/TextSocketStreamSuite.scala       | 231 --------------
 .../sources/TextSocketStreamSuite.scala         | 306 +++++++++++++++++++
 7 files changed, 582 insertions(+), 457 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
----------------------------------------------------------------------
diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index 0259c77..1fe9c09 100644
--- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -5,6 +5,6 @@ org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 org.apache.spark.sql.execution.datasources.text.TextFileFormat
 org.apache.spark.sql.execution.streaming.ConsoleSinkProvider
-org.apache.spark.sql.execution.streaming.TextSocketSourceProvider
 org.apache.spark.sql.execution.streaming.RateSourceProvider
+org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
 org.apache.spark.sql.execution.streaming.sources.RateSourceProviderV2

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 6e1b572..35fcff6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.TextSocketSourceProvider
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.streaming.OutputMode
@@ -563,6 +564,7 @@ object DataSource extends Logging {
     val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat"
     val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat"
     val nativeOrc = classOf[OrcFileFormat].getCanonicalName
+    val socket = classOf[TextSocketSourceProvider].getCanonicalName
 
     Map(
       "org.apache.spark.sql.jdbc" -> jdbc,
@@ -583,7 +585,8 @@ object DataSource extends Logging {
       "org.apache.spark.sql.execution.datasources.orc" -> nativeOrc,
       "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm,
       "org.apache.spark.ml.source.libsvm" -> libsvm,
-      "com.databricks.spark.csv" -> csv
+      "com.databricks.spark.csv" -> csv,
+      "org.apache.spark.sql.execution.streaming.TextSocketSourceProvider" -> socket
     )
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
deleted file mode 100644
index 0b22cbc..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.net.Socket
-import java.sql.Timestamp
-import java.text.SimpleDateFormat
-import java.util.{Calendar, Locale}
-import javax.annotation.concurrent.GuardedBy
-
-import scala.collection.mutable.ListBuffer
-import scala.util.{Failure, Success, Try}
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
-import org.apache.spark.unsafe.types.UTF8String
-
-
-object TextSocketSource {
-  val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
-  val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
-    StructField("timestamp", TimestampType) :: Nil)
-  val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
-}
-
-/**
- * A source that reads text lines through a TCP socket, designed only for tutorials and debugging.
- * This source will *not* work in production applications due to multiple reasons, including no
- * support for fault recovery and keeping all of the text read in memory forever.
- */
-class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
-  extends Source with Logging {
-
-  @GuardedBy("this")
-  private var socket: Socket = null
-
-  @GuardedBy("this")
-  private var readThread: Thread = null
-
-  /**
-   * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
-   * Stored in a ListBuffer to facilitate removing committed batches.
-   */
-  @GuardedBy("this")
-  protected val batches = new ListBuffer[(String, Timestamp)]
-
-  @GuardedBy("this")
-  protected var currentOffset: LongOffset = new LongOffset(-1)
-
-  @GuardedBy("this")
-  protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
-
-  initialize()
-
-  private def initialize(): Unit = synchronized {
-    socket = new Socket(host, port)
-    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
-    readThread = new Thread(s"TextSocketSource($host, $port)") {
-      setDaemon(true)
-
-      override def run(): Unit = {
-        try {
-          while (true) {
-            val line = reader.readLine()
-            if (line == null) {
-              // End of file reached
-              logWarning(s"Stream closed by $host:$port")
-              return
-            }
-            TextSocketSource.this.synchronized {
-              val newData = (line,
-                Timestamp.valueOf(
-                  TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
-                )
-              currentOffset = currentOffset + 1
-              batches.append(newData)
-            }
-          }
-        } catch {
-          case e: IOException =>
-        }
-      }
-    }
-    readThread.start()
-  }
-
-  /** Returns the schema of the data from this source */
-  override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP
-  else TextSocketSource.SCHEMA_REGULAR
-
-  override def getOffset: Option[Offset] = synchronized {
-    if (currentOffset.offset == -1) {
-      None
-    } else {
-      Some(currentOffset)
-    }
-  }
-
-  /** Returns the data that is between the offsets (`start`, `end`]. */
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
-    val startOrdinal =
-      start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-    val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
-    // Internal buffer only holds the batches after lastOffsetCommitted
-    val rawList = synchronized {
-      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
-      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
-      batches.slice(sliceStart, sliceEnd)
-    }
-
-    val rdd = sqlContext.sparkContext
-      .parallelize(rawList)
-      .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
-    sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
-  }
-
-  override def commit(end: Offset): Unit = synchronized {
-    val newOffset = LongOffset.convert(end).getOrElse(
-      sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
-        s"originate with an instance of this class")
-    )
-
-    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
-
-    if (offsetDiff < 0) {
-      sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
-    }
-
-    batches.trimStart(offsetDiff)
-    lastOffsetCommitted = newOffset
-  }
-
-  /** Stop this source. */
-  override def stop(): Unit = synchronized {
-    if (socket != null) {
-      try {
-        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
-        // stop the readThread is to close the socket.
-        socket.close()
-      } catch {
-        case e: IOException =>
-      }
-      socket = null
-    }
-  }
-
-  override def toString: String = s"TextSocketSource[host: $host, port: $port]"
-}
-
-class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
-  private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
-    Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
-      case Success(bool) => bool
-      case Failure(_) =>
-        throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
-    }
-  }
-
-  /** Returns the name and schema of the source that can be used to continually read data. */
-  override def sourceSchema(
-      sqlContext: SQLContext,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): (String, StructType) = {
-    logWarning("The socket source should not be used for production applications! " +
-      "It does not support recovery.")
-    if (!parameters.contains("host")) {
-      throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
-    }
-    if (!parameters.contains("port")) {
-      throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
-    }
-    if (schema.nonEmpty) {
-      throw new AnalysisException("The socket source does not support a user-specified schema.")
-    }
-
-    val sourceSchema =
-      if (parseIncludeTimestamp(parameters)) {
-        TextSocketSource.SCHEMA_TIMESTAMP
-      } else {
-        TextSocketSource.SCHEMA_REGULAR
-      }
-    ("textSocket", sourceSchema)
-  }
-
-  override def createSource(
-      sqlContext: SQLContext,
-      metadataPath: String,
-      schema: Option[StructType],
-      providerName: String,
-      parameters: Map[String, String]): Source = {
-    val host = parameters("host")
-    val port = parameters("port").toInt
-    new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext)
-  }
-
-  /** String that represents the format that this data source provider uses. */
-  override def shortName(): String = "socket"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
new file mode 100644
index 0000000..5aae46b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.net.Socket
+import java.sql.Timestamp
+import java.text.SimpleDateFormat
+import java.util.{Calendar, List => JList, Locale, Optional}
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+import scala.util.{Failure, Success, Try}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.streaming.LongOffset
+import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+object TextSocketMicroBatchReader {
+  val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil)
+  val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) ::
+    StructField("timestamp", TimestampType) :: Nil)
+  val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
+}
+
+/**
+ * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and
+ * debugging. This MicroBatchReader will *not* work in production applications due to multiple
+ * reasons, including no support for fault recovery.
+ */
+class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging {
+
+  private var startOffset: Offset = _
+  private var endOffset: Offset = _
+
+  private val host: String = options.get("host").get()
+  private val port: Int = options.get("port").get().toInt
+
+  @GuardedBy("this")
+  private var socket: Socket = null
+
+  @GuardedBy("this")
+  private var readThread: Thread = null
+
+  /**
+   * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
+   * Stored in a ListBuffer to facilitate removing committed batches.
+   */
+  @GuardedBy("this")
+  private val batches = new ListBuffer[(String, Timestamp)]
+
+  @GuardedBy("this")
+  private var currentOffset: LongOffset = LongOffset(-1L)
+
+  @GuardedBy("this")
+  private var lastOffsetCommitted: LongOffset = LongOffset(-1L)
+
+  initialize()
+
+  /** This method is only used for unit test */
+  private[sources] def getCurrentOffset(): LongOffset = synchronized {
+    currentOffset.copy()
+  }
+
+  private def initialize(): Unit = synchronized {
+    socket = new Socket(host, port)
+    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
+    readThread = new Thread(s"TextSocketSource($host, $port)") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          while (true) {
+            val line = reader.readLine()
+            if (line == null) {
+              // End of file reached
+              logWarning(s"Stream closed by $host:$port")
+              return
+            }
+            TextSocketMicroBatchReader.this.synchronized {
+              val newData = (line,
+                Timestamp.valueOf(
+                  TextSocketMicroBatchReader.DATE_FORMAT.format(Calendar.getInstance().getTime()))
+              )
+              currentOffset += 1
+              batches.append(newData)
+            }
+          }
+        } catch {
+          case e: IOException =>
+        }
+      }
+    }
+    readThread.start()
+  }
+
+  override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized {
+    startOffset = start.orElse(LongOffset(-1L))
+    endOffset = end.orElse(currentOffset)
+  }
+
+  override def getStartOffset(): Offset = {
+    Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+  }
+
+  override def getEndOffset(): Offset = {
+    Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
+  }
+
+  override def deserializeOffset(json: String): Offset = {
+    LongOffset(json.toLong)
+  }
+
+  override def readSchema(): StructType = {
+    if (options.getBoolean("includeTimestamp", false)) {
+      TextSocketMicroBatchReader.SCHEMA_TIMESTAMP
+    } else {
+      TextSocketMicroBatchReader.SCHEMA_REGULAR
+    }
+  }
+
+  override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+    assert(startOffset != null && endOffset != null,
+      "start offset and end offset should already be set before create read tasks.")
+
+    val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1
+    val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1
+
+    // Internal buffer only holds the batches after lastOffsetCommitted
+    val rawList = synchronized {
+      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+      batches.slice(sliceStart, sliceEnd)
+    }
+
+    assert(SparkSession.getActiveSession.isDefined)
+    val spark = SparkSession.getActiveSession.get
+    val numPartitions = spark.sparkContext.defaultParallelism
+
+    val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
+    rawList.zipWithIndex.foreach { case (r, idx) =>
+      slices(idx % numPartitions).append(r)
+    }
+
+    (0 until numPartitions).map { i =>
+      val slice = slices(i)
+      new DataReaderFactory[Row] {
+        override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+          private var currentIdx = -1
+
+          override def next(): Boolean = {
+            currentIdx += 1
+            currentIdx < slice.size
+          }
+
+          override def get(): Row = {
+            Row(slice(currentIdx)._1, slice(currentIdx)._2)
+          }
+
+          override def close(): Unit = {}
+        }
+      }
+    }.toList.asJava
+  }
+
+  override def commit(end: Offset): Unit = synchronized {
+    val newOffset = LongOffset.convert(end).getOrElse(
+      sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
+        s"originate with an instance of this class")
+    )
+
+    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
+
+    if (offsetDiff < 0) {
+      sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
+    }
+
+    batches.trimStart(offsetDiff)
+    lastOffsetCommitted = newOffset
+  }
+
+  /** Stop this source. */
+  override def stop(): Unit = synchronized {
+    if (socket != null) {
+      try {
+        // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
+        // stop the readThread is to close the socket.
+        socket.close()
+      } catch {
+        case e: IOException =>
+      }
+      socket = null
+    }
+  }
+
+  override def toString: String = s"TextSocket[host: $host, port: $port]"
+}
+
+class TextSocketSourceProvider extends DataSourceV2
+  with MicroBatchReadSupport with DataSourceRegister with Logging {
+
+  private def checkParameters(params: DataSourceOptions): Unit = {
+    logWarning("The socket source should not be used for production applications! " +
+      "It does not support recovery.")
+    if (!params.get("host").isPresent) {
+      throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
+    }
+    if (!params.get("port").isPresent) {
+      throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
+    }
+    Try {
+      params.get("includeTimestamp").orElse("false").toBoolean
+    } match {
+      case Success(_) =>
+      case Failure(_) =>
+        throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"")
+    }
+  }
+
+  override def createMicroBatchReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): MicroBatchReader = {
+    checkParameters(options)
+    if (schema.isPresent) {
+      throw new AnalysisException("The socket source does not support a user-specified schema.")
+    }
+
+    new TextSocketMicroBatchReader(options)
+  }
+
+  /** String that represents the format that this data source provider uses. */
+  override def shortName(): String = "socket"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 61e22fa..c393dcd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -172,15 +173,25 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
     }
     ds match {
       case s: MicroBatchReadSupport =>
-        val tempReader = s.createMicroBatchReader(
-          Optional.ofNullable(userSpecifiedSchema.orNull),
-          Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
-          options)
+        var tempReader: MicroBatchReader = null
+        val schema = try {
+          tempReader = s.createMicroBatchReader(
+            Optional.ofNullable(userSpecifiedSchema.orNull),
+            Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath,
+            options)
+          tempReader.readSchema()
+        } finally {
+          // Stop tempReader to avoid side-effect thing
+          if (tempReader != null) {
+            tempReader.stop()
+            tempReader = null
+          }
+        }
         Dataset.ofRows(
           sparkSession,
           StreamingRelationV2(
             s, source, extraOptions.toMap,
-            tempReader.readSchema().toAttributes, v1Relation)(sparkSession))
+            schema.toAttributes, v1Relation)(sparkSession))
       case s: ContinuousReadSupport =>
         val tempReader = s.createContinuousReader(
           Optional.ofNullable(userSpecifiedSchema.orNull),

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
deleted file mode 100644
index ec11549..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import java.io.{IOException, OutputStreamWriter}
-import java.net.ServerSocket
-import java.sql.Timestamp
-import java.util.concurrent.LinkedBlockingQueue
-
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
-
-class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {
-  import testImplicits._
-
-  override def afterEach() {
-    sqlContext.streams.active.foreach(_.stop())
-    if (serverThread != null) {
-      serverThread.interrupt()
-      serverThread.join()
-      serverThread = null
-    }
-    if (source != null) {
-      source.stop()
-      source = null
-    }
-  }
-
-  private var serverThread: ServerThread = null
-  private var source: Source = null
-
-  test("basic usage") {
-    serverThread = new ServerThread()
-    serverThread.start()
-
-    val provider = new TextSocketSourceProvider
-    val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
-    val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2
-    assert(schema === StructType(StructField("value", StringType) :: Nil))
-
-    source = provider.createSource(sqlContext, "", None, "", parameters)
-
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
-      }
-      withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
-        val offset1 = source.getOffset.get
-        val batch1 = source.getBatch(None, offset1)
-        assert(batch1.as[String].collect().toSeq === Seq("hello"))
-
-        serverThread.enqueue("world")
-        while (source.getOffset.get === offset1) {
-          Thread.sleep(10)
-        }
-        val offset2 = source.getOffset.get
-        val batch2 = source.getBatch(Some(offset1), offset2)
-        assert(batch2.as[String].collect().toSeq === Seq("world"))
-
-        val both = source.getBatch(None, offset2)
-        assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
-      }
-
-      // Try stopping the source to make sure this does not block forever.
-      source.stop()
-      source = null
-    }
-  }
-
-  test("timestamped usage") {
-    serverThread = new ServerThread()
-    serverThread.start()
-
-    val provider = new TextSocketSourceProvider
-    val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString,
-      "includeTimestamp" -> "true")
-    val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2
-    assert(schema === StructType(StructField("value", StringType) ::
-      StructField("timestamp", TimestampType) :: Nil))
-
-    source = provider.createSource(sqlContext, "", None, "", parameters)
-
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
-      }
-      withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
-        val offset1 = source.getOffset.get
-        val batch1 = source.getBatch(None, offset1)
-        val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
-        assert(batch1Seq.map(_._1) === Seq("hello"))
-        val batch1Stamp = batch1Seq(0)._2
-
-        serverThread.enqueue("world")
-        while (source.getOffset.get === offset1) {
-          Thread.sleep(10)
-        }
-        val offset2 = source.getOffset.get
-        val batch2 = source.getBatch(Some(offset1), offset2)
-        val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
-        assert(batch2Seq.map(_._1) === Seq("world"))
-        val batch2Stamp = batch2Seq(0)._2
-        assert(!batch2Stamp.before(batch1Stamp))
-      }
-
-      // Try stopping the source to make sure this does not block forever.
-      source.stop()
-      source = null
-    }
-  }
-
-  test("params not given") {
-    val provider = new TextSocketSourceProvider
-    intercept[AnalysisException] {
-      provider.sourceSchema(sqlContext, None, "", Map())
-    }
-    intercept[AnalysisException] {
-      provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost"))
-    }
-    intercept[AnalysisException] {
-      provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234"))
-    }
-  }
-
-  test("non-boolean includeTimestamp") {
-    val provider = new TextSocketSourceProvider
-    intercept[AnalysisException] {
-      provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost",
-      "port" -> "1234", "includeTimestamp" -> "fasle"))
-    }
-  }
-
-  test("user-specified schema given") {
-    val provider = new TextSocketSourceProvider
-    val userSpecifiedSchema = StructType(
-      StructField("name", StringType) ::
-      StructField("area", StringType) :: Nil)
-    val exception = intercept[AnalysisException] {
-      provider.sourceSchema(
-        sqlContext, Some(userSpecifiedSchema),
-        "",
-        Map("host" -> "localhost", "port" -> "1234"))
-    }
-    assert(exception.getMessage.contains(
-      "socket source does not support a user-specified schema"))
-  }
-
-  test("no server up") {
-    val provider = new TextSocketSourceProvider
-    val parameters = Map("host" -> "localhost", "port" -> "0")
-    intercept[IOException] {
-      source = provider.createSource(sqlContext, "", None, "", parameters)
-    }
-  }
-
-  test("input row metrics") {
-    serverThread = new ServerThread()
-    serverThread.start()
-
-    val provider = new TextSocketSourceProvider
-    val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString)
-    source = provider.createSource(sqlContext, "", None, "", parameters)
-
-    failAfter(streamingTimeout) {
-      serverThread.enqueue("hello")
-      while (source.getOffset.isEmpty) {
-        Thread.sleep(10)
-      }
-      withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
-        val batch = source.getBatch(None, source.getOffset.get).as[String]
-        batch.collect()
-        val numRowsMetric =
-          batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
-        assert(numRowsMetric.nonEmpty)
-        assert(numRowsMetric.get.value === 1)
-      }
-      source.stop()
-      source = null
-    }
-  }
-
-  private class ServerThread extends Thread with Logging {
-    private val serverSocket = new ServerSocket(0)
-    private val messageQueue = new LinkedBlockingQueue[String]()
-
-    val port = serverSocket.getLocalPort
-
-    override def run(): Unit = {
-      try {
-        val clientSocket = serverSocket.accept()
-        clientSocket.setTcpNoDelay(true)
-        val out = new OutputStreamWriter(clientSocket.getOutputStream)
-        while (true) {
-          val line = messageQueue.take()
-          out.write(line + "\n")
-          out.flush()
-        }
-      } catch {
-        case e: InterruptedException =>
-      } finally {
-        serverSocket.close()
-      }
-    }
-
-    def enqueue(line: String): Unit = {
-      messageQueue.put(line)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/707e6506/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
new file mode 100644
index 0000000..a15a980
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.io.IOException
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.nio.channels.ServerSocketChannel
+import java.sql.Timestamp
+import java.util.Optional
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+
+class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {
+
+  override def afterEach() {
+    sqlContext.streams.active.foreach(_.stop())
+    if (serverThread != null) {
+      serverThread.interrupt()
+      serverThread.join()
+      serverThread = null
+    }
+    if (batchReader != null) {
+      batchReader.stop()
+      batchReader = null
+    }
+  }
+
+  private var serverThread: ServerThread = null
+  private var batchReader: MicroBatchReader = null
+
+  case class AddSocketData(data: String*) extends AddData {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+      require(
+        query.nonEmpty,
+        "Cannot add data when there is no query for finding the active socket source")
+
+      val sources = query.get.logicalPlan.collect {
+        case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source
+      }
+      if (sources.isEmpty) {
+        throw new Exception(
+          "Could not find socket source in the StreamExecution logical plan to add data to")
+      } else if (sources.size > 1) {
+        throw new Exception(
+          "Could not select the socket source in the StreamExecution logical plan as there"
+
+            "are multiple socket sources:\n\t" + sources.mkString("\n\t"))
+      }
+      val socketSource = sources.head
+
+      assert(serverThread != null && serverThread.port != 0)
+      val currOffset = socketSource.getCurrentOffset()
+      data.foreach(serverThread.enqueue)
+
+      val newOffset = LongOffset(currOffset.offset + data.size)
+      (socketSource, newOffset)
+    }
+
+    override def toString: String = s"AddSocketData(data = $data)"
+  }
+
+  test("backward compatibility with old path") {
+    DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider",
+      spark.sqlContext.conf).newInstance() match {
+      case ds: MicroBatchReadSupport =>
+        assert(ds.isInstanceOf[TextSocketSourceProvider])
+      case _ =>
+        throw new IllegalStateException("Could not find socket source")
+    }
+  }
+
+  test("basic usage") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val ref = spark
+      import ref.implicits._
+
+      val socket = spark
+        .readStream
+        .format("socket")
+        .options(Map("host" -> "localhost", "port" -> serverThread.port.toString))
+        .load()
+        .as[String]
+
+      assert(socket.schema === StructType(StructField("value", StringType) :: Nil))
+
+      testStream(socket)(
+        StartStream(),
+        AddSocketData("hello"),
+        CheckAnswer("hello"),
+        AddSocketData("world"),
+        CheckLastBatch("world"),
+        CheckAnswer("hello", "world"),
+        StopStream
+      )
+    }
+  }
+
+  test("timestamped usage") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val socket = spark
+        .readStream
+        .format("socket")
+        .options(Map(
+          "host" -> "localhost",
+          "port" -> serverThread.port.toString,
+          "includeTimestamp" -> "true"))
+        .load()
+
+      assert(socket.schema === StructType(StructField("value", StringType) ::
+        StructField("timestamp", TimestampType) :: Nil))
+
+      var batch1Stamp: Timestamp = null
+      var batch2Stamp: Timestamp = null
+
+      val curr = System.currentTimeMillis()
+      testStream(socket)(
+        StartStream(),
+        AddSocketData("hello"),
+        CheckAnswerRowsByFunc(
+          rows => {
+            assert(rows.size === 1)
+            assert(rows.head.getAs[String](0) === "hello")
+            batch1Stamp = rows.head.getAs[Timestamp](1)
+            Thread.sleep(10)
+          },
+          true),
+        AddSocketData("world"),
+        CheckAnswerRowsByFunc(
+          rows => {
+            assert(rows.size === 1)
+            assert(rows.head.getAs[String](0) === "world")
+            batch2Stamp = rows.head.getAs[Timestamp](1)
+          },
+          true),
+        StopStream
+      )
+
+      // Timestamp for rate stream is round to second which leads to milliseconds lost, that will
+      // make batch1stamp smaller than current timestamp if both of them are in the same second.
+      // Comparing by second to make sure the correct behavior.
+      assert(batch1Stamp.getTime >= curr / 1000 * 1000)
+      assert(!batch2Stamp.before(batch1Stamp))
+    }
+  }
+
+  test("params not given") {
+    val provider = new TextSocketSourceProvider
+    intercept[AnalysisException] {
+      provider.createMicroBatchReader(Optional.empty(), "",
+        new DataSourceOptions(Map.empty[String, String].asJava))
+    }
+    intercept[AnalysisException] {
+      provider.createMicroBatchReader(Optional.empty(), "",
+        new DataSourceOptions(Map("host" -> "localhost").asJava))
+    }
+    intercept[AnalysisException] {
+      provider.createMicroBatchReader(Optional.empty(), "",
+        new DataSourceOptions(Map("port" -> "1234").asJava))
+    }
+  }
+
+  test("non-boolean includeTimestamp") {
+    val provider = new TextSocketSourceProvider
+    val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" ->
"fasle")
+    intercept[AnalysisException] {
+      val a = new DataSourceOptions(params.asJava)
+      provider.createMicroBatchReader(Optional.empty(), "", a)
+    }
+  }
+
+  test("user-specified schema given") {
+    val provider = new TextSocketSourceProvider
+    val userSpecifiedSchema = StructType(
+      StructField("name", StringType) ::
+      StructField("area", StringType) :: Nil)
+    val params = Map("host" -> "localhost", "port" -> "1234")
+    val exception = intercept[AnalysisException] {
+      provider.createMicroBatchReader(
+        Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava))
+    }
+    assert(exception.getMessage.contains(
+      "socket source does not support a user-specified schema"))
+  }
+
+  test("no server up") {
+    val provider = new TextSocketSourceProvider
+    val parameters = Map("host" -> "localhost", "port" -> "0")
+    intercept[IOException] {
+      batchReader = provider.createMicroBatchReader(
+        Optional.empty(), "", new DataSourceOptions(parameters.asJava))
+    }
+  }
+
+  test("input row metrics") {
+    serverThread = new ServerThread()
+    serverThread.start()
+
+    withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+      val ref = spark
+      import ref.implicits._
+
+      val socket = spark
+        .readStream
+        .format("socket")
+        .options(Map("host" -> "localhost", "port" -> serverThread.port.toString))
+        .load()
+        .as[String]
+
+      assert(socket.schema === StructType(StructField("value", StringType) :: Nil))
+
+      testStream(socket)(
+        StartStream(),
+        AddSocketData("hello"),
+        CheckAnswer("hello"),
+        AssertOnQuery { q =>
+          val numRowMetric =
+            q.lastExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
+          numRowMetric.nonEmpty && numRowMetric.get.value == 1
+        },
+        StopStream
+      )
+    }
+  }
+
+  private class ServerThread extends Thread with Logging {
+    private val serverSocketChannel = ServerSocketChannel.open()
+    serverSocketChannel.bind(new InetSocketAddress(0))
+    private val messageQueue = new LinkedBlockingQueue[String]()
+
+    val port = serverSocketChannel.socket().getLocalPort
+
+    override def run(): Unit = {
+      try {
+        while (true) {
+          val clientSocketChannel = serverSocketChannel.accept()
+          clientSocketChannel.configureBlocking(false)
+          clientSocketChannel.socket().setTcpNoDelay(true)
+
+          // Check whether remote client is closed but still send data to this closed socket.
+          // This happens in DataStreamReader where a source will be created to get the schema.
+          var remoteIsClosed = false
+          var cnt = 0
+          while (cnt < 3 && !remoteIsClosed) {
+            if (clientSocketChannel.read(ByteBuffer.allocate(1)) != -1) {
+              cnt += 1
+              Thread.sleep(100)
+            } else {
+              remoteIsClosed = true
+            }
+          }
+
+          if (remoteIsClosed) {
+            logInfo(s"remote client ${clientSocketChannel.socket()} is closed")
+          } else {
+            while (true) {
+              val line = messageQueue.take() + "\n"
+              clientSocketChannel.write(ByteBuffer.wrap(line.getBytes("UTF-8")))
+            }
+          }
+        }
+      } catch {
+        case e: InterruptedException =>
+      } finally {
+        serverSocketChannel.close()
+      }
+    }
+
+    def enqueue(line: String): Unit = {
+      messageQueue.put(line)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

