spark-commits mailing list archives

From: r...@apache.org
Subject: [1/2] spark git commit: [SPARK-7081] Faster sort-based shuffle path using binary processing cache-aware sort
Date: Thu, 14 May 2015 00:07:42 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.4 6c0644ae2 -> c53ebea9d


http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
new file mode 100644
index 0000000..f2bfef3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
+ */
+private[spark] class UnsafeShuffleHandle[K, V](
+    shuffleId: Int,
+    numMaps: Int,
+    dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+private[spark] object UnsafeShuffleManager extends Logging {
+
+  /**
+   * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
+   */
+  val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+  /**
+   * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
+   * path or whether it should fall back to the original sort-based shuffle.
+   */
+  def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
+    val shufId = dependency.shuffleId
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (dependency.keyOrdering.isDefined) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+      false
+    } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
+      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
+        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
+      false
+    } else {
+      log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
+      true
+    }
+  }
+}
+
+/**
+ * A shuffle implementation that uses directly-managed memory to implement several performance
+ * optimizations for certain types of shuffles. In cases where the new performance optimizations
+ * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
+ * shuffles.
+ *
+ * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
+ *
+ *  - The shuffle dependency specifies no aggregation or output ordering.
+ *  - The shuffle serializer supports relocation of serialized values (this is currently supported
+ *    by KryoSerializer and Spark SQL's custom serializers).
+ *  - The shuffle produces fewer than 16777216 output partitions.
+ *  - No individual record is larger than 128 MB when serialized.
+ *
+ * In addition, extra spill-merging optimizations are automatically applied when the shuffle
+ * compression codec supports concatenation of serialized streams. This is currently supported by
+ * Spark's LZF compression codec.
+ *
+ * At a high level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * UnsafeShuffleManager optimizes this process in several ways:
+ *
+ *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ *    consumption and GC overheads. This optimization requires the record serializer to have certain
+ *    properties to allow serialized records to be re-ordered without requiring deserialization.
+ *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ *  - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
+ *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ *    record in the sorting array, this fits more of the array into cache.
+ *
+ *  - The spill merging procedure operates on blocks of serialized records that belong to the same
+ *    partition and does not need to deserialize records during the merge.
+ *
+ *  - When the spill compression codec supports concatenation of compressed data, the spill merge
+ *    simply concatenates the serialized and compressed spill partitions to produce the final output
+ *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ *    and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on UnsafeShuffleManager's design, see SPARK-7081.
+ */
+private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+  if (!conf.getBoolean("spark.shuffle.spill", true)) {
+    logWarning(
+      "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
+      "manager; its optimized shuffles will continue to spill to disk when necessary.")
+  }
+
+  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
+
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
+      new UnsafeShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else {
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    sortShuffleManager.getReader(handle, startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext): ShuffleWriter[K, V] = {
+    handle match {
+      case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
+        val env = SparkEnv.get
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf)
+      case other =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
+        sortShuffleManager.getWriter(handle, mapId, context)
+    }
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
+  }
+
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
+    sortShuffleManager.shuffleBlockResolver
+  }
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {
+    sortShuffleManager.stop()
+  }
+}
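
As a minimal illustration of the eligibility rules in canUseUnsafeShuffle above, the sketch below (not part of this patch) configures a job whose shuffle dependency should satisfy all of them: a relocation-friendly Kryo serializer, no aggregator, no key ordering, and far fewer than 16777216 output partitions. The "tungsten-sort" manager name is taken from the warning message in this file; the class and app names are illustrative assumptions.

  import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
  import org.apache.spark.serializer.KryoSerializer

  object UnsafeShufflePathSketch {
    def main(args: Array[String]): Unit = {
      val conf = new SparkConf()
        .setAppName("unsafe-shuffle-sketch")
        .setMaster("local[2]")
        // Select the new shuffle manager (shortcut name as referenced in the warning above).
        .set("spark.shuffle.manager", "tungsten-sort")
        // Kryo supports relocation of serialized objects, one of the required conditions.
        .set("spark.serializer", classOf[KryoSerializer].getName)
      val sc = new SparkContext(conf)
      // partitionBy introduces a shuffle with no aggregator and no key ordering, so
      // canUseUnsafeShuffle should accept the resulting dependency.
      val repartitioned = sc.parallelize(1 to 1000)
        .map(i => (i % 8, i))
        .partitionBy(new HashPartitioner(8))
      repartitioned.count()
      sc.stop()
    }
  }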

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 8bc4e20..a33f22e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter(
   extends BlockObjectWriter(blockId)
   with Logging
 {
-  /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
-  private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
-    override def write(i: Int): Unit = callWithTiming(out.write(i))
-    override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b))
-    override def write(b: Array[Byte], off: Int, len: Int): Unit = {
-      callWithTiming(out.write(b, off, len))
-    }
-    override def close(): Unit = out.close()
-    override def flush(): Unit = out.flush()
-  }
 
   /** The file channel, used for repositioning / truncating the file. */
   private var channel: FileChannel = null
@@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter(
       throw new IllegalStateException("Writer already closed. Cannot be reopened.")
     }
     fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(fos)
+    ts = new TimeTrackingOutputStream(writeMetrics, fos)
     channel = fos.getChannel()
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializerInstance.serializeStream(bs)
@@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter(
         if (syncWrites) {
           // Force outstanding writes to disk and track how long it takes
           objOut.flush()
-          callWithTiming {
-            fos.getFD.sync()
-          }
+          val start = System.nanoTime()
+          fos.getFD.sync()
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
         }
       } {
         objOut.close()
@@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter(
     reportedPosition = pos
   }
 
-  private def callWithTiming(f: => Unit) = {
-    val start = System.nanoTime()
-    f
-    writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
-  }
-
   // For testing
   private[spark] override def flush() {
     objOut.flush()
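
The removed inner class above is replaced by a TimeTrackingOutputStream that takes the write metrics explicitly (see the constructor call in the hunk above). A rough Scala sketch of such a wrapper, reconstructed from the deleted code and the new call site rather than quoted from the patch, looks like this:

  import java.io.OutputStream
  import org.apache.spark.executor.ShuffleWriteMetrics

  // Each write is timed and the elapsed nanoseconds are added to the supplied metrics object,
  // mirroring the callWithTiming helper that this diff removes.
  class TimeTrackingOutputStreamSketch(writeMetrics: ShuffleWriteMetrics, out: OutputStream)
    extends OutputStream {

    private def timed(f: => Unit): Unit = {
      val start = System.nanoTime()
      f
      writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
    }

    override def write(i: Int): Unit = timed(out.write(i))
    override def write(b: Array[Byte]): Unit = timed(out.write(b))
    override def write(b: Array[Byte], off: Int, len: Int): Unit = timed(out.write(b, off, len))
    override def flush(): Unit = out.flush()
    override def close(): Unit = out.close()
  }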

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index b850973..df2d6ad 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C](
   // Number of bytes spilled in total
   private var _diskBytesSpilled = 0L
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = 
     sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7d5cf7b..3b9d14f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
   private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
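
The corrected comment in both hunks refers to SparkConf.getSizeAsKb's backwards-compatibility behavior: a value configured without a unit suffix is interpreted as kilobytes rather than bytes. A small sketch of that behavior (assuming getSizeAsKb's documented semantics; the values are illustrative):

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
  conf.set("spark.shuffle.file.buffer", "64")    // legacy style, no unit suffix
  // Interpreted as 64 kilobytes, matching the pre-unit-suffix behavior.
  val legacy = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k")
  conf.set("spark.shuffle.file.buffer", "64k")   // explicit unit
  val explicit = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k")
  assert(legacy == explicit)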
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
new file mode 100644
index 0000000..db9e827
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+  @Test
+  public void heap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void offHeap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void maximumPartitionIdCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+  }
+
+  @Test
+  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    try {
+      // Partition ids greater than the maximum partition ID will overflow or trigger an assertion error
+      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+      assertFalse(MAXIMUM_PARTITION_ID + 1 == packedPointer.getPartitionId());
+    } catch (AssertionError e) {
+      // pass
+    }
+  }
+
+  @Test
+  public void maximumOffsetInPageCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(address, packedPointer.getRecordPointer());
+  }
+
+  @Test
+  public void offsetsPastMaxOffsetInPageWillOverflow() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(0, packedPointer.getRecordPointer());
+  }
+}
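
The limits exercised above (a maximum partition id of 2^24 - 1 and a maximum in-page offset of 2^27 - 1) imply a packing of each record into a single 64-bit value. The sketch below shows one layout consistent with those constants, 24 bits of partition id, 13 bits of page number, and 27 bits of offset; it is inferred from the tests and scaladoc rather than copied from PackedRecordPointer, and the names are illustrative.

  object PackedPointerSketch {
    val MaximumPartitionId: Int = (1 << 24) - 1       // 16777215
    val MaximumPageSizeBytes: Long = 1L << 27         // 128 MB

    // [24 bits partition id][13 bits page number][27 bits offset in page]
    def pack(partitionId: Int, pageNumber: Int, offsetInPage: Long): Long = {
      require(partitionId <= MaximumPartitionId && offsetInPage < MaximumPageSizeBytes)
      (partitionId.toLong << 40) | (pageNumber.toLong << 27) | offsetInPage
    }

    def partitionId(packed: Long): Int = (packed >>> 40).toInt
    def pageNumber(packed: Long): Int = ((packed >>> 27) & 0x1fff).toInt
    def offsetInPage(packed: Long): Long = packed & (MaximumPageSizeBytes - 1)
  }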

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000..8fa7259
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleInMemorySorterSuite {
+
+  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+    final byte[] strBytes = new byte[strLength];
+    PlatformDependent.copyMemory(
+      baseObject,
+      baseOffset,
+      strBytes,
+      PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+    return new String(strBytes);
+  }
+
+  @Test
+  public void testSortingEmptyInput() {
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    assert(!iter.hasNext());
+  }
+
+  @Test
+  public void testBasicSorting() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+    // Write the records into the data page and store pointers into the sorter
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final byte[] strBytes = str.getBytes("utf-8");
+      PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+      position += 4;
+      PlatformDependent.copyMemory(
+        strBytes,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        baseObject,
+        position,
+        strBytes.length);
+      position += strBytes.length;
+      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+    }
+
+    // Sort the records
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int prevPartitionId = -1;
+    Arrays.sort(dataToSort);
+    for (int i = 0; i < dataToSort.length; i++) {
+      Assert.assertTrue(iter.hasNext());
+      iter.loadNext();
+      final int partitionId = iter.packedRecordPointer.getPartitionId();
+      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+        partitionId >= prevPartitionId);
+      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+      final int recordLength = PlatformDependent.UNSAFE.getInt(
+        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+      final String str = getStringFromDataPage(
+        memoryManager.getPage(recordAddress),
+        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+        recordLength);
+      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) >= 0);
+    }
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testSortingManyNumbers() throws Exception {
+    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    int[] numbersToSort = new int[128000];
+    Random random = new Random(16);
+    for (int i = 0; i < numbersToSort.length; i++) {
+      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+      sorter.insertRecord(0, numbersToSort[i]);
+    }
+    Arrays.sort(numbersToSort);
+    int[] sorterResult = new int[numbersToSort.length];
+    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int j = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+      j += 1;
+    }
+    Assert.assertArrayEquals(numbersToSort, sorterResult);
+  }
+}
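
testSortingManyNumbers passes because the partition id occupies the most significant bits of each packed 8-byte value, so an ordinary numeric sort of the long array orders records by partition id. A tiny illustration of that property, using the bit layout sketched after PackedRecordPointerSuite above:

  val partitionIds = Array(3, 0, 2, 1)
  val packed = partitionIds.map(p => (p.toLong << 40) | 12345L)  // arbitrary low-order bits
  val sortedIds = packed.sorted.map(v => (v >>> 40).toInt)
  assert(sortedIds.sameElements(Array(0, 1, 2, 3)))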

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000..730d265
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+  static final int NUM_PARTITITONS = 4;
+  final TaskMemoryManager taskMemoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+  File mergedOutputFile;
+  File tempDir;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
+  final Serializer serializer = new KryoSerializer(new SparkConf());
+  TaskMetrics taskMetrics;
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
+    }
+  }
+
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (leakedMemory != 0) {
+      fail("Test leaked " + leakedMemory + " bytes of managed memory");
+    }
+  }
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+    conf = new SparkConf();
+    taskMetrics = new TaskMetrics();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
+
+    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        return null;
+      }
+    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
+        @Override
+        public Tuple2<TempShuffleBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockResolver,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      0, // map id
+      taskContext,
+      conf
+    );
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        in = new LimitedInputStream(in, partitionSize);
+        if (conf.getBoolean("spark.shuffle.compress", true)) {
+          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+        }
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+    createWriter(false).stop(false);
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+  }
+
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      assertEquals(partitionSizesInMergedFile[0], size);
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      String compressionCodecName) throws IOException {
+    if (compressionCodecName != null) {
+      conf.set("spark.shuffle.compress", "true");
+      conf.set("spark.io.compression.codec", compressionCodecName);
+    } else {
+      conf.set("spark.shuffle.compress", "false");
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
+  }
+
+  @Test
+  public void writeEnoughDataToTriggerSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to allocate new data page
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+    for (int i = 0; i < 128 + 1; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to grow sort buffer
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    // Use a custom serializer so that we have exact control over the size of serialized data.
+    final Serializer byteArraySerializer = new Serializer() {
+      @Override
+      public SerializerInstance newInstance() {
+        return new SerializerInstance() {
+          @Override
+          public SerializationStream serializeStream(final OutputStream s) {
+            return new SerializationStream() {
+              @Override
+              public void flush() { }
+
+              @Override
+              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+                byte[] bytes = (byte[]) t;
+                try {
+                  s.write(bytes);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+                return this;
+              }
+
+              @Override
+              public void close() { }
+            };
+          }
+          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+          public DeserializationStream deserializeStream(InputStream s) { return null; }
+          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+        };
+      }
+    };
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    // Insert a record and force a spill so that there's something to clean up:
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+    writer.forceSorterToSpill();
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    new Random(42).nextBytes(atMaxRecordSize);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+    writer.forceSorterToSpill();
+    // Inserting a record that's larger than the max record size should fail:
+    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    Product2<Object, Object> hugeRecord =
+      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
+    try {
+      // Here, we write through the public `write()` interface instead of the test-only
+      // `insertRecordIntoSorter` interface:
+      writer.write(Collections.singletonList(hugeRecord).iterator());
+      fail("Expected exception to be thrown");
+    } catch (IOException e) {
+      // Pass
+    }
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.stop(false);
+    assertSpillFilesWereCleanedUp();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 8c6035f..cf6a143 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
+import com.google.common.io.ByteStreams
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
@@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lz4 does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
+    assert(codec.getClass === classOf[LZ4CompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("lzf compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
     assert(codec.getClass === classOf[LZFCompressionCodec])
@@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lzf supports concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
+    assert(codec.getClass === classOf[LZFCompressionCodec])
+    testConcatenationOfSerializedStreams(codec)
+  }
+
   test("snappy compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
     assert(codec.getClass === classOf[SnappyCompressionCodec])
@@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("snappy does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
+    assert(codec.getClass === classOf[SnappyCompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("bad compression codec") {
     intercept[IllegalArgumentException] {
       CompressionCodec.createCodec(conf, "foobar")
     }
   }
+
+  private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = {
+    val bytes1: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (0 to 64).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val bytes2: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (65 to 127).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2))
+    val decompressed: Array[Byte] = new Array[Byte](128)
+    ByteStreams.readFully(concatenatedBytes, decompressed)
+    assert(decompressed.toSeq === (0 to 127))
+  }
 }
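
These codec tests back the spill-merge fast path described in UnsafeShuffleManager's scaladoc: when the shuffle codec allows concatenation of serialized streams (as the LZF tests verify), spill partitions can be merged by raw byte copies such as NIO's transferTo, with no decompression. A rough sketch of such a merge for a single reduce partition, assuming the byte range of that partition within each spill file is already known; the function name and signature are illustrative, not taken from the patch:

  import java.io.{File, FileInputStream, FileOutputStream}

  // spills: (spill file, start offset, length) of this partition's bytes in each spill file.
  def concatenatePartition(spills: Seq[(File, Long, Long)], out: FileOutputStream): Unit = {
    val outChannel = out.getChannel
    for ((file, offset, length) <- spills) {
      val in = new FileInputStream(file).getChannel
      try {
        var transferred = 0L
        // transferTo may copy fewer bytes than requested, so loop until the range is drained.
        while (transferred < length) {
          transferred += in.transferTo(offset + transferred, length - transferred, outChannel)
        }
      } finally {
        in.close()
      }
    }
  }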

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
new file mode 100644
index 0000000..ed4d8ce
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+class JavaSerializerSuite extends FunSuite {
+  test("JavaSerializer instances are serializable") {
+    val serializer = new JavaSerializer(new SparkConf())
+    val instance = serializer.newInstance()
+    instance.deserialize[JavaSerializer](instance.serialize(serializer))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
new file mode 100644
index 0000000..49a04a2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class UnsafeShuffleManagerSuite extends FunSuite with Matchers {
+
+  import UnsafeShuffleManager.canUseUnsafeShuffle
+
+  private class RuntimeExceptionAnswer extends Answer[Object] {
+    override def answer(invocation: InvocationOnMock): Object = {
+      throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+    }
+  }
+
+  private def shuffleDep(
+      partitioner: Partitioner,
+      serializer: Option[Serializer],
+      keyOrdering: Option[Ordering[Any]],
+      aggregator: Option[Aggregator[Any, Any, Any]],
+      mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+    val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+    doReturn(0).when(dep).shuffleId
+    doReturn(partitioner).when(dep).partitioner
+    doReturn(serializer).when(dep).serializer
+    doReturn(keyOrdering).when(dep).keyOrdering
+    doReturn(aggregator).when(dep).aggregator
+    doReturn(mapSideCombine).when(dep).mapSideCombine
+    dep
+  }
+
+  test("supported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+    when(rangePartitioner.numPartitions).thenReturn(2)
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = rangePartitioner,
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+  }
+
+  test("unsupported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+    val java = Some(new JavaSerializer(new SparkConf()))
+
+    // We only support serializers that support object relocation
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = java,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles with more than 16 million output partitions
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles that perform any kind of aggregation or sorting of keys
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = None,
+      mapSideCombine = false
+    )))
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = false
+    )))
+    // Nor do we support shuffles that combine aggregation, a key ordering, and map-side combine
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = true
+    )))
+  }
+
+}

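The suite above builds its ShuffleDependency mocks with a default Answer that throws on any non-stubbed call, so the eligibility check can only touch the fields each test explicitly stubs. A minimal, self-contained sketch of that strict-mock pattern (the object and class names below are illustrative only, not part of this patch):

    import org.mockito.Mockito.{doReturn, mock}
    import org.mockito.invocation.InvocationOnMock
    import org.mockito.stubbing.Answer

    object StrictMockSketch {
      // Any method the test has not stubbed fails loudly instead of silently returning null/0.
      private class FailOnUnstubbed extends Answer[Object] {
        override def answer(invocation: InvocationOnMock): Object =
          throw new RuntimeException("Called non-stubbed method " + invocation.getMethod.getName)
      }

      def main(args: Array[String]): Unit = {
        val list = mock(classOf[java.util.List[String]], new FailOnUnstubbed)
        doReturn(3).when(list).size()
        println(list.size())   // 3: stubbed above
        // list.isEmpty()      // would throw: not stubbed
      }
    }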
http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
new file mode 100644
index 0000000..6351539
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
+class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+  // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
+
+  override def beforeAll() {
+    conf.set("spark.shuffle.manager", "tungsten-sort")
+    // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort
+    // shuffle records.
+    conf.set("spark.shuffle.memoryFraction", "0.5")
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle actually created files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+}

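The configuration exercised in beforeAll() above is all that an application needs in order to opt in to the new shuffle path. A rough usage sketch, assuming a local master and the property names used in this patch (the example itself is not part of the patch):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.rdd.ShuffledRDD
    import org.apache.spark.serializer.KryoSerializer

    object TungstenSortShuffleExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local")
          .setAppName("tungsten-sort-example")
          .set("spark.shuffle.manager", "tungsten-sort")   // route shuffles through UnsafeShuffleManager
          .set("spark.shuffle.memoryFraction", "0.5")      // leave headroom for the in-memory sorter
        val sc = new SparkContext(conf)
        try {
          val pairs = sc.parallelize(1 to 10, 1).map(x => (x, x))
          // Kryo supports relocation of serialized objects, so this shuffle is eligible
          // for the new path; a JavaSerializer here would fall back to sort-based shuffle.
          val shuffled = new ShuffledRDD[Int, Int, Int](pairs, new HashPartitioner(4))
            .setSerializer(new KryoSerializer(conf))
          println(shuffled.count())
        } finally {
          sc.stop()
        }
      }
    }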
http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index cf9279e..564a443 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
       <dependency>
         <groupId>org.mockito</groupId>
         <artifactId>mockito-all</artifactId>
-        <version>1.9.0</version>
+        <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
       <dependency>
@@ -685,6 +685,18 @@
         <scope>test</scope>
       </dependency>
       <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-core</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-library</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
         <groupId>com.novocode</groupId>
         <artifactId>junit-interface</artifactId>
         <version>0.10</version>

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fba7290..487062a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -131,6 +131,12 @@ object MimaExcludes {
             // SPARK-7530 Added StreamingContext.getState()
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.StreamingContext.state_=")
+          ) ++ Seq(
+            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+            // unnecessary type bounds in order to fix some compiler warnings that occurred when
+            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+              "org.apache.spark.shuffle.ShuffleWriter")
           )
 
         case v if v.startsWith("1.3") =>

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c70..3e46596 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
     // corner-cases where a partitioner constructed with `numPartitions` partitions may output
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
-    val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+      shuffleManager.isInstanceOf[UnsafeShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
     val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
-      // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
-      // However, there are two special cases where we can avoid the copy, described below:
-      if (partitioner.numPartitions <= bypassMergeThreshold) {
-        // If the number of output partitions is sufficiently small, then Spark will fall back to
-        // the old hash-based shuffle write path which doesn't buffer deserialized records.
+      val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+      if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+        // doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
         false
       } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
         // them. This optimization is guarded by a feature-flag and is only applied in cases where
         // shuffle dependency does not specify an ordering and the record serializer has certain
         // properties. If this optimization is enabled, we can safely avoid the copy.
+        //
+        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // None of the special cases held, so we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+        // both cases, we must copy.
         true
       }
     } else {

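The branching above decides whether Exchange must defensively copy rows before handing them to the shuffle machinery. Distilled into a standalone predicate it reads roughly as follows; the parameter names are stand-ins for the values Exchange reads from SparkEnv and SparkConf, and the final hash-shuffle branch, which lies outside this hunk, is defaulted conservatively:

    // Hypothetical distillation of the copy decision above; not the actual Spark API.
    def needsDefensiveCopy(
        hasKeyOrdering: Boolean,            // newOrdering.nonEmpty in the code above
        sortBasedShuffleOn: Boolean,        // SortShuffleManager or UnsafeShuffleManager in use
        usingOriginalSortShuffle: Boolean,  // the bypass only exists in SortShuffleManager
        numPartitions: Int,
        bypassMergeThreshold: Int,
        serializeMapOutputs: Boolean,
        serializerSupportsRelocation: Boolean): Boolean = {
      if (hasKeyOrdering) {
        // Sorting mutable rows requires a defensive copy.
        true
      } else if (sortBasedShuffleOn) {
        if (usingOriginalSortShuffle && numPartitions <= bypassMergeThreshold) {
          false  // hash-style bypass write path does not buffer deserialized records
        } else if (serializeMapOutputs && serializerSupportsRelocation) {
          false  // records are serialized before buffering, so the buffer never aliases them
        } else {
          true   // ExternalSorter buffers deserialized records, so a copy is required
        }
      } else {
        true     // hash-based shuffle branch is outside this hunk; assume a copy to be safe
      }
    }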
http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/unsafe/pom.xml
----------------------------------------------------------------------
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b07332..9e151fc 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
 
     <!-- Provided dependencies -->
     <dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988..2906ac8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.memory;
 
 import java.util.*;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
 
   private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
-  /**
-   * The number of entries in the page table.
-   */
-  private static final int PAGE_TABLE_SIZE = 1 << 13;
+  /** The number of bits used to address the page table. */
+  private static final int PAGE_NUMBER_BITS = 13;
+
+  /** The number of bits used to encode offsets in data pages. */
+  @VisibleForTesting
+  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
+
+  /** The number of entries in the page table. */
+  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+  /** Maximum supported data page size */
+  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public final class TaskMemoryManager {
    * intended for allocating large blocks of memory that will be shared between operators.
    */
   public MemoryBlock allocatePage(long size) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Allocating {} byte page", size);
-    }
-    if (size >= (1L << 51)) {
-      throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+    if (size > MAXIMUM_PAGE_SIZE) {
+      throw new IllegalArgumentException(
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
     }
 
     final int pageNumber;
@@ -120,8 +127,8 @@ public final class TaskMemoryManager {
     final MemoryBlock page = executorMemoryManager.allocate(size);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+    if (logger.isTraceEnabled()) {
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
     }
     return page;
   }
@@ -130,9 +137,6 @@ public final class TaskMemoryManager {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
    */
   public void freePage(MemoryBlock page) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
-    }
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public final class TaskMemoryManager {
       allocatedPages.clear(page.pageNumber);
     }
     pageTable[page.pageNumber] = null;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+    if (logger.isTraceEnabled()) {
+      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
   }
 
@@ -173,14 +177,36 @@ public final class TaskMemoryManager {
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
+   *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+   *                     this should be the value that you would pass as the base offset into an
+   *                     UNSAFE call (e.g. page.baseOffset() + something).
+   * @return an encoded page address.
    */
   public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
-    if (inHeap) {
-      assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
-      return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
-    } else {
-      return offsetInPage;
+    if (!inHeap) {
+      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+      // encode. Due to our page size limitation, though, we can convert this into an offset that's
+      // relative to the page's base offset; this relative offset will fit in 51 bits.
+      offsetInPage -= page.getBaseOffset();
     }
+    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+  }
+
+  @VisibleForTesting
+  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+  }
+
+  @VisibleForTesting
+  public static int decodePageNumber(long pagePlusOffsetAddress) {
+    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+  }
+
+  private static long decodeOffset(long pagePlusOffsetAddress) {
+    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
   }
 
   /**
@@ -189,7 +215,7 @@ public final class TaskMemoryManager {
    */
   public Object getPage(long pagePlusOffsetAddress) {
     if (inHeap) {
-      final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
       final Object page = pageTable[pageNumber].getBaseObject();
       assert (page != null);
@@ -204,10 +230,15 @@ public final class TaskMemoryManager {
    * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public long getOffsetInPage(long pagePlusOffsetAddress) {
+    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
     if (inHeap) {
-      return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+      return offsetInPage;
     } else {
-      return pagePlusOffsetAddress;
+      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+      // converted the absolute address into a relative address. Here, we invert that operation:
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
     }
   }
 

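The encoding above packs a 13-bit page number and a 51-bit in-page offset into a single 64-bit address, rebasing off-heap addresses against the page's base offset so that they fit in 51 bits. A self-contained sketch of the arithmetic, using the same bit widths but none of the actual classes:

    // Standalone sketch of the 13-bit page number / 51-bit offset packing described above.
    object PageAddressSketch {
      val PageNumberBits = 13
      val OffsetBits = 64 - PageNumberBits     // 51
      val Lower51Mask = 0x7FFFFFFFFFFFFL       // mask for the lower 51 bits of a long

      def encode(pageNumber: Int, offsetInPage: Long): Long =
        (pageNumber.toLong << OffsetBits) | (offsetInPage & Lower51Mask)

      def decodePageNumber(address: Long): Int = (address >>> OffsetBits).toInt

      def decodeOffset(address: Long): Long = address & Lower51Mask

      def main(args: Array[String]): Unit = {
        val addr = encode(pageNumber = 5, offsetInPage = 64L)
        assert(decodePageNumber(addr) == 5)
        assert(decodeOffset(addr) == 64L)
        println(s"page=${decodePageNumber(addr)}, offset=${decodeOffset(addr)}")
      }
    }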
http://git-wip-us.apache.org/repos/asf/spark/blob/c53ebea9/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f..06fb081 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public class TaskMemoryManagerSuite {
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner-case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
 }

