spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From t...@apache.org
Subject spark git commit: [SPARK-24662][SQL][SS] Support limit in structured streaming
Date Tue, 10 Jul 2018 18:08:11 GMT
Repository: spark
Updated Branches:
  refs/heads/master e0559f238 -> 32cb50835


[SPARK-24662][SQL][SS] Support limit in structured streaming

## What changes were proposed in this pull request?

Support the LIMIT operator in structured streaming.

For streams in append or complete output mode, a stream with a LIMIT operator will return
no more than the specified number of rows. LIMIT is still unsupported for the update output
mode.

This change also reverts https://github.com/apache/spark/commit/e4fee395ecd93ad4579d9afbf0861f82a303e563
(the row limit on the memory sink), because the LIMIT operator added here is a better and
more complete implementation.

## How was this patch tested?

New and existing unit tests.

Author: Mukul Murthy <mukul.murthy@gmail.com>

Closes #21662 from mukulmurthy/SPARK-24662.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/32cb5083
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/32cb5083
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/32cb5083

Branch: refs/heads/master
Commit: 32cb50835e7258625afff562939872be002232f2
Parents: e0559f2
Author: Mukul Murthy <mukul.murthy@gmail.com>
Authored: Tue Jul 10 11:08:04 2018 -0700
Committer: Tathagata Das <tathagata.das1565@gmail.com>
Committed: Tue Jul 10 11:08:04 2018 -0700

