spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From joshro...@apache.org
Subject spark git commit: [SPARK-7542][SQL] Support off-heap index/sort buffer
Date Fri, 06 Nov 2015 03:02:43 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 b655117d0 -> 1b43dd391


[SPARK-7542][SQL] Support off-heap index/sort buffer

This adds support for allocating the arrays inside BytesToBytesMap and InMemorySorter from off-heap memory,
so that all execution memory can be allocated off-heap.

Closes #8068

Author: Davies Liu <davies@databricks.com>

Closes #9477 from davies/unsafe_timsort.

(cherry picked from commit eec74ba8bde7f9446cc38e687bda103e85669d35)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1b43dd39
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1b43dd39
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1b43dd39

Branch: refs/heads/branch-1.6
Commit: 1b43dd391b2233c1b062df886628eb6fb92b2936
Parents: b655117
Author: Davies Liu <davies@databricks.com>
Authored: Thu Nov 5 19:02:18 2015 -0800
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Thu Nov 5 19:02:37 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/memory/MemoryConsumer.java | 36 ++++++-----
 .../apache/spark/memory/TaskMemoryManager.java  |  6 +-
 .../shuffle/sort/ShuffleExternalSorter.java     | 26 +++-----
 .../shuffle/sort/ShuffleInMemorySorter.java     | 67 ++++++++++----------
 .../shuffle/sort/ShuffleSortDataFormat.java     | 38 +++++++----
 .../spark/unsafe/map/BytesToBytesMap.java       | 18 +++---
 .../unsafe/sort/UnsafeExternalSorter.java       | 28 +++-----
 .../unsafe/sort/UnsafeInMemorySorter.java       | 66 +++++++++++--------
 .../unsafe/sort/UnsafeSortDataFormat.java       | 47 ++++++++------
 .../spark/memory/TaskMemoryManagerSuite.java    | 23 -------
 .../apache/spark/memory/TestMemoryConsumer.java | 45 +++++++++++++
 .../sort/ShuffleInMemorySorterSuite.java        | 16 +++--
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  1 -
 .../unsafe/sort/UnsafeInMemorySorterSuite.java  | 12 ++--
 .../sql/execution/UnsafeKVExternalSorter.java   |  3 +-
 .../apache/spark/unsafe/array/LongArray.java    | 18 +++++-
 .../spark/unsafe/array/LongArraySuite.java      |  4 ++
 17 files changed, 265 insertions(+), 189 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 008799c..8fbdb72 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -20,6 +20,7 @@ package org.apache.spark.memory;
 
 import java.io.IOException;
 
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 
 
