spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-12295] [SQL] external spilling for window functions
Date Thu, 07 Jan 2016 07:21:57 GMT
Repository: spark
Updated Branches:
  refs/heads/master 84e77a15d -> 6a1c864ab


[SPARK-12295] [SQL] external spilling for window functions

This PR manages the memory used by window functions (buffered rows) and also enables external spilling.

After this PR, we can run window functions on a partition with hundreds of millions of rows
using only 1G of memory.

Author: Davies Liu <davies@databricks.com>

Closes #10605 from davies/unsafe_window.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a1c864a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a1c864a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a1c864a

Branch: refs/heads/master
Commit: 6a1c864ab6ee3e869a16ffdbaf6fead21c7aac6d
Parents: 84e77a1
Author: Davies Liu <davies@databricks.com>
Authored: Wed Jan 6 23:21:52 2016 -0800
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Wed Jan 6 23:21:52 2016 -0800

----------------------------------------------------------------------
 .../unsafe/sort/UnsafeExternalSorter.java       |  21 +-
 .../unsafe/sort/UnsafeInMemorySorter.java       |  18 +-
 .../unsafe/sort/UnsafeSorterIterator.java       |   2 +
 .../unsafe/sort/UnsafeSorterSpillMerger.java    |   7 +
 .../unsafe/sort/UnsafeSorterSpillReader.java    |   8 +-
 .../org/apache/spark/sql/execution/Window.scala | 314 ++++++++++++++-----
 6 files changed, 276 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 77d0b70..68dc0c6 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -45,7 +45,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
 
+  @Nullable
   private final PrefixComparator prefixComparator;
+  @Nullable
   private final RecordComparator recordComparator;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
@@ -431,7 +433,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
     public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
       this.upstream = inMemIterator;
-      this.numRecords = inMemIterator.numRecordsLeft();
+      this.numRecords = inMemIterator.getNumRecords();
+    }
+
+    public int getNumRecords() {
+      return numRecords;
     }
 
     public long spill() throws IOException {
@@ -558,14 +564,24 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
     private final Queue<UnsafeSorterIterator> iterators;
     private UnsafeSorterIterator current;
+    private int numRecords;
 
     public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
       assert iterators.size() > 0;
+      this.numRecords = 0;
+      for (UnsafeSorterIterator iter: iterators) {
+        this.numRecords += iter.getNumRecords();
+      }
       this.iterators = iterators;
       this.current = iterators.remove();
     }
 
     @Override