----------------------------------------------------------------------
 .../analysis/UnsupportedOperationChecker.scala  |   6 +-
 .../spark/sql/execution/SparkStrategies.scala   |  26 ++++-
 .../streaming/IncrementalExecution.scala        |  11 +-
 .../streaming/StreamingGlobalLimitExec.scala    | 102 ++++++++++++++++++
 .../spark/sql/execution/streaming/memory.scala  |  70 ++----------
 .../execution/streaming/sources/memoryV2.scala  |  44 ++------
 .../spark/sql/streaming/DataStreamWriter.scala  |   4 +-
 .../execution/streaming/MemorySinkSuite.scala   |  62 +----------
 .../execution/streaming/MemorySinkV2Suite.scala |  80 +-------------
 .../spark/sql/streaming/StreamSuite.scala       | 108 +++++++++++++++++++
 .../apache/spark/sql/streaming/StreamTest.scala |   4 +-
 11 files changed, 272 insertions(+), 245 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 5ced1ca..f68df5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -315,8 +315,10 @@ object UnsupportedOperationChecker {
         case GroupingSets(_, _, child, _) if child.isStreaming =>
           throwError("GroupingSets is not supported on streaming DataFrames/Datasets")
 
-        case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) =>
-          throwError("Limits are not supported on streaming DataFrames/Datasets")
+        case GlobalLimit(_, _) | LocalLimit(_, _)
+            if subPlan.children.forall(_.isStreaming) && outputMode == InternalOutputModes.Update =>
+          throwError("Limits are not supported on streaming DataFrames/Datasets in Update " +
+            "output mode")
 
         case Sort(_, _, _) if !containsCompleteData(subPlan) =>
           throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " +

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cfbcb9a..02e095b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -34,7 +35,7 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
+import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -349,6 +350,29 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  /**
+   * Used to plan the streaming global limit operator for streams in append mode.
+   * We need to check for either a direct Limit or a Limit wrapped in a ReturnAnswer operator,
+   * following the example of the SpecialLimits Strategy above.
+   * Streams with limit in Append mode use the stateful StreamingGlobalLimitExec.
+   * Streams with limit in Complete mode use the stateless CollectLimitExec operator.
+   * Limit is unsupported for streams in Update mode.
+   */
+  case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case ReturnAnswer(rootPlan) => rootPlan match {
+        case Limit(IntegerLiteral(limit), child)
+            if plan.isStreaming && outputMode == InternalOutputModes.Append =>
+          StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+        case _ => Nil
+      }
+      case Limit(IntegerLiteral(limit), child)
+          if plan.isStreaming && outputMode == InternalOutputModes.Append =>
+        StreamingGlobalLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+      case _ => Nil
+    }
+  }
+
   object StreamingJoinStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
       plan match {

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index c480b96..6ae7f28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -59,7 +59,8 @@ class IncrementalExecution(
       StatefulAggregationStrategy ::
       FlatMapGroupsWithStateStrategy ::
       StreamingRelationStrategy ::
-      StreamingDeduplicationStrategy :: Nil
+      StreamingDeduplicationStrategy ::
+      StreamingGlobalLimitStrategy(outputMode) :: Nil
   }
 
   private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
@@ -134,8 +135,12 @@ class IncrementalExecution(
           stateWatermarkPredicates =
             StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates(
               j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full,
-              Some(offsetSeqMetadata.batchWatermarkMs))
-        )
+              Some(offsetSeqMetadata.batchWatermarkMs)))
+
+      case l: StreamingGlobalLimitExec =>
+        l.copy(
+          stateInfo = Some(nextStatefulOperationStateInfo),
+          outputMode = Some(outputMode))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala
new file mode 100644
index 0000000..bf4af60
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingGlobalLimitExec.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.streaming.state.StateStoreOps
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * A physical operator for executing a streaming limit, which makes sure no more than streamLimit
+ * rows are returned. This operator is meant for streams in Append mode only.
+ */
+case class StreamingGlobalLimitExec(
+    streamLimit: Long,
+    child: SparkPlan,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
+    outputMode: Option[OutputMode] = None)
+  extends UnaryExecNode with StateStoreWriter {
+
+  private val keySchema = StructType(Array(StructField("key", NullType)))
+  private val valueSchema = StructType(Array(StructField("value", LongType)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    metrics // force lazy init at driver
+
+    assert(outputMode.isDefined && outputMode.get == InternalOutputModes.Append,
+      "StreamingGlobalLimitExec is only valid for streams in Append output mode")
+
+    child.execute().mapPartitionsWithStateStore(
+        getStateInfo,
+        keySchema,
+        valueSchema,
+        indexOrdinal = None,
+        sqlContext.sessionState,
+        Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+      val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
+      val numOutputRows = longMetric("numOutputRows")
+      val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+      val commitTimeMs = longMetric("commitTimeMs")
+      val updatesStartTimeNs = System.nanoTime
+
+      val preBatchRowCount: Long = Option(store.get(key)).map(_.getLong(0)).getOrElse(0L)
+      var cumulativeRowCount = preBatchRowCount
+
+      val result = iter.filter { r =>
+        val x = cumulativeRowCount < streamLimit
+        if (x) {
+          cumulativeRowCount += 1
+        }
+        x
+      }
+
+      CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
+        if (cumulativeRowCount > preBatchRowCount) {
+          numUpdatedStateRows += 1
+          numOutputRows += cumulativeRowCount - preBatchRowCount
+          store.put(key, getValueRow(cumulativeRowCount))
+        }
+        allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+        commitTimeMs += timeTakenMs { store.commit() }
+        setStoreMetrics(store)
+      })
+    }
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] = AllTuples :: Nil
+
+  private def getValueRow(value: Long): UnsafeRow = {
+    UnsafeProjection.create(valueSchema)(new GenericInternalRow(Array[Any](value)))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7fa13c4..b137f98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
@@ -222,60 +221,19 @@ class MemoryStreamInputPartition(records: Array[UnsafeRow])
 }
 
 /** A common trait for MemorySinks with methods used for testing */