@@ -28,9 +29,9 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
  */
 public abstract class MemoryConsumer {
 
-  private final TaskMemoryManager taskMemoryManager;
+  protected final TaskMemoryManager taskMemoryManager;
   private final long pageSize;
-  private long used;
+  protected long used;
 
   protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
     this.taskMemoryManager = taskMemoryManager;
@@ -74,26 +75,29 @@ public abstract class MemoryConsumer {
   public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
 
   /**
-   * Acquire `size` bytes memory.
-   *
-   * If there is not enough memory, throws OutOfMemoryError.
+   * Allocates a LongArray of `size`.
    */
-  protected void acquireMemory(long size) {
-    long got = taskMemoryManager.acquireExecutionMemory(size, this);
-    if (got < size) {
-      taskMemoryManager.releaseExecutionMemory(got, this);
+  public LongArray allocateArray(long size) {
+    long required = size * 8L;
+    MemoryBlock page = taskMemoryManager.allocatePage(required, this);
+    if (page == null || page.size() < required) {
+      long got = 0;
+      if (page != null) {
+        got = page.size();
+        taskMemoryManager.freePage(page, this);
+      }
       taskMemoryManager.showMemoryUsage();
-      throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+      throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
     }
-    used += got;
+    used += required;
+    return new LongArray(page);
   }
 
   /**
-   * Release `size` bytes memory.
+   * Frees a LongArray.
    */
-  protected void releaseMemory(long size) {
-    used -= size;
-    taskMemoryManager.releaseExecutionMemory(size, this);
+  public void freeArray(LongArray array) {
+    freePage(array.memoryBlock());
   }
 
   /**
@@ -109,7 +113,7 @@ public abstract class MemoryConsumer {
       long got = 0;
       if (page != null) {
         got = page.size();
-        freePage(page);
+        taskMemoryManager.freePage(page, this);
       }
       taskMemoryManager.showMemoryUsage();
       throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 4230575..6440f9c 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -137,7 +137,7 @@ public class TaskMemoryManager {
       if (got < required) {
         // Call spill() on other consumers to release memory
         for (MemoryConsumer c: consumers) {
-          if (c != null && c != consumer && c.getUsed() > 0) {
+          if (c != consumer && c.getUsed() > 0) {
             try {
               long released = c.spill(required - got, consumer);
               if (released > 0) {
@@ -173,7 +173,9 @@ public class TaskMemoryManager {
         }
       }
 
-      consumers.add(consumer);
+      if (consumer != null) {
+        consumers.add(consumer);
+      }
       logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
       return got;
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 400d852..9affff8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.Utils;
 
@@ -114,8 +115,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
     this.writeMetrics = writeMetrics;
-    acquireMemory(initialSize * 8L);
-    this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+    this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
     this.peakMemoryUsedBytes = getMemoryUsage();
   }
 
@@ -301,9 +301,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   public void cleanupResources() {
     freeMemory();
     if (inMemSorter != null) {
-      long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+      inMemSorter.free();
       inMemSorter = null;
-      releaseMemory(sorterMemoryUsage);
     }
     for (SpillInfo spill : spills) {
       if (spill.file.exists() && !spill.file.delete()) {
@@ -321,9 +320,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      long needed = used + inMemSorter.getMemoryToExpand();
+      LongArray array;
       try {
-        acquireMemory(needed);  // could trigger spilling
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
       } catch (OutOfMemoryError e) {
         // should have trigger spilling
         assert(inMemSorter.hasSpaceForAnotherRecord());
@@ -331,16 +331,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       }
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
-        releaseMemory(needed);
+        freeArray(array);
       } else {
-        try {
-          inMemSorter.expandPointerArray();
-          releaseMemory(used);
-        } catch (OutOfMemoryError oom) {
-          // Just in case that JVM had run out of memory
-          releaseMemory(needed);
-          spill();
-        }
+        inMemSorter.expandPointerArray(array);
       }
     }
   }
@@ -404,9 +397,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
         // Do not count the final file towards the spill count.
         writeSortedFile(true);
         freeMemory();
-        long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+        inMemSorter.free();
         inMemSorter = null;
-        releaseMemory(sorterMemoryUsage);
       }
       return spills.toArray(new SpillInfo[spills.size()]);
     } catch (IOException e) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index e630575..58ad88e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -19,11 +19,14 @@ package org.apache.spark.shuffle.sort;
 
 import java.util.Comparator;
 
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.util.collection.Sorter;
 
 final class ShuffleInMemorySorter {
 
-  private final Sorter<PackedRecordPointer, long[]> sorter;
+  private final Sorter<PackedRecordPointer, LongArray> sorter;
   private static final class SortComparator implements Comparator<PackedRecordPointer> {
     @Override
     public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@@ -32,24 +35,34 @@ final class ShuffleInMemorySorter {
   }
   private static final SortComparator SORT_COMPARATOR = new SortComparator();
 
+  private final MemoryConsumer consumer;
+
   /**
    * An array of record pointers and partition ids that have been encoded by
    * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
    * records.
    */
-  private long[] array;
+  private LongArray array;
 
   /**
    * The position in the pointer array where new records can be inserted.
    */
   private int pos = 0;
 
-  public ShuffleInMemorySorter(int initialSize) {
+  public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+    this.consumer = consumer;
     assert (initialSize > 0);
-    this.array = new long[initialSize];
+    this.array = consumer.allocateArray(initialSize);
     this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
   }
 
+  public void free() {
+    if (array != null) {
+      consumer.freeArray(array);
+      array = null;
+    }
+  }
+
   public int numRecords() {
     return pos;
   }
@@ -58,30 +71,25 @@ final class ShuffleInMemorySorter {
     pos = 0;
   }
 
-  private int newLength() {
-    // Guard against overflow:
-    return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
-  }
-
-  /**
-   * Returns the memory needed to expand
-   */
-  public long getMemoryToExpand() {
-    return ((long) (newLength() - array.length)) * 8;
-  }
-
-  public void expandPointerArray() {
-    final long[] oldArray = array;
-    array = new long[newLength()];
-    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+  public void expandPointerArray(LongArray newArray) {
+    assert(newArray.size() > array.size());
+    Platform.copyMemory(
+      array.getBaseObject(),
+      array.getBaseOffset(),
+      newArray.getBaseObject(),
+      newArray.getBaseOffset(),
+      array.size() * 8L
+    );
+    consumer.freeArray(array);
+    array = newArray;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos < array.length;
+    return pos < array.size();
   }
 
   public long getMemoryUsage() {
-    return array.length * 8L;
+    return array.size() * 8L;
   }
 
   /**
@@ -96,14 +104,9 @@ final class ShuffleInMemorySorter {
    */
   public void insertRecord(long recordPointer, int partitionId) {
     if (!hasSpaceForAnotherRecord()) {
-      if (array.length == Integer.MAX_VALUE) {
-        throw new IllegalStateException("Sort pointer array has reached maximum size");
-      } else {
-        expandPointerArray();
-      }
+      expandPointerArray(consumer.allocateArray(array.size() * 2));
     }
-    array[pos] =
-        PackedRecordPointer.packPointer(recordPointer, partitionId);
+    array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
     pos++;
   }
 
@@ -112,12 +115,12 @@ final class ShuffleInMemorySorter {
    */
   public static final class ShuffleSorterIterator {
 
-    private final long[] pointerArray;
+    private final LongArray pointerArray;
     private final int numRecords;
     final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
     private int position = 0;
 
-    public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
+    public ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
       this.numRecords = numRecords;
       this.pointerArray = pointerArray;
     }
@@ -127,7 +130,7 @@ final class ShuffleInMemorySorter {
     }
 
     public void loadNext() {
-      packedRecordPointer.set(pointerArray[position]);
+      packedRecordPointer.set(pointerArray.get(position));
       position++;
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index 8a1e5ae..8f4e322 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -17,16 +17,19 @@
 
 package org.apache.spark.shuffle.sort;
 
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
-final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
 
   public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
 
   private ShuffleSortDataFormat() { }
 
   @Override
-  public PackedRecordPointer getKey(long[] data, int pos) {
+  public PackedRecordPointer getKey(LongArray data, int pos) {
     // Since we re-use keys, this method shouldn't be called.
     throw new UnsupportedOperationException();
   }
@@ -37,31 +40,38 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, lo
   }
 
   @Override
-  public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
-    reuse.set(data[pos]);
+  public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) {
+    reuse.set(data.get(pos));
     return reuse;
   }
 
   @Override
-  public void swap(long[] data, int pos0, int pos1) {
-    final long temp = data[pos0];
-    data[pos0] = data[pos1];
-    data[pos1] = temp;
+  public void swap(LongArray data, int pos0, int pos1) {
+    final long temp = data.get(pos0);
+    data.set(pos0, data.get(pos1));
+    data.set(pos1, temp);
   }
 
   @Override
-  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
-    dst[dstPos] = src[srcPos];
+  public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+    dst.set(dstPos, src.get(srcPos));
   }
 
   @Override
-  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
-    System.arraycopy(src, srcPos, dst, dstPos, length);
+  public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+    Platform.copyMemory(
+      src.getBaseObject(),
+      src.getBaseOffset() + srcPos * 8,
+      dst.getBaseObject(),
+      dst.getBaseOffset() + dstPos * 8,
+      length * 8
+    );
   }
 
   @Override
-  public long[] allocate(int length) {
-    return new long[length];
+  public LongArray allocate(int length) {
+    // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap.
+    return new LongArray(MemoryBlock.fromLongArray(new long[length]));
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 6656fd1..04694dc 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -20,7 +20,6 @@ package org.apache.spark.unsafe.map;
 import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedList;
 
@@ -724,11 +723,10 @@ public final class BytesToBytesMap extends MemoryConsumer {
    */
   private void allocate(int capacity) {
     assert (capacity >= 0);
-    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
     capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
     assert (capacity <= MAX_CAPACITY);
-    acquireMemory(capacity * 16);
-    longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+    longArray = allocateArray(capacity * 2);
+    longArray.zeroOut();
 
     this.growthThreshold = (int) (capacity * loadFactor);
     this.mask = capacity - 1;
@@ -743,9 +741,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
   public void free() {
     updatePeakMemoryUsed();
     if (longArray != null) {
-      long used = longArray.memoryBlock().size();
+      freeArray(longArray);
       longArray = null;
-      releaseMemory(used);
     }
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
@@ -834,9 +831,9 @@ public final class BytesToBytesMap extends MemoryConsumer {
   /**
    * Returns the underline long[] of longArray.
    */
-  public long[] getArray() {
+  public LongArray getArray() {
     assert(longArray != null);
-    return (long[]) longArray.memoryBlock().getBaseObject();
+    return longArray;
   }
 
   /**
@@ -844,7 +841,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
    */
   public void reset() {
     numElements = 0;
-    Arrays.fill(getArray(), 0);
+    longArray.zeroOut();
+
     while (dataPages.size() > 0) {
       MemoryBlock dataPage = dataPages.removeLast();
       freePage(dataPage);
@@ -887,7 +885,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
       longArray.set(newPos * 2, keyPointer);
       longArray.set(newPos * 2 + 1, hashcode);
     }
-    releaseMemory(oldLongArray.memoryBlock().size());
+    freeArray(oldLongArray);
 
     if (enablePerfMetrics) {
       timeSpentResizingNs += System.nanoTime() - resizeStartTime;

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index cba043b..9a7b2ad 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -32,6 +32,7 @@ import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.TaskCompletionListener;
 import org.apache.spark.util.Utils;
@@ -123,9 +124,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     this.writeMetrics = new ShuffleWriteMetrics();
 
     if (existingInMemorySorter == null) {
-      this.inMemSorter =
-        new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
-      acquireMemory(inMemSorter.getMemoryUsage());
+      this.inMemSorter = new UnsafeInMemorySorter(
+        this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
     } else {
       this.inMemSorter = existingInMemorySorter;
     }
@@ -277,9 +277,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       deleteSpillFiles();
       freeMemory();
       if (inMemSorter != null) {
-        long used = inMemSorter.getMemoryUsage();
+        inMemSorter.free();
         inMemSorter = null;
-        releaseMemory(used);
       }
     }
   }
@@ -293,9 +292,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     assert(inMemSorter != null);
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       long used = inMemSorter.getMemoryUsage();
-      long needed = used + inMemSorter.getMemoryToExpand();
+      LongArray array;
       try {
-        acquireMemory(needed);  // could trigger spilling
+        // could trigger spilling
+        array = allocateArray(used / 8 * 2);
       } catch (OutOfMemoryError e) {
         // should have trigger spilling
         assert(inMemSorter.hasSpaceForAnotherRecord());
@@ -303,16 +303,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       }
       // check if spilling is triggered or not
       if (inMemSorter.hasSpaceForAnotherRecord()) {
-        releaseMemory(needed);
+        freeArray(array);
       } else {
-        try {
-          inMemSorter.expandPointerArray();
-          releaseMemory(used);
-        } catch (OutOfMemoryError oom) {
-          // Just in case that JVM had run out of memory
-          releaseMemory(needed);
-          spill();
-        }
+        inMemSorter.expandPointerArray(array);
       }
     }
   }
@@ -498,9 +491,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
           nextUpstream = null;
 
           assert(inMemSorter != null);
-          long used = inMemSorter.getMemoryUsage();
+          inMemSorter.free();
           inMemSorter = null;
-          releaseMemory(used);
         }
         numRecords--;
         upstream.loadNext();

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index d57213b..a218ad4 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,8 +19,10 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.util.Comparator;
 
+import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.util.collection.Sorter;
 
 /**
@@ -62,15 +64,16 @@ public final class UnsafeInMemorySorter {
     }
   }
 
+  private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
-  private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+  private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
    * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
    * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
    */
-  private long[] array;
+  private LongArray array;
 
   /**
    * The position in the sort buffer where new records can be inserted.
@@ -78,22 +81,33 @@ public final class UnsafeInMemorySorter {
   private int pos = 0;
 
   public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
     final TaskMemoryManager memoryManager,
     final RecordComparator recordComparator,
     final PrefixComparator prefixComparator,
     int initialSize) {
-    this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+    this(consumer, memoryManager, recordComparator, prefixComparator,
+      consumer.allocateArray(initialSize * 2));
   }
 
   public UnsafeInMemorySorter(
+    final MemoryConsumer consumer,
       final TaskMemoryManager memoryManager,
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
-      long[] array) {
-    this.array = array;
+      LongArray array) {
+    this.consumer = consumer;
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+    this.array = array;
+  }
+
+  /**
+   * Free the memory used by pointer array.
+   */
+  public void free() {
+    consumer.freeArray(array);
   }
 
   public void reset() {
@@ -107,26 +121,26 @@ public final class UnsafeInMemorySorter {
     return pos / 2;
   }
 
-  private int newLength() {
-    return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
-  }
-
-  public long getMemoryToExpand() {
-    return (long) (newLength() - array.length) * 8L;
-  }
-
   public long getMemoryUsage() {
-    return array.length * 8L;
+    return array.size() * 8L;
   }
 
   public boolean hasSpaceForAnotherRecord() {
-    return pos + 2 <= array.length;
+    return pos + 2 <= array.size();
   }
 
-  public void expandPointerArray() {
-    final long[] oldArray = array;
-    array = new long[newLength()];
-    System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+  public void expandPointerArray(LongArray newArray) {
+    if (newArray.size() < array.size()) {
+      throw new OutOfMemoryError("Not enough memory to grow pointer array");
+    }
+    Platform.copyMemory(
+      array.getBaseObject(),
+      array.getBaseOffset(),
+      newArray.getBaseObject(),
+      newArray.getBaseOffset(),
+      array.size() * 8L);
+    consumer.freeArray(array);
+    array = newArray;
   }
 
   /**
@@ -138,11 +152,11 @@ public final class UnsafeInMemorySorter {
    */
   public void insertRecord(long recordPointer, long keyPrefix) {
     if (!hasSpaceForAnotherRecord()) {
-      expandPointerArray();
+      expandPointerArray(consumer.allocateArray(array.size() * 2));
     }
-    array[pos] = recordPointer;
+    array.set(pos, recordPointer);
     pos++;
-    array[pos] = keyPrefix;
+    array.set(pos, keyPrefix);
     pos++;
   }
 
@@ -150,7 +164,7 @@ public final class UnsafeInMemorySorter {
 
     private final TaskMemoryManager memoryManager;
     private final int sortBufferInsertPosition;
-    private final long[] sortBuffer;
+    private final LongArray sortBuffer;
     private int position = 0;
     private Object baseObject;
     private long baseOffset;
@@ -160,7 +174,7 @@ public final class UnsafeInMemorySorter {
     private SortedIterator(
         TaskMemoryManager memoryManager,
         int sortBufferInsertPosition,
-        long[] sortBuffer) {
+        LongArray sortBuffer) {
       this.memoryManager = memoryManager;
       this.sortBufferInsertPosition = sortBufferInsertPosition;
       this.sortBuffer = sortBuffer;
@@ -188,11 +202,11 @@ public final class UnsafeInMemorySorter {
     @Override
     public void loadNext() {
       // This pointer points to a 4-byte record length, followed by the record's bytes
-      final long recordPointer = sortBuffer[position];
+      final long recordPointer = sortBuffer.get(position);
       baseObject = memoryManager.getPage(recordPointer);
       baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4;  // Skip over record length
       recordLength = Platform.getInt(baseObject, baseOffset - 4);
-      keyPrefix = sortBuffer[position + 1];
+      keyPrefix = sortBuffer.get(position + 1);
       position += 2;
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d09c728..d3137f5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.util.collection.unsafe.sort;
 
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.collection.SortDataFormat;
 
 /**
@@ -26,14 +29,14 @@ import org.apache.spark.util.collection.SortDataFormat;
  * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at
  * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
  */
-final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
 
   public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
 
   private UnsafeSortDataFormat() { }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
     // Since we re-use keys, this method shouldn't be called.
     throw new UnsupportedOperationException();
   }
@@ -44,37 +47,43 @@ final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefi
   }
 
   @Override
-  public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix
reuse) {
-    reuse.recordPointer = data[pos * 2];
-    reuse.keyPrefix = data[pos * 2 + 1];
+  public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix
reuse) {
+    reuse.recordPointer = data.get(pos * 2);
+    reuse.keyPrefix = data.get(pos * 2 + 1);
     return reuse;
   }
 
   @Override
-  public void swap(long[] data, int pos0, int pos1) {
-    long tempPointer = data[pos0 * 2];
-    long tempKeyPrefix = data[pos0 * 2 + 1];
-    data[pos0 * 2] = data[pos1 * 2];
-    data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
-    data[pos1 * 2] = tempPointer;
-    data[pos1 * 2 + 1] = tempKeyPrefix;
+  public void swap(LongArray data, int pos0, int pos1) {
+    long tempPointer = data.get(pos0 * 2);
+    long tempKeyPrefix = data.get(pos0 * 2 + 1);
+    data.set(pos0 * 2, data.get(pos1 * 2));
+    data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
+    data.set(pos1 * 2, tempPointer);
+    data.set(pos1 * 2 + 1, tempKeyPrefix);
   }
 
   @Override
-  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
-    dst[dstPos * 2] = src[srcPos * 2];
-    dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+  public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+    dst.set(dstPos * 2, src.get(srcPos * 2));
+    dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
   }
 
   @Override
-  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
-    System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+  public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length)
{
+    Platform.copyMemory(
+      src.getBaseObject(),
+      src.getBaseOffset() + srcPos * 16,
+      dst.getBaseObject(),
+      dst.getBaseOffset() + dstPos * 16,
+      length * 16);
   }
 
   @Override
-  public long[] allocate(int length) {
+  public LongArray allocate(int length) {
     assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
-    return new long[length * 2];
+    // This is used as temporary buffer, it's fine to allocate from JVM heap.
+    return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index dab7b05..c731317 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -17,8 +17,6 @@
 
 package org.apache.spark.memory;
 
-import java.io.IOException;
-
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -27,27 +25,6 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
 
 public class TaskMemoryManagerSuite {
 
-  class TestMemoryConsumer extends MemoryConsumer {
-    TestMemoryConsumer(TaskMemoryManager memoryManager) {
-      super(memoryManager);
-    }
-
-    @Override
-    public long spill(long size, MemoryConsumer trigger) throws IOException {
-      long used = getUsed();
-      releaseMemory(used);
-      return used;
-    }
-
-    void use(long size) {
-      acquireMemory(size);
-    }
-
-    void free(long size) {
-      releaseMemory(size);
-    }
-  }
-
   @Test
   public void leakedPageMemoryIsDetected() {
     final TaskMemoryManager manager = new TaskMemoryManager(

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
new file mode 100644
index 0000000..8ae3642
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import java.io.IOException;
+
+public class TestMemoryConsumer extends MemoryConsumer {
+  public TestMemoryConsumer(TaskMemoryManager memoryManager) {
+    super(memoryManager);
+  }
+
+  @Override
+  public long spill(long size, MemoryConsumer trigger) throws IOException {
+    long used = getUsed();
+    free(used);
+    return used;
+  }
+
+  void use(long size) {
+    long got = taskMemoryManager.acquireExecutionMemory(size, this);
+    used += got;
+  }
+
+  void free(long size) {
+    used -= size;
+    taskMemoryManager.releaseExecutionMemory(size, this);
+  }
+}
+
+

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 2293b1b..faa5a86 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -25,13 +25,19 @@ import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryConsumer;
 import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
 
 public class ShuffleInMemorySorterSuite {
 
+  final TestMemoryManager memoryManager =
+    new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+  final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+  final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
+
   private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength)
{
     final byte[] strBytes = new byte[strLength];
     Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
@@ -40,7 +46,7 @@ public class ShuffleInMemorySorterSuite {
 
   @Test
   public void testSortingEmptyInput() {
-    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100);
     final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
     assert(!iter.hasNext());
   }
@@ -63,7 +69,7 @@ public class ShuffleInMemorySorterSuite {
       new TaskMemoryManager(new TestMemoryManager(conf), 0);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
-    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
     final HashPartitioner hashPartitioner = new HashPartitioner(4);
 
     // Write the records into the data page and store pointers into the sorter
@@ -104,7 +110,7 @@ public class ShuffleInMemorySorterSuite {
 
   @Test
   public void testSortingManyNumbers() throws Exception {
-    ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
     int[] numbersToSort = new int[128000];
     Random random = new Random(16);
     for (int i = 0; i < numbersToSort.length; i++) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index cfead0e..11c3a7b 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -390,7 +390,6 @@ public class UnsafeExternalSorterSuite {
       for (int i = 0; i < numRecordsPerPage * 10; i++) {
         insertNumber(sorter, i);
         newPeakMemory = sorter.getPeakMemoryUsedBytes();
-        // The first page is pre-allocated on instantiation
         if (i % numRecordsPerPage == 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 642f658..a203a09 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -23,6 +23,7 @@ import org.junit.Test;
 
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TestMemoryConsumer;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -44,9 +45,11 @@ public class UnsafeInMemorySorterSuite {
 
   @Test
   public void testSortingEmptyInput() {
-    final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
-      new TaskMemoryManager(
-        new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
+    final TaskMemoryManager memoryManager = new TaskMemoryManager(
+      new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
+    final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
+      memoryManager,
       mock(RecordComparator.class),
       mock(PrefixComparator.class),
       100);
@@ -69,6 +72,7 @@ public class UnsafeInMemorySorterSuite {
     };
     final TaskMemoryManager memoryManager = new TaskMemoryManager(
       new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+    final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
     final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
     final Object baseObject = dataPage.getBaseObject();
     // Write the records into the data page:
@@ -102,7 +106,7 @@ public class UnsafeInMemorySorterSuite {
         return (int) prefix1 - (int) prefix2;
       }
     };
-    UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+    UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator,
       prefixComparator, dataToSort.length);
     // Given a page of records, insert those records into the sorter one-by-one:
     position = dataPage.getBaseOffset();

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index e2898ef..8c9b9c8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -85,8 +85,9 @@ public final class UnsafeKVExternalSorter {
     } else {
       // During spilling, the array in map will not be used, so we can borrow that and use
it
       // as the underline array for in-memory sorter (it's always large enough).
+      // Since we will not grow the array, it's fine to pass `null` as consumer.
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
-        taskMemoryManager, recordComparator, prefixComparator, map.getArray());
+        null, taskMemoryManager, recordComparator, prefixComparator, map.getArray());
 
       // We cannot use the destructive iterator here because we are reusing the existing
memory
       // pages in BytesToBytesMap to hold records during sorting.

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
index 7410505..1a3cdff 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
@@ -39,7 +39,6 @@ public final class LongArray {
   private final long length;
 
   public LongArray(MemoryBlock memory) {
-    assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")";
     assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements";
     this.memory = memory;
     this.baseObj = memory.getBaseObject();
@@ -51,6 +50,14 @@ public final class LongArray {
     return memory;
   }
 
+  public Object getBaseObject() {
+    return baseObj;
+  }
+
+  public long getBaseOffset() {
+    return baseOffset;
+  }
+
   /**
    * Returns the number of elements this array can hold.
    */
@@ -59,6 +66,15 @@ public final class LongArray {
   }
 
   /**
+   * Fill this all with 0L.
+   */
+  public void zeroOut() {
+    for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) {
+      Platform.putLong(baseObj, off, 0);
+    }
+  }
+
+  /**
    * Sets the value at position {@code index}.
    */
   public void set(int index, long value) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1b43dd39/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
index 5974cf9..fb8e53b 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java
@@ -34,5 +34,9 @@ public class LongArraySuite {
     Assert.assertEquals(2, arr.size());
     Assert.assertEquals(1L, arr.get(0));
     Assert.assertEquals(3L, arr.get(1));
+
+    arr.zeroOut();
+    Assert.assertEquals(0L, arr.get(0));
+    Assert.assertEquals(0L, arr.get(1));
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message