spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From andrewo...@apache.org
Subject spark git commit: [SPARK-11078] Ensure spilling tests actually spill
Date Thu, 15 Oct 2015 21:50:04 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2d000124b -> 3b364ff0a


[SPARK-11078] Ensure spilling tests actually spill

PR #9084 revealed that many tests intended to exercise spilling never actually spill. This follow-up patch fixes those tests so that our unit tests can actually catch bugs in the spilling code paths. The size of this patch is inflated by a refactoring of `ExternalSorterSuite`, which contained a lot of duplicated code and logic.

Author: Andrew Or <andrew@databricks.com>

Closes #9124 from andrewor14/spilling-tests.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b364ff0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b364ff0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b364ff0

Branch: refs/heads/master
Commit: 3b364ff0a4f38c2b8023429a55623de32be5f329
Parents: 2d00012
Author: Andrew Or <andrew@databricks.com>
Authored: Thu Oct 15 14:50:01 2015 -0700
Committer: Andrew Or <andrew@databricks.com>
Committed: Thu Oct 15 14:50:01 2015 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/TestUtils.scala |  51 ++
 .../spark/shuffle/ShuffleMemoryManager.scala    |   6 +-
 .../util/collection/ExternalAppendOnlyMap.scala |   6 +
 .../spark/util/collection/Spillable.scala       |  37 +-
 .../org/apache/spark/DistributedSuite.scala     |  39 +-
 .../collection/ExternalAppendOnlyMapSuite.scala | 103 ++-
 .../util/collection/ExternalSorterSuite.scala   | 871 ++++++++-----------
 .../execution/TestShuffleMemoryManager.scala    |   2 +
 8 files changed, 534 insertions(+), 581 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/main/scala/org/apache/spark/TestUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 888763a..acfe751 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -24,10 +24,14 @@ import java.util.Arrays
 import java.util.jar.{JarEntry, JarOutputStream}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import com.google.common.io.{ByteStreams, Files}
 import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
 
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
 import org.apache.spark.util.Utils
 
 /**
@@ -154,4 +158,51 @@ private[spark] object TestUtils {
       "  @Override public String toString() { return \"" + toStringValue + "\"; }}")
     createCompiledClass(className, destDir, sourceFile, classpathUrls)
   }
+
+  /**
+   * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
+   */
+  def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+    val spillListener = new SpillListener
+    sc.addSparkListener(spillListener)
+    body
+    assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+  }
+
+  /**
+   * Run some code involving jobs submitted to the given context and assert that the jobs
+   * did not spill.
+   */
+  def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+    val spillListener = new SpillListener
+    sc.addSparkListener(spillListener)
+    body
+    assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+  }
+
+}
+
+
+/**
+ * A [[SparkListener]] that detects whether spills have occurred in Spark jobs.
+ */
+private class SpillListener extends SparkListener {
+  private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
+  private val spilledStageIds = new mutable.HashSet[Int]
+
+  def numSpilledStages: Int = spilledStageIds.size
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    stageIdToTaskMetrics.getOrElseUpdate(
+      taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
+  }
+
+  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = {
+    val stageId = stageComplete.stageInfo.stageId
+    val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
+    val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
+    if (spilled) {
+      spilledStageIds += stageId
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index aaf543c..9bd18da 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -139,8 +139,10 @@ class ShuffleMemoryManager protected (
       throw new SparkException(
         s"Internal error: release called on $numBytes bytes but task only has $curMem")
     }
-    taskMemory(taskAttemptId) -= numBytes
-    memoryManager.releaseExecutionMemory(numBytes)
+    if (taskMemory.contains(taskAttemptId)) {
+      taskMemory(taskAttemptId) -= numBytes
+      memoryManager.releaseExecutionMemory(numBytes)
+    }
     memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6a96b5d..cfa58f5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -96,6 +96,12 @@ class ExternalAppendOnlyMap[K, V, C](
   private val ser = serializer.newInstance()
 
   /**
+   * Number of files this map has spilled so far.
+   * Exposed for testing.
+   */
+  private[collection] def numSpills: Int = spilledMaps.size
+
+  /**
    * Insert the given key and value into the map.
    */
   def insert(key: K, value: V): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 747ecf0..d2a68ca 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -43,10 +43,15 @@ private[spark] trait Spillable[C] extends Logging {
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
   // Initial threshold for the size of a collection before we start tracking its memory usage
-  // Exposed for testing
+  // For testing only
   private[this] val initialMemoryThreshold: Long =
     SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
 
+  // Force this collection to spill once it has read more than this many elements since its last spill
+  // For testing only
+  private[this] val numElementsForceSpillThreshold: Long =
+    SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)
+
   // Threshold for this collection's size in bytes before we start tracking its memory usage
   // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
   private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -69,27 +74,27 @@ private[spark] trait Spillable[C] extends Logging {
    * @return true if `collection` was spilled to disk; false otherwise
    */
   protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+    var shouldSpill = false
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
       val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
       myMemoryThreshold += granted
-      if (myMemoryThreshold <= currentMemory) {
-        // We were granted too little memory to grow further (either tryToAcquire returned 0,
-        // or we already had more memory than myMemoryThreshold); spill the current collection
-        _spillCount += 1
-        logSpillage(currentMemory)
-
-        spill(collection)
-
-        _elementsRead = 0
-        // Keep track of spills, and release memory
-        _memoryBytesSpilled += currentMemory
-        releaseMemoryForThisThread()
-        return true
-      }
+      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
+      // or we already had more memory than myMemoryThreshold), spill the current collection
+      shouldSpill = currentMemory >= myMemoryThreshold
+    }
+    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
+    // Actually spill
+    if (shouldSpill) {
+      _spillCount += 1
+      logSpillage(currentMemory)
+      spill(collection)
+      _elementsRead = 0
+      _memoryBytesSpilled += currentMemory
+      releaseMemoryForThisThread()
     }
-    false
+    shouldSpill
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/test/scala/org/apache/spark/DistributedSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 34a4bb9..1c3f2bc 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -203,22 +203,35 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
   }
 
   test("compute without caching when no partitions fit in memory") {
-    sc = new SparkContext(clusterUrl, "test")
-    // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache
-    // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory
-    val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
+    val size = 10000
+    val conf = new SparkConf()
+      .set("spark.storage.unrollMemoryThreshold", "1024")
+      .set("spark.testing.memory", (size / 2).toString)
+    sc = new SparkContext(clusterUrl, "test", conf)
+    val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    // ensure no partitions were cached
+    val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
+    assert(rddBlocks.size === 0, s"expected no RDD blocks, found ${rddBlocks.size}")
   }
 
   test("compute when only some partitions fit in memory") {
-    sc = new SparkContext(clusterUrl, "test", new SparkConf)
-    // TODO: verify that only a subset of partitions fit in memory (SPARK-11078)
-    val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
+    val size = 10000
+    val numPartitions = 10
+    val conf = new SparkConf()
+      .set("spark.storage.unrollMemoryThreshold", "1024")
+      .set("spark.testing.memory", (size * numPartitions).toString)
+    sc = new SparkContext(clusterUrl, "test", conf)
+    val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    // ensure only a subset of partitions were cached
+    val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
+    assert(rddBlocks.size > 0, "no RDD blocks found")
+    assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions")
   }
 
   test("passing environment variables to cluster") {

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 0a03c32..5cb506e 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark._
 import org.apache.spark.io.CompressionCodec
 
-// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078)
 
 class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
+  import TestUtils.{assertNotSpilled, assertSpilled}
+
   private val allCompressionCodecs = CompressionCodec.ALL_COMPRESSION_CODECS
   private def createCombiner[T](i: T) = ArrayBuffer[T](i)
   private def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i
@@ -244,54 +245,53 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
    * If a compression codec is provided, use it. Otherwise, do not compress spills.
    */
   private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+    val size = 1000
     val conf = createSparkConf(loadDefaults = true, codec)  // Load defaults for Spark home
+    conf.set("spark.shuffle.manager", "hash") // avoid using external sorter
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
-    // reduceByKey - should spill ~8 times
-    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
-    val resultA = rddA.reduceByKey(math.max).collect()
-    assert(resultA.length === 50000)
-    resultA.foreach { case (k, v) =>
-      assert(v === k * 2 + 1, s"Value for $k was wrong: expected ${k * 2 + 1}, got $v")
+    assertSpilled(sc, "reduceByKey") {
+      val result = sc.parallelize(0 until size)
+        .map { i => (i / 2, i) }.reduceByKey(math.max).collect()
+      assert(result.length === size / 2)
+      result.foreach { case (k, v) =>
+        val expected = k * 2 + 1
+        assert(v === expected, s"Value for $k was wrong: expected $expected, got $v")
+      }
     }
 
-    // groupByKey - should spill ~17 times
-    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
-    val resultB = rddB.groupByKey().collect()
-    assert(resultB.length === 25000)
-    resultB.foreach { case (i, seq) =>
-      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
-      assert(seq.toSet === expected,
-        s"Value for $i was wrong: expected $expected, got ${seq.toSet}")
+    assertSpilled(sc, "groupByKey") {
+      val result = sc.parallelize(0 until size).map { i => (i / 2, i) }.groupByKey().collect()
+      assert(result.length == size / 2)
+      result.foreach { case (i, seq) =>
+        val actual = seq.toSet
+        val expected = Set(i * 2, i * 2 + 1)
+        assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual")
+      }
     }
 
-    // cogroup - should spill ~7 times
-    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
-    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
-    val resultC = rddC1.cogroup(rddC2).collect()
-    assert(resultC.length === 10000)
-    resultC.foreach { case (i, (seq1, seq2)) =>
-      i match {
-        case 0 =>
-          assert(seq1.toSet === Set[Int](0))
-          assert(seq2.toSet === Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
-        case 1 =>
-          assert(seq1.toSet === Set[Int](1))
-          assert(seq2.toSet === Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
-        case 5000 =>
-          assert(seq1.toSet === Set[Int](5000))
-          assert(seq2.toSet === Set[Int]())
-        case 9999 =>
-          assert(seq1.toSet === Set[Int](9999))
-          assert(seq2.toSet === Set[Int]())
-        case _ =>
+    assertSpilled(sc, "cogroup") {
+      val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) }
+      val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) }
+      val result = rdd1.cogroup(rdd2).collect()
+      assert(result.length === size / 2)
+      result.foreach { case (i, (seq1, seq2)) =>
+        val actual1 = seq1.toSet
+        val actual2 = seq2.toSet
+        val expected = Set(i * 2, i * 2 + 1)
+        assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1")
+        assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2")
       }
     }
+
     sc.stop()
   }
 
   test("spilling with hash collisions") {
+    val size = 1000
     val conf = createSparkConf(loadDefaults = true)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
     val map = createExternalMap[String]
 
@@ -315,11 +315,12 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
       assert(w1.hashCode === w2.hashCode)
     }
 
-    map.insertAll((1 to 100000).iterator.map(_.toString).map(i => (i, i)))
+    map.insertAll((1 to size).iterator.map(_.toString).map(i => (i, i)))
     collisionPairs.foreach { case (w1, w2) =>
       map.insert(w1, w2)
       map.insert(w2, w1)
     }
+    assert(map.numSpills > 0, "map did not spill")
 
     // A map of collision pairs in both directions
     val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
@@ -334,22 +335,25 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
       assert(kv._2.equals(expectedValue))
       count += 1
     }
-    assert(count === 100000 + collisionPairs.size * 2)
+    assert(count === size + collisionPairs.size * 2)
     sc.stop()
   }
 
   test("spilling with many hash collisions") {
+    val size = 1000
     val conf = createSparkConf(loadDefaults = true)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
     val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
 
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
     for (i <- 1 to 10) {
-      for (j <- 1 to 10000) {
+      for (j <- 1 to size) {
         map.insert(FixedHashObject(j, j % 2), 1)
       }
     }
+    assert(map.numSpills > 0, "map did not spill")
 
     val it = map.iterator
     var count = 0
@@ -358,17 +362,20 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
       assert(kv._2 === 10)
       count += 1
     }
-    assert(count === 10000)
+    assert(count === size)
     sc.stop()
   }
 
   test("spilling with hash collisions using the Int.MaxValue key") {
+    val size = 1000
     val conf = createSparkConf(loadDefaults = true)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
     val map = createExternalMap[Int]
 
-    (1 to 100000).foreach { i => map.insert(i, i) }
+    (1 to size).foreach { i => map.insert(i, i) }
     map.insert(Int.MaxValue, Int.MaxValue)
+    assert(map.numSpills > 0, "map did not spill")
 
     val it = map.iterator
     while (it.hasNext) {
@@ -379,14 +386,17 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("spilling with null keys and values") {
+    val size = 1000
     val conf = createSparkConf(loadDefaults = true)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
     val map = createExternalMap[Int]
 
-    map.insertAll((1 to 100000).iterator.map(i => (i, i)))
+    map.insertAll((1 to size).iterator.map(i => (i, i)))
     map.insert(null.asInstanceOf[Int], 1)
     map.insert(1, null.asInstanceOf[Int])
     map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int])
+    assert(map.numSpills > 0, "map did not spill")
 
     val it = map.iterator
     while (it.hasNext) {
@@ -397,17 +407,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("external aggregation updates peak execution memory") {
+    val spillThreshold = 1000
     val conf = createSparkConf(loadDefaults = false)
       .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter
-      .set("spark.testing.memory", (10 * 1024 * 1024).toString)
+      .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString)
     sc = new SparkContext("local", "test", conf)
     // No spilling
     AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map without spilling") {
-      sc.parallelize(1 to 10, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+      assertNotSpilled(sc, "verify peak memory") {
+        sc.parallelize(1 to spillThreshold / 2, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+      }
     }
     // With spilling
     AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external map with spilling") {
-      sc.parallelize(1 to 1000 * 1000, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+      assertSpilled(sc, "verify peak memory") {
+        sc.parallelize(1 to spillThreshold * 3, 2).map { i => (i, i) }.reduceByKey(_ + _).count()
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 651c7ea..e2cb791 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -18,535 +18,92 @@
 package org.apache.spark.util.collection
 
 import scala.collection.mutable.ArrayBuffer
-
 import scala.util.Random
 
 import org.apache.spark._
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 
-// TODO: some of these spilling tests probably aren't actually spilling (SPARK-11078)
 
 class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
-  private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
-    val conf = new SparkConf(loadDefaults)
-    if (kryo) {
-      conf.set("spark.serializer", classOf[KryoSerializer].getName)
-    } else {
-      // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
-      // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
-      conf.set("spark.serializer.objectStreamReset", "1")
-      conf.set("spark.serializer", classOf[JavaSerializer].getName)
-    }
-    conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
-    // Ensure that we actually have multiple batches per spill file
-    conf.set("spark.shuffle.spill.batchSize", "10")
-    conf.set("spark.testing.memory", "2000000")
-    conf
-  }
-
-  test("empty data stream with kryo ser") {
-    emptyDataStream(createSparkConf(false, true))
-  }
-
-  test("empty data stream with java ser") {
-    emptyDataStream(createSparkConf(false, false))
-  }
-
-  def emptyDataStream(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val ord = implicitly[Ordering[Int]]
-
-    // Both aggregator and ordering
-    val sorter = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
-    assert(sorter.iterator.toSeq === Seq())
-    sorter.stop()
-
-    // Only aggregator
-    val sorter2 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(3)), None, None)
-    assert(sorter2.iterator.toSeq === Seq())
-    sorter2.stop()
-
-    // Only ordering
-    val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    assert(sorter3.iterator.toSeq === Seq())
-    sorter3.stop()
-
-    // Neither aggregator nor ordering
-    val sorter4 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), None, None)
-    assert(sorter4.iterator.toSeq === Seq())
-    sorter4.stop()
-  }
+  import TestUtils.{assertNotSpilled, assertSpilled}
 
-  test("few elements per partition with kryo ser") {
-    fewElementsPerPartition(createSparkConf(false, true))
-  }
+  testWithMultipleSer("empty data stream")(emptyDataStream)
 
-  test("few elements per partition with java ser") {
-    fewElementsPerPartition(createSparkConf(false, false))
-  }
+  testWithMultipleSer("few elements per partition")(fewElementsPerPartition)
 
-  def fewElementsPerPartition(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val ord = implicitly[Ordering[Int]]
-    val elements = Set((1, 1), (2, 2), (5, 5))
-    val expected = Set(
-      (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
-      (5, Set((5, 5))), (6, Set()))
-
-    // Both aggregator and ordering
-    val sorter = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
-    sorter.insertAll(elements.iterator)
-    assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
-    sorter.stop()
-
-    // Only aggregator
-    val sorter2 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(7)), None, None)
-    sorter2.insertAll(elements.iterator)
-    assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
-    sorter2.stop()
+  testWithMultipleSer("empty partitions with spilling")(emptyPartitionsWithSpilling)
 
-    // Only ordering
-    val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), Some(ord), None)
-    sorter3.insertAll(elements.iterator)
-    assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
-    sorter3.stop()
-
-    // Neither aggregator nor ordering
-    val sorter4 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), None, None)
-    sorter4.insertAll(elements.iterator)
-    assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
-    sorter4.stop()
-  }
-
-  test("empty partitions with spilling with kryo ser") {
-    emptyPartitionsWithSpilling(createSparkConf(false, true))
+  // Load defaults, otherwise SPARK_HOME is not found
+  testWithMultipleSer("spilling in local cluster", loadDefaults = true) {
+    (conf: SparkConf) => testSpillingInLocalCluster(conf, 2)
   }
 
-  test("empty partitions with spilling with java ser") {
-    emptyPartitionsWithSpilling(createSparkConf(false, false))
-  }
-
-  def emptyPartitionsWithSpilling(conf: SparkConf) {
-    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val ord = implicitly[Ordering[Int]]
-    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), Some(ord), None)
-    sorter.insertAll(elements)
-    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
-    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
-    assert(iter.next() === (0, Nil))
-    assert(iter.next() === (1, List((1, 1))))
-    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
-    assert(iter.next() === (3, Nil))
-    assert(iter.next() === (4, Nil))
-    assert(iter.next() === (5, List((5, 5))))
-    assert(iter.next() === (6, Nil))
-    sorter.stop()
-  }
-
-  test("spilling in local cluster with kryo ser") {
-    // Load defaults, otherwise SPARK_HOME is not found
-    testSpillingInLocalCluster(createSparkConf(true, true))
-  }
-
-  test("spilling in local cluster with java ser") {
-    // Load defaults, otherwise SPARK_HOME is not found
-    testSpillingInLocalCluster(createSparkConf(true, false))
-  }
-
-  def testSpillingInLocalCluster(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
-
-    // reduceByKey - should spill ~8 times
-    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
-    val resultA = rddA.reduceByKey(math.max).collect()
-    assert(resultA.length == 50000)
-    resultA.foreach { case(k, v) =>
-      if (v != k * 2 + 1) {
-        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
-      }
-    }
-
-    // groupByKey - should spill ~17 times
-    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
-    val resultB = rddB.groupByKey().collect()
-    assert(resultB.length == 25000)
-    resultB.foreach { case(i, seq) =>
-      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
-      if (seq.toSet != expected) {
-        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
-      }
-    }
-
-    // cogroup - should spill ~7 times
-    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
-    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
-    val resultC = rddC1.cogroup(rddC2).collect()
-    assert(resultC.length == 10000)
-    resultC.foreach { case(i, (seq1, seq2)) =>
-      i match {
-        case 0 =>
-          assert(seq1.toSet == Set[Int](0))
-          assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
-        case 1 =>
-          assert(seq1.toSet == Set[Int](1))
-          assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
-        case 5000 =>
-          assert(seq1.toSet == Set[Int](5000))
-          assert(seq2.toSet == Set[Int]())
-        case 9999 =>
-          assert(seq1.toSet == Set[Int](9999))
-          assert(seq2.toSet == Set[Int]())
-        case _ =>
-      }
-    }
-
-    // larger cogroup - should spill ~7 times
-    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
-    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
-    val resultD = rddD1.cogroup(rddD2).collect()
-    assert(resultD.length == 5000)
-    resultD.foreach { case(i, (seq1, seq2)) =>
-      val expected = Set(i * 2, i * 2 + 1)
-      if (seq1.toSet != expected) {
-        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
-      }
-      if (seq2.toSet != expected) {
-        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
-      }
-    }
-
-    // sortByKey - should spill ~17 times
-    val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
-    val resultE = rddE.sortByKey().collect().toSeq
-    assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
-  }
-
-  test("spilling in local cluster with many reduce tasks with kryo ser") {
-    spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true))
-  }
-
-  test("spilling in local cluster with many reduce tasks with java ser") {
-    spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false))
-  }
-
-  def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
-
-    // reduceByKey - should spill ~4 times per executor
-    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
-    val resultA = rddA.reduceByKey(math.max _, 100).collect()
-    assert(resultA.length == 50000)
-    resultA.foreach { case(k, v) =>
-      if (v != k * 2 + 1) {
-        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
-      }
-    }
-
-    // groupByKey - should spill ~8 times per executor
-    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
-    val resultB = rddB.groupByKey(100).collect()
-    assert(resultB.length == 25000)
-    resultB.foreach { case(i, seq) =>
-      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
-      if (seq.toSet != expected) {
-        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
-      }
-    }
-
-    // cogroup - should spill ~4 times per executor
-    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
-    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
-    val resultC = rddC1.cogroup(rddC2, 100).collect()
-    assert(resultC.length == 10000)
-    resultC.foreach { case(i, (seq1, seq2)) =>
-      i match {
-        case 0 =>
-          assert(seq1.toSet == Set[Int](0))
-          assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
-        case 1 =>
-          assert(seq1.toSet == Set[Int](1))
-          assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
-        case 5000 =>
-          assert(seq1.toSet == Set[Int](5000))
-          assert(seq2.toSet == Set[Int]())
-        case 9999 =>
-          assert(seq1.toSet == Set[Int](9999))
-          assert(seq2.toSet == Set[Int]())
-        case _ =>
-      }
-    }
-
-    // larger cogroup - should spill ~4 times per executor
-    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
-    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
-    val resultD = rddD1.cogroup(rddD2).collect()
-    assert(resultD.length == 5000)
-    resultD.foreach { case(i, (seq1, seq2)) =>
-      val expected = Set(i * 2, i * 2 + 1)
-      if (seq1.toSet != expected) {
-        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
-      }
-      if (seq2.toSet != expected) {
-        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
-      }
-    }
-
-    // sortByKey - should spill ~8 times per executor
-    val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
-    val resultE = rddE.sortByKey().collect().toSeq
-    assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
+  testWithMultipleSer("spilling in local cluster with many reduce tasks", loadDefaults = true) {
+    (conf: SparkConf) => testSpillingInLocalCluster(conf, 100)
   }
 
   test("cleanup of intermediate files in sorter") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val ord = implicitly[Ordering[Int]]
