spark-commits mailing list archives

From joshro...@apache.org
Subject [2/2] spark git commit: [SPARK-9451] [SQL] Support entries larger than default page size in BytesToBytesMap & integrate with ShuffleMemoryManager
Date Sat, 01 Aug 2015 02:19:35 GMT
[SPARK-9451] [SQL] Support entries larger than default page size in BytesToBytesMap & integrate with ShuffleMemoryManager

This patch adds support for entries larger than the default page size in BytesToBytesMap.  These large rows are handled by allocating special overflow pages to hold individual entries.
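
As an illustration (not code from the patch), the rule that decides when an overflow page is
needed boils down to comparing the record's footprint against the usable page size: a record is
laid out as (8-byte key length)(key)(8-byte value length)(value), and 8 bytes of each regular
page are reserved for an end-of-page marker.  A minimal Java sketch of that check, with
`needsOverflowPage` as a hypothetical helper name:

    // Sketch of the sizing rule that putNewKey() uses to decide between appending to a
    // regular data page and allocating a dedicated overflow page for one large record.
    static boolean needsOverflowPage(int keyLengthBytes, int valueLengthBytes, long pageSizeBytes) {
      final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
      return requiredSize > pageSizeBytes - 8;  // 8 bytes kept free for the end-of-page marker
    }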

In addition, this patch integrates BytesToBytesMap with the ShuffleMemoryManager:

- Move BytesToBytesMap from `unsafe` to `core` so that it can import `ShuffleMemoryManager`.
- Before allocating new data pages, ask the ShuffleMemoryManager to reserve the memory:
  - `putNewKey()` now returns a boolean indicating whether the insert succeeded or failed due to a lack of memory.  The caller can use this value to respond to memory pressure (e.g. by spilling); see the sketch after this list.
- `UnsafeFixedWidthAggregationMap.getAggregationBuffer()` now returns `null` to signal failure due to a lack of memory.
- Updated all uses of these classes to handle these error conditions.
- Added new tests for allocating large records and for allocations which fail due to memory pressure.
- Extended the `afterAll()` test teardown methods to detect ShuffleMemoryManager leaks.
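
As a caller-side sketch of the new failure path (illustrative only; `spillAndRetry()` is a
hypothetical placeholder, since the actual spilling response is outside the scope of this patch):

    // Look up the key; if it is absent, try to insert it and react to memory pressure.
    BytesToBytesMap.Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
    if (!loc.isDefined()) {
      if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes,
                         valueBaseObject, valueBaseOffset, valueLengthInBytes)) {
        // putNewKey() could not reserve memory from the ShuffleMemoryManager.
        spillAndRetry();  // hypothetical caller response, e.g. spill and retry the insert
      }
    }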

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7762 from JoshRosen/large-rows and squashes the following commits:

ae7bc56 [Josh Rosen] Fix compilation
82fc657 [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-rows
34ab943 [Josh Rosen] Remove semi
31a525a [Josh Rosen] Integrate BytesToBytesMap with ShuffleMemoryManager.
626b33c [Josh Rosen] Move code to sql/core and spark/core packages so that ShuffleMemoryManager can be integrated
ec4484c [Josh Rosen] Move BytesToBytesMap from unsafe package to core.
642ed69 [Josh Rosen] Rename size to numElements
bea1152 [Josh Rosen] Add basic test.
2cd3570 [Josh Rosen] Remove accidental duplicated code
07ff9ef [Josh Rosen] Basic support for large rows in BytesToBytesMap.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8cb415a4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8cb415a4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8cb415a4

Branch: refs/heads/master
Commit: 8cb415a4b9bc1f82127ccce4a5579d433f4e8f83
Parents: f51fd6f
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Fri Jul 31 19:19:27 2015 -0700
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Fri Jul 31 19:19:27 2015 -0700

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       | 709 +++++++++++++++++++
 .../spark/unsafe/map/HashMapGrowthStrategy.java |  41 ++
 .../spark/shuffle/ShuffleMemoryManager.scala    |   8 +-
 .../map/AbstractBytesToBytesMapSuite.java       | 499 +++++++++++++
 .../unsafe/map/BytesToBytesMapOffHeapSuite.java |  29 +
 .../unsafe/map/BytesToBytesMapOnHeapSuite.java  |  29 +
 .../UnsafeFixedWidthAggregationMap.java         | 223 ------
 .../UnsafeFixedWidthAggregationMapSuite.scala   | 132 ----
 .../UnsafeFixedWidthAggregationMap.java         | 234 ++++++
 .../sql/execution/GeneratedAggregate.scala      |   6 +
 .../sql/execution/joins/HashedRelation.scala    |  27 +-
 .../UnsafeFixedWidthAggregationMapSuite.scala   | 140 ++++
 .../spark/unsafe/map/BytesToBytesMap.java       | 643 -----------------
 .../spark/unsafe/map/HashMapGrowthStrategy.java |  41 --
 .../apache/spark/unsafe/memory/MemoryBlock.java |   2 +-
 .../spark/unsafe/memory/TaskMemoryManager.java  |   1 +
 .../map/AbstractBytesToBytesMapSuite.java       | 385 ----------
 .../unsafe/map/BytesToBytesMapOffHeapSuite.java |  29 -
 .../unsafe/map/BytesToBytesMapOnHeapSuite.java  |  29 -
 19 files changed, 1717 insertions(+), 1490 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
new file mode 100644
index 0000000..0f42950
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -0,0 +1,709 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.io.IOException;
+import java.lang.Override;
+import java.lang.UnsupportedOperationException;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.unsafe.*;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.*;
+
+/**
+ * An append-only hash map where keys and values are contiguous regions of bytes.
+ * <p>
+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
+ * which is guaranteed to exhaust the space.
+ * <p>
+ * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
+ * probably be using sorting instead of hashing for better cache locality.
+ * <p>
+ * This class is not thread safe.
+ */
+public final class BytesToBytesMap {
+
+  private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
+
+  private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
+
+  private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
+
+  /**
+   * Special record length that is placed after the last record in a data page.
+   */
+  private static final int END_OF_PAGE_MARKER = -1;
+
+  private final TaskMemoryManager taskMemoryManager;
+
+  private final ShuffleMemoryManager shuffleMemoryManager;
+
+  /**
+   * A linked list for tracking all allocated data pages so that we can free all of our memory.
+   */
+  private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+
+  /**
+   * The data page that will be used to store keys and values for new hashtable entries. When this
+   * page becomes full, a new page will be allocated and this pointer will change to point to that
+   * new page.
+   */
+  private MemoryBlock currentDataPage = null;
+
+  /**
+   * Offset into `currentDataPage` that points to the location where new data can be inserted into
+   * the page. This does not incorporate the page's base offset.
+   */
+  private long pageCursor = 0;
+
+  /**
+   * The maximum number of keys that BytesToBytesMap supports. The hash table has to be
+   * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
+   * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array
+   * entries per key, giving us a maximum capacity of (1 << 29).
+   */
+  @VisibleForTesting
+  static final int MAX_CAPACITY = (1 << 29);
+
+  // This choice of page table size and page size means that we can address up to 500 gigabytes
+  // of memory.
+
+  /**
+   * A single array that stores, for each defined key, the address of its key/value record and the
+   * key's hashcode.
+   *
+   * Position {@code 2 * i} in the array holds a pointer to the record for the key at index
+   * {@code i}, while position {@code 2 * i + 1} holds that key's full 32-bit hashcode.
+   */
+  private LongArray longArray;
+  // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
+  // and exploit word-alignment to use fewer bits to hold the address.  This might let us store
+  // only one long per map entry, increasing the chance that this array will fit in cache at the
+  // expense of maybe performing more lookups if we have hash collisions.  Say that we stored only
+  // 27 bits of the hashcode and 37 bits of the address.  37 bits is enough to address 1 terabyte
+  // of RAM given word-alignment.  If we use 13 bits of this for our page table, that gives us a
+  // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store
+  // full base addresses in the page table for off-heap mode so that we can reconstruct the full
+  // absolute memory addresses.
+
+  /**
+   * A {@link BitSet} used to track which positions in the map have a key defined.
+   * Size of the bitset should be half of the size of the long array.
+   */
+  private BitSet bitset;
+
+  private final double loadFactor;
+
+  /**
+   * The size of the data pages that hold key and value data. Map entries cannot span multiple
+   * pages, so this limits the maximum entry size.
+   */
+  private final long pageSizeBytes;
+
+  /**
+   * Number of keys defined in the map.
+   */
+  private int numElements;
+
+  /**
+   * The map will be expanded once the number of keys exceeds this threshold.
+   */
+  private int growthThreshold;
+
+  /**
+   * Mask for truncating hashcodes so that they do not exceed the long array's size.
+   * This is a strength reduction optimization; we're essentially performing a modulus operation,
+   * but doing so with a bitmask because this is a power-of-2-sized hash map.
+   */
+  private int mask;
+
+  /**
+   * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}.
+   */
+  private final Location loc;
+
+  private final boolean enablePerfMetrics;
+
+  private long timeSpentResizingNs = 0;
+
+  private long numProbes = 0;
+
+  private long numKeyLookups = 0;
+
+  private long numHashCollisions = 0;
+
+  public BytesToBytesMap(
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      int initialCapacity,
+      double loadFactor,
+      long pageSizeBytes,
+      boolean enablePerfMetrics) {
+    this.taskMemoryManager = taskMemoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.loadFactor = loadFactor;
+    this.loc = new Location();
+    this.pageSizeBytes = pageSizeBytes;
+    this.enablePerfMetrics = enablePerfMetrics;
+    if (initialCapacity <= 0) {
+      throw new IllegalArgumentException("Initial capacity must be greater than 0");
+    }
+    if (initialCapacity > MAX_CAPACITY) {
+      throw new IllegalArgumentException(
+        "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
+    }
+    if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
+      throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
+        TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
+    }
+    allocate(initialCapacity);
+  }
+
+  public BytesToBytesMap(
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      int initialCapacity,
+      long pageSizeBytes) {
+    this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+  }
+
+  public BytesToBytesMap(
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      int initialCapacity,
+      long pageSizeBytes,
+      boolean enablePerfMetrics) {
+    this(
+      taskMemoryManager,
+      shuffleMemoryManager,
+      initialCapacity,
+      0.70,
+      pageSizeBytes,
+      enablePerfMetrics);
+  }
+
+  /**
+   * Returns the number of keys defined in the map.
+   */
+  public int numElements() { return numElements; }
+
+  private static final class BytesToBytesMapIterator implements Iterator<Location> {
+
+    private final int numRecords;
+    private final Iterator<MemoryBlock> dataPagesIterator;
+    private final Location loc;
+
+    private int currentRecordNumber = 0;
+    private Object pageBaseObject;
+    private long offsetInPage;
+
+    BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+      this.numRecords = numRecords;
+      this.dataPagesIterator = dataPagesIterator;
+      this.loc = loc;
+      if (dataPagesIterator.hasNext()) {
+        advanceToNextPage();
+      }
+    }
+
+    private void advanceToNextPage() {
+      final MemoryBlock currentPage = dataPagesIterator.next();
+      pageBaseObject = currentPage.getBaseObject();
+      offsetInPage = currentPage.getBaseOffset();
+    }
+
+    @Override
+    public boolean hasNext() {
+      return currentRecordNumber != numRecords;
+    }
+
+    @Override
+    public Location next() {
+      int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+      if (keyLength == END_OF_PAGE_MARKER) {
+        advanceToNextPage();
+        keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+      }
+      loc.with(pageBaseObject, offsetInPage);
+      offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+      currentRecordNumber++;
+      return loc;
+    }
+
+    @Override
+    public void remove() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  /**
+   * Returns an iterator for iterating over the entries of this map.
+   *
+   * For efficiency, all calls to `next()` will return the same {@link Location} object.
+   *
+   * If any other lookups or operations are performed on this map while iterating over it, including
+   * `lookup()`, the behavior of the returned iterator is undefined.
+   */
+  public Iterator<Location> iterator() {
+    return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
+  }
+
+  /**
+   * Looks up a key, and returns a {@link Location} handle that can be used to test existence
+   * and read/write values.
+   *
+   * This function always returns the same {@link Location} instance to avoid object allocation.
+   */
+  public Location lookup(
+      Object keyBaseObject,
+      long keyBaseOffset,
+      int keyRowLengthBytes) {
+    if (enablePerfMetrics) {
+      numKeyLookups++;
+    }
+    final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+    int pos = hashcode & mask;
+    int step = 1;
+    while (true) {
+      if (enablePerfMetrics) {
+        numProbes++;
+      }
+      if (!bitset.isSet(pos)) {
+        // This is a new key.
+        return loc.with(pos, hashcode, false);
+      } else {
+        long stored = longArray.get(pos * 2 + 1);
+        if ((int) (stored) == hashcode) {
+          // Full hash code matches.  Let's compare the keys for equality.
+          loc.with(pos, hashcode, true);
+          if (loc.getKeyLength() == keyRowLengthBytes) {
+            final MemoryLocation keyAddress = loc.getKeyAddress();
+            final Object storedKeyBaseObject = keyAddress.getBaseObject();
+            final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+            final boolean areEqual = ByteArrayMethods.arrayEquals(
+              keyBaseObject,
+              keyBaseOffset,
+              storedKeyBaseObject,
+              storedKeyBaseOffset,
+              keyRowLengthBytes
+            );
+            if (areEqual) {
+              return loc;
+            } else {
+              if (enablePerfMetrics) {
+                numHashCollisions++;
+              }
+            }
+          }
+        }
+      }
+      pos = (pos + step) & mask;
+      step++;
+    }
+  }
+
+  /**
+   * Handle returned by the {@link BytesToBytesMap#lookup(Object, long, int)} function.
+   */
+  public final class Location {
+    /** An index into the hash map's Long array */
+    private int pos;
+    /** True if this location points to a position where a key is defined, false otherwise */
+    private boolean isDefined;
+    /**
+     * The hashcode of the most recent key passed to
+     * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
+     * avoid re-hashing the key when storing a value for that key.
+     */
+    private int keyHashcode;
+    private final MemoryLocation keyMemoryLocation = new MemoryLocation();
+    private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+    private int keyLength;
+    private int valueLength;
+
+    private void updateAddressesAndSizes(long fullKeyAddress) {
+      updateAddressesAndSizes(
+        taskMemoryManager.getPage(fullKeyAddress),
+        taskMemoryManager.getOffsetInPage(fullKeyAddress));
+    }
+
+    private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
+        long position = keyOffsetInPage;
+        keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+        position += 8; // word used to store the key size
+        keyMemoryLocation.setObjAndOffset(page, position);
+        position += keyLength;
+        valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+        position += 8; // word used to store the value size
+        valueMemoryLocation.setObjAndOffset(page, position);
+    }
+
+    Location with(int pos, int keyHashcode, boolean isDefined) {
+      this.pos = pos;
+      this.isDefined = isDefined;
+      this.keyHashcode = keyHashcode;
+      if (isDefined) {
+        final long fullKeyAddress = longArray.get(pos * 2);
+        updateAddressesAndSizes(fullKeyAddress);
+      }
+      return this;
+    }
+
+    Location with(Object page, long keyOffsetInPage) {
+      this.isDefined = true;
+      updateAddressesAndSizes(page, keyOffsetInPage);
+      return this;
+    }
+
+    /**
+     * Returns true if the key is defined at this position, and false otherwise.
+     */
+    public boolean isDefined() {
+      return isDefined;
+    }
+
+    /**
+     * Returns the address of the key defined at this position.
+     * This points to the first byte of the key data.
+     * Unspecified behavior if the key is not defined.
+     * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+     */
+    public MemoryLocation getKeyAddress() {
+      assert (isDefined);
+      return keyMemoryLocation;
+    }
+
+    /**
+     * Returns the length of the key defined at this position.
+     * Unspecified behavior if the key is not defined.
+     */
+    public int getKeyLength() {
+      assert (isDefined);
+      return keyLength;
+    }
+
+    /**
+     * Returns the address of the value defined at this position.
+     * This points to the first byte of the value data.
+     * Unspecified behavior if the key is not defined.
+     * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+     */
+    public MemoryLocation getValueAddress() {
+      assert (isDefined);
+      return valueMemoryLocation;
+    }
+
+    /**
+     * Returns the length of the value defined at this position.
+     * Unspecified behavior if the key is not defined.
+     */
+    public int getValueLength() {
+      assert (isDefined);
+      return valueLength;
+    }
+
+    /**
+     * Store a new key and value. This method may only be called once for a given key; if you want
+     * to update the value associated with a key, then you can directly manipulate the bytes stored
+     * at the value address. The return value indicates whether the put succeeded or whether it
+     * failed because additional memory could not be acquired.
+     * <p>
+     * It is only valid to call this method immediately after calling `lookup()` using the same key.
+     * </p>
+     * <p>
+     * The key and value must be word-aligned (that is, their sizes must be multiples of 8).
+     * </p>
+     * <p>
+     * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
+     * will return information on the data stored by this `putNewKey` call.
+     * </p>
+     * <p>
+     * As an example usage, here's the proper way to store a new key:
+     * </p>
+     * <pre>
+     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   if (!loc.isDefined()) {
+     *     if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+     *       // handle failure to grow map (by spilling, for example)
+     *     }
+     *   }
+     * </pre>
+     * <p>
+     * Unspecified behavior if the key is not defined.
+     * </p>
+     *
+     * @return true if the put() was successful and false if the put() failed because memory could
+     *         not be acquired.
+     */
+    public boolean putNewKey(
+        Object keyBaseObject,
+        long keyBaseOffset,
+        int keyLengthBytes,
+        Object valueBaseObject,
+        long valueBaseOffset,
+        int valueLengthBytes) {
+      assert (!isDefined) : "Can only set value once for a key";
+      assert (keyLengthBytes % 8 == 0);
+      assert (valueLengthBytes % 8 == 0);
+      if (numElements == MAX_CAPACITY) {
+        throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+      }
+
+      // Here, we'll copy the data into our data pages. Because we only store a relative offset from
+      // the key address instead of storing the absolute address of the value, the key and value
+      // must be stored in the same memory page.
+      // (8 byte key length) (key) (8 byte value length) (value)
+      final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
+
+      // --- Figure out where to insert the new record ---------------------------------------------
+
+      final MemoryBlock dataPage;
+      final Object dataPageBaseObject;
+      final long dataPageInsertOffset;
+      boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
+      if (useOverflowPage) {
+        // The record is larger than the page size, so allocate a special overflow page just to hold
+        // that record.
+        final long memoryRequested = requiredSize + 8;
+        final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
+        if (memoryGranted != memoryRequested) {
+          shuffleMemoryManager.release(memoryGranted);
+          logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+          return false;
+        }
+        MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
+        dataPages.add(overflowPage);
+        dataPage = overflowPage;
+        dataPageBaseObject = overflowPage.getBaseObject();
+        dataPageInsertOffset = overflowPage.getBaseOffset();
+      } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+        // The record can fit in a data page, but either we have not allocated any pages yet or
+        // the current page does not have enough space.
+        if (currentDataPage != null) {
+          // There wasn't enough space in the current page, so write an end-of-page marker:
+          final Object pageBaseObject = currentDataPage.getBaseObject();
+          final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
+          PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+        }
+        final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+        if (memoryGranted != pageSizeBytes) {
+          shuffleMemoryManager.release(memoryGranted);
+          logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+          return false;
+        }
+        MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+        dataPages.add(newPage);
+        pageCursor = 0;
+        currentDataPage = newPage;
+        dataPage = currentDataPage;
+        dataPageBaseObject = currentDataPage.getBaseObject();
+        dataPageInsertOffset = currentDataPage.getBaseOffset();
+      } else {
+        // There is enough space in the current data page.
+        dataPage = currentDataPage;
+        dataPageBaseObject = currentDataPage.getBaseObject();
+        dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
+      }
+
+      // --- Append the key and value data to the current data page --------------------------------
+
+      long insertCursor = dataPageInsertOffset;
+
+      // Compute all of our offsets up-front:
+      final long keySizeOffsetInPage = insertCursor;
+      insertCursor += 8; // word used to store the key size
+      final long keyDataOffsetInPage = insertCursor;
+      insertCursor += keyLengthBytes;
+      final long valueSizeOffsetInPage = insertCursor;
+      insertCursor += 8; // word used to store the value size
+      final long valueDataOffsetInPage = insertCursor;
+      insertCursor += valueLengthBytes; // space occupied by the value data
+
+      // Copy the key
+      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+      PlatformDependent.copyMemory(
+        keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+      // Copy the value
+      PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+      PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
+        valueDataOffsetInPage, valueLengthBytes);
+
+      // --- Update bookkeeping data structures ----------------------------------------------------
+
+      if (useOverflowPage) {
+        // Store the end-of-page marker at the end of the data page
+        PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+      } else {
+        pageCursor += requiredSize;
+      }
+
+      numElements++;
+      bitset.set(pos);
+      final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
+        dataPage, keySizeOffsetInPage);
+      longArray.set(pos * 2, storedKeyAddress);
+      longArray.set(pos * 2 + 1, keyHashcode);
+      updateAddressesAndSizes(storedKeyAddress);
+      isDefined = true;
+      if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
+        growAndRehash();
+      }
+      return true;
+    }
+  }
+
+  /**
+   * Allocate new data structures for this map. When calling this outside of the constructor,
+   * make sure to keep references to the old data structures so that you can free them.
+   *
+   * @param capacity the new map capacity
+   */
+  private void allocate(int capacity) {
+    assert (capacity >= 0);
+    // The capacity needs to be divisible by 64 so that our bit set can be sized properly
+    capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+    assert (capacity <= MAX_CAPACITY);
+    longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+    bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
+
+    this.growthThreshold = (int) (capacity * loadFactor);
+    this.mask = capacity - 1;
+  }
+
+  /**
+   * Free all allocated memory associated with this map, including the storage for keys and values
+   * as well as the hash map array itself.
+   *
+   * This method is idempotent.
+   */
+  public void free() {
+    longArray = null;
+    bitset = null;
+    Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
+    while (dataPagesIterator.hasNext()) {
+      MemoryBlock dataPage = dataPagesIterator.next();
+      dataPagesIterator.remove();
+      taskMemoryManager.freePage(dataPage);
+      shuffleMemoryManager.release(dataPage.size());
+    }
+    assert(dataPages.isEmpty());
+  }
+
+  /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+  public long getTotalMemoryConsumption() {
+    long totalDataPagesSize = 0L;
+    for (MemoryBlock dataPage : dataPages) {
+      totalDataPagesSize += dataPage.size();
+    }
+    return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+  }
+
+  /**
+   * Returns the total amount of time spent resizing this map (in nanoseconds).
+   */
+  public long getTimeSpentResizingNs() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException();
+    }
+    return timeSpentResizingNs;
+  }
+
+
+  /**
+   * Returns the average number of probes per key lookup.
+   */
+  public double getAverageProbesPerLookup() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException();
+    }
+    return (1.0 * numProbes) / numKeyLookups;
+  }
+
+  public long getNumHashCollisions() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException();
+    }
+    return numHashCollisions;
+  }
+
+  @VisibleForTesting
+  int getNumDataPages() {
+    return dataPages.size();
+  }
+
+  /**
+   * Grows the size of the hash table and re-hashes everything.
+   */
+  @VisibleForTesting
+  void growAndRehash() {
+    long resizeStartTime = -1;
+    if (enablePerfMetrics) {
+      resizeStartTime = System.nanoTime();
+    }
+    // Store references to the old data structures to be used when we re-hash
+    final LongArray oldLongArray = longArray;
+    final BitSet oldBitSet = bitset;
+    final int oldCapacity = (int) oldBitSet.capacity();
+
+    // Allocate the new data structures
+    allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+
+    // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
+    for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
+      final long keyPointer = oldLongArray.get(pos * 2);
+      final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
+      int newPos = hashcode & mask;
+      int step = 1;
+      boolean keepGoing = true;
+
+      // No need to check for equality here when we insert, so this has one less if branch than
+      // the similar probing loop in lookup().
+      while (keepGoing) {
+        if (!bitset.isSet(newPos)) {
+          bitset.set(newPos);
+          longArray.set(newPos * 2, keyPointer);
+          longArray.set(newPos * 2 + 1, hashcode);
+          keepGoing = false;
+        } else {
+          newPos = (newPos + step) & mask;
+          step++;
+        }
+      }
+    }
+
+    if (enablePerfMetrics) {
+      timeSpentResizingNs += System.nanoTime() - resizeStartTime;
+    }
+  }
+
+  /** Returns the smallest power of 2 that is greater than or equal to num. */
+  private static long nextPowerOf2(long num) {
+    final long highBit = Long.highestOneBit(num);
+    return (highBit == num) ? num : highBit << 1;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
new file mode 100644
index 0000000..20654e4
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+/**
+ * Interface that defines how we can grow the size of a hash map when it is over a threshold.
+ */
+public interface HashMapGrowthStrategy {
+
+  int nextCapacity(int currentCapacity);
+
+  /**
+   * Double the size of the hash map every time.
+   */
+  HashMapGrowthStrategy DOUBLING = new Doubling();
+
+  class Doubling implements HashMapGrowthStrategy {
+    @Override
+    public int nextCapacity(int currentCapacity) {
+      assert (currentCapacity > 0);
+      // Guard against overflow
+      return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE;
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index f038b72..00c1e07 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -85,7 +85,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
           return toGrant
         } else {
           logInfo(
-            s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
+            s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
         }
       } else {
@@ -116,6 +116,12 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
     taskMemory.remove(taskAttemptId)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
+
+  /** Returns the memory consumption, in bytes, for the current task */
+  def getMemoryConsumptionForThisTask(): Long = synchronized {
+    val taskAttemptId = currentTaskAttemptId()
+    taskMemory.getOrElse(taskAttemptId, 0L)
+  }
 }
 
 private object ShuffleMemoryManager {

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
new file mode 100644
index 0000000..60f483a
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -0,0 +1,499 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Exception;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import org.junit.*;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.mockito.AdditionalMatchers.geq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET;
+
+
+public abstract class AbstractBytesToBytesMapSuite {
+
+  private final Random rand = new Random(42);
+
+  private ShuffleMemoryManager shuffleMemoryManager;
+  private TaskMemoryManager taskMemoryManager;
+  private TaskMemoryManager sizeLimitedTaskMemoryManager;
+  private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+
+  @Before
+  public void setup() {
+    shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+    taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+    // Mocked memory manager for tests that check the maximum array size, since actually allocating
+    // such large arrays will cause us to run out of memory in our tests.
+    sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
+    when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
+      new Answer<MemoryBlock>() {
+        @Override
+        public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+          if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+            throw new OutOfMemoryError("Requested array size exceeds VM limit");
+          }
+          return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
+        }
+      }
+    );
+  }
+
+  @After
+  public void tearDown() {
+    if (taskMemoryManager != null) {
+      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+      Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
+      Assert.assertEquals(0, leakedShuffleMemory);
+      shuffleMemoryManager = null;
+      taskMemoryManager = null;
+    }
+  }
+
+  protected abstract MemoryAllocator getMemoryAllocator();
+
+  private static byte[] getByteArray(MemoryLocation loc, int size) {
+    final byte[] arr = new byte[size];
+    PlatformDependent.copyMemory(
+      loc.getBaseObject(),
+      loc.getBaseOffset(),
+      arr,
+      BYTE_ARRAY_OFFSET,
+      size
+    );
+    return arr;
+  }
+
+  private byte[] getRandomByteArray(int numWords) {
+    Assert.assertTrue(numWords >= 0);
+    final int lengthInBytes = numWords * 8;
+    final byte[] bytes = new byte[lengthInBytes];
+    rand.nextBytes(bytes);
+    return bytes;
+  }
+
+  /**
+   * Fast equality checking for byte arrays, since these comparisons are a bottleneck
+   * in our stress tests.
+   */
+  private static boolean arrayEquals(
+      byte[] expected,
+      MemoryLocation actualAddr,
+      long actualLengthBytes) {
+    return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
+      expected,
+      BYTE_ARRAY_OFFSET,
+      actualAddr.getBaseObject(),
+      actualAddr.getBaseOffset(),
+      expected.length
+    );
+  }
+
+  @Test
+  public void emptyMap() {
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    try {
+      Assert.assertEquals(0, map.numElements());
+      final int keyLengthInWords = 10;
+      final int keyLengthInBytes = keyLengthInWords * 8;
+      final byte[] key = getRandomByteArray(keyLengthInWords);
+      Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+      Assert.assertFalse(map.iterator().hasNext());
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void setAndRetrieveAKey() {
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+    final int recordLengthWords = 10;
+    final int recordLengthBytes = recordLengthWords * 8;
+    final byte[] keyData = getRandomByteArray(recordLengthWords);
+    final byte[] valueData = getRandomByteArray(recordLengthWords);
+    try {
+      final BytesToBytesMap.Location loc =
+        map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
+      Assert.assertFalse(loc.isDefined());
+      Assert.assertTrue(loc.putNewKey(
+        keyData,
+        BYTE_ARRAY_OFFSET,
+        recordLengthBytes,
+        valueData,
+        BYTE_ARRAY_OFFSET,
+        recordLengthBytes
+      ));
+      // After storing the key and value, the other location methods should return results that
+      // reflect the result of this store without us having to call lookup() again on the same key.
+      Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+      Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+      // After calling lookup() the location should still point to the correct data.
+      Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
+      Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+      Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+      try {
+        Assert.assertTrue(loc.putNewKey(
+          keyData,
+          BYTE_ARRAY_OFFSET,
+          recordLengthBytes,
+          valueData,
+          BYTE_ARRAY_OFFSET,
+          recordLengthBytes
+        ));
+        Assert.fail("Should not be able to set a new value for a key");
+      } catch (AssertionError e) {
+        // Expected exception; do nothing.
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void iteratorTest() throws Exception {
+    final int size = 4096;
+    BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
+    try {
+      for (long i = 0; i < size; i++) {
+        final long[] value = new long[] { i };
+        final BytesToBytesMap.Location loc =
+          map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+        Assert.assertFalse(loc.isDefined());
+        // Ensure that we store some zero-length keys
+        if (i % 5 == 0) {
+          Assert.assertTrue(loc.putNewKey(
+            null,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            0,
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8
+          ));
+        } else {
+          Assert.assertTrue(loc.putNewKey(
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8,
+            value,
+            PlatformDependent.LONG_ARRAY_OFFSET,
+            8
+          ));
+        }
+      }
+      final java.util.BitSet valuesSeen = new java.util.BitSet(size);
+      final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+      while (iter.hasNext()) {
+        final BytesToBytesMap.Location loc = iter.next();
+        Assert.assertTrue(loc.isDefined());
+        final MemoryLocation keyAddress = loc.getKeyAddress();
+        final MemoryLocation valueAddress = loc.getValueAddress();
+        final long value = PlatformDependent.UNSAFE.getLong(
+          valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+        final long keyLength = loc.getKeyLength();
+        if (keyLength == 0) {
+          Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
+        } else {
+          final long key = PlatformDependent.UNSAFE.getLong(
+            keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+          Assert.assertEquals(value, key);
+        }
+        valuesSeen.set((int) value);
+      }
+      Assert.assertEquals(size, valuesSeen.cardinality());
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void iteratingOverDataPagesWithWastedSpace() throws Exception {
+    final int NUM_ENTRIES = 1000 * 1000;
+    final int KEY_LENGTH = 16;
+    final int VALUE_LENGTH = 40;
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+    // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
+    // pages won't be evenly-divisible by records of this size, which will cause us to waste some
+    // space at the end of the page. This is necessary in order for us to take the end-of-page
+    // handling branch in iterator().
+    try {
+      for (int i = 0; i < NUM_ENTRIES; i++) {
+        final long[] key = new long[] { i, i };  // 2 * 8 = 16 bytes
+        final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
+        final BytesToBytesMap.Location loc = map.lookup(
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH
+        );
+        Assert.assertFalse(loc.isDefined());
+        Assert.assertTrue(loc.putNewKey(
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH,
+          value,
+          LONG_ARRAY_OFFSET,
+          VALUE_LENGTH
+        ));
+      }
+      Assert.assertEquals(2, map.getNumDataPages());
+
+      final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES);
+      final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+      final long key[] = new long[KEY_LENGTH / 8];
+      final long value[] = new long[VALUE_LENGTH / 8];
+      while (iter.hasNext()) {
+        final BytesToBytesMap.Location loc = iter.next();
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
+        Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
+        PlatformDependent.copyMemory(
+          loc.getKeyAddress().getBaseObject(),
+          loc.getKeyAddress().getBaseOffset(),
+          key,
+          LONG_ARRAY_OFFSET,
+          KEY_LENGTH
+        );
+        PlatformDependent.copyMemory(
+          loc.getValueAddress().getBaseObject(),
+          loc.getValueAddress().getBaseOffset(),
+          value,
+          LONG_ARRAY_OFFSET,
+          VALUE_LENGTH
+        );
+        for (long j : key) {
+          Assert.assertEquals(key[0], j);
+        }
+        for (long j : value) {
+          Assert.assertEquals(key[0], j);
+        }
+        valuesSeen.set((int) key[0]);
+      }
+      Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality());
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void randomizedStressTest() {
+    final int size = 65536;
+    // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+    // into ByteBuffers in order to use them as keys here.
+    final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
+
+    try {
+      // Fill the map to 90% full so that we can trigger probing
+      for (int i = 0; i < size * 0.9; i++) {
+        final byte[] key = getRandomByteArray(rand.nextInt(256) + 1);
+        final byte[] value = getRandomByteArray(rand.nextInt(512) + 1);
+        if (!expected.containsKey(ByteBuffer.wrap(key))) {
+          expected.put(ByteBuffer.wrap(key), value);
+          final BytesToBytesMap.Location loc = map.lookup(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length
+          );
+          Assert.assertFalse(loc.isDefined());
+          Assert.assertTrue(loc.putNewKey(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length,
+            value,
+            BYTE_ARRAY_OFFSET,
+            value.length
+          ));
+          // After calling putNewKey, the following should be true, even before calling
+          // lookup():
+          Assert.assertTrue(loc.isDefined());
+          Assert.assertEquals(key.length, loc.getKeyLength());
+          Assert.assertEquals(value.length, loc.getValueLength());
+          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+        }
+      }
+
+      for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+        final byte[] key = entry.getKey().array();
+        final byte[] value = entry.getValue();
+        final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void randomizedTestWithRecordsLargerThanPageSize() {
+    final long pageSizeBytes = 128;
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+    // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+    // into ByteBuffers in order to use them as keys here.
+    final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+    try {
+      for (int i = 0; i < 1000; i++) {
+        final byte[] key = getRandomByteArray(rand.nextInt(128));
+        final byte[] value = getRandomByteArray(rand.nextInt(128));
+        if (!expected.containsKey(ByteBuffer.wrap(key))) {
+          expected.put(ByteBuffer.wrap(key), value);
+          final BytesToBytesMap.Location loc = map.lookup(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length
+          );
+          Assert.assertFalse(loc.isDefined());
+          Assert.assertTrue(loc.putNewKey(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length,
+            value,
+            BYTE_ARRAY_OFFSET,
+            value.length
+          ));
+          // After calling putNewKey, the following should be true, even before calling
+          // lookup():
+          Assert.assertTrue(loc.isDefined());
+          Assert.assertEquals(key.length, loc.getKeyLength());
+          Assert.assertEquals(value.length, loc.getValueLength());
+          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+        }
+      }
+      for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+        final byte[] key = entry.getKey().array();
+        final byte[] value = entry.getValue();
+        final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void failureToAllocateFirstPage() {
+    shuffleMemoryManager = new ShuffleMemoryManager(1024);
+    BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    try {
+      final long[] emptyArray = new long[0];
+      final BytesToBytesMap.Location loc =
+        map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0);
+      Assert.assertFalse(loc.isDefined());
+      Assert.assertFalse(loc.putNewKey(
+        emptyArray, LONG_ARRAY_OFFSET, 0,
+        emptyArray, LONG_ARRAY_OFFSET, 0
+      ));
+    } finally {
+      map.free();
+    }
+  }
+
+
+  @Test
+  public void failureToGrow() {
+    shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10);
+    BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+    try {
+      boolean success = true;
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+        success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8);
+        if (!success) {
+          break;
+        }
+      }
+      Assert.assertThat(i, greaterThan(0));
+      Assert.assertFalse(success);
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void initialCapacityBoundsChecking() {
+    try {
+      new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
+      Assert.fail("Expected IllegalArgumentException to be thrown");
+    } catch (IllegalArgumentException e) {
+      // expected exception
+    }
+
+    try {
+      new BytesToBytesMap(
+        sizeLimitedTaskMemoryManager,
+        shuffleMemoryManager,
+        BytesToBytesMap.MAX_CAPACITY + 1,
+        PAGE_SIZE_BYTES);
+      Assert.fail("Expected IllegalArgumentException to be thrown");
+    } catch (IllegalArgumentException e) {
+      // expected exception
+    }
+
+    // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+    // Can allocate _at_ the max capacity
+    //    BytesToBytesMap map = new BytesToBytesMap(
+    //      sizeLimitedTaskMemoryManager,
+    //      shuffleMemoryManager,
+    //      BytesToBytesMap.MAX_CAPACITY,
+    //      PAGE_SIZE_BYTES);
+    //    map.free();
+  }
+
+  // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+  @Ignore
+  public void resizingLargeMap() {
+    // As long as a map's capacity is below the max, we should be able to resize up to the max
+    BytesToBytesMap map = new BytesToBytesMap(
+      sizeLimitedTaskMemoryManager,
+      shuffleMemoryManager,
+      BytesToBytesMap.MAX_CAPACITY - 64,
+      PAGE_SIZE_BYTES);
+    map.growAndRehash();
+    map.free();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
new file mode 100644
index 0000000..5a10de4
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
+
+  @Override
+  protected MemoryAllocator getMemoryAllocator() {
+    return MemoryAllocator.UNSAFE;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
new file mode 100644
index 0000000..12cc9b2
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
+
+  @Override
+  protected MemoryAllocator getMemoryAllocator() {
+    return MemoryAllocator.HEAP;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
deleted file mode 100644
index f3b4627..0000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ /dev/null
@@ -1,223 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import java.util.Iterator;
-
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.PlatformDependent;
-import org.apache.spark.unsafe.map.BytesToBytesMap;
-import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-/**
- * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
- *
- * This map supports a maximum of 2 billion keys.
- */
-public final class UnsafeFixedWidthAggregationMap {
-
-  /**
-   * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
-   * map, we copy this buffer and use it as the value.
-   */
-  private final byte[] emptyAggregationBuffer;
-
-  private final StructType aggregationBufferSchema;
-
-  private final StructType groupingKeySchema;
-
-  /**
-   * Encodes grouping keys as UnsafeRows.
-   */
-  private final UnsafeProjection groupingKeyProjection;
-
-  /**
-   * A hashmap which maps from opaque bytearray keys to bytearray values.
-   */
-  private final BytesToBytesMap map;
-
-  /**
-   * Re-used pointer to the current aggregation buffer
-   */
-  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
-
-  private final boolean enablePerfMetrics;
-
-  /**
-   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
-   *         schema, false otherwise.
-   */
-  public static boolean supportsAggregationBufferSchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (field.dataType() instanceof DecimalType) {
-        DecimalType dt = (DecimalType) field.dataType();
-        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
-          return false;
-        }
-      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  /**
-   * Create a new UnsafeFixedWidthAggregationMap.
-   *
-   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
-   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
-   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
-   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
-   * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
-   * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
-   */
-  public UnsafeFixedWidthAggregationMap(
-      InternalRow emptyAggregationBuffer,
-      StructType aggregationBufferSchema,
-      StructType groupingKeySchema,
-      TaskMemoryManager memoryManager,
-      int initialCapacity,
-      long pageSizeBytes,
-      boolean enablePerfMetrics) {
-    this.aggregationBufferSchema = aggregationBufferSchema;
-    this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
-    this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
-    this.enablePerfMetrics = enablePerfMetrics;
-
-    // Initialize the buffer for aggregation value
-    final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
-    this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
-    assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
-      UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
-  }
-
-  /**
-   * Return the aggregation buffer for the current group. For efficiency, all calls to this method
-   * return the same object.
-   */
-  public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
-    final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
-
-    // Probe our map using the serialized key
-    final BytesToBytesMap.Location loc = map.lookup(
-      unsafeGroupingKeyRow.getBaseObject(),
-      unsafeGroupingKeyRow.getBaseOffset(),
-      unsafeGroupingKeyRow.getSizeInBytes());
-    if (!loc.isDefined()) {
-      // This is the first time that we've seen this grouping key, so we'll insert a copy of the
-      // empty aggregation buffer into the map:
-      loc.putNewKey(
-        unsafeGroupingKeyRow.getBaseObject(),
-        unsafeGroupingKeyRow.getBaseOffset(),
-        unsafeGroupingKeyRow.getSizeInBytes(),
-        emptyAggregationBuffer,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        emptyAggregationBuffer.length
-      );
-    }
-
-    // Reset the pointer to point to the value that we just stored or looked up:
-    final MemoryLocation address = loc.getValueAddress();
-    currentAggregationBuffer.pointTo(
-      address.getBaseObject(),
-      address.getBaseOffset(),
-      aggregationBufferSchema.length(),
-      loc.getValueLength()
-    );
-    return currentAggregationBuffer;
-  }
-
-  /**
-   * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
-   */
-  public static class MapEntry {
-    private MapEntry() { };
-    public final UnsafeRow key = new UnsafeRow();
-    public final UnsafeRow value = new UnsafeRow();
-  }
-
-  /**
-   * Returns an iterator over the keys and values in this map.
-   *
-   * For efficiency, each call returns the same object.
-   */
-  public Iterator<MapEntry> iterator() {
-    return new Iterator<MapEntry>() {
-
-      private final MapEntry entry = new MapEntry();
-      private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
-
-      @Override
-      public boolean hasNext() {
-        return mapLocationIterator.hasNext();
-      }
-
-      @Override
-      public MapEntry next() {
-        final BytesToBytesMap.Location loc = mapLocationIterator.next();
-        final MemoryLocation keyAddress = loc.getKeyAddress();
-        final MemoryLocation valueAddress = loc.getValueAddress();
-        entry.key.pointTo(
-          keyAddress.getBaseObject(),
-          keyAddress.getBaseOffset(),
-          groupingKeySchema.length(),
-          loc.getKeyLength()
-        );
-        entry.value.pointTo(
-          valueAddress.getBaseObject(),
-          valueAddress.getBaseOffset(),
-          aggregationBufferSchema.length(),
-          loc.getValueLength()
-        );
-        return entry;
-      }
-
-      @Override
-      public void remove() {
-        throw new UnsupportedOperationException();
-      }
-    };
-  }
-
-  /**
-   * Free the unsafe memory associated with this map.
-   */
-  public void free() {
-    map.free();
-  }
-
-  @SuppressWarnings("UseOfSystemOutOrSystemErr")
-  public void printPerfMetrics() {
-    if (!enablePerfMetrics) {
-      throw new IllegalStateException("Perf metrics not enabled");
-    }
-    System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
-    System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
-    System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
-    System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
deleted file mode 100644
index c6b4c72..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import scala.collection.JavaConverters._
-import scala.util.Random
-
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
-import org.apache.spark.unsafe.types.UTF8String
-
-
-class UnsafeFixedWidthAggregationMapSuite
-  extends SparkFunSuite
-  with Matchers
-  with BeforeAndAfterEach {
-
-  import UnsafeFixedWidthAggregationMap._
-
-  private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
-  private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
-  private def emptyAggregationBuffer: InternalRow = InternalRow(0)
-  private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
-
-  private var memoryManager: TaskMemoryManager = null
-
-  override def beforeEach(): Unit = {
-    memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-  }
-
-  override def afterEach(): Unit = {
-    if (memoryManager != null) {
-      memoryManager.cleanUpAllAllocatedMemory()
-      memoryManager = null
-    }
-  }
-
-  test("supported schemas") {
-    assert(supportsAggregationBufferSchema(
-      StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
-    assert(!supportsAggregationBufferSchema(
-      StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
-    assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
-    assert(
-      !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
-  }
-
-  test("empty map") {
-    val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
-      memoryManager,
-      1024, // initial capacity,
-      PAGE_SIZE_BYTES,
-      false // disable perf metrics
-    )
-    assert(!map.iterator().hasNext)
-    map.free()
-  }
-
-  test("updating values for a single key") {
-    val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
-      memoryManager,
-      1024, // initial capacity
-      PAGE_SIZE_BYTES,
-      false // disable perf metrics
-    )
-    val groupKey = InternalRow(UTF8String.fromString("cats"))
-
-    // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
-    map.getAggregationBuffer(groupKey)
-    val iter = map.iterator()
-    val entry = iter.next()
-    assert(!iter.hasNext)
-    entry.key.getString(0) should be ("cats")
-    entry.value.getInt(0) should be (0)
-
-    // Modifications to rows retrieved from the map should update the values in the map
-    entry.value.setInt(0, 42)
-    map.getAggregationBuffer(groupKey).getInt(0) should be (42)
-
-    map.free()
-  }
-
-  test("inserting large random keys") {
-    val map = new UnsafeFixedWidthAggregationMap(
-      emptyAggregationBuffer,
-      aggBufferSchema,
-      groupKeySchema,
-      memoryManager,
-      128, // initial capacity
-      PAGE_SIZE_BYTES,
-      false // disable perf metrics
-    )
-    val rand = new Random(42)
-    val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
-    groupKeys.foreach { keyString =>
-      map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
-    }
-    val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
-      entry.key.getString(0)
-    }.toSet
-    seenKeys.size should be (groupKeys.size)
-    seenKeys should be (groupKeys)
-
-    map.free()
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
new file mode 100644
index 0000000..66012e3
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+/**
+ * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
+ *
+ * This map supports a maximum of 2 billion keys.
+ */
+public final class UnsafeFixedWidthAggregationMap {
+
+  /**
+   * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
+   * map, we copy this buffer and use it as the value.
+   */
+  private final byte[] emptyAggregationBuffer;
+
+  private final StructType aggregationBufferSchema;
+
+  private final StructType groupingKeySchema;
+
+  /**
+   * Encodes grouping keys as UnsafeRows.
+   */
+  private final UnsafeProjection groupingKeyProjection;
+
+  /**
+   * A hashmap which maps from opaque bytearray keys to bytearray values.
+   */
+  private final BytesToBytesMap map;
+
+  /**
+   * Re-used pointer to the current aggregation buffer
+   */
+  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+
+  private final boolean enablePerfMetrics;
+
+  /**
+   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
+   *         schema, false otherwise.
+   */
+  public static boolean supportsAggregationBufferSchema(StructType schema) {
+    for (StructField field: schema.fields()) {
+      if (field.dataType() instanceof DecimalType) {
+        DecimalType dt = (DecimalType) field.dataType();
+        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+          return false;
+        }
+      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Create a new UnsafeFixedWidthAggregationMap.
+   *
+   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
+   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
+   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+   *                             other tasks.
+   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
+   * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
+   * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
+   */
+  public UnsafeFixedWidthAggregationMap(
+      InternalRow emptyAggregationBuffer,
+      StructType aggregationBufferSchema,
+      StructType groupingKeySchema,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      int initialCapacity,
+      long pageSizeBytes,
+      boolean enablePerfMetrics) {
+    this.aggregationBufferSchema = aggregationBufferSchema;
+    this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
+    this.groupingKeySchema = groupingKeySchema;
+    this.map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+    this.enablePerfMetrics = enablePerfMetrics;
+
+    // Initialize the buffer for aggregation value
+    final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
+    this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+    assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 +
+      UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length()));
+  }
+
+  /**
+   * Return the aggregation buffer for the current group. For efficiency, all calls to this method
+   * return the same object. If additional memory could not be allocated, then this method will
+   * signal an error by returning null.
+   */
+  public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
+    final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
+
+    // Probe our map using the serialized key
+    final BytesToBytesMap.Location loc = map.lookup(
+      unsafeGroupingKeyRow.getBaseObject(),
+      unsafeGroupingKeyRow.getBaseOffset(),
+      unsafeGroupingKeyRow.getSizeInBytes());
+    if (!loc.isDefined()) {
+      // This is the first time that we've seen this grouping key, so we'll insert a copy of the
+      // empty aggregation buffer into the map:
+      boolean putSucceeded = loc.putNewKey(
+        unsafeGroupingKeyRow.getBaseObject(),
+        unsafeGroupingKeyRow.getBaseOffset(),
+        unsafeGroupingKeyRow.getSizeInBytes(),
+        emptyAggregationBuffer,
+        PlatformDependent.BYTE_ARRAY_OFFSET,
+        emptyAggregationBuffer.length
+      );
+      if (!putSucceeded) {
+        return null;
+      }
+    }
+
+    // Reset the pointer to point to the value that we just stored or looked up:
+    final MemoryLocation address = loc.getValueAddress();
+    currentAggregationBuffer.pointTo(
+      address.getBaseObject(),
+      address.getBaseOffset(),
+      aggregationBufferSchema.length(),
+      loc.getValueLength()
+    );
+    return currentAggregationBuffer;
+  }
+
+  /**
+   * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}.
+   */
+  public static class MapEntry {
+    private MapEntry() { }
+    public final UnsafeRow key = new UnsafeRow();
+    public final UnsafeRow value = new UnsafeRow();
+  }
+
+  /**
+   * Returns an iterator over the keys and values in this map.
+   *
+   * For efficiency, each call returns the same object.
+   */
+  public Iterator<MapEntry> iterator() {
+    return new Iterator<MapEntry>() {
+
+      private final MapEntry entry = new MapEntry();
+      private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();
+
+      @Override
+      public boolean hasNext() {
+        return mapLocationIterator.hasNext();
+      }
+
+      @Override
+      public MapEntry next() {
+        final BytesToBytesMap.Location loc = mapLocationIterator.next();
+        final MemoryLocation keyAddress = loc.getKeyAddress();
+        final MemoryLocation valueAddress = loc.getValueAddress();
+        entry.key.pointTo(
+          keyAddress.getBaseObject(),
+          keyAddress.getBaseOffset(),
+          groupingKeySchema.length(),
+          loc.getKeyLength()
+        );
+        entry.value.pointTo(
+          valueAddress.getBaseObject(),
+          valueAddress.getBaseOffset(),
+          aggregationBufferSchema.length(),
+          loc.getValueLength()
+        );
+        return entry;
+      }
+
+      @Override
+      public void remove() {
+        throw new UnsupportedOperationException();
+      }
+    };
+  }
+
+  /**
+   * Free the unsafe memory associated with this map.
+   */
+  public void free() {
+    map.free();
+  }
+
+  @SuppressWarnings("UseOfSystemOutOrSystemErr")
+  public void printPerfMetrics() {
+    if (!enablePerfMetrics) {
+      throw new IllegalStateException("Perf metrics not enabled");
+    }
+    System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
+    System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
+    System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
+    System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+  }
+
+}
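
Since the relocated UnsafeFixedWidthAggregationMap now takes a ShuffleMemoryManager and signals allocation failure by returning null from getAggregationBuffer(), a hedged usage sketch (not part of the commit) follows; the schemas and rows mirror the removed catalyst test suite above and are illustrative only.

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String

val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf())

val aggMap = new UnsafeFixedWidthAggregationMap(
  InternalRow(0),          // "zero" aggregation buffer copied in for new keys
  aggBufferSchema,
  groupKeySchema,
  taskMemoryManager,
  shuffleMemoryManager,
  1024,                    // initial capacity
  1L << 26,                // 64 megabyte data pages
  false)                   // disable perf metrics

val buffer = aggMap.getAggregationBuffer(InternalRow(UTF8String.fromString("cats")))
if (buffer == null) {
  // No memory could be reserved for a new page; GeneratedAggregate (below) converts
  // this into an IOException, but a caller could also spill and retry.
}
aggMap.free()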

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index d851eae..469de6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.io.IOException
+
 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
@@ -266,6 +268,7 @@ case class GeneratedAggregate(
           aggregationBufferSchema,
           groupKeySchema,
           TaskContext.get.taskMemoryManager(),
+          SparkEnv.get.shuffleMemoryManager,
           1024 * 16, // initial capacity
           pageSizeBytes,
           false // disable tracking of performance metrics
@@ -275,6 +278,9 @@ case class GeneratedAggregate(
           val currentRow: InternalRow = iter.next()
           val groupKey: InternalRow = groupProjection(currentRow)
           val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+          if (aggregationBuffer == null) {
+            throw new IOException("Could not allocate memory to grow aggregation buffer")
+          }
           updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8cb415a4/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f88a45f..cc8bbfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.execution.joins
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
 import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
+import org.apache.spark.shuffle.ShuffleMemoryManager
 import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.PlatformDependent
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.CompactBuffer
 
 
@@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation(
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = {
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeInt(hashTable.size())
 
     val iter = hashTable.entrySet().iterator()
@@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation(
     }
   }
 
-  override def readExternal(in: ObjectInput): Unit = {
+  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
     val nKeys = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
-    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+    // Dummy shuffle memory manager which always grants all memory allocation requests.
+    // We use this because it doesn't make sense to count shared broadcast variables' memory usage
+    // towards individual tasks' quotas. In the future, we should devise a better way of handling
+    // this.
+    val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
+      override def tryToAcquire(numBytes: Long): Long = numBytes
+      override def release(numBytes: Long): Unit = {}
+    }
 
     val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
       .getSizeAsBytes("spark.buffer.pageSize", "64m")
 
     binaryMap = new BytesToBytesMap(
-      memoryManager,
+      taskMemoryManager,
+      shuffleMemoryManager,
       nKeys * 2, // reduce hash collision
       pageSizeBytes)
 
@@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation(
       // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
       assert(!loc.isDefined, "Duplicated key found!")
-      loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+      val putSucceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+      if (!putSucceeded) {
+        throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+      }
       i += 1
     }
   }
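
The grant-everything ShuffleMemoryManager above keeps broadcast hash tables out of per-task quotas. Going the other direction, here is a small sketch (not part of the commit, assuming the same subclassing pattern shown in readExternal()) of a manager that refuses every request, which is one way to drive the new failure paths in a test.

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.ShuffleMemoryManager

// Grants nothing, so the first data-page allocation should fail: putNewKey() would then
// return false and UnsafeFixedWidthAggregationMap.getAggregationBuffer() would return null.
val noMemoryShuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
  override def tryToAcquire(numBytes: Long): Long = 0L
  override def release(numBytes: Long): Unit = {}
}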