+    public int getNumRecords() {
+      return numRecords;
+    }
+
+    @Override
     public boolean hasNext() {
       while (!current.hasNext() && !iterators.isEmpty()) {
         current = iterators.remove();
@@ -575,6 +591,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
 
     @Override
     public void loadNext() throws IOException {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
       current.loadNext();
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index b7ab456..f71b8d1 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,6 +19,8 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.util.Comparator;
 
+import org.apache.avro.reflect.Nullable;
+
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -66,7 +68,9 @@ public final class UnsafeInMemorySorter {
 
   private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
+  @Nullable
   private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
+  @Nullable
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
 
   /**
@@ -98,10 +102,11 @@ public final class UnsafeInMemorySorter {
       LongArray array) {
     this.consumer = consumer;
     this.memoryManager = memoryManager;
-    this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     if (recordComparator != null) {
+      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
     } else {
+      this.sorter = null;
       this.sortComparator = null;
     }
     this.array = array;
@@ -190,12 +195,13 @@ public final class UnsafeInMemorySorter {
     }
 
     @Override
-    public boolean hasNext() {
-      return position / 2 < numRecords;
+    public int getNumRecords() {
+      return numRecords;
     }
 
-    public int numRecordsLeft() {
-      return numRecords - position / 2;
+    @Override
+    public boolean hasNext() {
+      return position / 2 < numRecords;
     }
 
     @Override
@@ -227,7 +233,7 @@ public final class UnsafeInMemorySorter {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    if (sortComparator != null) {
+    if (sorter != null) {
       sorter.sort(array, 0, pos / 2, sortComparator);
     }
     return new SortedIterator(pos / 2);

http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 16ac2e8..1b3167f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator {
   public abstract int getRecordLength();
 
   public abstract long getKeyPrefix();
+
+  public abstract int getNumRecords();
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 3874a9f..ceb5935 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -23,6 +23,7 @@ import java.util.PriorityQueue;
 
 final class UnsafeSorterSpillMerger {
 
+  private int numRecords = 0;
   private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
 
   public UnsafeSorterSpillMerger(
@@ -59,6 +60,7 @@ final class UnsafeSorterSpillMerger {
       // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
       spillReader.loadNext();
       priorityQueue.add(spillReader);
+      numRecords += spillReader.getNumRecords();
     }
   }
 
@@ -68,6 +70,11 @@ final class UnsafeSorterSpillMerger {
       private UnsafeSorterIterator spillReader;
 
       @Override
+      public int getNumRecords() {
+        return numRecords;
+      }
+
+      @Override
       public boolean hasNext() {
         return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index dcb13e6..20ee1c8 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -38,6 +38,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator
implemen
   // Variables that change with every record read:
   private int recordLength;
   private long keyPrefix;
+  private int numRecords;
   private int numRecordsRemaining;
 
   private byte[] arr = new byte[1024 * 1024];
@@ -53,7 +54,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator
implemen
     try {
       this.in = blockManager.wrapForCompression(blockId, bs);
       this.din = new DataInputStream(this.in);
-      numRecordsRemaining = din.readInt();
+      numRecords = numRecordsRemaining = din.readInt();
     } catch (IOException e) {
       Closeables.close(bs, /* swallowIOException = */ true);
       throw e;
@@ -61,6 +62,11 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator
implemen
   }
 
   @Override
+  public int getNumRecords() {
+    return numRecords;
+  }
+
+  @Override
   public boolean hasNext() {
     return (numRecordsRemaining > 0);
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a1c864a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 89b17c8..be88539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.util
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -26,6 +28,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+import org.apache.spark.{SparkEnv, TaskContext}
 
 /**
  * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -283,23 +287,26 @@ case class Window(
         val grouping = UnsafeProjection.create(partitionSpec, child.output)
 
         // Manage the stream and the grouping.
-        var nextRow: InternalRow = EmptyRow
-        var nextGroup: InternalRow = EmptyRow
+        var nextRow: UnsafeRow = null
+        var nextGroup: UnsafeRow = null
         var nextRowAvailable: Boolean = false
         private[this] def fetchNextRow() {
           nextRowAvailable = stream.hasNext
           if (nextRowAvailable) {
-            nextRow = stream.next()
+            nextRow = stream.next().asInstanceOf[UnsafeRow]
             nextGroup = grouping(nextRow)
           } else {
-            nextRow = EmptyRow
-            nextGroup = EmptyRow
+            nextRow = null
+            nextGroup = null
           }
         }
         fetchNextRow()
 
         // Manage the current partition.
-        val rows = ArrayBuffer.empty[InternalRow]
+        val rows = ArrayBuffer.empty[UnsafeRow]
+        val inputFields = child.output.length
+        var sorter: UnsafeExternalSorter = null
+        var rowBuffer: RowBuffer = null
         val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType))
         val frames = factories.map(_(windowFunctionResult))
         val numFrames = frames.length
@@ -307,27 +314,63 @@ case class Window(
           // Collect all the rows in the current partition.
           // Before we start to fetch new input rows, make a copy of nextGroup.
           val currentGroup = nextGroup.copy()
-          rows.clear()
+
+          // clear last partition
+          if (sorter != null) {
+            // the last sorter of this task will be cleaned up via task completion listener
+            sorter.cleanupResources()
+            sorter = null
+          } else {
+            rows.clear()
+          }
+
           while (nextRowAvailable && nextGroup == currentGroup) {
-            rows += nextRow.copy()
+            if (sorter == null) {
+              rows += nextRow.copy()
+
+              if (rows.length >= 4096) {
+                // We will not sort the rows, so prefixComparator and recordComparator are
null.
+                sorter = UnsafeExternalSorter.create(
+                  TaskContext.get().taskMemoryManager(),
+                  SparkEnv.get.blockManager,
+                  TaskContext.get(),
+                  null,
+                  null,
+                  1024,
+                  SparkEnv.get.memoryManager.pageSizeBytes)
+                rows.foreach { r =>
+                  sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes,
0)
+                }
+                rows.clear()
+              }
+            } else {
+              sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
+                nextRow.getSizeInBytes, 0)
+            }
             fetchNextRow()
           }
+          if (sorter != null) {
+            rowBuffer = new ExternalRowBuffer(sorter, inputFields)
+          } else {
+            rowBuffer = new ArrayRowBuffer(rows)
+          }
 
           // Setup the frames.
           var i = 0
           while (i < numFrames) {
-            frames(i).prepare(rows)
+            frames(i).prepare(rowBuffer.copy())
             i += 1
           }
 
           // Setup iteration
           rowIndex = 0
-          rowsSize = rows.size
+          rowsSize = rowBuffer.size()
         }
 
         // Iteration
         var rowIndex = 0
-        var rowsSize = 0
+        var rowsSize = 0L
+
         override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
 
         val join = new JoinedRow
@@ -340,13 +383,14 @@ case class Window(
           if (rowIndex < rowsSize) {
             // Get the results for the window frames.
             var i = 0
+            val current = rowBuffer.next()
             while (i < numFrames) {
-              frames(i).write()
+              frames(i).write(rowIndex, current)
               i += 1
             }
 
             // 'Merge' the input row with the window function result
-            join(rows(rowIndex), windowFunctionResult)
+            join(current, windowFunctionResult)
             rowIndex += 1
 
             // Return the projection.
@@ -362,14 +406,18 @@ case class Window(
  * Function for comparing boundary values.
  */
 private[execution] abstract class BoundOrdering {
-  def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int
+  def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex:
Int): Int
 }
 
 /**
  * Compare the input index to the bound of the output index.
  */
 private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering {
-  override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int =
     inputIndex - (outputIndex + offset)
 }
 
@@ -380,8 +428,100 @@ private[execution] final case class RangeBoundOrdering(
     ordering: Ordering[InternalRow],
     current: Projection,
     bound: Projection) extends BoundOrdering {
-  override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
-    ordering.compare(current(input(inputIndex)), bound(input(outputIndex)))
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int =
+    ordering.compare(current(inputRow), bound(outputRow))
+}
+
+/**
+  * The interface of row buffer for a partition
+  */
+private[execution] abstract class RowBuffer {
+
+  /** Number of rows. */
+  def size(): Int
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer
+}
+
+/**
+  * A row buffer based on ArrayBuffer (the number of rows is limited)
+  */
+private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer
{
+
+  private[this] var cursor: Int = -1
+
+  /** Number of rows. */
+  def size(): Int = buffer.length
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow = {
+    cursor += 1
+    if (cursor < buffer.length) {
+      buffer(cursor)
+    } else {
+      null
+    }
+  }
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit = {
+    cursor += n
+  }
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer = {
+    new ArrayRowBuffer(buffer)
+  }
+}
+
+/**
+  * An external buffer of rows based on UnsafeExternalSorter
+  */
+private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
+  extends RowBuffer {
+
+  private[this] val iter: UnsafeSorterIterator = sorter.getIterator
+
+  private[this] val currentRow = new UnsafeRow(numFields)
+
+  /** Number of rows. */
+  def size(): Int = iter.getNumRecords()
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow = {
+    if (iter.hasNext) {
+      iter.loadNext()
+      currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+      currentRow
+    } else {
+      null
+    }
+  }
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit = {
+    var i = 0
+    while (i < n && iter.hasNext) {
+      iter.loadNext()
+      i += 1
+    }
+  }
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer = {
+    new ExternalRowBuffer(sorter, numFields)
+  }
 }
 
 /**
@@ -395,12 +535,12 @@ private[execution] abstract class WindowFunctionFrame {
    *
    * @param rows to calculate the frame results for.
    */
-  def prepare(rows: ArrayBuffer[InternalRow]): Unit
+  def prepare(rows: RowBuffer): Unit
 
   /**
    * Write the current results to the target row.
    */
-  def write(): Unit
+  def write(index: Int, current: InternalRow): Unit
 }
 
 /**
@@ -421,14 +561,11 @@ private[execution] final class OffsetWindowFunctionFrame(
     offset: Int) extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
 
   /** Index of the input row currently used for output. */
   private[this] var inputIndex = 0
 
-  /** Index of the current output row. */
-  private[this] var outputIndex = 0
-
   /** Row used when there is no valid input. */
   private[this] val emptyRow = new GenericInternalRow(inputSchema.size)
 
@@ -463,22 +600,26 @@ private[execution] final class OffsetWindowFunctionFrame(
     newMutableProjection(boundExpressions, Nil)().target(target)
   }
 
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
+    // drain the first few rows if offset is larger than zero
+    inputIndex = 0
+    while (inputIndex < offset) {
+      input.next()
+      inputIndex += 1
+    }
     inputIndex = offset
-    outputIndex = 0
   }
 