-
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.insertAll((0 until 120000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-
-    val sorter2 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter2.insertAll((0 until 120000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet)
-    sorter2.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
+    cleanupIntermediateFilesInSorter(withFailures = false)
   }
 
-  test("cleanup of intermediate files in sorter if there are errors") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val ord = implicitly[Ordering[Int]]
-
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    intercept[SparkException] {
-      sorter.insertAll((0 until 120000).iterator.map(i => {
-        if (i == 119990) {
-          throw new SparkException("Intentional failure")
-        }
-        (i, i)
-      }))
-    }
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
+  test("cleanup of intermediate files in sorter with failures") {
+    cleanupIntermediateFilesInSorter(withFailures = true)
   }
 
   test("cleanup of intermediate files in shuffle") {
-    val conf = createSparkConf(false, false)
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val data = sc.parallelize(0 until 100000, 2).map(i => (i, i))
-    assert(data.reduceByKey(_ + _).count() === 100000)
-
-    // After the shuffle, there should be only 4 files on disk: our two map output files and
-    // their index files. All other intermediate files should've been deleted.
-    assert(diskBlockManager.getAllFiles().length === 4)
-  }
-
-  test("cleanup of intermediate files in shuffle with errors") {
-    val conf = createSparkConf(false, false)
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val data = sc.parallelize(0 until 100000, 2).map(i => {
-      if (i == 99990) {
-        throw new Exception("Intentional failure")
-      }
-      (i, i)
-    })
-    intercept[SparkException] {
-      data.reduceByKey(_ + _).count()
-    }
-
-    // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index.
-    // All other files (map 2's output and intermediate merge files) should've been deleted.
-    assert(diskBlockManager.getAllFiles().length === 2)
-  }
-
-  test("no partial aggregation or sorting with kryo ser") {
-    noPartialAggregationOrSorting(createSparkConf(false, true))
-  }
-
-  test("no partial aggregation or sorting with java ser") {
-    noPartialAggregationOrSorting(createSparkConf(false, false))
-  }
-
-  def noPartialAggregationOrSorting(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
-    }).toSet
-    assert(results === expected)
-  }
-
-  test("partial aggregation without spill with kryo ser") {
-    partialAggregationWithoutSpill(createSparkConf(false, true))
-  }
-
-  test("partial aggregation without spill with java ser") {
-    partialAggregationWithoutSpill(createSparkConf(false, false))
-  }
-
-  def partialAggregationWithoutSpill(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
-    }).toSet
-    assert(results === expected)
+    cleanupIntermediateFilesInShuffle(withFailures = false)
   }
 
