Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/20382#discussion_r164937540
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala ---
@@ -47,130 +48,141 @@ object TextSocketSource {
* This source will *not* work in production applications due to multiple reasons, including no
* support for fault recovery and keeping all of the text read in memory forever.
*/
-class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext)
- extends Source with Logging {
-
- @GuardedBy("this")
- private var socket: Socket = null
-
- @GuardedBy("this")
- private var readThread: Thread = null
-
- /**
- * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
- * Stored in a ListBuffer to facilitate removing committed batches.
- */
- @GuardedBy("this")
- protected val batches = new ListBuffer[(String, Timestamp)]
-
- @GuardedBy("this")
- protected var currentOffset: LongOffset = new LongOffset(-1)
-
- @GuardedBy("this")
- protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
+class TextSocketSource(
+ protected val host: String,
+ protected val port: Int,
+ includeTimestamp: Boolean,
+ sqlContext: SQLContext)
+ extends Source with TextSocketReader with Logging {
initialize()
- private def initialize(): Unit = synchronized {
- socket = new Socket(host, port)
- val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
- readThread = new Thread(s"TextSocketSource($host, $port)") {
- setDaemon(true)
-
- override def run(): Unit = {
- try {
- while (true) {
- val line = reader.readLine()
- if (line == null) {
- // End of file reached
- logWarning(s"Stream closed by $host:$port")
- return
- }
- TextSocketSource.this.synchronized {
- val newData = (line,
- Timestamp.valueOf(
- TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
- )
- currentOffset = currentOffset + 1
- batches.append(newData)
- }
- }
- } catch {
- case e: IOException =>
- }
- }
- }
- readThread.start()
- }
-
/** Returns the schema of the data from this source */
- override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP
- else TextSocketSource.SCHEMA_REGULAR
-
- override def getOffset: Option[Offset] = synchronized {
- if (currentOffset.offset == -1) {
- None
- } else {
- Some(currentOffset)
- }
- }
+ override def schema: StructType =
+ if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else TextSocketSource.SCHEMA_REGULAR
+
+ override def getOffset: Option[Offset] = getOffsetInternal.map(LongOffset(_))
/** Returns the data that is between the offsets (`start`, `end`]. */
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized {
- val startOrdinal =
- start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
- val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
- // Internal buffer only holds the batches after lastOffsetCommitted
- val rawList = synchronized {
- val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
- val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
- batches.slice(sliceStart, sliceEnd)
- }
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ val rawList = getBatchInternal(start.flatMap(LongOffset.convert).map(_.offset),
+ LongOffset.convert(end).map(_.offset))
val rdd = sqlContext.sparkContext
.parallelize(rawList)
.map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
- override def commit(end: Offset): Unit = synchronized {
+ override def commit(end: Offset): Unit = {
val newOffset = LongOffset.convert(end).getOrElse(
sys.error(s"TextSocketStream.commit() received an offset ($end) that did not "
+
s"originate with an instance of this class")
)
- val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
+ commitInternal(newOffset.offset)
+ }
- if (offsetDiff < 0) {
- sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
- }
+ override def toString: String = s"TextSocketSource[host: $host, port: $port]"
+}
+
+case class TextSocketOffset(offset: Long) extends V2Offset {
+ override def json(): String = offset.toString
+}
+
+class TextSocketMicroBatchReader(options: DataSourceV2Options)
+ extends MicroBatchReader with TextSocketReader with Logging {
+
+ private var startOffset: TextSocketOffset = _
+ private var endOffset: TextSocketOffset = _
+
+ protected val host: String = options.get("host").get()
+ protected val port: Int = options.get("port").get().toInt
- batches.trimStart(offsetDiff)
- lastOffsetCommitted = newOffset
+ initialize()
+
+ override def setOffsetRange(start: Optional[V2Offset], end: Optional[V2Offset]): Unit = {
+ startOffset = start.orElse(TextSocketOffset(-1L)).asInstanceOf[TextSocketOffset]
+ endOffset = end.orElse(
+ TextSocketOffset(getOffsetInternal.getOrElse(-1L))).asInstanceOf[TextSocketOffset]
}
- /** Stop this source. */
- override def stop(): Unit = synchronized {
- if (socket != null) {
- try {
- // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to
- // stop the readThread is to close the socket.
- socket.close()
- } catch {
- case e: IOException =>
- }
- socket = null
+ override def getStartOffset(): V2Offset = {
+ Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set"))
+ }
+
+ override def getEndOffset(): V2Offset = {
+ Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set"))
+ }
+
+ override def deserializeOffset(json: String): V2Offset = {
+ TextSocketOffset(json.toLong)
+ }
+
+ override def readSchema(): StructType = {
+ val includeTimestamp = options.getBoolean("includeTimestamp", false)
+ if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else TextSocketSource.SCHEMA_REGULAR
+ }
+
+ override def createReadTasks(): JList[ReadTask[Row]] = {
+ val rawList = getBatchInternal(Option(startOffset.offset), Option(endOffset.offset))
+
+ assert(SparkSession.getActiveSession.isDefined)
+ val spark = SparkSession.getActiveSession.get
+ val numPartitions = spark.sparkContext.defaultParallelism
+
+ val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)])
+ rawList.zipWithIndex.foreach { case (r, idx) =>
+ slices(idx % numPartitions).append(r)
}
+
+ (0 until numPartitions).map { i =>
+ val slice = slices(i)
+ new ReadTask[Row] {
+ override def createDataReader(): DataReader[Row] = new DataReader[Row] {
+ private var currentIdx = -1
+
+ override def next(): Boolean = {
+ currentIdx += 1
+ currentIdx < slice.size
+ }
+
+ override def get(): Row = {
+ Row(slice(currentIdx)._1, slice(currentIdx)._2)
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+ }.toList.asJava
}
- override def toString: String = s"TextSocketSource[host: $host, port: $port]"
+ override def commit(end: V2Offset): Unit = {
+ val newOffset = end.asInstanceOf[TextSocketOffset]
+ commitInternal(newOffset.offset)
+ }
+
+ override def toString: String = s"TextSocketMicroBatchReader[host: $host, port: $port]"
}
-class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {
- private def parseIncludeTimestamp(params: Map[String, String]): Boolean = {
- Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
- case Success(bool) => bool
+class TextSocketSourceProvider extends DataSourceV2
+ with MicroBatchReadSupport with StreamSourceProvider with DataSourceRegister with Logging {
--- End diff --
Aah, I see earlier comments.
---
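For context on the refactor above: both TextSocketSource and TextSocketMicroBatchReader now mix in a TextSocketReader trait whose body is not shown in this excerpt. Below is a minimal sketch of what that trait presumably provides, reconstructed from the removed V1 code and the getOffsetInternal/getBatchInternal/commitInternal/initialize calls visible in the diff; the exact signatures and internals are assumptions, not the PR's actual code.

import java.io.{BufferedReader, IOException, InputStreamReader}
import java.net.Socket
import java.sql.Timestamp
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.ListBuffer

// Hypothetical sketch: socket handling, the in-memory buffer, and offset bookkeeping
// factored out of the old TextSocketSource so that both the V1 Source and the
// V2 MicroBatchReader can reuse them.
trait TextSocketReader {
  protected def host: String
  protected def port: Int

  @GuardedBy("this")
  private var socket: Socket = null

  @GuardedBy("this")
  private var readThread: Thread = null

  // All batches from `lastOffsetCommitted + 1` to `currentOffset`, inclusive.
  @GuardedBy("this")
  private val batches = new ListBuffer[(String, Timestamp)]

  @GuardedBy("this")
  private var currentOffset: Long = -1L

  @GuardedBy("this")
  private var lastOffsetCommitted: Long = -1L

  /** Connects and keeps appending incoming lines (with a receive timestamp) to the buffer. */
  protected def initialize(): Unit = synchronized {
    socket = new Socket(host, port)
    val reader = new BufferedReader(new InputStreamReader(socket.getInputStream))
    readThread = new Thread(s"TextSocketReader($host, $port)") {
      setDaemon(true)
      override def run(): Unit = {
        try {
          while (true) {
            val line = reader.readLine()
            if (line == null) return // stream closed by the server
            TextSocketReader.this.synchronized {
              currentOffset += 1
              // The PR formats the timestamp via TextSocketSource.DATE_FORMAT; simplified here.
              batches.append((line, new Timestamp(System.currentTimeMillis())))
            }
          }
        } catch {
          case _: IOException => // socket closed by stopInternal()
        }
      }
    }
    readThread.start()
  }

  /** Latest available offset, or None if nothing has been received yet. */
  protected def getOffsetInternal: Option[Long] = synchronized {
    if (currentOffset == -1L) None else Some(currentOffset)
  }

  /** Lines in the range (`start`, `end`]; the buffer only holds uncommitted batches. */
  protected def getBatchInternal(
      start: Option[Long], end: Option[Long]): List[(String, Timestamp)] = synchronized {
    val startOrdinal = start.getOrElse(-1L).toInt + 1
    val endOrdinal = end.getOrElse(-1L).toInt + 1
    val sliceStart = startOrdinal - lastOffsetCommitted.toInt - 1
    val sliceEnd = endOrdinal - lastOffsetCommitted.toInt - 1
    batches.slice(sliceStart, sliceEnd).toList
  }

  /** Drops everything up to and including `end` from the buffer. */
  protected def commitInternal(end: Long): Unit = synchronized {
    val offsetDiff = (end - lastOffsetCommitted).toInt
    if (offsetDiff < 0) {
      sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
    }
    batches.trimStart(offsetDiff)
    lastOffsetCommitted = end
  }

  /** Closes the socket; readLine() cannot be interrupted, so this is how the read thread stops. */
  protected def stopInternal(): Unit = synchronized {
    if (socket != null) {
      try socket.close() catch { case _: IOException => }
      socket = null
    }
  }
}

With a shared trait like this, the V1 Source and the V2 MicroBatchReader only differ in how they translate offsets (LongOffset vs. TextSocketOffset) and how they hand the buffered rows back to Spark (an RDD vs. ReadTask[Row] slices), which matches what the diff shows.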