-  override def write(): Unit = {
-    val size = input.size
-    if (inputIndex >= 0 && inputIndex < size) {
-      join(input(inputIndex), input(outputIndex))
+  override def write(index: Int, current: InternalRow): Unit = {
+    if (inputIndex >= 0 && inputIndex < input.size) {
+      val r = input.next()
+      join(r, current)
     } else {
-      join(emptyRow, input(outputIndex))
+      join(emptyRow, current)
     }
     projection(join)
     inputIndex += 1
-    outputIndex += 1
   }
 }
 
@@ -498,7 +639,13 @@ private[execution] final class SlidingWindowFunctionFrame(
     ubound: BoundOrdering) extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input`. */
+  private[this] var nextRow: InternalRow = null
+
+  /** The rows within current sliding window. */
+  private[this] val buffer = new util.ArrayDeque[InternalRow]()
 
   /** Index of the first input row with a value greater than the upper bound of the current
     * output row. */
@@ -508,33 +655,32 @@ private[execution] final class SlidingWindowFunctionFrame(
     * current output row. */
   private[this] var inputLowIndex = 0
 
-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. Reset all variables. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
+    nextRow = rows.next()
     inputHighIndex = 0
     inputLowIndex = 0
-    outputIndex = 0
+    buffer.clear()
   }
 
   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0
 
     // Add all rows to the buffer for which the input row value is equal to or less than
     // the output row upper bound.
-    while (inputHighIndex < input.size &&
-      ubound.compare(input, inputHighIndex, outputIndex) <= 0) {
+    while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index)
<= 0) {
+      buffer.add(nextRow.copy())
+      nextRow = input.next()
       inputHighIndex += 1
       bufferUpdated = true
     }
 
     // Drop all rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    while (inputLowIndex < inputHighIndex &&
-      lbound.compare(input, inputLowIndex, outputIndex) < 0) {
+    while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current,
index) < 0) {
+      buffer.remove()
       inputLowIndex += 1
       bufferUpdated = true
     }
@@ -542,12 +688,12 @@ private[execution] final class SlidingWindowFunctionFrame(
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
       processor.initialize(input.size)
-      processor.update(input, inputLowIndex, inputHighIndex)
+      val iter = buffer.iterator()
+      while (iter.hasNext) {
+        processor.update(iter.next())
+      }
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
 
@@ -567,13 +713,18 @@ private[execution] final class UnboundedWindowFunctionFrame(
     processor: AggregateProcessor) extends WindowFunctionFrame {
 
   /** Prepare the frame for calculating a new partition. Process all rows eagerly. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
-    processor.initialize(rows.size)
-    processor.update(rows, 0, rows.size)
+  override def prepare(rows: RowBuffer): Unit = {
+    val size = rows.size()
+    processor.initialize(size)
+    var i = 0
+    while (i < size) {
+      processor.update(rows.next())
+      i += 1
+    }
   }
 
   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
+  override def write(index: Int, current: InternalRow): Unit = {
     // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
     // for each row.
     processor.evaluate(target)
@@ -600,31 +751,32 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
     ubound: BoundOrdering) extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input`. */
+  private[this] var nextRow: InternalRow = null
 
   /** Index of the first input row with a value greater than the upper bound of the current
    * output row. */
   private[this] var inputIndex = 0
 
-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
+    nextRow = rows.next()
     inputIndex = 0
-    outputIndex = 0
     processor.initialize(input.size)
   }
 
   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0
 
     // Add all rows to the aggregates for which the input row value is equal to or less than
     // the output row upper bound.