-  test("partial aggregation with spill, no ordering with kryo ser") {
-    partialAggregationWIthSpillNoOrdering(createSparkConf(false, true))
+  test("cleanup of intermediate files in shuffle with failures") {
+    cleanupIntermediateFilesInShuffle(withFailures = true)
   }
 
-  test("partial aggregation with spill, no ordering with java ser") {
-    partialAggregationWIthSpillNoOrdering(createSparkConf(false, false))
+  testWithMultipleSer("no sorting or partial aggregation") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = false)
   }
 
-  def partialAggregationWIthSpillNoOrdering(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
-    sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
-    }).toSet
-    assert(results === expected)
+  testWithMultipleSer("no sorting or partial aggregation with spilling") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = false, withOrdering = false, withSpilling = true)
   }
 
-  test("partial aggregation with spill, with ordering with kryo ser") {
-    partialAggregationWithSpillWithOrdering(createSparkConf(false, true))
+  testWithMultipleSer("sorting, no partial aggregation") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = false)
   }
 
-
-  test("partial aggregation with spill, with ordering with java ser") {
-    partialAggregationWithSpillWithOrdering(createSparkConf(false, false))
+  testWithMultipleSer("sorting, no partial aggregation with spilling") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = false, withOrdering = true, withSpilling = true)
   }
 
