spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list
Date Tue, 03 Jan 2017 14:11:58 GMT
Repository: spark
Updated Branches:
  refs/heads/master e5c307c50 -> 52636226d


[SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list

## What changes were proposed in this pull request?

Currently collect_set/collect_list aggregation expression don't support partial aggregation.
This patch is to enable partial aggregation for them.

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16371 from viirya/collect-partial-support.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52636226
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52636226
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52636226

Branch: refs/heads/master
Commit: 52636226dc8cb7fcf00381d65e280d651b25a382
Parents: e5c307c
Author: Liang-Chi Hsieh <viirya@gmail.com>
Authored: Tue Jan 3 22:11:54 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Jan 3 22:11:54 2017 +0800

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      |  7 +--
 .../expressions/aggregate/collect.scala         | 62 +++++++++++---------
 .../expressions/aggregate/interfaces.scala      |  4 +-
 .../RewriteDistinctAggregatesSuite.scala        |  9 ---
 4 files changed, 39 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d..eaeb010 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
  * the given percentage(s) with value range in [0.0, 1.0].
  *
- * The operator is bound to the slower sort based aggregation path because the number of elements
- * and their partial order cannot be determined in advance. Therefore we have to store all the
- * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined in advance.
+ * Therefore we have to store all the elements in memory, and so notice that too many elements can
+ * cause GC paused and eventually OutOfMemory Errors.
  *
  * @param child child expression that produce numeric column value with `child.eval(inputRow)`
  * @param percentageExpression Expression that represents a single percentage value or an array of

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a..411f058 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
 import scala.collection.generic.Growable
 import scala.collection.mutable
 
@@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
- * The Collect aggregate function collects all seen expression values into a list of values.
+ * A base class for collect_list and collect_set aggregate functions.
  *
- * The operator is bound to the slower sort based aggregation path because the number of
- * elements (and their memory usage) can not be determined in advance. This also means that the
- * collected elements are stored on heap, and that too many elements can cause GC pauses and
- * eventually Out of Memory Errors.
+ * We have to store all the collected elements in memory, and so notice that too many elements
+ * can cause GC paused and eventually OutOfMemory Errors.
  */
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {
 
   val child: Expression
 
@@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def supportsPartial: Boolean = false
-
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
   // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
   // actual order of input rows.
   override def deterministic: Boolean = false
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
-
-  override def initialize(b: InternalRow): Unit = {
-    buffer.clear()
-  }
+  override def update(buffer: T, input: InternalRow): T = {
+    val value = child.eval(input)
 
-  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
-    val value = child.eval(input)
     if (value != null) {
       buffer += value
     }
+    buffer
   }
 
-  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
-    sys.error("Collect cannot be used in partial aggregations.")
+  override def merge(buffer: T, other: T): T = {
+    buffer ++= other
   }
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(buffer: T): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  private lazy val projection = UnsafeProjection.create(
+    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+  private lazy val row = new UnsafeRow(1)
+
+  override def serialize(obj: T): Array[Byte] = {
+    val array = new GenericArrayData(obj.toArray)
+    projection.apply(InternalRow.apply(array)).getBytes()
+  }
+
+  override def deserialize(bytes: Array[Byte]): T = {
+    val buffer = createAggregationBuffer()
+    row.pointTo(bytes, bytes.length)
+    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    buffer
+  }
 }
 
 /**
@@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate {
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -98,9 +102,9 @@ case class CollectList(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def prettyName: String = "collect_list"
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+  override def prettyName: String = "collect_list"
 }
 
 /**
@@ -111,7 +115,7 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -131,5 +135,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba..ccd4ae6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
  * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
  * buffer's storage format, which is not supported by hash based aggregation. Hash based
  * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
  */
 abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3..5c1faae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     comparePlans(input, rewrite)
   }
 
-  test("single distinct group with non-partial aggregates") {
-    val input = testRelation
-      .groupBy('a, 'd)(
-        countDistinct('e, 'c).as('agg1),
-        CollectSet('b).toAggregateExpression().as('agg2))
-      .analyze
-    checkRewrite(RewriteDistinctAggregates(input))
-  }
-
   test("multiple distinct groups") {
     val input = testRelation
       .groupBy('a)(countDistinct('b, 'c), countDistinct('d))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message