-trait MemorySinkBase extends BaseStreamingSink with Logging {
+trait MemorySinkBase extends BaseStreamingSink {
   def allData: Seq[Row]
   def latestBatchData: Seq[Row]
   def dataSinceBatch(sinceBatchId: Long): Seq[Row]
   def latestBatchId: Option[Long]
-
-  /**
-   * Truncates the given rows to return at most maxRows rows.
-   * @param rows The data that may need to be truncated.
-   * @param batchLimit Number of rows to keep in this batch; the rest will be truncated
-   * @param sinkLimit Total number of rows kept in this sink, for logging purposes.
-   * @param batchId The ID of the batch that sent these rows, for logging purposes.
-   * @return Truncated rows.
-   */
-  protected def truncateRowsIfNeeded(
-      rows: Array[Row],
-      batchLimit: Int,
-      sinkLimit: Int,
-      batchId: Long): Array[Row] = {
-    if (rows.length > batchLimit && batchLimit >= 0) {
-      logWarning(s"Truncating batch $batchId to $batchLimit rows because of sink limit $sinkLimit")
-      rows.take(batchLimit)
-    } else {
-      rows
-    }
-  }
-}
-
-/**
- * Companion object to MemorySinkBase.
- */
-object MemorySinkBase {
-  val MAX_MEMORY_SINK_ROWS = "maxRows"
-  val MAX_MEMORY_SINK_ROWS_DEFAULT = -1
-
-  /**
-   * Gets the max number of rows a MemorySink should store. This number is based on the memory
-   * sink row limit option if it is set. If not, we use a large value so that data truncates
-   * rather than causing out of memory errors.
-   * @param options Options for writing from which we get the max rows option
-   * @return The maximum number of rows a memorySink should store.
-   */
-  def getMemorySinkCapacity(options: DataSourceOptions): Int = {
-    val maxRows = options.getInt(MAX_MEMORY_SINK_ROWS, MAX_MEMORY_SINK_ROWS_DEFAULT)
-    if (maxRows >= 0) maxRows else Int.MaxValue - 10
-  }
 }
 
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.
  */
-class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSourceOptions)
-  extends Sink with MemorySinkBase with Logging {
+class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
+  with MemorySinkBase with Logging {
 
   private case class AddedData(batchId: Long, data: Array[Row])
 
@@ -283,12 +241,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
-  /** The number of rows in this MemorySink. */
-  private var numRows = 0
-
-  /** The capacity in rows of this sink. */
-  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -321,23 +273,14 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          var rowsToAdd = data.collect()
-          synchronized {
-            rowsToAdd =
-              truncateRowsIfNeeded(rowsToAdd, sinkCapacity - numRows, sinkCapacity, batchId)
-            val rows = AddedData(batchId, rowsToAdd)
-            batches += rows
-            numRows += rowsToAdd.length
-          }
+          val rows = AddedData(batchId, data.collect())
+          synchronized { batches += rows }
 
         case Complete =>
-          var rowsToAdd = data.collect()
+          val rows = AddedData(batchId, data.collect())
           synchronized {
-            rowsToAdd = truncateRowsIfNeeded(rowsToAdd, sinkCapacity, sinkCapacity, batchId)
-            val rows = AddedData(batchId, rowsToAdd)
             batches.clear()
             batches += rows
-            numRows = rowsToAdd.length
           }
 
         case _ =>
@@ -351,7 +294,6 @@ class MemorySink(val schema: StructType, outputMode: OutputMode, options: DataSo
 
   def clear(): Unit = synchronized {
     batches.clear()
-    numRows = 0
   }
 
   override def toString(): String = "MemorySink"

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 29f8cca..f2a35a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -46,7 +46,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode, options)
+    new MemoryStreamWriter(this, mode)
   }
 
   private case class AddedData(batchId: Long, data: Array[Row])
@@ -55,9 +55,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
   @GuardedBy("this")
   private val batches = new ArrayBuffer[AddedData]()
 
-  /** The number of rows in this MemorySink. */
-  private var numRows = 0
-
   /** Returns all rows that are stored in this [[Sink]]. */
   def allData: Seq[Row] = synchronized {
     batches.flatMap(_.data)
@@ -84,11 +81,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
     }.mkString("\n")
   }
 