-  def partialAggregationWithSpillWithOrdering(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val ord = implicitly[Ordering[Int]]
-    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
-
-    // avoid combine before spill
-    sorter.insertAll((0 until 50000).iterator.map(i => (i , 2 * i)))
-    sorter.insertAll((0 until 50000).iterator.map(i => (i, 2 * i + 1)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
-    }).toSet
-    assert(results === expected)
+  testWithMultipleSer("partial aggregation, no sorting") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = false)
   }
 
-  test("sorting without aggregation, no spill with kryo ser") {
-    sortingWithoutAggregationNoSpill(createSparkConf(false, true))
+  testWithMultipleSer("partial aggregation, no sorting with spilling") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = true, withOrdering = false, withSpilling = true)
   }
 
-  test("sorting without aggregation, no spill with java ser") {
-    sortingWithoutAggregationNoSpill(createSparkConf(false, false))
+  testWithMultipleSer("partial aggregation and sorting") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = false)
   }
 
-  def sortingWithoutAggregationNoSpill(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val ord = implicitly[Ordering[Int]]
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.insertAll((0 until 100).iterator.map(i => (i, i)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
-    }).toSeq
-    assert(results === expected)
-  }
-
-  test("sorting without aggregation, with spill with kryo ser") {
-    sortingWithoutAggregationWithSpill(createSparkConf(false, true))
-  }
-
-  test("sorting without aggregation, with spill with java ser") {
-    sortingWithoutAggregationWithSpill(createSparkConf(false, false))
+  testWithMultipleSer("partial aggregation and sorting with spilling") { (conf: SparkConf) =>
+    basicSorterTest(conf, withPartialAgg = true, withOrdering = true, withSpilling = true)
   }
 
-  def sortingWithoutAggregationWithSpill(conf: SparkConf) {
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val ord = implicitly[Ordering[Int]]
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(3)), Some(ord), None)
-    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
-    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
-    val expected = (0 until 3).map(p => {
-      (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
-    }).toSeq
-    assert(results === expected)
-  }
+  testWithMultipleSer("sort without breaking sorting contracts", loadDefaults = true)(
+    sortWithoutBreakingSortingContracts)
 
   test("spilling with hash collisions") {
-    val conf = createSparkConf(true, false)
+    val size = 1000
+    val conf = createSparkConf(loadDefaults = true, kryo = false)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
     def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
-    def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String])
-      : ArrayBuffer[String] = buffer1 ++= buffer2
+    def mergeCombiners(
+        buffer1: ArrayBuffer[String],
+        buffer2: ArrayBuffer[String]): ArrayBuffer[String] = buffer1 ++= buffer2
 
     val agg = new Aggregator[String, String, ArrayBuffer[String]](
       createCombiner _, mergeValue _, mergeCombiners _)
