spark-commits mailing list archives

From joshro...@apache.org
Subject [2/2] spark git commit: [SPARK-8735] [SQL] Expose memory usage for shuffles, joins and aggregations
Date Mon, 03 Aug 2015 21:22:14 GMT
[SPARK-8735] [SQL] Expose memory usage for shuffles, joins and aggregations

This patch exposes the memory used by internal data structures on the SparkUI. It tracks memory used by all spilling operations and by SQL operators backed by Tungsten, e.g. `BroadcastHashJoin`, `ExternalSort`, `GeneratedAggregate`, etc. The metric exposed is "peak execution memory", which broadly refers to the peak in-memory size of each of these data structures.

A separate patch will extend this by linking the new information to the SQL operators themselves.
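
To illustrate the bookkeeping, here is a minimal, self-contained Java sketch of the pattern the diff below applies to the sorters: record the peak just before memory is freed, refresh it again when it is read, and then sum the per-structure peaks into one per-task figure. The names `PeakMemorySketch`, `SpillableBuffer`, `allocatePage` and `peakExecutionMemory` are hypothetical and exist only for this example; only `updatePeakMemoryUsed`, `getPeakMemoryUsedBytes`, `getMemoryUsage` and `freeMemory` mirror method names in the patch, and their bodies here are simplified stand-ins rather than Spark's implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative only: hypothetical classes, not part of Spark.
final class PeakMemorySketch {

  /** A toy structure that holds pages in memory and frees all of them on spill. */
  static final class SpillableBuffer {
    private final List<Long> pageSizes = new ArrayList<>();
    private long peakMemoryUsedBytes = 0;

    /** Pretend to allocate a page of the given size (hypothetical helper). */
    void allocatePage(long sizeBytes) {
      pageSizes.add(sizeBytes);
    }

    /** Current memory usage: the sum of all live page sizes. */
    long getMemoryUsage() {
      long total = 0;
      for (long size : pageSizes) {
        total += size;
      }
      return total;
    }

    private void updatePeakMemoryUsed() {
      long mem = getMemoryUsage();
      if (mem > peakMemoryUsedBytes) {
        peakMemoryUsedBytes = mem;
      }
    }

    /** Refresh the peak on read so allocations made since the last free are counted. */
    long getPeakMemoryUsedBytes() {
      updatePeakMemoryUsed();
      return peakMemoryUsedBytes;
    }

    /** Record the peak before releasing pages; otherwise the high-water mark is lost. */
    long freeMemory() {
      updatePeakMemoryUsed();
      long freed = getMemoryUsage();
      pageSizes.clear();
      return freed;
    }
  }

  /** Per-task figure: the sum of the peaks of every structure the task created. */
  static long peakExecutionMemory(List<SpillableBuffer> buffers) {
    long sum = 0;
    for (SpillableBuffer b : buffers) {
      sum += b.getPeakMemoryUsedBytes();
    }
    return sum;
  }

  public static void main(String[] args) {
    SpillableBuffer a = new SpillableBuffer();
    a.allocatePage(100);
    a.allocatePage(200);
    a.freeMemory();       // spill: usage drops to 0, peak stays at 300
    a.allocatePage(50);   // smaller than the earlier peak

    SpillableBuffer b = new SpillableBuffer();
    b.allocatePage(64);

    List<SpillableBuffer> buffers = new ArrayList<>();
    buffers.add(a);
    buffers.add(b);
    System.out.println(peakExecutionMemory(buffers)); // prints 364
  }
}
```

Because the per-task value is a sum of individual peaks that may occur at different times, it is an approximation rather than the true simultaneous high-water mark; the comment on `InternalAccumulator.create()` in the diff says the same.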

Author: Andrew Or <andrew@databricks.com>

Closes #7770 from andrewor14/expose-memory-metrics and squashes the following commits:

9abecb9 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
f5b0d68 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
d7df332 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
8eefbc5 [Andrew Or] Fix non-failing tests
9de2a12 [Andrew Or] Fix tests due to another logical merge conflict
876bfa4 [Andrew Or] Fix failing test after logical merge conflict
361a359 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
40b4802 [Andrew Or] Fix style?
d0fef87 [Andrew Or] Fix tests?
b3b92f6 [Andrew Or] Address comments
0625d73 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
c00a197 [Andrew Or] Fix potential NPEs
10da1cd [Andrew Or] Fix compile
17f4c2d [Andrew Or] Fix compile?
a87b4d0 [Andrew Or] Fix compile?
d70874d [Andrew Or] Fix test compile + address comments
2840b7d [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
6aa2f7a [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
b889a68 [Andrew Or] Minor changes: comments, spacing, style
663a303 [Andrew Or] UnsafeShuffleWriter: update peak memory before close
d090a94 [Andrew Or] Fix style
2480d84 [Andrew Or] Expand test coverage
5f1235b [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
1ecf678 [Andrew Or] Minor changes: comments, style, unused imports
0b6926c [Andrew Or] Oops
111a05e [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
a7a39a5 [Andrew Or] Strengthen presence check for accumulator
a919eb7 [Andrew Or] Add tests for unsafe shuffle writer
23c845d [Andrew Or] Add tests for SQL operators
a757550 [Andrew Or] Address comments
b5c51c1 [Andrew Or] Re-enable test in JavaAPISuite
5107691 [Andrew Or] Add tests for internal accumulators
59231e4 [Andrew Or] Fix tests
9528d09 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
5b5e6f3 [Andrew Or] Add peak execution memory to summary table + tooltip
92b4b6b [Andrew Or] Display peak execution memory on the UI
eee5437 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
d9b9015 [Andrew Or] Track execution memory in unsafe shuffles
770ee54 [Andrew Or] Track execution memory in broadcast joins
9c605a4 [Andrew Or] Track execution memory in GeneratedAggregate
9e824f2 [Andrew Or] Add back execution memory tracking for *ExternalSort
4ef4cb1 [Andrew Or] Merge branch 'master' of github.com:apache/spark into expose-memory-metrics
e6c3e2f [Andrew Or] Move internal accumulators creation to Stage
a417592 [Andrew Or] Expose memory metrics in UnsafeExternalSorter
3c4f042 [Andrew Or] Track memory usage in ExternalAppendOnlyMap / ExternalSorter
bd7ab3f [Andrew Or] Add internal accumulators to TaskContext


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/702aa9d7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/702aa9d7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/702aa9d7

Branch: refs/heads/master
Commit: 702aa9d7fb16c98a50e046edfd76b8a7861d0391
Parents: e4765a4
Author: Andrew Or <andrew@databricks.com>
Authored: Mon Aug 3 14:22:07 2015 -0700
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Mon Aug 3 14:22:07 2015 -0700

----------------------------------------------------------------------
 .../unsafe/UnsafeShuffleExternalSorter.java     |  27 ++-
 .../shuffle/unsafe/UnsafeShuffleWriter.java     |  38 +++-
 .../spark/unsafe/map/BytesToBytesMap.java       |   8 +-
 .../unsafe/sort/UnsafeExternalSorter.java       |  29 ++-
 .../org/apache/spark/ui/static/webui.css        |   2 +-
 .../scala/org/apache/spark/Accumulators.scala   |  60 +++++-
 .../scala/org/apache/spark/Aggregator.scala     |  24 +--
 .../scala/org/apache/spark/TaskContext.scala    |  13 +-
 .../org/apache/spark/TaskContextImpl.scala      |   8 +
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |   9 +-
 .../spark/scheduler/AccumulableInfo.scala       |   9 +-
 .../apache/spark/scheduler/DAGScheduler.scala   |  28 ++-
 .../org/apache/spark/scheduler/ResultTask.scala |   6 +-
 .../apache/spark/scheduler/ShuffleMapTask.scala |  10 +-
 .../org/apache/spark/scheduler/Stage.scala      |  16 ++
 .../scala/org/apache/spark/scheduler/Task.scala |  18 +-
 .../spark/shuffle/hash/HashShuffleReader.scala  |   8 +-
 .../scala/org/apache/spark/ui/ToolTips.scala    |   7 +
 .../org/apache/spark/ui/jobs/StagePage.scala    | 140 ++++++++++----
 .../spark/ui/jobs/TaskDetailsClassNames.scala   |   1 +
 .../util/collection/ExternalAppendOnlyMap.scala |  13 +-
 .../spark/util/collection/ExternalSorter.scala  |  20 +-
 .../java/org/apache/spark/JavaAPISuite.java     |   3 +-
 .../unsafe/UnsafeShuffleWriterSuite.java        |  54 ++++++
 .../map/AbstractBytesToBytesMapSuite.java       |  39 ++++
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  46 +++++
 .../org/apache/spark/AccumulatorSuite.scala     | 193 ++++++++++++++++++-
 .../org/apache/spark/CacheManagerSuite.scala    |  10 +-
 .../org/apache/spark/rdd/PipedRDDSuite.scala    |   2 +-
 .../org/apache/spark/scheduler/FakeTask.scala   |   6 +-
 .../scheduler/NotSerializableFakeTask.scala     |   2 +-
 .../spark/scheduler/TaskContextSuite.scala      |   7 +-
 .../spark/scheduler/TaskSetManagerSuite.scala   |   2 +-
 .../shuffle/hash/HashShuffleReaderSuite.scala   |   2 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      |   8 +-
 .../org/apache/spark/ui/StagePageSuite.scala    |  76 ++++++++
 .../collection/ExternalAppendOnlyMapSuite.scala |  15 ++
 .../util/collection/ExternalSorterSuite.scala   |  14 +-
 .../sql/execution/UnsafeExternalRowSorter.java  |   7 +
 .../UnsafeFixedWidthAggregationMap.java         |   8 +
 .../sql/execution/GeneratedAggregate.scala      |  11 +-
 .../sql/execution/joins/BroadcastHashJoin.scala |  10 +-
 .../joins/BroadcastHashOuterJoin.scala          |   8 +
 .../joins/BroadcastLeftSemiJoinHash.scala       |  10 +-
 .../sql/execution/joins/HashedRelation.scala    |  22 ++-
 .../org/apache/spark/sql/execution/sort.scala   |  12 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  60 ++++--
 .../spark/sql/execution/TungstenSortSuite.scala |  12 ++
 .../UnsafeFixedWidthAggregationMapSuite.scala   |   3 +-
 .../execution/UnsafeKVExternalSorterSuite.scala |   3 +-
 .../execution/joins/BroadcastJoinSuite.scala    |  94 +++++++++
 51 files changed, 1070 insertions(+), 163 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
index 1aa6ba4..bf4eaa5 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.unsafe;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import javax.annotation.Nullable;
 
 import scala.Tuple2;
 
@@ -86,9 +87,12 @@ final class UnsafeShuffleExternalSorter {
 
   private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
 
+  /** Peak memory used by this sorter so far, in bytes. **/
+  private long peakMemoryUsedBytes;
+
   // These variables are reset after spilling:
-  private UnsafeShuffleInMemorySorter sorter;
-  private MemoryBlock currentPage = null;
+  @Nullable private UnsafeShuffleInMemorySorter sorter;
+  @Nullable private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
 
@@ -106,6 +110,7 @@ final class UnsafeShuffleExternalSorter {
     this.blockManager = blockManager;
     this.taskContext = taskContext;
     this.initialSize = initialSize;
+    this.peakMemoryUsedBytes = initialSize;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
@@ -279,10 +284,26 @@ final class UnsafeShuffleExternalSorter {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return sorter.getMemoryUsage() + totalPageSize;
+    return ((sorter == null) ? 0 : sorter.getMemoryUsage()) + totalPageSize;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   private long freeMemory() {
+    updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       memoryManager.freePage(block);

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
index d47d6fc..6e2eeb3 100644
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -27,6 +27,7 @@ import scala.Product2;
 import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
+import scala.collection.immutable.Map;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
@@ -78,8 +79,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final SparkConf sparkConf;
   private final boolean transferToEnabled;
 
-  private MapStatus mapStatus = null;
-  private UnsafeShuffleExternalSorter sorter = null;
+  @Nullable private MapStatus mapStatus;
+  @Nullable private UnsafeShuffleExternalSorter sorter;
+  private long peakMemoryUsedBytes = 0;
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
   private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
@@ -131,9 +133,28 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   public int maxRecordSizeBytes() {
+    assert(sorter != null);
     return sorter.maxRecordSizeBytes;
   }
 
+  private void updatePeakMemoryUsed() {
+    // sorter can be null if this writer is closed
+    if (sorter != null) {
+      long mem = sorter.getPeakMemoryUsedBytes();
+      if (mem > peakMemoryUsedBytes) {
+        peakMemoryUsedBytes = mem;
+      }
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
   /**
    * This convenience method should only be called in test code.
    */
@@ -144,7 +165,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
-    // Keep track of success so we know if we ecountered an exception
+    // Keep track of success so we know if we encountered an exception
     // We do this rather than a standard try/catch/re-throw to handle
     // generic throwables.
     boolean success = false;
@@ -189,6 +210,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
+    assert(sorter != null);
+    updatePeakMemoryUsed();
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -209,6 +232,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    assert(sorter != null);
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
     serBuffer.reset();
@@ -431,6 +455,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   @Override
   public Option<MapStatus> stop(boolean success) {
     try {
+      // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+      Map<String, Accumulator<Object>> internalAccumulators =
+        taskContext.internalMetricsToAccumulators();
+      if (internalAccumulators != null) {
+        internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+          .add(getPeakMemoryUsedBytes());
+      }
+
       if (stopping) {
         return Option.apply(null);
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 01a6608..2034743 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -505,7 +505,7 @@ public final class BytesToBytesMap {
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (8 byte value length) (value)
+      // (8 byte key length) (key) (value)
       final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
 
       // --- Figure out where to insert the new record ---------------------------------------------
@@ -655,7 +655,10 @@ public final class BytesToBytesMap {
     return pageSizeBytes;
   }
 
-  /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+  /**
+   * Returns the total amount of memory, in bytes, consumed by this map's managed structures.
+   * Note that this is also the peak memory used by this map, since the map is append-only.
+   */
   public long getTotalMemoryConsumption() {
     long totalDataPagesSize = 0L;
     for (MemoryBlock dataPage : dataPages) {
@@ -674,7 +677,6 @@ public final class BytesToBytesMap {
     return timeSpentResizingNs;
   }
 
-
   /**
    * Returns the average number of probes per key lookup.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index b984301..bf5f965 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -70,13 +70,14 @@ public final class UnsafeExternalSorter {
   private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
 
   // These variables are reset after spilling:
-  private UnsafeInMemorySorter inMemSorter;
+  @Nullable private UnsafeInMemorySorter inMemSorter;
   // Whether the in-mem sorter is created internally, or passed in from outside.
   // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
   private boolean isInMemSorterExternal = false;
   private MemoryBlock currentPage = null;
   private long currentPagePosition = -1;
   private long freeSpaceInCurrentPage = 0;
+  private long peakMemoryUsedBytes = 0;
 
   public static UnsafeExternalSorter createWithExistingInMemorySorter(
       TaskMemoryManager taskMemoryManager,
@@ -183,6 +184,7 @@ public final class UnsafeExternalSorter {
    * Sort and spill the current records in response to memory pressure.
    */
   public void spill() throws IOException {
+    assert(inMemSorter != null);
     logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
       Thread.currentThread().getId(),
       Utils.bytesToString(getMemoryUsage()),
@@ -219,7 +221,22 @@ public final class UnsafeExternalSorter {
     for (MemoryBlock page : allocatedPages) {
       totalPageSize += page.size();
     }
-    return inMemSorter.getMemoryUsage() + totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Return the peak memory used so far, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
   }
 
   @VisibleForTesting
@@ -233,6 +250,7 @@ public final class UnsafeExternalSorter {
    * @return the number of bytes freed.
    */
   public long freeMemory() {
+    updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
       taskMemoryManager.freePage(block);
@@ -277,7 +295,8 @@ public final class UnsafeExternalSorter {
    * @return true if the record can be inserted without requiring more allocations, false otherwise.
    */
   private boolean haveSpaceForRecord(int requiredSpace) {
-    assert (requiredSpace > 0);
+    assert(requiredSpace > 0);
+    assert(inMemSorter != null);
     return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
   }
 
@@ -290,6 +309,7 @@ public final class UnsafeExternalSorter {
    *                      the record size.
    */
   private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+    assert(inMemSorter != null);
     // TODO: merge these steps to first calculate total memory requirements for this insert,
     // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
     // data page.
@@ -350,6 +370,7 @@ public final class UnsafeExternalSorter {
     if (!haveSpaceForRecord(totalSpaceRequired)) {
       allocateSpaceForRecord(totalSpaceRequired);
     }
+    assert(inMemSorter != null);
 
     final long recordAddress =
       taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -382,6 +403,7 @@ public final class UnsafeExternalSorter {
     if (!haveSpaceForRecord(totalSpaceRequired)) {
       allocateSpaceForRecord(totalSpaceRequired);
     }
+    assert(inMemSorter != null);
 
     final long recordAddress =
       taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
@@ -405,6 +427,7 @@ public final class UnsafeExternalSorter {
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
+    assert(inMemSorter != null);
     final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
     int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/resources/org/apache/spark/ui/static/webui.css
----------------------------------------------------------------------
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index b1cef47..648cd1b 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -207,7 +207,7 @@ span.additional-metric-title {
 /* Hide all additional metrics by default. This is done here rather than using JavaScript to
  * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
 .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
-.serialization_time, .getting_result_time {
+.serialization_time, .getting_result_time, .peak_execution_memory {
   display: none;
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/Accumulators.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index eb75f26..b6a0119 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -152,8 +152,14 @@ class Accumulable[R, T] private[spark] (
     in.defaultReadObject()
     value_ = zero
     deserialized = true
-    val taskContext = TaskContext.get()
-    taskContext.registerAccumulator(this)
+    // Automatically register the accumulator when it is deserialized with the task closure.
+    // Note that internal accumulators are deserialized before the TaskContext is created and
+    // are registered in the TaskContext constructor.
+    if (!isInternal) {
+      val taskContext = TaskContext.get()
+      assume(taskContext != null, "Task context was null when deserializing user accumulators")
+      taskContext.registerAccumulator(this)
+    }
   }
 
   override def toString: String = if (value_ == null) "null" else value_.toString
@@ -248,10 +254,20 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  * @param param helper object defining how to add elements of type `T`
  * @tparam T result type
  */
-class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
-  extends Accumulable[T, T](initialValue, param, name) {
+class Accumulator[T] private[spark] (
+    @transient initialValue: T,
+    param: AccumulatorParam[T],
+    name: Option[String],
+    internal: Boolean)
+  extends Accumulable[T, T](initialValue, param, name, internal) {
+
+  def this(initialValue: T, param: AccumulatorParam[T], name: Option[String]) = {
+    this(initialValue, param, name, false)
+  }
 
-  def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
+  def this(initialValue: T, param: AccumulatorParam[T]) = {
+    this(initialValue, param, None, false)
+  }
 }
 
 /**
@@ -342,3 +358,37 @@ private[spark] object Accumulators extends Logging {
   }
 
 }
+
+private[spark] object InternalAccumulator {
+  val PEAK_EXECUTION_MEMORY = "peakExecutionMemory"
+  val TEST_ACCUMULATOR = "testAccumulator"
+
+  // For testing only.
+  // This needs to be a def since we don't want to reuse the same accumulator across stages.
+  private def maybeTestAccumulator: Option[Accumulator[Long]] = {
+    if (sys.props.contains("spark.testing")) {
+      Some(new Accumulator(
+        0L, AccumulatorParam.LongAccumulatorParam, Some(TEST_ACCUMULATOR), internal = true))
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Accumulators for tracking internal metrics.
+   *
+   * These accumulators are created with the stage such that all tasks in the stage will
+   * add to the same set of accumulators. We do this to report the distribution of accumulator
+   * values across all tasks within each stage.
+   */
+  def create(): Seq[Accumulator[Long]] = {
+    Seq(
+      // Execution memory refers to the memory used by internal data structures created
+      // during shuffles, aggregations and joins. The value of this accumulator should be
+      // approximately the sum of the peak sizes across all such data structures created
+      // in this task. For SQL jobs, this only tracks all unsafe operators and ExternalSort.
+      new Accumulator(
+        0L, AccumulatorParam.LongAccumulatorParam, Some(PEAK_EXECUTION_MEMORY), internal = true)
+    ) ++ maybeTestAccumulator.toSeq
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/Aggregator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ceeb580..289aab9 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -58,12 +58,7 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
       combiners.insertAll(iter)
-      // Update task metrics if context is not null
-      // TODO: Make context non optional in a future release
-      Option(context).foreach { c =>
-        c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
-        c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
-      }
+      updateMetrics(context, combiners)
       combiners.iterator
     }
   }
@@ -89,13 +84,18 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
       combiners.insertAll(iter)
-      // Update task metrics if context is not null
-      // TODO: Make context non-optional in a future release
-      Option(context).foreach { c =>
-        c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
-        c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
-      }
+      updateMetrics(context, combiners)
       combiners.iterator
     }
   }
+
+  /** Update task metrics after populating the external map. */
+  private def updateMetrics(context: TaskContext, map: ExternalAppendOnlyMap[_, _, _]): Unit = {
+    Option(context).foreach { c =>
+      c.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+      c.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+      c.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 5d2c551..63cca80 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -61,12 +61,12 @@ object TaskContext {
   protected[spark] def unset(): Unit = taskContext.remove()
 
   /**
-   * Return an empty task context that is not actually used.
-   * Internal use only.
+   * An empty task context that does not represent an actual task.
    */
-  private[spark] def empty(): TaskContext = {
-    new TaskContextImpl(0, 0, 0, 0, null, null)
+  private[spark] def empty(): TaskContextImpl = {
+    new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty)
   }
+
 }
 
 
@@ -187,4 +187,9 @@ abstract class TaskContext extends Serializable {
    * accumulator id and the value of the Map is the latest accumulator local value.
    */
   private[spark] def collectAccumulators(): Map[Long, Any]
+
+  /**
+   * Accumulators for tracking internal metrics indexed by the name.
+   */
+  private[spark] val internalMetricsToAccumulators: Map[String, Accumulator[Long]]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 9ee168a..5df94c6 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -32,6 +32,7 @@ private[spark] class TaskContextImpl(
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,
     @transient private val metricsSystem: MetricsSystem,
+    internalAccumulators: Seq[Accumulator[Long]],
     val runningLocally: Boolean = false,
     val taskMetrics: TaskMetrics = TaskMetrics.empty)
   extends TaskContext
@@ -114,4 +115,11 @@ private[spark] class TaskContextImpl(
   private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
     accumulators.mapValues(_.localValue).toMap
   }
+
+  private[spark] override val internalMetricsToAccumulators: Map[String, Accumulator[Long]] = {
+    // Explicitly register internal accumulators here because these are
+    // not captured in the task closure and are already deserialized
+    internalAccumulators.foreach(registerAccumulator)
+    internalAccumulators.map { a => (a.name.get, a) }.toMap
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 130b588..9c617fc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
-import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
 import org.apache.spark.util.Utils
@@ -169,8 +168,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       for ((it, depNum) <- rddIterators) {
         map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
       }
-      context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
-      context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
+      context.taskMetrics().incMemoryBytesSpilled(map.memoryBytesSpilled)
+      context.taskMetrics().incDiskBytesSpilled(map.diskBytesSpilled)
+      context.internalMetricsToAccumulators(
+        InternalAccumulator.PEAK_EXECUTION_MEMORY).add(map.peakMemoryUsedBytes)
       new InterruptibleIterator(context,
         map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index e0edd7d..11d123e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -24,11 +24,12 @@ import org.apache.spark.annotation.DeveloperApi
  * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage.
  */
 @DeveloperApi
-class AccumulableInfo (
+class AccumulableInfo private[spark] (
     val id: Long,
     val name: String,
     val update: Option[String], // represents a partial update within a task
-    val value: String) {
+    val value: String,
+    val internal: Boolean) {
 
   override def equals(other: Any): Boolean = other match {
     case acc: AccumulableInfo =>
@@ -40,10 +41,10 @@ class AccumulableInfo (
 
 object AccumulableInfo {
   def apply(id: Long, name: String, update: Option[String], value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, update, value)
+    new AccumulableInfo(id, name, update, value, internal = false)
   }
 
   def apply(id: Long, name: String, value: String): AccumulableInfo = {
-    new AccumulableInfo(id, name, None, value)
+    new AccumulableInfo(id, name, None, value, internal = false)
   }
 }
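
Note on the hunk above: the constructor becomes restricted while the two public `apply` factories keep their signatures and always pass `internal = false`. The sketch below shows the same shape in isolation. `InfoSketch` and `forInternalUse` are illustrative names, and a plain `private` constructor is used here in place of `private[spark]` because the sketch has no enclosing package.

```scala
// Illustrative class, not Spark's AccumulableInfo.
class InfoSketch private (
    val id: Long,
    val name: String,
    val value: String,
    val internal: Boolean)

object InfoSketch {
  // Public factory: outside callers cannot mark an entry as internal.
  def apply(id: Long, name: String, value: String): InfoSketch =
    new InfoSketch(id, name, value, internal = false)

  // Stands in for code inside the package that is allowed to set the flag.
  def forInternalUse(id: Long, name: String, value: String): InfoSketch =
    new InfoSketch(id, name, value, internal = true)
}
```

Existing callers of the factory keep compiling unchanged, which is presumably why the patch adds the flag to the constructor only.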

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c4fa277..bb489c6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -773,16 +773,26 @@ class DAGScheduler(
     stage.pendingTasks.clear()
 
     // First figure out the indexes of partition ids to compute.
-    val partitionsToCompute: Seq[Int] = {
+    val (allPartitions: Seq[Int], partitionsToCompute: Seq[Int]) = {
       stage match {
         case stage: ShuffleMapStage =>
-          (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
+          val allPartitions = 0 until stage.numPartitions
+          val filteredPartitions = allPartitions.filter { id => stage.outputLocs(id).isEmpty }
+          (allPartitions, filteredPartitions)
         case stage: ResultStage =>
           val job = stage.resultOfJob.get
-          (0 until job.numPartitions).filter(id => !job.finished(id))
+          val allPartitions = 0 until job.numPartitions
+          val filteredPartitions = allPartitions.filter { id => !job.finished(id) }
+          (allPartitions, filteredPartitions)
       }
     }
 
+    // Reset internal accumulators only if this stage is not partially submitted
+    // Otherwise, we may override existing accumulator values from some tasks
+    if (allPartitions == partitionsToCompute) {
+      stage.resetInternalAccumulators()
+    }
+
     val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull
 
     runningStages += stage
@@ -852,7 +862,8 @@ class DAGScheduler(
           partitionsToCompute.map { id =>
             val locs = taskIdToLocations(id)
             val part = stage.rdd.partitions(id)
-            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
+            new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
+              taskBinary, part, locs, stage.internalAccumulators)
           }
 
         case stage: ResultStage =>
@@ -861,7 +872,8 @@ class DAGScheduler(
             val p: Int = job.partitions(id)
             val part = stage.rdd.partitions(p)
             val locs = taskIdToLocations(id)
-            new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
+            new ResultTask(stage.id, stage.latestInfo.attemptId,
+              taskBinary, part, locs, id, stage.internalAccumulators)
           }
       }
     } catch {
@@ -916,9 +928,11 @@ class DAGScheduler(
           // To avoid UI cruft, ignore cases where value wasn't updated
           if (acc.name.isDefined && partialValue != acc.zero) {
             val name = acc.name.get
-            stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
+            val value = s"${acc.value}"
+            stage.latestInfo.accumulables(id) =
+              new AccumulableInfo(id, name, None, value, acc.isInternal)
             event.taskInfo.accumulables +=
-              AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
+              new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
           }
         }
       } catch {
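
Note on the reset condition in the first DAGScheduler hunk: the accumulators are re-created only when every partition of the stage is about to be computed. A small sketch of that check, with a hypothetical `isDone` predicate standing in for `stage.outputLocs(id).isEmpty` / `job.finished(id)`:

```scala
object ResetSketch {
  /** True when no partition has finished yet, i.e. the whole stage will run. */
  def shouldReset(numPartitions: Int, isDone: Int => Boolean): Boolean = {
    val allPartitions: Seq[Int] = 0 until numPartitions
    val partitionsToCompute: Seq[Int] = allPartitions.filter(id => !isDone(id))
    // Seq equality is element-wise, so a Range and its filtered copy
    // compare equal exactly when nothing was filtered out.
    allPartitions == partitionsToCompute
  }
}
```

With four partitions and none finished the check returns true; once any one partition is finished it returns false, so values already reported by finished tasks are kept.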

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 9c2606e..c4dc080 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -45,8 +45,10 @@ private[spark] class ResultTask[T, U](
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
-    val outputId: Int)
-  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
+    val outputId: Int,
+    internalAccumulators: Seq[Accumulator[Long]])
+  extends Task[U](stageId, stageAttemptId, partition.index, internalAccumulators)
+  with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 14c8c00..f478f99 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -43,12 +43,14 @@ private[spark] class ShuffleMapTask(
     stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
-    @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
+    @transient private var locs: Seq[TaskLocation],
+    internalAccumulators: Seq[Accumulator[Long]])
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index, internalAccumulators)
+  with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
@@ -69,7 +71,7 @@ private[spark] class ShuffleMapTask(
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
-      return writer.stop(success = true).get
+      writer.stop(success = true).get
     } catch {
       case e: Exception =>
         try {

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 40a333a..de05ee2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -68,6 +68,22 @@ private[spark] abstract class Stage(
   val name = callSite.shortForm
   val details = callSite.longForm
 
+  private var _internalAccumulators: Seq[Accumulator[Long]] = Seq.empty
+
+  /** Internal accumulators shared across all tasks in this stage. */
+  def internalAccumulators: Seq[Accumulator[Long]] = _internalAccumulators
+
+  /**
+   * Re-initialize the internal accumulators associated with this stage.
+   *
+   * This is called every time the stage is submitted, *except* when a subset of tasks
+   * belonging to this stage has already finished. Otherwise, reinitializing the internal
+   * accumulators here again will override partial values from the finished tasks.
+   */
+  def resetInternalAccumulators(): Unit = {
+    _internalAccumulators = InternalAccumulator.create()
+  }
+
   /**
    * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
    * here, before any attempts have actually been created, because the DAGScheduler uses this

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1978305..9edf9f0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -47,7 +47,8 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](
     val stageId: Int,
     val stageAttemptId: Int,
-    var partitionId: Int) extends Serializable {
+    val partitionId: Int,
+    internalAccumulators: Seq[Accumulator[Long]]) extends Serializable {
 
   /**
    * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
@@ -68,12 +69,13 @@ private[spark] abstract class Task[T](
     metricsSystem: MetricsSystem)
   : (T, AccumulatorUpdates) = {
     context = new TaskContextImpl(
-      stageId = stageId,
-      partitionId = partitionId,
-      taskAttemptId = taskAttemptId,
-      attemptNumber = attemptNumber,
-      taskMemoryManager = taskMemoryManager,
-      metricsSystem = metricsSystem,
+      stageId,
+      partitionId,
+      taskAttemptId,
+      attemptNumber,
+      taskMemoryManager,
+      metricsSystem,
+      internalAccumulators,
       runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index de79fa5..0c8f08f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
@@ -100,8 +100,10 @@ private[spark] class HashShuffleReader[K, C](
         // the ExternalSorter won't spill to disk.
         val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
         sorter.insertAll(aggregatedIter)
-        context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
-        context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
         sorter.iterator
       case None =>
         aggregatedIter

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index e2d25e3..cb122ea 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -62,6 +62,13 @@ private[spark] object ToolTips {
     """Time that the executor spent paused for Java garbage collection while the task was
        running."""
 
+  val PEAK_EXECUTION_MEMORY =
+    """Execution memory refers to the memory used by internal data structures created during
+       shuffles, aggregations and joins when Tungsten is enabled. The value of this accumulator
+       should be approximately the sum of the peak sizes across all such data structures created
+       in this task. For SQL jobs, this only tracks all unsafe operators, broadcast joins, and
+       external sort."""
+
   val JOB_TIMELINE =
     """Shows when jobs started and ended and when executors joined or left. Drag to scroll.
        Click Enable Zooming and use mouse wheel to zoom in/out."""
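
Note on the tooltip text: "approximately the sum of the peak sizes" follows from each data structure adding its own peak to the same per-task accumulator. A worked sketch under that assumption (`PeakSumSketch` and the size sequences are made up for illustration):

```scala
object PeakSumSketch {
  /** Peak of one structure: the largest size estimate it ever reported. */
  def peak(estimates: Seq[Long]): Long =
    estimates.foldLeft(0L)((a, b) => math.max(a, b))

  /** Value a task-level counter would hold if each structure adds its own peak. */
  def taskPeak(structures: Seq[Seq[Long]]): Long =
    structures.map(peak).sum
}
```

For a sorter whose estimates are 10, 40, 25 and a map whose estimates are 5, 30, the result is 40 + 30 = 70. The two peaks need not occur at the same moment, so the sum can overstate the true simultaneous footprint, which matches the "approximately" in the tooltip.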

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index cf04b5e..3954c3d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -26,6 +26,7 @@ import scala.xml.{Elem, Node, Unparsed}
 
 import org.apache.commons.lang3.StringEscapeUtils
 
+import org.apache.spark.{InternalAccumulator, SparkConf}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
 import org.apache.spark.ui._
@@ -67,6 +68,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
   // if we find that it's okay.
   private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000)
 
+  private val displayPeakExecutionMemory =
+    parent.conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
 
   def render(request: HttpServletRequest): Seq[Node] = {
     progressListener.synchronized {
@@ -114,10 +117,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
 
       val stageData = stageDataOption.get
       val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)
-
       val numCompleted = tasks.count(_.taskInfo.finished)
-      val accumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
-      val hasAccumulators = accumulables.size > 0
+
+      val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables
+      val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal }
+      val hasAccumulators = externalAccumulables.size > 0
 
       val summary =
         <div>
@@ -221,6 +225,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
                   <span class="additional-metric-title">Getting Result Time</span>
                 </span>
               </li>
+              {if (displayPeakExecutionMemory) {
+                <li>
+                  <span data-toggle="tooltip"
+                        title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                    <input type="checkbox" name={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}/>
+                    <span class="additional-metric-title">Peak Execution Memory</span>
+                  </span>
+                </li>
+              }}
             </ul>
           </div>
         </div>
@@ -241,11 +254,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
       val accumulableTable = UIUtils.listingTable(
         accumulableHeaders,
         accumulableRow,
-        accumulables.values.toSeq)
+        externalAccumulables.toSeq)
 
       val currentTime = System.currentTimeMillis()
       val (taskTable, taskTableHTML) = try {
         val _taskTable = new TaskPagedTable(
+          parent.conf,
           UIUtils.prependBaseUri(parent.basePath) +
             s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}",
           tasks,
@@ -294,12 +308,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
         else {
           def getDistributionQuantiles(data: Seq[Double]): IndexedSeq[Double] =
             Distribution(data).get.getQuantiles()
-
           def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = {
             getDistributionQuantiles(times).map { millis =>
               <td>{UIUtils.formatDuration(millis.toLong)}</td>
             }
           }
+          def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] = {
+            getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
+          }
 
           val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
             metrics.get.executorDeserializeTime.toDouble
@@ -349,6 +365,23 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
               </span>
             </td> +:
             getFormattedTimeQuantiles(gettingResultTimes)
+
+          val peakExecutionMemory = validTasks.map { case TaskUIData(info, _, _) =>
+            info.accumulables
+              .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+              .map { acc => acc.value.toLong }
+              .getOrElse(0L)
+              .toDouble
+          }
+          val peakExecutionMemoryQuantiles = {
+            <td>
+              <span data-toggle="tooltip"
+                    title={ToolTips.PEAK_EXECUTION_MEMORY} data-placement="right">
+                Peak Execution Memory
+              </span>
+            </td> +: getFormattedSizeQuantiles(peakExecutionMemory)
+          }
+
           // The scheduler delay includes the network delay to send the task to the worker
           // machine and to send back the result (but not the time to fetch the task result,
           // if it needed to be fetched from the block manager on the worker).
@@ -359,10 +392,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
             title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay</span></td>
           val schedulerDelayQuantiles = schedulerDelayTitle +:
             getFormattedTimeQuantiles(schedulerDelays)
-
-          def getFormattedSizeQuantiles(data: Seq[Double]): Seq[Elem] =
-            getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
-
           def getFormattedSizeQuantilesWithRecords(data: Seq[Double], records: Seq[Double])
             : Seq[Elem] = {
             val recordDist = getDistributionQuantiles(records).iterator
@@ -466,6 +495,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
               {serializationQuantiles}
             </tr>,
             <tr class={TaskDetailsClassNames.GETTING_RESULT_TIME}>{gettingResultQuantiles}</tr>,
+            if (displayPeakExecutionMemory) {
+              <tr class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+                {peakExecutionMemoryQuantiles}
+              </tr>
+            } else {
+              Nil
+            },
             if (stageData.hasInput) <tr>{inputQuantiles}</tr> else Nil,
             if (stageData.hasOutput) <tr>{outputQuantiles}</tr> else Nil,
             if (stageData.hasShuffleRead) {
@@ -499,7 +535,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
       val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
 
       val maybeAccumulableTable: Seq[Node] =
-        if (accumulables.size > 0) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
+        if (hasAccumulators) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
 
       val content =
         summary ++
@@ -750,29 +786,30 @@ private[ui] case class TaskTableRowBytesSpilledData(
  * Contains all data that needs for sorting and generating HTML. Using this one rather than
  * TaskUIData to avoid creating duplicate contents during sorting the data.
  */
-private[ui] case class TaskTableRowData(
-    index: Int,
-    taskId: Long,
-    attempt: Int,
-    speculative: Boolean,
-    status: String,
-    taskLocality: String,
-    executorIdAndHost: String,
-    launchTime: Long,
-    duration: Long,
-    formatDuration: String,
-    schedulerDelay: Long,
-    taskDeserializationTime: Long,
-    gcTime: Long,
-    serializationTime: Long,
-    gettingResultTime: Long,
-    accumulators: Option[String], // HTML
-    input: Option[TaskTableRowInputData],
-    output: Option[TaskTableRowOutputData],
-    shuffleRead: Option[TaskTableRowShuffleReadData],
-    shuffleWrite: Option[TaskTableRowShuffleWriteData],
-    bytesSpilled: Option[TaskTableRowBytesSpilledData],
-    error: String)
+private[ui] class TaskTableRowData(
+    val index: Int,
+    val taskId: Long,
+    val attempt: Int,
+    val speculative: Boolean,
+    val status: String,
+    val taskLocality: String,
+    val executorIdAndHost: String,
+    val launchTime: Long,
+    val duration: Long,
+    val formatDuration: String,
+    val schedulerDelay: Long,
+    val taskDeserializationTime: Long,
+    val gcTime: Long,
+    val serializationTime: Long,
+    val gettingResultTime: Long,
+    val peakExecutionMemoryUsed: Long,
+    val accumulators: Option[String], // HTML
+    val input: Option[TaskTableRowInputData],
+    val output: Option[TaskTableRowOutputData],
+    val shuffleRead: Option[TaskTableRowShuffleReadData],
+    val shuffleWrite: Option[TaskTableRowShuffleWriteData],
+    val bytesSpilled: Option[TaskTableRowBytesSpilledData],
+    val error: String)
 
 private[ui] class TaskDataSource(
     tasks: Seq[TaskUIData],
@@ -816,10 +853,15 @@ private[ui] class TaskDataSource(
     val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
     val gettingResultTime = getGettingResultTime(info, currentTime)
 
-    val maybeAccumulators = info.accumulables
-    val accumulatorsReadable = maybeAccumulators.map { acc =>
+    val (taskInternalAccumulables, taskExternalAccumulables) =
+      info.accumulables.partition(_.internal)
+    val externalAccumulableReadable = taskExternalAccumulables.map { acc =>
       StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}")
     }
+    val peakExecutionMemoryUsed = taskInternalAccumulables
+      .find { acc => acc.name == InternalAccumulator.PEAK_EXECUTION_MEMORY }
+      .map { acc => acc.value.toLong }
+      .getOrElse(0L)
 
     val maybeInput = metrics.flatMap(_.inputMetrics)
     val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L)
@@ -923,7 +965,7 @@ private[ui] class TaskDataSource(
         None
       }
 
-    TaskTableRowData(
+    new TaskTableRowData(
       info.index,
       info.taskId,
       info.attempt,
@@ -939,7 +981,8 @@ private[ui] class TaskDataSource(
       gcTime,
       serializationTime,
       gettingResultTime,
-      if (hasAccumulators) Some(accumulatorsReadable.mkString("<br/>")) else None,
+      peakExecutionMemoryUsed,
+      if (hasAccumulators) Some(externalAccumulableReadable.mkString("<br/>")) else None,
       input,
       output,
       shuffleRead,
@@ -1006,6 +1049,10 @@ private[ui] class TaskDataSource(
         override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
           Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime)
       }
+      case "Peak Execution Memory" => new Ordering[TaskTableRowData] {
+        override def compare(x: TaskTableRowData, y: TaskTableRowData): Int =
+          Ordering.Long.compare(x.peakExecutionMemoryUsed, y.peakExecutionMemoryUsed)
+      }
       case "Accumulators" =>
         if (hasAccumulators) {
           new Ordering[TaskTableRowData] {
@@ -1132,6 +1179,7 @@ private[ui] class TaskDataSource(
 }
 
 private[ui] class TaskPagedTable(
+    conf: SparkConf,
     basePath: String,
     data: Seq[TaskUIData],
     hasAccumulators: Boolean,
@@ -1143,7 +1191,11 @@ private[ui] class TaskPagedTable(
     currentTime: Long,
     pageSize: Int,
     sortColumn: String,
-    desc: Boolean) extends PagedTable[TaskTableRowData]{
+    desc: Boolean) extends PagedTable[TaskTableRowData] {
+
+  // We only track peak memory used for unsafe operators
+  private val displayPeakExecutionMemory =
+    conf.getOption("spark.sql.unsafe.enabled").exists(_.toBoolean)
 
   override def tableId: String = ""
 
@@ -1195,6 +1247,13 @@ private[ui] class TaskPagedTable(
         ("GC Time", ""),
         ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME),
         ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++
+        {
+          if (displayPeakExecutionMemory) {
+            Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY))
+          } else {
+            Nil
+          }
+        } ++
         {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
         {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++
         {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++
@@ -1271,6 +1330,11 @@ private[ui] class TaskPagedTable(
       <td class={TaskDetailsClassNames.GETTING_RESULT_TIME}>
         {UIUtils.formatDuration(task.gettingResultTime)}
       </td>
+      {if (displayPeakExecutionMemory) {
+        <td class={TaskDetailsClassNames.PEAK_EXECUTION_MEMORY}>
+          {Utils.bytesToString(task.peakExecutionMemoryUsed)}
+        </td>
+      }}
       {if (task.accumulators.nonEmpty) {
         <td>{Unparsed(task.accumulators.get)}</td>
       }}
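
Note on the StagePage changes: two small idioms recur. The per-task value is found by name among the internal accumulables and defaults to 0 when absent, and the column is shown only when a boolean config key is present and true. A self-contained sketch of both, using a hypothetical `AccInfo` case class and a plain `Map` in place of `SparkConf`:

```scala
// Hypothetical record; only the fields the lookup needs.
final case class AccInfo(name: String, value: String, internal: Boolean)

object StagePageSketch {
  /** Value of the named internal metric, or 0 if the task never reported it. */
  def peakFor(accs: Seq[AccInfo], metricName: String): Long = {
    val (internal, _) = accs.partition(_.internal)
    internal.find(_.name == metricName).map(_.value.toLong).getOrElse(0L)
  }

  /** False when the key is unset; otherwise the parsed boolean value. */
  def isEnabled(conf: Map[String, String], key: String): Boolean =
    conf.get(key).exists(_.toBoolean)
}
```

One consequence of the `exists` form is that an unset key hides the column, regardless of what default the setting has elsewhere.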

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 9bf67db..d2dfc5a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -31,4 +31,5 @@ private[spark] object TaskDetailsClassNames {
   val SHUFFLE_READ_REMOTE_SIZE = "shuffle_read_remote"
   val RESULT_SERIALIZATION_TIME = "serialization_time"
   val GETTING_RESULT_TIME = "getting_result_time"
+  val PEAK_EXECUTION_MEMORY = "peak_execution_memory"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index d166037..f929b12 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -89,6 +89,7 @@ class ExternalAppendOnlyMap[K, V, C](
 
   // Number of bytes spilled in total
   private var _diskBytesSpilled = 0L
+  def diskBytesSpilled: Long = _diskBytesSpilled
 
   // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize =
@@ -97,6 +98,10 @@ class ExternalAppendOnlyMap[K, V, C](
   // Write metrics for current spill
   private var curWriteMetrics: ShuffleWriteMetrics = _
 
+  // Peak size of the in-memory map observed so far, in bytes
+  private var _peakMemoryUsedBytes: Long = 0L
+  def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
+
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
@@ -126,7 +131,11 @@ class ExternalAppendOnlyMap[K, V, C](
 
     while (entries.hasNext) {
       curEntry = entries.next()
-      if (maybeSpill(currentMap, currentMap.estimateSize())) {
+      val estimatedSize = currentMap.estimateSize()
+      if (estimatedSize > _peakMemoryUsedBytes) {
+        _peakMemoryUsedBytes = estimatedSize
+      }
+      if (maybeSpill(currentMap, estimatedSize)) {
         currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
@@ -207,8 +216,6 @@ class ExternalAppendOnlyMap[K, V, C](
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
   }
 
-  def diskBytesSpilled: Long = _diskBytesSpilled
-
   /**
    * Return an iterator that merges the in-memory map with the spilled maps.
    * If no spill has occurred, simply return the in-memory map's iterator.

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index ba7ec83..19287ed 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -152,6 +152,9 @@ private[spark] class ExternalSorter[K, V, C](
   private var _diskBytesSpilled = 0L
   def diskBytesSpilled: Long = _diskBytesSpilled
 
+  // Peak size of the in-memory data structure observed so far, in bytes
+  private var _peakMemoryUsedBytes: Long = 0L
+  def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes
 
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -224,15 +227,22 @@ private[spark] class ExternalSorter[K, V, C](
       return
     }
 
+    var estimatedSize = 0L
     if (usingMap) {
-      if (maybeSpill(map, map.estimateSize())) {
+      estimatedSize = map.estimateSize()
+      if (maybeSpill(map, estimatedSize)) {
         map = new PartitionedAppendOnlyMap[K, C]
       }
     } else {
-      if (maybeSpill(buffer, buffer.estimateSize())) {
+      estimatedSize = buffer.estimateSize()
+      if (maybeSpill(buffer, estimatedSize)) {
         buffer = newBuffer()
       }
     }
+
+    if (estimatedSize > _peakMemoryUsedBytes) {
+      _peakMemoryUsedBytes = estimatedSize
+    }
   }
 
   /**
@@ -684,8 +694,10 @@ private[spark] class ExternalSorter[K, V, C](
       }
     }
 
-    context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
-    context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
+    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
+    context.internalMetricsToAccumulators(
+      InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemoryUsedBytes)
 
     lengths
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e948ca3..ffe4b4b 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -51,7 +51,6 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
-import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.input.PortableDataStream;
 import org.apache.spark.partial.BoundedDouble;
 import org.apache.spark.partial.PartialResult;
@@ -1011,7 +1010,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics());
+    TaskContext context = TaskContext$.MODULE$.empty();
     Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
index 04fc09b..98c32bb 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -190,6 +190,7 @@ public class UnsafeShuffleWriterSuite {
       });
 
     when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
 
     when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
@@ -542,4 +543,57 @@ public class UnsafeShuffleWriterSuite {
     writer.stop(false);
     assertSpillFilesWereCleanedUp();
   }
+
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    final SparkConf conf = new SparkConf().set("spark.buffer.pageSize", pageSizeBytes + "b");
+    final UnsafeShuffleWriter<Object, Object> writer =
+      new UnsafeShuffleWriter<Object, Object>(
+        blockManager,
+        shuffleBlockResolver,
+        taskMemoryManager,
+        shuffleMemoryManager,
+        new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+        0, // map id
+        taskContext,
+        conf);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+        newPeakMemory = writer.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      writer.forceSorterToSpill();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      }
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
+      // Closing the writer should not change peak memory
+      writer.closeAndWriteOutput();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      writer.stop(false);
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index dbb7c66..0e23a64 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,6 +25,7 @@ import org.junit.*;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.*;
 import static org.mockito.AdditionalMatchers.geq;
 import static org.mockito.Mockito.*;
 
@@ -495,4 +496,42 @@ public abstract class AbstractBytesToBytesMapSuite {
     map.growAndRehash();
     map.free();
   }
+
+  @Test
+  public void testTotalMemoryConsumption() {
+    final long recordLengthBytes = 24;
+    final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
+    final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+
+    // Since BytesToBytesMap is append-only, we expect the total memory consumption to be
+    // monotonically increasing. More specifically, every time we allocate a new page it
+    // should increase by exactly the size of the page. In this regard, the memory usage
+    // at any given time is also the peak memory used.
+    long previousMemory = map.getTotalMemoryConsumption();
+    long newMemory;
+    try {
+      for (long i = 0; i < numRecordsPerPage * 10; i++) {
+        final long[] value = new long[]{i};
+        map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8).putNewKey(
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8,
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8);
+        newMemory = map.getTotalMemoryConsumption();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousMemory + pageSizeBytes, newMemory);
+        } else {
+          assertEquals(previousMemory, newMemory);
+        }
+        previousMemory = newMemory;
+      }
+    } finally {
+      map.free();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 52fa8bc..c11949d 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -247,4 +247,50 @@ public class UnsafeExternalSorterSuite {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+      taskMemoryManager,
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
+      recordComparator,
+      prefixComparator,
+      1024,
+      pageSizeBytes);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = sorter.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        insertNumber(sorter, i);
+        newPeakMemory = sorter.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0) {
+          // We allocated a new page for this record, so peak memory should change
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      sorter.spill();
+      newPeakMemory = sorter.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        insertNumber(sorter, i);
+      }
+      newPeakMemory = sorter.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      sorter.freeMemory();
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index e942d65..48f5495 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -18,13 +18,17 @@
 package org.apache.spark
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.ref.WeakReference
 
 import org.scalatest.Matchers
+import org.scalatest.exceptions.TestFailedException
 
+import org.apache.spark.scheduler._
 
-class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
 
+class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
+  import InternalAccumulator._
 
   implicit def setAccum[A]: AccumulableParam[mutable.Set[A], A] =
     new AccumulableParam[mutable.Set[A], A] {
@@ -155,4 +159,191 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
     assert(!Accumulators.originals.get(accId).isDefined)
   }
 
+  test("internal accumulators in TaskContext") {
+    val accums = InternalAccumulator.create()
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, accums)
+    val internalMetricsToAccums = taskContext.internalMetricsToAccumulators
+    val collectedInternalAccums = taskContext.collectInternalAccumulators()
+    val collectedAccums = taskContext.collectAccumulators()
+    assert(internalMetricsToAccums.size > 0)
+    assert(internalMetricsToAccums.values.forall(_.isInternal))
+    assert(internalMetricsToAccums.contains(TEST_ACCUMULATOR))
+    val testAccum = internalMetricsToAccums(TEST_ACCUMULATOR)
+    assert(collectedInternalAccums.size === internalMetricsToAccums.size)
+    assert(collectedInternalAccums.size === collectedAccums.size)
+    assert(collectedInternalAccums.contains(testAccum.id))
+    assert(collectedAccums.contains(testAccum.id))
+  }
+
+  test("internal accumulators in a stage") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Have each task add 1 to the internal accumulator
+    sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
+      TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+      iter
+    }.count()
+    val stageInfos = listener.getCompletedStageInfos
+    val taskInfos = listener.getCompletedTaskInfos
+    assert(stageInfos.size === 1)
+    assert(taskInfos.size === numPartitions)
+    // The accumulator values should be merged in the stage
+    val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+    assert(stageAccum.value.toLong === numPartitions)
+    // The accumulator should be updated locally on each task
+    val taskAccumValues = taskInfos.map { taskInfo =>
+      val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+      assert(taskAccum.update.isDefined)
+      assert(taskAccum.update.get.toLong === 1)
+      taskAccum.value.toLong
+    }
+    // Each task should keep track of the partial value on the way, i.e. 1, 2, ... numPartitions
+    assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+  }
+
+  test("internal accumulators in multiple stages") {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    sc = new SparkContext("local", "test")
+    sc.addSparkListener(listener)
+    // Each stage creates its own set of internal accumulators so the
+    // values for the same metric should not be mixed up across stages
+    sc.parallelize(1 to 100, numPartitions)
+      .map { i => (i, i) }
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+        iter
+      }
+      .reduceByKey { case (x, y) => x + y }
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 10
+        iter
+      }
+      .repartition(numPartitions * 2)
+      .mapPartitions { iter =>
+        TaskContext.get().internalMetricsToAccumulators(TEST_ACCUMULATOR) += 100
+        iter
+      }
+      .count()
+    // We ran 3 stages, and the accumulator values should be distinct
+    val stageInfos = listener.getCompletedStageInfos
+    assert(stageInfos.size === 3)
+    val firstStageAccum = findAccumulableInfo(stageInfos(0).accumulables.values, TEST_ACCUMULATOR)
+    val secondStageAccum = findAccumulableInfo(stageInfos(1).accumulables.values, TEST_ACCUMULATOR)
+    val thirdStageAccum = findAccumulableInfo(stageInfos(2).accumulables.values, TEST_ACCUMULATOR)
+    assert(firstStageAccum.value.toLong === numPartitions)
+    assert(secondStageAccum.value.toLong === numPartitions * 10)
+    assert(thirdStageAccum.value.toLong === numPartitions * 2 * 100)
+  }
+
+  test("internal accumulators in fully resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
+  }
+
+  test("internal accumulators in partially resubmitted stages") {
+    testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
+  }
+
+  /**
+   * Return the accumulable info that matches the specified name.
+   */
+  private def findAccumulableInfo(
+      accums: Iterable[AccumulableInfo],
+      name: String): AccumulableInfo = {
+    accums.find { a => a.name == name }.getOrElse {
+      throw new TestFailedException(s"internal accumulator '$name' not found", 0)
+    }
+  }
+
+  /**
+   * Test whether internal accumulators are merged properly if some tasks fail.
+   */
+  private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
+    val listener = new SaveInfoListener
+    val numPartitions = 10
+    val numFailedPartitions = (0 until numPartitions).count(failCondition)
+    // This says use 1 core and retry tasks up to 2 times
+    sc = new SparkContext("local[1, 2]", "test")
+    sc.addSparkListener(listener)
+    sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
+      val taskContext = TaskContext.get()
+      taskContext.internalMetricsToAccumulators(TEST_ACCUMULATOR) += 1
+      // Fail the first attempts of a subset of the tasks
+      if (failCondition(i) && taskContext.attemptNumber() == 0) {
+        throw new Exception("Failing a task intentionally.")
+      }
+      iter
+    }.count()
+    val stageInfos = listener.getCompletedStageInfos
+    val taskInfos = listener.getCompletedTaskInfos
+    assert(stageInfos.size === 1)
+    assert(taskInfos.size === numPartitions + numFailedPartitions)
+    val stageAccum = findAccumulableInfo(stageInfos.head.accumulables.values, TEST_ACCUMULATOR)
+    // We should not double count values in the merged accumulator
+    assert(stageAccum.value.toLong === numPartitions)
+    val taskAccumValues = taskInfos.flatMap { taskInfo =>
+      if (!taskInfo.failed) {
+        // If a task succeeded, its update value should always be 1
+        val taskAccum = findAccumulableInfo(taskInfo.accumulables, TEST_ACCUMULATOR)
+        assert(taskAccum.update.isDefined)
+        assert(taskAccum.update.get.toLong === 1)
+        Some(taskAccum.value.toLong)
+      } else {
+        // If a task failed, we should not get its accumulator values
+        assert(taskInfo.accumulables.isEmpty)
+        None
+      }
+    }
+    assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
+  }
+
+}
+
+private[spark] object AccumulatorSuite {
+
+  /**
+   * Run one or more Spark jobs and verify that the peak execution memory accumulator
+   * is updated afterwards.
+   */
+  def verifyPeakExecutionMemorySet(
+      sc: SparkContext,
+      testName: String)(testBody: => Unit): Unit = {
+    val listener = new SaveInfoListener
+    sc.addSparkListener(listener)
+    // Verify that the accumulator does not already exist
+    sc.parallelize(1 to 10).count()
+    val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
+    assert(!accums.exists(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY))
+    testBody
+    // Verify that peak execution memory is updated
+    val accum = listener.getCompletedStageInfos
+      .flatMap(_.accumulables.values)
+      .find(_.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
+      .getOrElse {
+        throw new TestFailedException(
+          s"peak execution memory accumulator not set in '$testName'", 0)
+      }
+    assert(accum.value.toLong > 0)
+  }
+}
+
+/**
+ * A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
+ */
+private class SaveInfoListener extends SparkListener {
+  private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
+  private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
+
+  def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
+  def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
+
+  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+    completedStageInfos += stageCompleted.stageInfo
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    completedTaskInfos += taskEnd.taskInfo
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 618a5fb..cb8bd04 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.mock.MockitoSugar
 
-import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.executor.{DataReadMethod, TaskMetrics}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage._
 
@@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     // in blockManager.put is a losing battle. You have been warned.
     blockManager = sc.env.blockManager
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
     assert(computeValue.toList === List(1, 2, 3, 4))
@@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))
 
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(5, 6, 7))
   }
@@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before
     // Local computation should not persist the resulting value, so don't expect a put().
     when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)
 
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null, true)
+    val context = new TaskContextImpl(0, 0, 0, 0, null, null, Seq.empty, runningLocally = true)
     val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
     assert(value.toList === List(1, 2, 3, 4))
   }
 
   test("verify task metrics updated correctly") {
     cacheManager = sc.env.cacheManager
-    val context = new TaskContextImpl(0, 0, 0, 0, null, null)
+    val context = TaskContext.empty()
     cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
     assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 3e8816a..5f73ec8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
       }
       val hadoopPart1 = generateFakeHadoopPartition()
       val pipedRdd = new PipedRDD(nums, "printenv " + varName)
-      val tContext = new TaskContextImpl(0, 0, 0, 0, null, null)
+      val tContext = TaskContext.empty()
       val rddIter = pipedRdd.compute(hadoopPart1, tContext)
       val arr = rddIter.toArray
       assert(arr(0) == "/some/path")

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index b3ca150..f7e16af 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,9 +19,11 @@ package org.apache.spark.scheduler
 
 import org.apache.spark.TaskContext
 
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
+class FakeTask(
+    stageId: Int,
+    prefLocs: Seq[TaskLocation] = Nil)
+  extends Task[Int](stageId, 0, 0, Seq.empty) {
   override def runTask(context: TaskContext): Int = 0
-
   override def preferredLocations: Seq[TaskLocation] = prefLocs
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/702aa9d7/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 383855c..f333247 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
  * A Task implementation that fails to serialize.
  */
 private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
-  extends Task[Array[Byte]](stageId, 0, 0) {
+  extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
 
   override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
   override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
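
The peak-tracking pattern this patch adds to ExternalAppendOnlyMap and ExternalSorter can be sketched in isolation as below. This is only an illustrative sketch and is not part of the patch: PeakTrackingBuffer, its fields, and its fixed per-record size are hypothetical stand-ins, and a real collection would estimate its size by sampling and write its contents to disk on spill. The point it shows is that the peak is taken from the size estimate made before the spill decision, so spilling never lowers it.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a spillable collection; not a Spark class. */
final class PeakTrackingBuffer {
  private final long spillThresholdBytes;
  private final long bytesPerRecord; // assumption: fixed-size records
  private final List<Long> records = new ArrayList<>();
  private long peakMemoryUsedBytes = 0L;
  private int spillCount = 0;

  PeakTrackingBuffer(long spillThresholdBytes, long bytesPerRecord) {
    this.spillThresholdBytes = spillThresholdBytes;
    this.bytesPerRecord = bytesPerRecord;
  }

  void insert(long record) {
    records.add(record);
    // Estimate once and reuse the value for both the peak and the spill check.
    final long estimatedSize = records.size() * bytesPerRecord;
    if (estimatedSize > peakMemoryUsedBytes) {
      peakMemoryUsedBytes = estimatedSize;
    }
    if (estimatedSize >= spillThresholdBytes) {
      // "Spill": a real implementation would write the records to disk first.
      records.clear();
      spillCount++;
    }
  }

  long getPeakMemoryUsedBytes() {
    return peakMemoryUsedBytes;
  }

  int getSpillCount() {
    return spillCount;
  }

  public static void main(String[] args) {
    PeakTrackingBuffer buffer = new PeakTrackingBuffer(32, 8);
    for (long i = 0; i < 5; i++) {
      buffer.insert(i);
    }
    System.out.println("peak bytes: " + buffer.getPeakMemoryUsedBytes()
        + ", spills: " + buffer.getSpillCount());
  }
}
```

As in the new testPeakMemoryUsed tests, the reported peak stays put after a spill and only grows again if a later in-memory size exceeds it.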


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