-  def write(
-      batchId: Long,
-      outputMode: OutputMode,
-      newRows: Array[Row],
-      sinkCapacity: Int): Unit = {
+  def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
     val notCommitted = synchronized {
       latestBatchId.isEmpty || batchId > latestBatchId.get
     }
@@ -96,21 +89,14 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          synchronized {
-            val rowsToAdd =
-              truncateRowsIfNeeded(newRows, sinkCapacity - numRows, sinkCapacity, batchId)
-            val rows = AddedData(batchId, rowsToAdd)
-            batches += rows
-            numRows += rowsToAdd.length
-          }
+          val rows = AddedData(batchId, newRows)
+          synchronized { batches += rows }
 
         case Complete =>
+          val rows = AddedData(batchId, newRows)
           synchronized {
-            val rowsToAdd = truncateRowsIfNeeded(newRows, sinkCapacity, sinkCapacity, batchId)
-            val rows = AddedData(batchId, rowsToAdd)
             batches.clear()
             batches += rows
-            numRows = rowsToAdd.length
           }
 
         case _ =>
@@ -124,7 +110,6 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
 
   def clear(): Unit = synchronized {
     batches.clear()
-    numRows = 0
   }
 
   override def toString(): String = "MemorySinkV2"
@@ -132,22 +117,16 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
 
 case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
 
-class MemoryWriter(
-    sink: MemorySinkV2,
-    batchId: Long,
-    outputMode: OutputMode,
-    options: DataSourceOptions)
+class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
   extends DataSourceWriter with Logging {
 
-  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   def commit(messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(batchId, outputMode, newRows, sinkCapacity)
+    sink.write(batchId, outputMode, newRows)
   }
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = {
@@ -155,21 +134,16 @@ class MemoryWriter(
   }
 }
 
-class MemoryStreamWriter(
-    val sink: MemorySinkV2,
-    outputMode: OutputMode,
-    options: DataSourceOptions)
+class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
   extends StreamWriter {
 
-  val sinkCapacity: Int = MemorySinkBase.getMemorySinkCapacity(options)
-
   override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
       case message: MemoryWriterCommitMessage => message.data
     }
-    sink.write(epochId, outputMode, newRows, sinkCapacity)
+    sink.write(epochId, outputMode, newRows)
   }
 
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 926c0b6..3b9a56f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.StreamWriteSupport
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -250,7 +250,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
           val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
           (s, r)
         case _ =>
-          val s = new MemorySink(df.schema, outputMode, new DataSourceOptions(extraOptions.asJava))
+          val s = new MemorySink(df.schema, outputMode)
           val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
           (s, r)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index b2fd6ba..3bc36ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
@@ -38,7 +36,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Append output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Append)
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -70,35 +68,9 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, 1 to 9)
   }
 
-  test("directly add data in Append output mode with row limit") {
-    implicit val schema = new StructType().add(new StructField("value", IntegerType))
-
-    var optionsMap = new scala.collection.mutable.HashMap[String, String]
-    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
-    var options = new DataSourceOptions(optionsMap.toMap.asJava)
-    val sink = new MemorySink(schema, OutputMode.Append, options)
-
-    // Before adding data, check output
-    assert(sink.latestBatchId === None)
-    checkAnswer(sink.latestBatchData, Seq.empty)
-    checkAnswer(sink.allData, Seq.empty)
-
-    // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
-    assert(sink.latestBatchId === Some(0))
-    checkAnswer(sink.latestBatchData, 1 to 3)
-    checkAnswer(sink.allData, 1 to 3)
-
-    // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 6)
-    assert(sink.latestBatchId === Some(1))
-    checkAnswer(sink.latestBatchData, 4 to 5)
-    checkAnswer(sink.allData, 1 to 5)     // new data should not go over the limit
-  }
-
   test("directly add data in Update output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Update, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Update)
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -132,7 +104,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("directly add data in Complete output mode") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Complete, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Complete)
 
     // Before adding data, check output
     assert(sink.latestBatchId === None)