@@ -574,10 +131,11 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       assert(w1.hashCode === w2.hashCode)
     }
 
-    val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
+    val toInsert = (1 to size).iterator.map(_.toString).map(s => (s, s)) ++
       collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
 
     sorter.insertAll(toInsert)
+    assert(sorter.numSpills > 0, "sorter did not spill")
 
     // A map of collision pairs in both directions
     val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
@@ -592,21 +150,21 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       assert(kv._2.equals(expectedValue))
       count += 1
     }
-    assert(count === 100000 + collisionPairs.size * 2)
+    assert(count === size + collisionPairs.size * 2)
   }
 
   test("spilling with many hash collisions") {
-    val conf = createSparkConf(true, false)
+    val size = 1000
+    val conf = createSparkConf(loadDefaults = true, kryo = false)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
-
     val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
     val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
-
     // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
     // problems if the map fails to group together the objects with the same code (SPARK-2043).
-    val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
+    val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1)
     sorter.insertAll(toInsert.iterator)
-
+    assert(sorter.numSpills > 0, "sorter did not spill")
     val it = sorter.iterator
     var count = 0
     while (it.hasNext) {
@@ -614,11 +172,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       assert(kv._2 === 10)
       count += 1
     }
-    assert(count === 10000)
+    assert(count === size)
   }
 
   test("spilling with hash collisions using the Int.MaxValue key") {
-    val conf = createSparkConf(true, false)
+    val size = 1000
+    val conf = createSparkConf(loadDefaults = true, kryo = false)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i)
@@ -629,10 +189,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
 
     val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
     val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
