spark-commits mailing list archives

From hvanhov...@apache.org
Subject spark git commit: [SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it
Date Wed, 15 Mar 2017 19:18:57 GMT
Repository: spark
Updated Branches:
  refs/heads/master 7387126f8 -> 02c274eab


[SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it

## What issue does this PR address ?

Jira: https://issues.apache.org/jira/browse/SPARK-13450

In `SortMergeJoinExec`, rows of the right relation that share the same join-key value are buffered in memory. When the data is skewed, this causes OOMs (see the comments in SPARK-13450 for more details). A heap dump from a failed job confirms this: https://issues.apache.org/jira/secure/attachment/12846382/heap-dump-analysis.png . While it is possible to work around the problem by increasing the heap size, Spark should be resilient to such issues because skew can occur arbitrarily.

## Change proposed in this pull request

- Introduces `ExternalAppendOnlyUnsafeRowArray`
  - It holds `UnsafeRow`s in memory up to a certain threshold.
  - After the threshold is hit, it switches to `UnsafeExternalSorter`, which enables spilling the rows to disk. It does NOT sort the data.
  - Allows iterating over the array multiple times. However, any alteration to the array (using `add` or `clear`) invalidates the existing iterator(s).
- `WindowExec` was already using `UnsafeExternalSorter` to support spilling. Changed it to use the new array
- Changed `SortMergeJoinExec` to use the new array implementation
  - NOTE: I have not changed FULL OUTER JOIN to use this new array implementation. Changing that will need more surgery, and I would rather put up a separate PR for it once this gets in.
- Changed `CartesianProductExec` to use the new array implementation
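
The behaviour listed above (buffer in memory up to a threshold, then switch stores, and invalidate iterators on any change) can be illustrated with a small toy model. This is only a sketch and not the Spark class: `SpillingBuffer`, `spillStore` and `hasSpilled` are hypothetical names, and a second `ArrayBuffer` stands in for the disk-backed `UnsafeExternalSorter`, so nothing is actually written to disk.

```scala
import java.util.ConcurrentModificationException
import scala.collection.mutable.ArrayBuffer

// Toy stand-in for the design described above; NOT the Spark implementation.
// All names here are hypothetical. `spillStore` plays the role of the
// spillable sorter but is just another in-memory buffer.
final class SpillingBuffer[A](spillThreshold: Int) {
  private val inMemory = new ArrayBuffer[A]()
  private var spillStore: ArrayBuffer[A] = null // created on first spill
  private var numRows = 0
  private var modifications = 0L

  def length: Int = numRows
  def hasSpilled: Boolean = spillStore != null

  def add(row: A): Unit = {
    if (numRows < spillThreshold) {
      inMemory += row
    } else {
      if (spillStore == null) {
        // Threshold hit: move the buffered rows over, keeping insertion order.
        spillStore = new ArrayBuffer[A]()
        spillStore ++= inMemory
        inMemory.clear()
      }
      spillStore += row
    }
    numRows += 1
    modifications += 1
  }

  def clear(): Unit = {
    inMemory.clear()
    spillStore = null
    numRows = 0
    modifications += 1
  }

  // Every call returns a fresh iterator, so rows can be scanned repeatedly.
  def iterator(): Iterator[A] = new Iterator[A] {
    private val expected = modifications
    private val underlying =
      (if (spillStore != null) spillStore else inMemory).iterator

    // A modified buffer reports no more rows; next() fails loudly.
    override def hasNext: Boolean =
      expected == modifications && underlying.hasNext

    override def next(): A = {
      if (expected != modifications) {
        throw new ConcurrentModificationException(
          "buffer modified after this iterator was created")
      }
      underlying.next()
    }
  }
}
```

With a threshold of 2, the third `add` triggers the switch, repeated calls to `iterator()` return the same rows in insertion order, and an iterator created before a later `add` stops reporting rows.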

#### Note for reviewers

The diff can be divided into 3 parts. My motive behind having all the changes in a single PR was to demonstrate that the API is sane and supports 2 use cases. If reviewing as 3 separate PRs would help, I am happy to make the split.

## How was this patch tested ?

#### Unit testing
- Added unit tests for `ExternalAppendOnlyUnsafeRowArray` to validate all its APIs and access patterns
- Added unit tests for `SortMergeJoinExec`
  - with and without spill for inner join, left outer join, and right outer join, to confirm that the spill threshold config behaves as expected and the output is correct.
  - This PR touches the scanning logic in `SortMergeJoinExec` for _all_ joins (except FULL OUTER JOIN). However, I expect the existing test cases to catch any regression in correctness.
- Added unit test for `WindowExec` to check behavior of spilling and correctness of results.

#### Stress testing
- Confirmed that the OOM is gone by running against a production job which used to OOM
- Since I cannot share details about the prod workload externally, I created synthetic data to mimic the issue, then ran the query before and after the fix to demonstrate the issue and the query succeeding with this PR

Generating the synthetic data

```
./bin/spark-shell --driver-memory=6G

import org.apache.spark.sql._
val hc = SparkSession.builder.master("local").getOrCreate()

hc.sql("DROP TABLE IF EXISTS spark_13450_large_table").collect
hc.sql("DROP TABLE IF EXISTS spark_13450_one_row_table").collect

val df1 = (0 until 1).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2")
df1.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_one_row_table")

val df2 = (0 until 3000000).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2")
df2.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_large_table")
```

Ran the following query against trunk and against a local build with this PR. The OOM reproduces on trunk; with the fix, the query runs fine.

```
./bin/spark-shell --driver-java-options="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp/spark.driver.heapdump.hprof"

import org.apache.spark.sql._
val hc = SparkSession.builder.master("local").getOrCreate()
hc.sql("SET spark.sql.autoBroadcastJoinThreshold=1")
hc.sql("SET spark.sql.sortMergeJoinExec.buffer.spill.threshold=10000")

hc.sql("DROP TABLE IF EXISTS spark_13450_result").collect
hc.sql("""
  CREATE TABLE spark_13450_result
  AS
  SELECT
    a.i AS a_i, a.j AS a_j, a.str1 AS a_str1, a.str2 AS a_str2,
    b.i AS b_i, b.j AS b_j, b.str1 AS b_str1, b.str2 AS b_str2
  FROM
    spark_13450_one_row_table a
  JOIN
    spark_13450_large_table b
  ON
    a.i=b.i AND
    a.j=b.j
""")
```

## Performance comparison

### Macro-benchmark

I ran an SMB join query over two real-world tables (2 trillion rows (40 TB) and 6 million rows (120 GB)). Note that this dataset does not have skew, so no spill happened. I saw a 2-4% improvement in CPU time over the version without this PR. This did not add up, as I had expected some regression. I think allocating the array with a capacity of 128 at the start (instead of starting with the default size of 16) is the sole reason for the performance gain: https://github.com/tejasapatil/spark/blob/SPARK-13450_smb_buffer_oom/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala#L43 . I could remove that and rerun, but the change will effectively be deployed in this form and I wanted to see its effect over a large workload.
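
To make the initial-capacity hypothesis concrete, the sketch below counts how many times a growable buffer has to be enlarged before it can hold a given number of rows. It assumes the buffer doubles its capacity each time it fills up; that growth policy is an assumption made for illustration, not something measured in this PR, and `GrowthSketch` and `resizesNeeded` are hypothetical names.

```scala
object GrowthSketch {
  // Number of capacity doublings needed before `rows` elements fit.
  // Doubling on overflow is an assumed growth policy.
  def resizesNeeded(initialCapacity: Int, rows: Int): Int = {
    require(initialCapacity > 0, "initial capacity must be positive")
    var capacity = initialCapacity.toLong
    var resizes = 0
    while (capacity < rows) {
      capacity *= 2
      resizes += 1
    }
    resizes
  }
}
```

Under that assumption, buffering 1000 matching rows takes 6 enlargements when starting from 16 but only 3 when starting from 128, and each enlargement copies the rows buffered so far.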

### Micro-benchmark

Two types of benchmarking can be found in `ExternalAppendOnlyUnsafeRowArrayBenchmark`:

[A] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `ArrayBuffer` when all rows fit in-memory and there is no spill

```
Array with 1000 rows:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
ArrayBuffer                                   7821 / 7941         33.5          29.8       1.0X
ExternalAppendOnlyUnsafeRowArray              8798 / 8819         29.8          33.6       0.9X

Array with 30000 rows:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
ArrayBuffer                                 19200 / 19206         25.6          39.1       1.0X
ExternalAppendOnlyUnsafeRowArray            19558 / 19562         25.1          39.8       1.0X

Array with 100000 rows:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
ArrayBuffer                                   5949 / 6028         17.2          58.1       1.0X
ExternalAppendOnlyUnsafeRowArray              6078 / 6138         16.8          59.4       1.0X
```

[B] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `UnsafeExternalSorter` when there is spilling of data

```
Spilling with 1000 rows:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
UnsafeExternalSorter                          9239 / 9470         28.4          35.2       1.0X
ExternalAppendOnlyUnsafeRowArray              8857 / 8909         29.6          33.8       1.0X

Spilling with 10000 rows:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
UnsafeExternalSorter                             4 /    5         39.3          25.5       1.0X
ExternalAppendOnlyUnsafeRowArray                 5 /    6         29.8          33.5       0.8X
```
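
When reading the tables above, the columns appear to be related as follows: `Rate(M/s)` is the reciprocal of `Per Row(ns)` scaled to millions of rows per second, and `Relative` is the baseline's best time divided by the candidate's best time. This relationship is inferred from the printed numbers rather than taken from the benchmark source, and `BenchmarkColumns` is a hypothetical helper written only to check the arithmetic.

```scala
object BenchmarkColumns {
  // Millions of rows per second, given nanoseconds spent per row.
  def rateMillionsPerSec(perRowNs: Double): Double = 1000.0 / perRowNs

  // Speed relative to the baseline, computed from best times in ms.
  // Values below 1.0 mean the candidate is slower than the baseline.
  def relative(baselineBestMs: Double, candidateBestMs: Double): Double =
    baselineBestMs / candidateBestMs
}
```

For the 1000-row in-memory case, 29.8 ns per row gives roughly 33.5 M rows/s, and 7821 / 8798 is about 0.89, which the table rounds to 0.9X.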

Author: Tejas Patil <tejasp@fb.com>

Closes #16909 from tejasapatil/SPARK-13450_smb_buffer_oom.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/02c274ea
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/02c274ea
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/02c274ea

Branch: refs/heads/master
Commit: 02c274eaba0a8e7611226e0d4e93d3c36253f4ce
Parents: 7387126
Author: Tejas Patil <tejasp@fb.com>
Authored: Wed Mar 15 20:18:39 2017 +0100
Committer: Herman van Hovell <hvanhovell@databricks.com>
Committed: Wed Mar 15 20:18:39 2017 +0100

----------------------------------------------------------------------
 .../org/apache/spark/sql/internal/SQLConf.scala |  30 ++
 .../ExternalAppendOnlyUnsafeRowArray.scala      | 243 +++++++++++++
 .../execution/joins/CartesianProductExec.scala  |  52 +--
 .../sql/execution/joins/SortMergeJoinExec.scala | 117 ++++---
 .../spark/sql/execution/window/RowBuffer.scala  | 115 ------
 .../spark/sql/execution/window/WindowExec.scala |  72 ++--
 .../execution/window/WindowFunctionFrame.scala  |  97 +++--
 .../scala/org/apache/spark/sql/JoinSuite.scala  | 136 ++++++-
 ...ernalAppendOnlyUnsafeRowArrayBenchmark.scala | 233 ++++++++++++
 .../ExternalAppendOnlyUnsafeRowArraySuite.scala | 351 +++++++++++++++++++
 .../sql/execution/SQLWindowFunctionSuite.scala  |  33 ++
 11 files changed, 1187 insertions(+), 292 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8f65672..a85f87a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // This file defines the configuration options for Spark SQL.
@@ -715,6 +716,27 @@ object SQLConf {
       .stringConf
       .createWithDefault(TimeZone.getDefault().getID())
 
+  val WINDOW_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.windowExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in window operator")
+      .intConf
+      .createWithDefault(4096)
+
+  val SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.sortMergeJoinExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in sort merge join operator")
+      .intConf
+      .createWithDefault(Int.MaxValue)
+
+  val CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD =
+    buildConf("spark.sql.cartesianProductExec.buffer.spill.threshold")
+      .internal()
+      .doc("Threshold for number of rows buffered in cartesian product operator")
+      .intConf
+      .createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -945,6 +967,14 @@ class SQLConf extends Serializable with Logging {
 
   def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD)
 
+  def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD)
+
+  def sortMergeJoinExecBufferSpillThreshold: Int =
+    getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD)
+
+  def cartesianProductExecBufferSpillThreshold: Int =
+    getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD)
+
   def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH)
 
   /** ********************** SQLConf functionality methods ************ */

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
new file mode 100644
index 0000000..458ac4b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined
+ * threshold of rows is reached.
+ *
+ * Setting spill threshold faces following trade-off:
+ *
+ * - If the spill threshold is too high, the in-memory array may occupy more memory than is
+ *   available, resulting in OOM.
+ * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
+ *   This may lead to a performance regression compared to the normal case of using an
+ *   [[ArrayBuffer]] or [[Array]].
+ */
+private[sql] class ExternalAppendOnlyUnsafeRowArray(
+    taskMemoryManager: TaskMemoryManager,
+    blockManager: BlockManager,
+    serializerManager: SerializerManager,
+    taskContext: TaskContext,
+    initialSize: Int,
+    pageSizeBytes: Long,
+    numRowsSpillThreshold: Int) extends Logging {
+
+  def this(numRowsSpillThreshold: Int) {
+    this(
+      TaskContext.get().taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get(),
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes,
+      numRowsSpillThreshold)
+  }
+
+  private val initialSizeOfInMemoryBuffer =
+    Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+
+  private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+    new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
+  } else {
+    null
+  }
+
+  private var spillableArray: UnsafeExternalSorter = _
+  private var numRows = 0
+
+  // A counter to keep track of total modifications done to this array since its creation.
+  // This helps to invalidate iterators when there are changes done to the backing array.
+  private var modificationsCount: Long = 0
+
+  private var numFieldsPerRow = 0
+
+  def length: Int = numRows
+
+  def isEmpty: Boolean = numRows == 0
+
+  /**
+   * Clears up resources (eg. memory) held by the backing storage
+   */
+  def clear(): Unit = {
+    if (spillableArray != null) {
+      // The last `spillableArray` of this task will be cleaned up via task completion listener
+      // inside `UnsafeExternalSorter`
+      spillableArray.cleanupResources()
+      spillableArray = null
+    } else if (inMemoryBuffer != null) {
+      inMemoryBuffer.clear()
+    }
+    numFieldsPerRow = 0
+    numRows = 0
+    modificationsCount += 1
+  }
+
+  def add(unsafeRow: UnsafeRow): Unit = {
+    if (numRows < numRowsSpillThreshold) {
+      inMemoryBuffer += unsafeRow.copy()
+    } else {
+      if (spillableArray == null) {
+        logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+          s"${classOf[UnsafeExternalSorter].getName}")
+
+        // We will not sort the rows, so prefixComparator and recordComparator are null
+        spillableArray = UnsafeExternalSorter.create(
+          taskMemoryManager,
+          blockManager,
+          serializerManager,
+          taskContext,
+          null,
+          null,
+          initialSize,
+          pageSizeBytes,
+          numRowsSpillThreshold,
+          false)
+
+        // populate with existing in-memory buffered rows
+        if (inMemoryBuffer != null) {
+          inMemoryBuffer.foreach(existingUnsafeRow =>
+            spillableArray.insertRecord(
+              existingUnsafeRow.getBaseObject,
+              existingUnsafeRow.getBaseOffset,
+              existingUnsafeRow.getSizeInBytes,
+              0,
+              false)
+          )
+          inMemoryBuffer.clear()
+        }
+        numFieldsPerRow = unsafeRow.numFields()
+      }
+
+      spillableArray.insertRecord(
+        unsafeRow.getBaseObject,
+        unsafeRow.getBaseOffset,
+        unsafeRow.getSizeInBytes,
+        0,
+        false)
+    }
+
+    numRows += 1
+    modificationsCount += 1
+  }
+
+  /**
+   * Creates an [[Iterator]] for the current rows in the array starting from a user provided index
+   *
+   * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of
+   * the iterator, then the iterator is invalidated thus saving clients from thinking that they
+   * have read all the data while there were new rows added to this array.
+   */
+  def generateIterator(startIndex: Int): Iterator[UnsafeRow] = {
+    if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) {
+      throw new ArrayIndexOutOfBoundsException(
+        "Invalid `startIndex` provided for generating iterator over the array. " +
+          s"Total elements: $numRows, requested `startIndex`: $startIndex")
+    }
+
+    if (spillableArray == null) {
+      new InMemoryBufferIterator(startIndex)
+    } else {
+      new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex)
+    }
+  }
+
+  def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0)
+
+  private[this]
+  abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] {
+    private val expectedModificationsCount = modificationsCount
+
+    protected def isModified(): Boolean = expectedModificationsCount != modificationsCount
+
+    protected def throwExceptionIfModified(): Unit = {
+      if (expectedModificationsCount != modificationsCount) {
+        throw new ConcurrentModificationException(
+          s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " +
+            s"since the creation of this Iterator")
+      }
+    }
+  }
+
+  private[this] class InMemoryBufferIterator(startIndex: Int)
+    extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+    private var currentIndex = startIndex
+
+    override def hasNext(): Boolean = !isModified() && currentIndex < numRows
+
+    override def next(): UnsafeRow = {
+      throwExceptionIfModified()
+      val result = inMemoryBuffer(currentIndex)
+      currentIndex += 1
+      result
+    }
+  }
+
+  private[this] class SpillableArrayIterator(
+      iterator: UnsafeSorterIterator,
+      numFieldPerRow: Int,
+      startIndex: Int)
+    extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+    private val currentRow = new UnsafeRow(numFieldPerRow)
+
+    def init(): Unit = {
+      var i = 0
+      while (i < startIndex) {
+        if (iterator.hasNext) {
+          iterator.loadNext()
+        } else {
+          throw new ArrayIndexOutOfBoundsException(
+            "Invalid `startIndex` provided for generating iterator over the array. " +
+              s"Total elements: $numRows, requested `startIndex`: $startIndex")
+        }
+        i += 1
+      }
+    }
+
+    // Traverse upto the given [[startIndex]]
+    init()
+
+    override def hasNext(): Boolean = !isModified() && iterator.hasNext
+
+    override def next(): UnsafeRow = {
+      throwExceptionIfModified()
+      iterator.loadNext()
+      currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength)
+      currentRow
+    }
+  }
+}
+
+private[sql] object ExternalAppendOnlyUnsafeRowArray {
+  val DefaultInitialSizeOfInMemoryBuffer = 128
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8341fe2..f380986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 /**
  * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
  * will be much faster than building the right partition for every row in left RDD, it also
  * materialize the right RDD (in case of the right RDD is nondeterministic).
  */
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+class UnsafeCartesianRDD(
+    left : RDD[UnsafeRow],
+    right : RDD[UnsafeRow],
+    numFieldsOfRight: Int,
+    spillThreshold: Int)
   extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
 
   override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
-    // We will not sort the rows, so prefixComparator and recordComparator are null.
-    val sorter = UnsafeExternalSorter.create(
-      context.taskMemoryManager(),
-      SparkEnv.get.blockManager,
-      SparkEnv.get.serializerManager,
-      context,
-      null,
-      null,
-      1024,
-      SparkEnv.get.memoryManager.pageSizeBytes,
-      SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
-      false)
+    val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
 
     val partition = split.asInstanceOf[CartesianPartition]
-    for (y <- rdd2.iterator(partition.s2, context)) {
-      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
-    }
+    rdd2.iterator(partition.s2, context).foreach(rowArray.add)
 
-    // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
-    def createIter(): Iterator[UnsafeRow] = {
-      val iter = sorter.getIterator
-      val unsafeRow = new UnsafeRow(numFieldsOfRight)
-      new Iterator[UnsafeRow] {
-        override def hasNext: Boolean = {
-          iter.hasNext
-        }
-        override def next(): UnsafeRow = {
-          iter.loadNext()
-          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
-          unsafeRow
-        }
-      }
-    }
+    // Create an iterator from rowArray
+    def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
 
     val resultIter =
       for (x <- rdd1.iterator(partition.s1, context);
            y <- createIter()) yield (x, y)
     CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
-      resultIter, sorter.cleanupResources())
+      resultIter, rowArray.clear())
   }
 }
 
@@ -97,7 +71,9 @@ case class CartesianProductExec(
     val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
-    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+    val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
+
+    val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
     pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ca9c0ed..bcdc4dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
+ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.util.collection.BitSet
 
@@ -95,9 +96,13 @@ case class SortMergeJoinExec(
   private def createRightKeyGenerator(): Projection =
     UnsafeProjection.create(rightKeys, right.output)
 
+  private def getSpillThreshold: Int = {
+    sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-
+    val spillThreshold = getSpillThreshold
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
@@ -115,39 +120,39 @@ case class SortMergeJoinExec(
         case _: InnerLike =>
           new RowIterator {
             private[this] var currentLeftRow: InternalRow = _
-            private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
-            private[this] var currentMatchIdx: Int = -1
+            private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
+            private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
             private[this] val smjScanner = new SortMergeJoinScanner(
               createLeftKeyGenerator(),
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
             if (smjScanner.findNextInnerJoinRows()) {
               currentRightMatches = smjScanner.getBufferedMatches
               currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
+              rightMatchesIterator = currentRightMatches.generateIterator()
             }
 
             override def advanceNext(): Boolean = {
-              while (currentMatchIdx >= 0) {
-                if (currentMatchIdx == currentRightMatches.length) {
+              while (rightMatchesIterator != null) {
+                if (!rightMatchesIterator.hasNext) {
                   if (smjScanner.findNextInnerJoinRows()) {
                     currentRightMatches = smjScanner.getBufferedMatches
                     currentLeftRow = smjScanner.getStreamedRow
-                    currentMatchIdx = 0
+                    rightMatchesIterator = currentRightMatches.generateIterator()
                   } else {
                     currentRightMatches = null
                     currentLeftRow = null
-                    currentMatchIdx = -1
+                    rightMatchesIterator = null
                     return false
                   }
                 }
-                joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
-                currentMatchIdx += 1
+                joinRow(currentLeftRow, rightMatchesIterator.next())
                 if (boundCondition(joinRow)) {
                   numOutputRows += 1
                   return true
@@ -165,7 +170,8 @@ case class SortMergeJoinExec(
             bufferedKeyGenerator = createRightKeyGenerator(),
             keyOrdering,
             streamedIter = RowIterator.fromScala(leftIter),
-            bufferedIter = RowIterator.fromScala(rightIter)
+            bufferedIter = RowIterator.fromScala(rightIter),
+            spillThreshold
           )
           val rightNullRow = new GenericInternalRow(right.output.length)
           new LeftOuterIterator(
@@ -177,7 +183,8 @@ case class SortMergeJoinExec(
             bufferedKeyGenerator = createLeftKeyGenerator(),
             keyOrdering,
             streamedIter = RowIterator.fromScala(rightIter),
-            bufferedIter = RowIterator.fromScala(leftIter)
+            bufferedIter = RowIterator.fromScala(leftIter),
+            spillThreshold
           )
           val leftNullRow = new GenericInternalRow(left.output.length)
           new RightOuterIterator(
@@ -209,7 +216,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -217,14 +225,15 @@ case class SortMergeJoinExec(
               while (smjScanner.findNextInnerJoinRows()) {
                 val currentRightMatches = smjScanner.getBufferedMatches
                 currentLeftRow = smjScanner.getStreamedRow
-                var i = 0
-                while (i < currentRightMatches.length) {
-                  joinRow(currentLeftRow, currentRightMatches(i))
-                  if (boundCondition(joinRow)) {
-                    numOutputRows += 1
-                    return true
+                if (currentRightMatches != null && currentRightMatches.length > 0) {
+                  val rightMatchesIterator = currentRightMatches.generateIterator()
+                  while (rightMatchesIterator.hasNext) {
+                    joinRow(currentLeftRow, rightMatchesIterator.next())
+                    if (boundCondition(joinRow)) {
+                      numOutputRows += 1
+                      return true
+                    }
                   }
-                  i += 1
                 }
               }
               false
@@ -241,7 +250,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -249,17 +259,16 @@ case class SortMergeJoinExec(
               while (smjScanner.findNextOuterJoinRows()) {
                 currentLeftRow = smjScanner.getStreamedRow
                 val currentRightMatches = smjScanner.getBufferedMatches
-                if (currentRightMatches == null) {
+                if (currentRightMatches == null || currentRightMatches.length == 0) {
                   return true
                 }
-                var i = 0
                 var found = false
-                while (!found && i < currentRightMatches.length) {
-                  joinRow(currentLeftRow, currentRightMatches(i))
+                val rightMatchesIterator = currentRightMatches.generateIterator()
+                while (!found && rightMatchesIterator.hasNext) {
+                  joinRow(currentLeftRow, rightMatchesIterator.next())
                   if (boundCondition(joinRow)) {
                     found = true
                   }
-                  i += 1
                 }
                 if (!found) {
                   numOutputRows += 1
@@ -281,7 +290,8 @@ case class SortMergeJoinExec(
               createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
+              RowIterator.fromScala(rightIter),
+              spillThreshold
             )
             private[this] val joinRow = new JoinedRow
 
@@ -290,14 +300,13 @@ case class SortMergeJoinExec(
                 currentLeftRow = smjScanner.getStreamedRow
                 val currentRightMatches = smjScanner.getBufferedMatches
                 var found = false
-                if (currentRightMatches != null) {
-                  var i = 0
-                  while (!found && i < currentRightMatches.length) {
-                    joinRow(currentLeftRow, currentRightMatches(i))
+                if (currentRightMatches != null && currentRightMatches.length > 0) {
+                  val rightMatchesIterator = currentRightMatches.generateIterator()
+                  while (!found && rightMatchesIterator.hasNext) {
+                    joinRow(currentLeftRow, rightMatchesIterator.next())
                     if (boundCondition(joinRow)) {
                       found = true
                     }
-                    i += 1
                   }
                 }
                 result.setBoolean(0, found)
@@ -376,8 +385,11 @@ case class SortMergeJoinExec(
 
     // A list to hold all matched rows from right side.
     val matches = ctx.freshName("matches")
-    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
-    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+    val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
+
+    val spillThreshold = getSpillThreshold
+
+    ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
     // Copy the left keys as class members so they could be used in next function call.
     val matchedKeyVars = copyKeys(ctx, leftKeyVars)
 
@@ -428,7 +440,7 @@ case class SortMergeJoinExec(
          |        }
          |        $leftRow = null;
          |      } else {
-         |        $matches.add($rightRow.copy());
+         |        $matches.add((UnsafeRow) $rightRow);
          |        $rightRow = null;;
          |      }
          |    } while ($leftRow != null);
@@ -517,8 +529,7 @@ case class SortMergeJoinExec(
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
 
-    val size = ctx.freshName("size")
-    val i = ctx.freshName("i")
+    val iterator = ctx.freshName("iterator")
     val numOutput = metricTerm(ctx, "numOutputRows")
     val (beforeLoop, condCheck) = if (condition.isDefined) {
       // Split the code of creating variables based on whether it's used by condition or not.
@@ -551,10 +562,10 @@ case class SortMergeJoinExec(
 
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
-       |  int $size = $matches.size();
        |  ${beforeLoop.trim}
-       |  for (int $i = 0; $i < $size; $i ++) {
-       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
+       |  scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+       |  while ($iterator.hasNext()) {
+       |    InternalRow $rightRow = (InternalRow) $iterator.next();
        |    ${condCheck.trim}
        |    $numOutput.add(1);
        |    ${consume(ctx, leftVars ++ rightVars)}
@@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner(
     bufferedKeyGenerator: Projection,
     keyOrdering: Ordering[InternalRow],
     streamedIter: RowIterator,
-    bufferedIter: RowIterator) {
+    bufferedIter: RowIterator,
+    bufferThreshold: Int) {
   private[this] var streamedRow: InternalRow = _
   private[this] var streamedRowKey: InternalRow = _
   private[this] var bufferedRow: InternalRow = _
@@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner(
    */
   private[this] var matchJoinKey: InternalRow = _
   /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
-  private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
 
   // Initialization (note: do _not_ want to advance streamed here).
   advancedBufferedToRowWithNullFreeJoinKey()
@@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner(
 
   def getStreamedRow: InternalRow = streamedRow
 
-  def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+  def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
 
   /**
    * Advances both input iterators, stopping when we have found rows with matching join keys.
@@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner(
     matchJoinKey = streamedRowKey.copy()
     bufferedMatches.clear()
     do {
-      bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+      bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
       advancedBufferedToRowWithNullFreeJoinKey()
     } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
   }
@@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator(
   protected[this] val joinedRow: JoinedRow = new JoinedRow()
 
   // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
-  private[this] var bufferIndex: Int = 0
+  private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
 
   // This iterator is initialized lazily so there should be no matches initially
   assert(smjScanner.getBufferedMatches.length == 0)
@@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator(
    * @return whether there are more rows in the stream to consume.
    */
   private def advanceStream(): Boolean = {
-    bufferIndex = 0
+    rightMatchesIterator = null
     if (smjScanner.findNextOuterJoinRows()) {
       setStreamSideOutput(smjScanner.getStreamedRow)
       if (smjScanner.getBufferedMatches.isEmpty) {
@@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator(
    */
   private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
     var foundMatch: Boolean = false
-    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
-      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+    if (rightMatchesIterator == null) {
+      rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator()
+    }
+
+    while (!foundMatch && rightMatchesIterator.hasNext) {
+      setBufferedSideOutput(rightMatchesIterator.next())
       foundMatch = boundCondition(joinedRow)
-      bufferIndex += 1
     }
     foundMatch
   }
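
Note on the buffering change in `SortMergeJoinScanner` above: the matches are no longer held in an indexable `ArrayBuffer`, so every consumer now asks the buffer for a fresh iterator instead of indexing into it. The sketch below is a minimal, Spark-free illustration of that contract (append-only, threshold-triggered spill, repeatable iteration, iterators invalidated by `add` or `clear`). The class name `SpillableAppendOnlyBuffer`, the `hasSpilled` accessor, and the in-memory `spilled` store are assumptions made for illustration; the real `ExternalAppendOnlyUnsafeRowArray` hands rows to `UnsafeExternalSorter` and may differ in detail.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for the buffer described in the commit message.
// It is NOT Spark's implementation; it only mirrors the documented behaviour.
final class SpillableAppendOnlyBuffer[T](spillThreshold: Int) {
  private val inMemory = new ArrayBuffer[T]
  // Stands in for the on-disk spill. The real class uses UnsafeExternalSorter here.
  private val spilled = new ArrayBuffer[T]
  // Bumped on every add/clear so that stale iterators can be detected.
  private var modificationCount = 0L

  def length: Int = inMemory.length + spilled.length
  def isEmpty: Boolean = length == 0
  def hasSpilled: Boolean = spilled.nonEmpty

  def add(row: T): Unit = {
    if (spilled.isEmpty && inMemory.length < spillThreshold) {
      inMemory += row
    } else {
      // Threshold reached: move the in-memory rows to the spill store, then append there.
      if (inMemory.nonEmpty) {
        spilled ++= inMemory
        inMemory.clear()
      }
      spilled += row
    }
    modificationCount += 1
  }

  def clear(): Unit = {
    inMemory.clear()
    spilled.clear()
    modificationCount += 1
  }

  /** Returns a fresh iterator in insertion order, starting at `startIndex`. */
  def generateIterator(startIndex: Int = 0): Iterator[T] = {
    val expected = modificationCount
    val underlying = (inMemory.iterator ++ spilled.iterator).drop(startIndex)
    new Iterator[T] {
      private def check(): Unit =
        if (expected != modificationCount) {
          throw new java.util.ConcurrentModificationException(
            "buffer was modified after this iterator was created")
        }
      override def hasNext: Boolean = { check(); underlying.hasNext }
      override def next(): T = { check(); underlying.next() }
    }
  }
}

object SpillableBufferDemo {
  def main(args: Array[String]): Unit = {
    val buffer = new SpillableAppendOnlyBuffer[String](spillThreshold = 2)
    Seq("a", "b", "c").foreach(buffer.add)
    // The same contents can be scanned as many times as needed.
    println(buffer.generateIterator().toList)
    println(buffer.generateIterator(startIndex = 1).toList)
  }
}
```

Tracking a modification counter is one simple way to make a stale iterator fail fast rather than silently read rows from a cleared or re-filled buffer, which is the hazard the callers in this diff avoid by calling `generateIterator()` again after each `findNext*JoinRows()`.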

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
deleted file mode 100644
index ee36c84..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.window
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-
-
-/**
- * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the
- * row buffer is used to materialize a partition of rows since we need to repeatedly scan these
- * rows in window function processing.
- */
-private[window] abstract class RowBuffer {
-
-  /** Number of rows. */
-  def size: Int
-
-  /** Return next row in the buffer, null if no more left. */
-  def next(): InternalRow
-
-  /** Skip the next `n` rows. */
-  def skip(n: Int): Unit
-
-  /** Return a new RowBuffer that has the same rows. */
-  def copy(): RowBuffer
-}
-
-/**
- * A row buffer based on ArrayBuffer (the number of rows is limited).
- */
-private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
-
-  private[this] var cursor: Int = -1
-
-  /** Number of rows. */
-  override def size: Int = buffer.length
-
-  /** Return next row in the buffer, null if no more left. */
-  override def next(): InternalRow = {
-    cursor += 1
-    if (cursor < buffer.length) {
-      buffer(cursor)
-    } else {
-      null
-    }
-  }
-
-  /** Skip the next `n` rows. */
-  override def skip(n: Int): Unit = {
-    cursor += n
-  }
-
-  /** Return a new RowBuffer that has the same rows. */
-  override def copy(): RowBuffer = {
-    new ArrayRowBuffer(buffer)
-  }
-}
-
-/**
- * An external buffer of rows based on UnsafeExternalSorter.
- */
-private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
-  extends RowBuffer {
-
-  private[this] val iter: UnsafeSorterIterator = sorter.getIterator
-
-  private[this] val currentRow = new UnsafeRow(numFields)
-
-  /** Number of rows. */
-  override def size: Int = iter.getNumRecords()
-
-  /** Return next row in the buffer, null if no more left. */
-  override def next(): InternalRow = {
-    if (iter.hasNext) {
-      iter.loadNext()
-      currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
-      currentRow
-    } else {
-      null
-    }
-  }
-
-  /** Skip the next `n` rows. */
-  override def skip(n: Int): Unit = {
-    var i = 0
-    while (i < n && iter.hasNext) {
-      iter.loadNext()
-      i += 1
-    }
-  }
-
-  /** Return a new RowBuffer that has the same rows. */
-  override def copy(): RowBuffer = {
-    new ExternalRowBuffer(sorter, numFields)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 80b87d5..950a679 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 /**
  * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -284,6 +282,7 @@ case class WindowExec(
     // Unwrap the expressions and factories from the map.
     val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
     val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+    val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
 
     // Start processing.
     child.execute().mapPartitions { stream =>
@@ -310,10 +309,12 @@ case class WindowExec(
         fetchNextRow()
 
         // Manage the current partition.
-        val rows = ArrayBuffer.empty[UnsafeRow]
         val inputFields = child.output.length
-        var sorter: UnsafeExternalSorter = null
-        var rowBuffer: RowBuffer = null
+
+        val buffer: ExternalAppendOnlyUnsafeRowArray =
+          new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+        var bufferIterator: Iterator[UnsafeRow] = _
+
         val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
         val frames = factories.map(_(windowFunctionResult))
         val numFrames = frames.length
@@ -323,78 +324,43 @@ case class WindowExec(
           val currentGroup = nextGroup.copy()
 
           // clear last partition
-          if (sorter != null) {
-            // the last sorter of this task will be cleaned up via task completion listener
-            sorter.cleanupResources()
-            sorter = null
-          } else {
-            rows.clear()
-          }
+          buffer.clear()
 
           while (nextRowAvailable && nextGroup == currentGroup) {
-            if (sorter == null) {
-              rows += nextRow.copy()
-
-              if (rows.length >= 4096) {
-                // We will not sort the rows, so prefixComparator and recordComparator are null.
-                sorter = UnsafeExternalSorter.create(
-                  TaskContext.get().taskMemoryManager(),
-                  SparkEnv.get.blockManager,
-                  SparkEnv.get.serializerManager,
-                  TaskContext.get(),
-                  null,
-                  null,
-                  1024,
-                  SparkEnv.get.memoryManager.pageSizeBytes,
-                  SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-                    UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
-                  false)
-                rows.foreach { r =>
-                  sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
-                }
-                rows.clear()
-              }
-            } else {
-              sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
-                nextRow.getSizeInBytes, 0, false)
-            }
+            buffer.add(nextRow)
             fetchNextRow()
           }
-          if (sorter != null) {
-            rowBuffer = new ExternalRowBuffer(sorter, inputFields)
-          } else {
-            rowBuffer = new ArrayRowBuffer(rows)
-          }
 
           // Setup the frames.
           var i = 0
           while (i < numFrames) {
-            frames(i).prepare(rowBuffer.copy())
+            frames(i).prepare(buffer)
             i += 1
           }
 
           // Setup iteration
           rowIndex = 0
-          rowsSize = rowBuffer.size
+          bufferIterator = buffer.generateIterator()
         }
 
         // Iteration
         var rowIndex = 0
-        var rowsSize = 0L
 
-        override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+        override final def hasNext: Boolean =
+          (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
 
         val join = new JoinedRow
         override final def next(): InternalRow = {
           // Load the next partition if we need to.
-          if (rowIndex >= rowsSize && nextRowAvailable) {
+          if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
             fetchNextPartition()
           }
 
-          if (rowIndex < rowsSize) {
+          if (bufferIterator.hasNext) {
+            val current = bufferIterator.next()
+
             // Get the results for the window frames.
             var i = 0
-            val current = rowBuffer.next()
             while (i < numFrames) {
               frames(i).write(rowIndex, current)
               i += 1
@@ -406,7 +372,9 @@ case class WindowExec(
 
             // Return the projection.
             result(join)
-          } else throw new NoSuchElementException
+          } else {
+            throw new NoSuchElementException
+          }
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 70efc0f..af2b4fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -22,6 +22,7 @@ import java.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
 
 /**
@@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame {
    *
    * @param rows to calculate the frame results for.
    */
-  def prepare(rows: RowBuffer): Unit
+  def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit
 
   /**
    * Write the current results to the target row.
@@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame {
   def write(index: Int, current: InternalRow): Unit
 }
 
+object WindowFunctionFrame {
+  def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = {
+    if (iterator.hasNext) iterator.next() else null
+  }
+}
+
 /**
  * The offset window frame calculates frames containing LEAD/LAG statements.
  *
@@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** Index of the input row currently used for output. */
   private[this] var inputIndex = 0
@@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame(
     newMutableProjection(boundExpressions, Nil).target(target)
   }
 
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
+    inputIterator = input.generateIterator()
     // drain the first few rows if offset is larger than zero
     inputIndex = 0
     while (inputIndex < offset) {
-      input.next()
+      if (inputIterator.hasNext) inputIterator.next()
       inputIndex += 1
     }
     inputIndex = offset
   }
 
   override def write(index: Int, current: InternalRow): Unit = {
-    if (inputIndex >= 0 && inputIndex < input.size) {
-      val r = input.next()
+    if (inputIndex >= 0 && inputIndex < input.length) {
+      val r = WindowFunctionFrame.getNextOrNull(inputIterator)
       projection(r)
     } else {
       // Use default values since the offset row does not exist.
@@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** The next row from `input`. */
   private[this] var nextRow: InternalRow = null
@@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame(
   private[this] var inputLowIndex = 0
 
   /** Prepare the frame for calculating a new partition. Reset all variables. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
-    nextRow = rows.next()
+    inputIterator = input.generateIterator()
+    nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
     inputHighIndex = 0
     inputLowIndex = 0
     buffer.clear()
@@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame(
     // the output row upper bound.
     while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
       buffer.add(nextRow.copy())
-      nextRow = input.next()
+      nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
       inputHighIndex += 1
       bufferUpdated = true
     }
@@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame(
 
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
-      processor.initialize(input.size)
+      processor.initialize(input.length)
       val iter = buffer.iterator()
       while (iter.hasNext) {
         processor.update(iter.next())
@@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Prepare the frame for calculating a new partition. Process all rows eagerly. */
-  override def prepare(rows: RowBuffer): Unit = {
-    val size = rows.size
-    processor.initialize(size)
-    var i = 0
-    while (i < size) {
-      processor.update(rows.next())
-      i += 1
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+    processor.initialize(rows.length)
+
+    val iterator = rows.generateIterator()
+    while (iterator.hasNext) {
+      processor.update(iterator.next())
     }
   }
 
@@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+  /**
+   * An iterator over the [[input]]
+   */
+  private[this] var inputIterator: Iterator[UnsafeRow] = _
 
   /** The next row from `input`. */
   private[this] var nextRow: InternalRow = null
@@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
   private[this] var inputIndex = 0
 
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
-    nextRow = rows.next()
     inputIndex = 0
-    processor.initialize(input.size)
+    inputIterator = input.generateIterator()
+    if (inputIterator.hasNext) {
+      nextRow = inputIterator.next()
+    }
+
+    processor.initialize(input.length)
   }
 
   /** Write the frame columns for the current row to the given target row. */
@@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
     // the output row upper bound.
     while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
       processor.update(nextRow)
-      nextRow = input.next()
+      nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
       inputIndex += 1
       bufferUpdated = true
     }
@@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: RowBuffer = null
+  private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
 
   /**
    * Index of the first input row with a value equal to or greater than the lower bound of the
@@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   private[this] var inputIndex = 0
 
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: RowBuffer): Unit = {
+  override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
     inputIndex = 0
   }
@@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
   override def write(index: Int, current: InternalRow): Unit = {
     var bufferUpdated = index == 0
 
-    // Duplicate the input to have a new iterator
-    val tmp = input.copy()
-
-    // Drop all rows from the buffer for which the input row value is smaller than
+    // Ignore all the rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    tmp.skip(inputIndex)
-    var nextRow = tmp.next()
+    val iterator = input.generateIterator(startIndex = inputIndex)
+
+    var nextRow = WindowFunctionFrame.getNextOrNull(iterator)
     while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
-      nextRow = tmp.next()
       inputIndex += 1
       bufferUpdated = true
+      nextRow = WindowFunctionFrame.getNextOrNull(iterator)
     }
 
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
-      processor.initialize(input.size)
-      while (nextRow != null) {
+      processor.initialize(input.length)
+      if (nextRow != null) {
         processor.update(nextRow)
-        nextRow = tmp.next()
+      }
+      while (iterator.hasNext) {
+        processor.update(iterator.next())
       }
       processor.evaluate(target)
     }
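
Note on the `UnboundedFollowingWindowFunctionFrame.write` change above: the old `input.copy()` plus `skip(inputIndex)` pair is replaced by one iterator that starts at `inputIndex`, and `getNextOrNull` keeps the null-terminated loop shape the frames already relied on. The following is a rough, Spark-free sketch of that loop under simplifying assumptions: `SketchRow`, `sumFromLowerBound`, and the integer sum (standing in for `processor.update`) are hypothetical names and are not part of the patch.

```scala
// Hypothetical row type used only for this sketch.
final case class SketchRow(value: Int)

object UnboundedFollowingSketch {
  /** Same shape as WindowFunctionFrame.getNextOrNull in the diff, generalized over nullable types. */
  def getNextOrNull[T >: Null](iterator: Iterator[T]): T =
    if (iterator.hasNext) iterator.next() else null

  /**
   * Starting at `startIndex`, skips rows whose value is below `lowerBound`,
   * then folds the remaining rows into a sum.
   * Returns the index of the first row kept and the sum.
   */
  def sumFromLowerBound(
      rows: Vector[SketchRow],
      startIndex: Int,
      lowerBound: Int): (Int, Int) = {
    // Plays the role of input.generateIterator(startIndex = inputIndex).
    val iterator = rows.iterator.drop(startIndex)
    var index = startIndex
    var nextRow = getNextOrNull(iterator)
    while (nextRow != null && nextRow.value < lowerBound) {
      index += 1
      nextRow = getNextOrNull(iterator)
    }
    var sum = 0
    // The first row that passed the bound was already pulled from the iterator.
    if (nextRow != null) sum += nextRow.value
    while (iterator.hasNext) sum += iterator.next().value
    (index, sum)
  }
}
```

Because the returned index is fed back in as the next `startIndex`, rows that fell below an earlier lower bound are not scanned again, which is what `inputIndex` achieves in the frame.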

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2e00673..1a66aa8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable.ListBuffer
 import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 
 class JoinSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
     cartesianQueries.foreach(checkCartesianDetection)
   }
+
+  test("test SortMergeJoin (without spill)") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) {
+
+      assertNotSpilled(sparkContext, "inner join") {
+        checkAnswer(
+          sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+          Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+        )
+      }
+
+      val expected = new ListBuffer[Row]()
+      expected.append(
+        Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+        Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+        Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+      )
+      for (i <- 4 to 100) {
+        expected.append(Row(i, i.toString, null, null))
+      }
+
+      assertNotSpilled(sparkContext, "left outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData big
+              |LEFT OUTER JOIN
+              |  testData2 small
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      assertNotSpilled(sparkContext, "right outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |RIGHT OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+    }
+  }
+
+  test("test SortMergeJoin (with spill)") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+      "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+
+      assertSpilled(sparkContext, "inner join") {
+        checkAnswer(
+          sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+          Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+        )
+      }
+
+      val expected = new ListBuffer[Row]()
+      expected.append(
+        Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+        Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+        Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+      )
+      for (i <- 4 to 100) {
+        expected.append(Row(i, i.toString, null, null))
+      }
+
+      assertSpilled(sparkContext, "left outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData big
+              |LEFT OUTER JOIN
+              |  testData2 small
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      assertSpilled(sparkContext, "right outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |RIGHT OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+
+      // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]]
+      // so should not cause any spill
+      assertNotSpilled(sparkContext, "full outer join") {
+        checkAnswer(
+          sql(
+            """
+              |SELECT
+              |  big.key, big.value, small.a, small.b
+              |FROM
+              |  testData2 small
+              |FULL OUTER JOIN
+              |  testData big
+              |ON
+              |  big.key = small.a
+            """.stripMargin),
+          expected
+        )
+      }
+    }
+  }
 }
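
Reviewer note on the spill threshold used in the tests above: setting `spark.sql.sortMergeJoinExec.buffer.spill.threshold` to `"0"` appears to force every buffered row past the in-memory stage. The toy below is a minimal sketch of that threshold behaviour, assuming rows stay in memory while the row count is below the threshold and everything moves to an external store afterwards. `SpillableLongBuffer` and its temp-file store are hypothetical stand-ins, not the class added by this commit, which hands rows to `UnsafeExternalSorter` instead.

```scala
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream,
  DataOutputStream, File, FileInputStream, FileOutputStream}
import scala.collection.mutable.ArrayBuffer

// Hypothetical toy: an append-only buffer of Longs that keeps values in memory
// until `spillThreshold` rows are held, then moves everything to a temp file.
final class SpillableLongBuffer(spillThreshold: Int) {
  private val inMemory = new ArrayBuffer[Long]
  private var spillFile: Option[File] = None
  private var out: DataOutputStream = null
  private var numRows = 0

  def spilled: Boolean = spillFile.isDefined
  def length: Int = numRows

  def add(value: Long): Unit = {
    if (!spilled && numRows < spillThreshold) {
      inMemory += value
    } else {
      if (!spilled) {
        // First row past the threshold: move the in-memory rows to disk, in order.
        val f = File.createTempFile("spill", ".bin")
        f.deleteOnExit()
        out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(f)))
        inMemory.foreach(v => out.writeLong(v))
        inMemory.clear()
        spillFile = Some(f)
      }
      out.writeLong(value)
    }
    numRows += 1
  }

  // Reads all values back in insertion order (no sorting is involved).
  def values: Vector[Long] = spillFile match {
    case None => inMemory.toVector
    case Some(f) =>
      out.flush()
      val in = new DataInputStream(new BufferedInputStream(new FileInputStream(f)))
      try Vector.fill(numRows)(in.readLong()) finally in.close()
  }
}
```

Under these assumptions a threshold of 0 spills on the very first `add`, while a threshold of 100 holds exactly 100 rows in memory and spills on the 101st.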

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
new file mode 100644
index 0000000..00c5f25
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+object ExternalAppendOnlyUnsafeRowArrayBenchmark {
+
+  def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = {
+    val random = new java.util.Random()
+    val rows = (1 to numRows).map(_ => {
+      val row = new UnsafeRow(1)
+      row.pointTo(new Array[Byte](64), 16)
+      row.setLong(0, random.nextLong())
+      row
+    })
+
+    val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows)
+
+    // Internally, `ExternalAppendOnlyUnsafeRowArray` creates an in-memory buffer whose initial
+    // size is the smaller of its default initial size and `numSpillThreshold`. Mimic that here.
+    val initialSize =
+      Math.min(
+        ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+        numSpillThreshold)
+
+    benchmark.addCase("ArrayBuffer") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ArrayBuffer[UnsafeRow](initialSize)
+
+        // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a
+        // copy of the row. This will mimic that
+        rows.foreach(x => array += x.copy())
+
+        var i = 0
+        val n = array.length
+        while (i < n) {
+          sum = sum + array(i).getLong(0)
+          i += 1
+        }
+        array.clear()
+      }
+    }
+
+    benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+        rows.foreach(x => array.add(x))
+
+        val iterator = array.generateIterator()
+        while (iterator.hasNext) {
+          sum = sum + iterator.next().getLong(0)
+        }
+        array.clear()
+      }
+    }
+
+    val conf = new SparkConf(false)
+    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+    conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+    val sc = new SparkContext("local", "test", conf)
+    val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+    benchmark.run()
+    sc.stop()
+  }
+
+  def testAgainstRawUnsafeExternalSorter(
+      numSpillThreshold: Int,
+      numRows: Int,
+      iterations: Int): Unit = {
+
+    val random = new java.util.Random()
+    val rows = (1 to numRows).map(_ => {
+      val row = new UnsafeRow(1)
+      row.pointTo(new Array[Byte](64), 16)
+      row.setLong(0, random.nextLong())
+      row
+    })
+
+    val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows)
+
+    benchmark.addCase("UnsafeExternalSorter") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = UnsafeExternalSorter.create(
+          TaskContext.get().taskMemoryManager(),
+          SparkEnv.get.blockManager,
+          SparkEnv.get.serializerManager,
+          TaskContext.get(),
+          null,
+          null,
+          1024,
+          SparkEnv.get.memoryManager.pageSizeBytes,
+          numSpillThreshold,
+          false)
+
+        rows.foreach(x =>
+          array.insertRecord(
+            x.getBaseObject,
+            x.getBaseOffset,
+            x.getSizeInBytes,
+            0,
+            false))
+
+        val unsafeRow = new UnsafeRow(1)
+        val iter = array.getIterator
+        while (iter.hasNext) {
+          iter.loadNext()
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+          sum = sum + unsafeRow.getLong(0)
+        }
+        array.cleanupResources()
+      }
+    }
+
+    benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+      var sum = 0L
+      for (_ <- 0L until iterations) {
+        val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+        rows.foreach(x => array.add(x))
+
+        val iterator = array.generateIterator()
+        while (iterator.hasNext) {
+          sum = sum + iterator.next().getLong(0)
+        }
+        array.clear()
+      }
+    }
+
+    val conf = new SparkConf(false)
+    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+    conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+    val sc = new SparkContext("local", "test", conf)
+    val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+    benchmark.run()
+    sc.stop()
+  }
+
+  def main(args: Array[String]): Unit = {
+
+    // ========================================================================================= //
+    // WITHOUT SPILL
+    // ========================================================================================= //
+
+    val spillThreshold = 100 * 1000
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 1000 rows:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                   7821 / 7941         33.5          29.8       1.0X
+    ExternalAppendOnlyUnsafeRowArray              8798 / 8819         29.8          33.6       0.9X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 30000 rows:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                 19200 / 19206         25.6          39.1       1.0X
+    ExternalAppendOnlyUnsafeRowArray            19558 / 19562         25.1          39.8       1.0X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Array with 100000 rows:                  Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    ArrayBuffer                                   5949 / 6028         17.2          58.1       1.0X
+    ExternalAppendOnlyUnsafeRowArray              6078 / 6138         16.8          59.4       1.0X
+    */
+    testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10)
+
+    // ========================================================================================= //
+    // WITH SPILL
+    // ========================================================================================= //
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Spilling with 1000 rows:                 Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    UnsafeExternalSorter                          9239 / 9470         28.4          35.2       1.0X
+    ExternalAppendOnlyUnsafeRowArray              8857 / 8909         29.6          33.8       1.0X
+    */
+    testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18)
+
+    /*
+    Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+    Spilling with 10000 rows:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    UnsafeExternalSorter                             4 /    5         39.3          25.5       1.0X
+    ExternalAppendOnlyUnsafeRowArray                 5 /    6         29.8          33.5       0.8X
+    */
+    testAgainstRawUnsafeExternalSorter(
+      UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4)
+  }
+}
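
Reviewer note on reading the result tables in the comments above: the `Rate(M/s)` and `Per Row(ns)` columns appear to follow from the best time and the total number of rows processed, which is the `iterations * numRows` value passed to `Benchmark`. This is inferred from the reported numbers, not from the `Benchmark` source, and the helper names below are hypothetical.

```scala
// Total rows processed by one benchmark case.
def totalRows(numRows: Long, iterations: Long): Long = numRows * iterations

// Millions of rows per second, derived from the best wall-clock time in ms.
def rateMillionsPerSec(numRows: Long, iterations: Long, bestMs: Double): Double =
  totalRows(numRows, iterations) / (bestMs / 1000.0) / 1e6

// Nanoseconds spent per row, derived from the same best time.
def nanosPerRow(numRows: Long, iterations: Long, bestMs: Double): Double =
  bestMs * 1e6 / totalRows(numRows, iterations)

// "Array with 1000 rows", ArrayBuffer case: 1000 rows, 1 << 18 iterations, best 7821 ms.
val rate = rateMillionsPerSec(1000L, 1L << 18, 7821.0)
val perRow = nanosPerRow(1000L, 1L << 18, 7821.0)
println(f"$rate%.1f M/s, $perRow%.1f ns per row")
```

For that case the total is 262,144,000 rows, which gives roughly 33.5 M rows/s and 29.8 ns per row, matching the table.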

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
new file mode 100644
index 0000000..53c4163
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext {
+  private val random = new java.util.Random()
+  private var taskContext: TaskContext = _
+
+  override def afterAll(): Unit = TaskContext.unset()
+
+  private def withExternalArray(spillThreshold: Int)
+                               (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
+    sc = new SparkContext("local", "test", new SparkConf(false))
+
+    taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+    TaskContext.setTaskContext(taskContext)
+
+    val array = new ExternalAppendOnlyUnsafeRowArray(
+      taskContext.taskMemoryManager(),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      taskContext,
+      1024,
+      SparkEnv.get.memoryManager.pageSizeBytes,
+      spillThreshold)
+    try f(array) finally {
+      array.clear()
+    }
+  }
+
+  private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = {
+    val valueInserted = random.nextLong()
+
+    val row = new UnsafeRow(1)
+    row.pointTo(new Array[Byte](64), 16)
+    row.setLong(0, valueInserted)
+    array.add(row)
+    valueInserted
+  }
+
+  private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = {
+    assert(iterator.hasNext)
+    val actualRow = iterator.next()
+    assert(actualRow.getLong(0) == expectedValue)
+    assert(actualRow.getSizeInBytes == 16)
+  }
+
+  private def validateData(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = {
+    val iterator = array.generateIterator()
+    for (value <- expectedValues) {
+      checkIfValueExists(iterator, value)
+    }
+
+    assert(!iterator.hasNext)
+    iterator
+  }
+
+  private def populateRows(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      numRowsToBePopulated: Int): ArrayBuffer[Long] = {
+    val populatedValues = new ArrayBuffer[Long]
+    populateRows(array, numRowsToBePopulated, populatedValues)
+  }
+
+  private def populateRows(
+      array: ExternalAppendOnlyUnsafeRowArray,
+      numRowsToBePopulated: Int,
+      populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = {
+    for (_ <- 0 until numRowsToBePopulated) {
+      populatedValues.append(insertRow(array))
+    }
+    populatedValues
+  }
+
+  private def getNumBytesSpilled: Long = {
+    TaskContext.get().taskMetrics().memoryBytesSpilled
+  }
+
+  private def assertNoSpill(): Unit = {
+    assert(getNumBytesSpilled == 0)
+  }
+
+  private def assertSpill(): Unit = {
+    assert(getNumBytesSpilled > 0)
+  }
+
+  test("insert rows less than the spillThreshold") {
+    val spillThreshold = 100
+    withExternalArray(spillThreshold) { array =>
+      assert(array.isEmpty)
+
+      val expectedValues = populateRows(array, 1)
+      assert(!array.isEmpty)
+      assert(array.length == 1)
+
+      val iterator1 = validateData(array, expectedValues)
+
+      // Add more rows (but not enough to trigger the switch to [[UnsafeExternalSorter]])
+      // Verify that NO spill has happened
+      populateRows(array, spillThreshold - 1, expectedValues)
+      assert(array.length == spillThreshold)
+      assertNoSpill()
+
+      val iterator2 = validateData(array, expectedValues)
+
+      assert(!iterator1.hasNext)
+      assert(!iterator2.hasNext)
+    }
+  }
+
+  test("insert rows more than the spillThreshold to force spill") {
+    val spillThreshold = 100
+    withExternalArray(spillThreshold) { array =>
+      val numValuesInserted = 20 * spillThreshold
+
+      assert(array.isEmpty)
+      val expectedValues = populateRows(array, 1)
+      assert(array.length == 1)
+
+      val iterator1 = validateData(array, expectedValues)
+
+      // Populate more rows to trigger spill. Verify that spill has happened
+      populateRows(array, numValuesInserted - 1, expectedValues)
+      assert(array.length == numValuesInserted)
+      assertSpill()
+
+      val iterator2 = validateData(array, expectedValues)
+      assert(!iterator2.hasNext)
+
+      assert(!iterator1.hasNext)
+      intercept[ConcurrentModificationException](iterator1.next())
+    }
+  }
+
+  test("iterator on an empty array should be empty") {
+    withExternalArray(spillThreshold = 10) { array =>
+      val iterator = array.generateIterator()
+      assert(array.isEmpty)
+      assert(array.length == 0)
+      assert(!iterator.hasNext)
+    }
+  }
+
+  test("generate iterator with negative start index") {
+    withExternalArray(spillThreshold = 2) { array =>
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
+
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array")
+      )
+    }
+  }
+
+  test("generate iterator with start index exceeding array's size (without spill)") {
+    val spillThreshold = 2
+    withExternalArray(spillThreshold) { array =>
+      populateRows(array, spillThreshold / 2)
+
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](
+          array.generateIterator(startIndex = spillThreshold * 10))
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array"))
+    }
+  }
+
+  test("generate iterator with start index exceeding array's size (with spill)") {
+    val spillThreshold = 2
+    withExternalArray(spillThreshold) { array =>
+      populateRows(array, spillThreshold * 2)
+
+      val exception =
+        intercept[ArrayIndexOutOfBoundsException](
+          array.generateIterator(startIndex = spillThreshold * 10))
+
+      assert(exception.getMessage.contains(
+        "Invalid `startIndex` provided for generating iterator over the array"))
+    }
+  }
+
+  test("generate iterator with custom start index (without spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      val expectedValues = populateRows(array, spillThreshold)
+      val startIndex = spillThreshold / 2
+      val iterator = array.generateIterator(startIndex = startIndex)
+      for (i <- startIndex until expectedValues.length) {
+        checkIfValueExists(iterator, expectedValues(i))
+      }
+    }
+  }
+
+  test("generate iterator with custom start index (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      val expectedValues = populateRows(array, spillThreshold * 10)
+      val startIndex = spillThreshold * 2
+      val iterator = array.generateIterator(startIndex = startIndex)
+      for (i <- startIndex until expectedValues.length) {
+        checkIfValueExists(iterator, expectedValues(i))
+      }
+    }
+  }
+
+  test("test iterator invalidation (without spill)") {
+    withExternalArray(spillThreshold = 10) { array =>
+      // insert 2 rows, iterate until the first row
+      populateRows(array, 2)
+
+      var iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      // Adding more row(s) should invalidate any old iterators
+      populateRows(array, 1)
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+
+      // Clearing the array should also invalidate any old iterators
+      iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      array.clear()
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("test iterator invalidation (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate enough rows so that a spill happens
+      populateRows(array, spillThreshold * 2)
+      assertSpill()
+
+      var iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      // Adding more row(s) should invalidate any old iterators
+      populateRows(array, 1)
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+
+      // Clearing the array should also invalidate any old iterators
+      iterator = array.generateIterator()
+      assert(iterator.hasNext)
+      iterator.next()
+
+      array.clear()
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("clear on an empty array") {
+    withExternalArray(spillThreshold = 2) { array =>
+      val iterator = array.generateIterator()
+      assert(!iterator.hasNext)
+
+      // Calling clear() multiple times should not have any side effect
+      array.clear()
+      array.clear()
+      array.clear()
+      assert(array.isEmpty)
+      assert(array.length == 0)
+
+      // Clearing an empty array should also invalidate any old iterators
+      assert(!iterator.hasNext)
+      intercept[ConcurrentModificationException](iterator.next())
+    }
+  }
+
+  test("clear array (without spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate rows ... but not enough to trigger spill
+      populateRows(array, spillThreshold / 2)
+      assertNoSpill()
+
+      // Clear the array
+      array.clear()
+      assert(array.isEmpty)
+
+      // Re-populate a few rows so that there is no spill
+      // Verify the data. Verify that there was no spill
+      val expectedValues = populateRows(array, spillThreshold / 3)
+      validateData(array, expectedValues)
+      assertNoSpill()
+
+      // Populate more rows ... but still not enough to trigger a spill.
+      // Verify the data. Verify that there was no spill
+      populateRows(array, spillThreshold / 3, expectedValues)
+      validateData(array, expectedValues)
+      assertNoSpill()
+    }
+  }
+
+  test("clear array (with spill)") {
+    val spillThreshold = 10
+    withExternalArray(spillThreshold) { array =>
+      // Populate enough rows to trigger spill
+      populateRows(array, spillThreshold * 2)
+      val bytesSpilled = getNumBytesSpilled
+      assert(bytesSpilled > 0)
+
+      // Clear the array
+      array.clear()
+      assert(array.isEmpty)
+
+      // Re-populate the array ... but NOT up to the point that there is spill.
+      // Verify data. Verify that there was NO "extra" spill
+      val expectedValues = populateRows(array, spillThreshold / 2)
+      validateData(array, expectedValues)
+      assert(getNumBytesSpilled == bytesSpilled)
+
+      // Populate more rows to trigger spill
+      // Verify the data. Verify that there was "extra" spill
+      populateRows(array, spillThreshold * 2, expectedValues)
+      validateData(array, expectedValues)
+      assert(getNumBytesSpilled > bytesSpilled)
+    }
+  }
+}
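
Reviewer note on the iterator invalidation tests above: after `add` or `clear`, the tests expect an old iterator to report `hasNext == false` and to throw `ConcurrentModificationException` from `next()`. One common way to get that behaviour is a modification counter captured when the iterator is created. The sketch below assumes that approach; `InvalidatingBuffer` is a hypothetical simplification and is not claimed to be how `ExternalAppendOnlyUnsafeRowArray` is implemented.

```scala
import java.util.ConcurrentModificationException
import scala.collection.mutable.ArrayBuffer

// Hypothetical simplified stand-in for an append-only array with invalidating iterators.
final class InvalidatingBuffer[T] {
  private val data = new ArrayBuffer[T]
  private var modCount = 0L

  // Every mutation bumps the counter, including clear() on an empty buffer.
  def add(x: T): Unit = { data += x; modCount += 1 }
  def clear(): Unit = { data.clear(); modCount += 1 }
  def length: Int = data.length
  def isEmpty: Boolean = data.isEmpty

  def generateIterator(startIndex: Int = 0): Iterator[T] = {
    if (startIndex < 0 || startIndex > data.length) {
      throw new ArrayIndexOutOfBoundsException(
        s"Invalid `startIndex` provided for generating iterator over the array: $startIndex")
    }
    val expectedModCount = modCount
    new Iterator[T] {
      private var i = startIndex
      private def stale: Boolean = modCount != expectedModCount

      // A stale iterator reports no further elements.
      def hasNext: Boolean = !stale && i < data.length

      def next(): T = {
        if (stale) {
          throw new ConcurrentModificationException(
            "The buffer was modified after this iterator was created")
        }
        if (i >= data.length) throw new NoSuchElementException("next on empty iterator")
        val v = data(i)
        i += 1
        v
      }
    }
  }
}
```

With this design, an iterator created before a later `add` or `clear` becomes unusable, which is the contract the two invalidation tests exercise.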

http://git-wip-us.apache.org/repos/asf/spark/blob/02c274ea/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index afd4789..52e4f04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.TestUtils.assertSpilled
 
 case class WindowData(month: Int, area: String, product: Int)
 
@@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
       """.stripMargin),
       Row(1, 3, null) :: Row(2, null, 4) :: Nil)
   }
+
+  test("test with low buffer spill threshold") {
+    val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+    nums.createOrReplaceTempView("nums")
+
+    val expected =
+      Row(1, 1, 1) ::
+        Row(0, 2, 3) ::
+        Row(1, 3, 6) ::
+        Row(0, 4, 10) ::
+        Row(1, 5, 15) ::
+        Row(0, 6, 21) ::
+        Row(1, 7, 28) ::
+        Row(0, 8, 36) ::
+        Row(1, 9, 45) ::
+        Row(0, 10, 55) :: Nil
+
+    val actual = sql(
+      """
+        |SELECT y, x, sum(x) OVER w1 AS running_sum
+        |FROM nums
+        |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+      """.stripMargin)
+
+    withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+      assertSpilled(sparkContext, "test with low buffer spill threshold") {
+        checkAnswer(actual, expected)
+      }
+    }
+
+    spark.catalog.dropTempView("nums")
+  }
 }
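
Reviewer note on the expected rows in the window test above: with `ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, `running_sum` is the prefix sum of `x`, and `y` is `x % 2`. The plain-Scala sketch below recomputes the same triples without Spark; the value names are illustrative only.

```scala
// x runs over 1..10, as in the test's `nums` view.
val xs = 1 to 10

// Prefix sums of x: scanLeft starts from 0, so drop that seed with `tail`.
val runningSums = xs.scanLeft(0)(_ + _).tail

// (y, x, running_sum) in the same column order as the test's expected rows.
val expectedTriples = xs.zip(runningSums).map { case (x, sum) => (x % 2, x, sum) }

expectedTriples.foreach(println)
```

The first triple is `(1, 1, 1)` and the last is `(0, 10, 55)`, in line with the `Row` values listed in the test.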


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