@@ -164,32 +136,6 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     checkAnswer(sink.allData, 7 to 9)
   }
 
-  test("directly add data in Complete output mode with row limit") {
-    implicit val schema = new StructType().add(new StructField("value", IntegerType))
-
-    var optionsMap = new scala.collection.mutable.HashMap[String, String]
-    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
-    var options = new DataSourceOptions(optionsMap.toMap.asJava)
-    val sink = new MemorySink(schema, OutputMode.Complete, options)
-
-    // Before adding data, check output
-    assert(sink.latestBatchId === None)
-    checkAnswer(sink.latestBatchData, Seq.empty)
-    checkAnswer(sink.allData, Seq.empty)
-
-    // Add batch 0 and check outputs
-    sink.addBatch(0, 1 to 3)
-    assert(sink.latestBatchId === Some(0))
-    checkAnswer(sink.latestBatchData, 1 to 3)
-    checkAnswer(sink.allData, 1 to 3)
-
-    // Add batch 1 and check outputs
-    sink.addBatch(1, 4 to 10)
-    assert(sink.latestBatchId === Some(1))
-    checkAnswer(sink.latestBatchData, 4 to 8)
-    checkAnswer(sink.allData, 4 to 8)     // new data should replace old data
-  }
-
 
   test("registering as a table in Append output mode") {
     val input = MemoryStream[Int]
@@ -265,7 +211,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
   test("MemoryPlan statistics") {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
-    val sink = new MemorySink(schema, OutputMode.Append, DataSourceOptions.empty())
+    val sink = new MemorySink(schema, OutputMode.Append)
     val plan = new MemoryPlan(sink)
 
     // Before adding data, check output

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
index e539510..9be22d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
@@ -17,16 +17,11 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import scala.collection.JavaConverters._
-
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.StructType
 
 class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
   test("data writer") {
@@ -45,7 +40,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
   test("continuous writer") {
     val sink = new MemorySinkV2
-    val writer = new MemoryStreamWriter(sink, OutputMode.Append(), DataSourceOptions.empty())
+    val writer = new MemoryStreamWriter(sink, OutputMode.Append())
     writer.commit(0,
       Array(
         MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
@@ -67,7 +62,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
   test("microbatch writer") {
     val sink = new MemorySinkV2
-    new MemoryWriter(sink, 0, OutputMode.Append(), DataSourceOptions.empty()).commit(
+    new MemoryWriter(sink, 0, OutputMode.Append()).commit(
       Array(
         MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
         MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
@@ -75,7 +70,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
       ))
     assert(sink.latestBatchId.contains(0))
     assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
-    new MemoryWriter(sink, 19, OutputMode.Append(), DataSourceOptions.empty()).commit(
+    new MemoryWriter(sink, 19, OutputMode.Append()).commit(
       Array(
         MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
         MemoryWriterCommitMessage(0, Seq(Row(33)))
@@ -85,73 +80,4 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
     assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
   }
-
-  test("continuous writer with row limit") {
-    val sink = new MemorySinkV2
-    val optionsMap = new scala.collection.mutable.HashMap[String, String]
-    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 7.toString())
-    val options = new DataSourceOptions(optionsMap.toMap.asJava)
-    val appendWriter = new MemoryStreamWriter(sink, OutputMode.Append(), options)
-    appendWriter.commit(0, Array(
-        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
-        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
-        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))))
-    assert(sink.latestBatchId.contains(0))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
-    appendWriter.commit(19, Array(
-        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
-        MemoryWriterCommitMessage(0, Seq(Row(33)))))
-    assert(sink.latestBatchId.contains(19))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11))
-
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11))
-
-    val completeWriter = new MemoryStreamWriter(sink, OutputMode.Complete(), options)
-    completeWriter.commit(20, Array(
-        MemoryWriterCommitMessage(4, Seq(Row(11), Row(22))),
-        MemoryWriterCommitMessage(5, Seq(Row(33)))))
-    assert(sink.latestBatchId.contains(20))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
-    completeWriter.commit(21, Array(
-      MemoryWriterCommitMessage(0, Seq(Row(1), Row(2), Row(3))),
-      MemoryWriterCommitMessage(1, Seq(Row(4), Row(5), Row(6))),
-      MemoryWriterCommitMessage(2, Seq(Row(7), Row(8), Row(9)))))
-    assert(sink.latestBatchId.contains(21))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7))
-
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5, 6, 7))
-  }
-
-  test("microbatch writer with row limit") {
-    val sink = new MemorySinkV2
-    val optionsMap = new scala.collection.mutable.HashMap[String, String]
-    optionsMap.put(MemorySinkBase.MAX_MEMORY_SINK_ROWS, 5.toString())
-    val options = new DataSourceOptions(optionsMap.toMap.asJava)
-
-    new MemoryWriter(sink, 25, OutputMode.Append(), options).commit(Array(
-      MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
-      MemoryWriterCommitMessage(1, Seq(Row(3), Row(4)))))
-    assert(sink.latestBatchId.contains(25))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4))
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4))
-    new MemoryWriter(sink, 26, OutputMode.Append(), options).commit(Array(
-      MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))),
-      MemoryWriterCommitMessage(3, Seq(Row(7), Row(8)))))
-    assert(sink.latestBatchId.contains(26))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(5))
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 5))
-
-    new MemoryWriter(sink, 27, OutputMode.Complete(), options).commit(Array(
-      MemoryWriterCommitMessage(4, Seq(Row(9), Row(10))),
-      MemoryWriterCommitMessage(5, Seq(Row(11), Row(12)))))
-    assert(sink.latestBatchId.contains(27))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12))
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(9, 10, 11, 12))
-    new MemoryWriter(sink, 28, OutputMode.Complete(), options).commit(Array(
-      MemoryWriterCommitMessage(4, Seq(Row(13), Row(14), Row(15))),
-      MemoryWriterCommitMessage(5, Seq(Row(16), Row(17), Row(18)))))
-    assert(sink.latestBatchId.contains(28))
-    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17))
-    assert(sink.allData.map(_.getInt(0)).sorted == Seq(13, 14, 15, 16, 17))
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c1ec1eb..ca38f04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -805,6 +805,114 @@ class StreamSuite extends StreamTest {
     }
   }
 