-    while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex)
<= 0) {
-      processor.update(input(inputIndex))
+    while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index)
<= 0) {
+      processor.update(nextRow)
+      nextRow = input.next()
       inputIndex += 1
       bufferUpdated = true
     }
@@ -633,9 +785,6 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
     if (bufferUpdated) {
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
 
@@ -661,29 +810,31 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
     lbound: BoundOrdering) extends WindowFunctionFrame {
 
   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
 
   /** Index of the first input row with a value equal to or greater than the lower bound
of the
    * current output row. */
   private[this] var inputIndex = 0
 
-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
     inputIndex = 0
-    outputIndex = 0
   }
 
   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0
+
+    // Duplicate the input to have a new iterator
+    val tmp = input.copy()
 
     // Drop all rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex)
< 0) {
+    tmp.skip(inputIndex)
+    var nextRow = tmp.next()
+    while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index)
< 0) {
+      nextRow = tmp.next()
       inputIndex += 1
       bufferUpdated = true
     }
@@ -691,12 +842,12 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
       processor.initialize(input.size)
-      processor.update(input, inputIndex, input.size)
+      while (nextRow != null) {
+        processor.update(nextRow)
+        nextRow = tmp.next()
+      }
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
 
@@ -825,15 +976,6 @@ private[execution] final class AggregateProcessor(
     }
   }
 
-  /** Bulk update the given buffer. */
-  def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = {
-    var i = begin
-    while (i < end) {
-      update(input(i))
-      i += 1
-    }
-  }
-
   /** Evaluate buffer. */
   def evaluate(target: MutableRow): Unit =
     evaluateProjection.target(target)(buffer)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message