-
     sorter.insertAll(
-      (1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
-
+      (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+    assert(sorter.numSpills > 0, "sorter did not spill")
     val it = sorter.iterator
     while (it.hasNext) {
       // Should not throw NoSuchElementException
@@ -641,7 +200,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("spilling with null keys and values") {
-    val conf = createSparkConf(true, false)
+    val size = 1000
+    val conf = createSparkConf(loadDefaults = true, kryo = false)
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
@@ -655,12 +216,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
       Some(agg), None, None, None)
 
-    sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+    sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
       (null.asInstanceOf[String], "1"),
       ("1", null.asInstanceOf[String]),
       (null.asInstanceOf[String], null.asInstanceOf[String])
     ))
-
+    assert(sorter.numSpills > 0, "sorter did not spill")
     val it = sorter.iterator
     while (it.hasNext) {
       // Should not throw NullPointerException
@@ -668,16 +229,301 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  test("sort without breaking sorting contracts with kryo ser") {
-    sortWithoutBreakingSortingContracts(createSparkConf(true, true))
+  /* ============================= *
+   |  Helper test utility methods  |
+   * ============================= */
+
+  private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
+    val conf = new SparkConf(loadDefaults)
+    if (kryo) {
+      conf.set("spark.serializer", classOf[KryoSerializer].getName)
+    } else {
+      // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+      // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+      conf.set("spark.serializer.objectStreamReset", "1")
+      conf.set("spark.serializer", classOf[JavaSerializer].getName)
+    }
+    conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
+    // Ensure that we actually have multiple batches per spill file
+    conf.set("spark.shuffle.spill.batchSize", "10")
+    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
+    conf
+  }
+
+  /**
+   * Run a test multiple times, each time with a different serializer.
+   */
+  private def testWithMultipleSer(
+      name: String,
+      loadDefaults: Boolean = false)(body: (SparkConf => Unit)): Unit = {
+    test(name + " with kryo ser") {
+      body(createSparkConf(loadDefaults, kryo = true))
+    }
+    test(name + " with java ser") {
+      body(createSparkConf(loadDefaults, kryo = false))
+    }
   }
 
-  test("sort without breaking sorting contracts with java ser") {
-    sortWithoutBreakingSortingContracts(createSparkConf(true, false))
+  /* =========================================== *
+   |  Helper methods that contain the test body  |
+   * =========================================== */
+
+  private def emptyDataStream(conf: SparkConf) {
+    conf.set("spark.shuffle.manager", "sort")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter.iterator.toSeq === Seq())
+    sorter.stop()
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), None, None)
+    assert(sorter2.iterator.toSeq === Seq())
+    sorter2.stop()
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter3.iterator.toSeq === Seq())
+    sorter3.stop()
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), None, None)
+    assert(sorter4.iterator.toSeq === Seq())
+    sorter4.stop()
+  }
+
+  private def fewElementsPerPartition(conf: SparkConf) {
+    conf.set("spark.shuffle.manager", "sort")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Set((1, 1), (2, 2), (5, 5))
+    val expected = Set(
+      (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
+      (5, Set((5, 5))), (6, Set()))
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+    sorter.insertAll(elements.iterator)
+    assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter.stop()
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), None, None)
+    sorter2.insertAll(elements.iterator)
+    assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter2.stop()
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    sorter3.insertAll(elements.iterator)
+    assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter3.stop()
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter4.insertAll(elements.iterator)
+    assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter4.stop()
+  }
+
+  private def emptyPartitionsWithSpilling(conf: SparkConf) {
+    val size = 1000 // twice the force-spill threshold set below, so insertAll should spill
+    conf.set("spark.shuffle.manager", "sort")
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
+    sc = new SparkContext("local", "test", conf)
+
+    val ord = implicitly[Ordering[Int]]
+    val elements = Iterator((1, 1), (5, 5)) ++ Iterator.fill(size)((2, 2))
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    sorter.insertAll(elements)
+    assert(sorter.numSpills > 0, "sorter did not spill")
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, List.fill(size)((2, 2)))) // all `size` copies land in partition 2
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
+  }
+
+  private def testSpillingInLocalCluster(conf: SparkConf, numReduceTasks: Int) {
+    val size = 5000 // 4x the force-spill threshold set below; assertSpilled checks each stage
+    conf.set("spark.shuffle.manager", "sort")
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+
+    assertSpilled(sc, "reduceByKey") {
+      val result = sc.parallelize(0 until size)
+        .map { i => (i / 2, i) }
+        .reduceByKey(math.max _, numReduceTasks)
+        .collect()
+      assert(result.length === size / 2)
+      result.foreach { case (k, v) =>
+        val expected = k * 2 + 1
+        assert(v === expected, s"Value for $k was wrong: expected $expected, got $v")
+      }
+    }
+
+    assertSpilled(sc, "groupByKey") {
+      val result = sc.parallelize(0 until size)
+        .map { i => (i / 2, i) }
+        .groupByKey(numReduceTasks)
+        .collect()
+      assert(result.length === size / 2)
+      result.foreach { case (i, seq) =>
+        val actual = seq.toSet
+        val expected = Set(i * 2, i * 2 + 1)
+        assert(actual === expected, s"Value for $i was wrong: expected $expected, got $actual")
+      }
+    }
+
+    assertSpilled(sc, "cogroup") {
+      val rdd1 = sc.parallelize(0 until size).map { i => (i / 2, i) }
+      val rdd2 = sc.parallelize(0 until size).map { i => (i / 2, i) }
+      val result = rdd1.cogroup(rdd2, numReduceTasks).collect()
+      assert(result.length === size / 2)
+      result.foreach { case (i, (seq1, seq2)) =>
+        val actual1 = seq1.toSet
+        val actual2 = seq2.toSet
+        val expected = Set(i * 2, i * 2 + 1)
+        assert(actual1 === expected, s"Value 1 for $i was wrong: expected $expected, got $actual1")
+        assert(actual2 === expected, s"Value 2 for $i was wrong: expected $expected, got $actual2")
+      }
+    }
+
+    assertSpilled(sc, "sortByKey") {
+      val result = sc.parallelize(0 until size)
+        .map { i => (i / 2, i) }
+        .sortByKey(numPartitions = numReduceTasks)
+        .collect()
+      val expected = (0 until size).map { i => (i / 2, i) }.toArray
+      assert(result.length === size)
+      result.zipWithIndex.foreach { case ((k, _), i) =>
+        val (expectedKey, _) = expected(i)
+        assert(k === expectedKey, s"Value for $i was wrong: expected $expectedKey, got $k")
+      }
+    }
+  }
+
+  private def cleanupIntermediateFilesInSorter(withFailures: Boolean): Unit = {
+    val size = 1200
+    val conf = createSparkConf(loadDefaults = false, kryo = false)
+    conf.set("spark.shuffle.manager", "sort")
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = sc.env.blockManager.diskBlockManager
+    val ord = implicitly[Ordering[Int]]
+    val expectedSize = if (withFailures) size - 1 else size
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    if (withFailures) {
+      intercept[SparkException] {
+        sorter.insertAll((0 until size).iterator.map { i =>
+          if (i == size - 1) { throw new SparkException("intentional failure") }
+          (i, i)
+        })
+      }
+    } else {
+      sorter.insertAll((0 until size).iterator.map(i => (i, i)))
+    }
+    assert(sorter.iterator.toSet === (0 until expectedSize).map(i => (i, i)).toSet)
+    assert(sorter.numSpills > 0, "sorter did not spill")
+    assert(diskBlockManager.getAllFiles().nonEmpty, "sorter did not spill")
+    sorter.stop()
+    assert(diskBlockManager.getAllFiles().isEmpty, "spilled files were not cleaned up")
+  }
+
+  private def cleanupIntermediateFilesInShuffle(withFailures: Boolean): Unit = {
+    val size = 1200
+    val conf = createSparkConf(loadDefaults = false, kryo = false)
+    conf.set("spark.shuffle.manager", "sort")
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = sc.env.blockManager.diskBlockManager
+    val data = sc.parallelize(0 until size, 2).map { i =>
+      if (withFailures && i == size - 1) {
+        throw new SparkException("intentional failure")
+      }
+      (i, i)
+    }
+
+    assertSpilled(sc, "test shuffle cleanup") {
+      if (withFailures) {
+        intercept[SparkException] {
+          data.reduceByKey(_ + _).count()
+        }
+        // After the shuffle, there should be only 2 files on disk: the output of task 1 and
+        // its index. All other files (map 2's output and intermediate merge files) should
+        // have been deleted.
+        assert(diskBlockManager.getAllFiles().length === 2)
+      } else {
+        assert(data.reduceByKey(_ + _).count() === size)
+        // After the shuffle, there should be only 4 files on disk: the output of both tasks
+        // and their indices. All intermediate merge files should have been deleted.
+        assert(diskBlockManager.getAllFiles().length === 4)
+      }
+    }
+  }
+
+  private def basicSorterTest(
+      conf: SparkConf,
+      withPartialAgg: Boolean,
+      withOrdering: Boolean,
+      withSpilling: Boolean) {
+    val size = 1000 // with spilling, the force-spill threshold below is size / 2 elements
+    if (withSpilling) {
+      conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
+    }
+    conf.set("spark.shuffle.manager", "sort")
+    sc = new SparkContext("local", "test", conf)
+    val agg =
+      if (withPartialAgg) {
+        Some(new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j))
+      } else {
+        None
+      }
+    val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None
+    val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None)
+    sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) })
+    if (withSpilling) {
+      assert(sorter.numSpills > 0, "sorter did not spill")
+    } else {
+      assert(sorter.numSpills === 0, "sorter spilled")
+    }
+    val results = sorter.partitionedIterator.map { case (p, vs) => (p, vs.toSet) }.toSet
+    val expected = (0 until 3).map { p =>
+      val pairs = (0 until size).map { i => (i / 4, i) }.filter { case (k, _) => k % 3 == p }
+      val v: Iterable[(Int, Int)] = // with an aggregator, one (key, sum of values) per key
+        if (withPartialAgg) pairs.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2).sum) }
+        else pairs
+      (p, v.toSet)
+    }.toSet
+    assert(results === expected)
   }
 
   private def sortWithoutBreakingSortingContracts(conf: SparkConf) {
+    val size = 100000
+    // Use the caller's conf (loadDefaults = true, kryo or java ser); do not rebuild it here
     conf.set("spark.shuffle.manager", "sort")
+    conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     // Using wrongOrdering to show integer overflow introduced exception.
@@ -690,17 +536,18 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
       }
     }
 