+  test("streaming limit without state") {
+    val inputData1 = MemoryStream[Int]
+    testStream(inputData1.toDF().limit(0))(
+      AddData(inputData1, 1 to 8: _*),
+      CheckAnswer())
+
+    val inputData2 = MemoryStream[Int]
+    testStream(inputData2.toDF().limit(4))(
+      AddData(inputData2, 1 to 8: _*),
+      CheckAnswer(1 to 4: _*))
+  }
+
+  test("streaming limit with state") {
+    val inputData = MemoryStream[Int]
+    testStream(inputData.toDF().limit(4))(
+      AddData(inputData, 1 to 2: _*),
+      CheckAnswer(1 to 2: _*),
+      AddData(inputData, 3 to 6: _*),
+      CheckAnswer(1 to 4: _*),
+      AddData(inputData, 7 to 9: _*),
+      CheckAnswer(1 to 4: _*))
+  }
+
+  test("streaming limit with other operators") {
+    val inputData = MemoryStream[Int]
+    testStream(inputData.toDF().where("value % 2 = 1").limit(4))(
+      AddData(inputData, 1 to 5: _*),
+      CheckAnswer(1, 3, 5),
+      AddData(inputData, 6 to 9: _*),
+      CheckAnswer(1, 3, 5, 7),
+      AddData(inputData, 10 to 12: _*),
+      CheckAnswer(1, 3, 5, 7))
+  }
+
+  test("streaming limit with multiple limits") {
+    val inputData1 = MemoryStream[Int]
+    testStream(inputData1.toDF().limit(4).limit(2))(
+      AddData(inputData1, 1),
+      CheckAnswer(1),
+      AddData(inputData1, 2 to 8: _*),
+      CheckAnswer(1, 2))
+
+    val inputData2 = MemoryStream[Int]
+    testStream(inputData2.toDF().limit(4).limit(100).limit(3))(
+      AddData(inputData2, 1, 2),
+      CheckAnswer(1, 2),
+      AddData(inputData2, 3 to 8: _*),
+      CheckAnswer(1 to 3: _*))
+  }
+
+  test("streaming limit in complete mode") {
+    val inputData = MemoryStream[Int]
+    val limited = inputData.toDF().limit(5).groupBy("value").count()
+    testStream(limited, OutputMode.Complete())(
+      AddData(inputData, 1 to 3: _*),
+      CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)),
+      AddData(inputData, 1 to 9: _*),
+      CheckAnswer(Row(1, 2), Row(2, 2), Row(3, 2), Row(4, 1), Row(5, 1)))
+  }
+
+  test("streaming limits in complete mode") {
+    val inputData = MemoryStream[Int]
+    val limited = inputData.toDF().limit(4).groupBy("value").count().orderBy("value").limit(3)
+    testStream(limited, OutputMode.Complete())(
+      AddData(inputData, 1 to 9: _*),
+      CheckAnswer(Row(1, 1), Row(2, 1), Row(3, 1)),
+      AddData(inputData, 2 to 6: _*),
+      CheckAnswer(Row(1, 1), Row(2, 2), Row(3, 2)))
+  }
+
+  test("streaming limit in update mode") {
+    val inputData = MemoryStream[Int]
+    val e = intercept[AnalysisException] {
+      testStream(inputData.toDF().limit(5), OutputMode.Update())(
+        AddData(inputData, 1 to 3: _*)
+      )
+    }
+    assert(e.getMessage.contains(
+      "Limits are not supported on streaming DataFrames/Datasets in Update output mode"))
+  }
+
+  test("streaming limit in multiple partitions") {
+    val inputData = MemoryStream[Int]
+    testStream(inputData.toDF().repartition(2).limit(7))(
+      AddData(inputData, 1 to 10: _*),
+      CheckAnswerRowsByFunc(
+        rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)),
+        false),
+      AddData(inputData, 11 to 20: _*),
+      CheckAnswerRowsByFunc(
+        rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 10)),
+        false))
+  }
+
+  test("streaming limit in multiple partitions by column") {
+    val inputData = MemoryStream[(Int, Int)]
+    val df = inputData.toDF().repartition(2, $"_2").limit(7)
+    testStream(df)(
+      AddData(inputData, (1, 0), (2, 0), (3, 1), (4, 1)),
+      CheckAnswerRowsByFunc(
+        rows => assert(rows.size == 4 && rows.forall(r => r.getInt(0) <= 4)),
+        false),
+      AddData(inputData, (5, 0), (6, 0), (7, 1), (8, 1)),
+      CheckAnswerRowsByFunc(
+        rows => assert(rows.size == 7 && rows.forall(r => r.getInt(0) <= 8)),
+        false))
+  }
+
   for (e <- Seq(
     new InterruptedException,
     new InterruptedIOException,

http://git-wip-us.apache.org/repos/asf/spark/blob/32cb5083/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e41b453..4c3fd58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -45,7 +45,6 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution,
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.{Clock, SystemClock, Utils}
@@ -338,8 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     var currentStream: StreamExecution = null
     var lastStream: StreamExecution = null
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
-    val sink = if (useV2Sink) new MemorySinkV2
-      else new MemorySink(stream.schema, outputMode, DataSourceOptions.empty())
+    val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode)
     val resetConfValues = mutable.Map[String, Option[String]]()
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message