-    val testData = Array.tabulate(100000) { _ => rand.nextInt().toString }
+    val testData = Array.tabulate(size) { _ => rand.nextInt().toString }
 
     val sorter1 = new ExternalSorter[String, String, String](
       None, None, Some(wrongOrdering), None)
     val thrown = intercept[IllegalArgumentException] {
       sorter1.insertAll(testData.iterator.map(i => (i, i)))
+      assert(sorter1.numSpills > 0, "sorter did not spill")
       sorter1.iterator
     }
 
-    assert(thrown.getClass() === classOf[IllegalArgumentException])
-    assert(thrown.getMessage().contains("Comparison method violates its general contract"))
+    assert(thrown.getClass === classOf[IllegalArgumentException])
+    assert(thrown.getMessage.contains("Comparison method violates its general contract"))
     sorter1.stop()
 
     // Using aggregation and external spill to make sure ExternalSorter using
@@ -716,6 +563,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
     val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]](
       Some(agg), None, None, None)
     sorter2.insertAll(testData.iterator.map(i => (i, i)))
+    assert(sorter2.numSpills > 0, "sorter did not spill")
 
     // To validate the hash ordering of key
     var minKey = Int.MinValue
@@ -729,12 +577,23 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   }
 
   test("sorting updates peak execution memory") {
+    val spillThreshold = 1000  // element-count threshold for forced spills (set on conf below)
     val conf = createSparkConf(loadDefaults = false, kryo = false)
       .set("spark.shuffle.manager", "sort")
+      .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString)
     sc = new SparkContext("local", "test", conf)
     // Avoid aggregating here to make sure we're not also using ExternalAppendOnlyMap
-    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter") {
-      sc.parallelize(1 to 1000, 2).repartition(100).count()
+    // No spilling: spillThreshold / 2 = 500 elements in total, split over 2 input partitions
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter without spilling") {
+      assertNotSpilled(sc, "verify peak memory") {
+        sc.parallelize(1 to spillThreshold / 2, 2).repartition(100).count()
+      }
+    }
+    // With spilling: spillThreshold * 3 = 3000 elements, about 1500 per input partition
+    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "external sorter with spilling") {
+      assertSpilled(sc, "verify peak memory") {
+        sc.parallelize(1 to spillThreshold * 3, 2).repartition(100).count()
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3b364ff0/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
index 835f52f..c4358f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
@@ -68,6 +68,8 @@ private class GrantEverythingMemoryManager extends MemoryManager {
       blockId: BlockId,
       numBytes: Long,
       evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
+  override def releaseExecutionMemory(numBytes: Long): Unit = { }
+  override def releaseStorageMemory(numBytes: Long): Unit = { }
   override def maxExecutionMemory: Long = Long.MaxValue
   override def maxStorageMemory: Long = Long.MaxValue
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message