From: wenchen@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list
Date: Tue, 3 Jan 2017 14:11:58 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master e5c307c50 -> 52636226d


[SPARK-18932][SQL] Support partial aggregation for collect_set/collect_list

## What changes were proposed in this pull request?

Currently the collect_set/collect_list aggregate expressions don't support partial aggregation. This patch enables partial aggregation for them.

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh

Closes #16371 from viirya/collect-partial-support.
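
To make the user-visible behavior concrete, here is a small usage sketch (standard Spark SQL DataFrame API; the data and names are invented for illustration). The queries themselves are unchanged by this patch; what changes is that the planner may now build per-partition buffers before the shuffle instead of funneling every row into a single non-partial aggregate:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{collect_list, collect_set}

    val spark = SparkSession.builder().appName("collect-demo").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 1), ("b", 2), ("b", 3)).toDF("key", "value")

    // collect_set drops duplicates, collect_list keeps them; both previously
    // declared supportsPartial = false and so could not be combined map-side.
    df.groupBy("key")
      .agg(collect_set("value").as("distinct_values"),
           collect_list("value").as("all_values"))
      .show()
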
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/52636226
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/52636226
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/52636226

Branch: refs/heads/master
Commit: 52636226dc8cb7fcf00381d65e280d651b25a382
Parents: e5c307c
Author: Liang-Chi Hsieh
Authored: Tue Jan 3 22:11:54 2017 +0800
Committer: Wenchen Fan
Committed: Tue Jan 3 22:11:54 2017 +0800

----------------------------------------------------------------------
 .../expressions/aggregate/Percentile.scala      |  7 +--
 .../expressions/aggregate/collect.scala         | 62 +++++++++++---------
 .../expressions/aggregate/interfaces.scala      |  4 +-
 .../RewriteDistinctAggregatesSuite.scala        |  9 ---
 4 files changed, 39 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d..eaeb010 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
  * the given percentage(s) with value range in [0.0, 1.0].
  *
- * The operator is bound to the slower sort based aggregation path because the number of elements
- * and their partial order cannot be determined in advance. Therefore we have to store all the
- * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined in advance, we
+ * have to store all the elements in memory. Note that too many elements can cause GC pauses
+ * and eventually OutOfMemoryErrors.
  *
  * @param child child expression that produce numeric column value with `child.eval(inputRow)`
  * @param percentageExpression Expression that represents a single percentage value or an array of
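
Since the comment above is essentially an operational warning, a short sketch of what it means in practice may help (the table and column names are invented; percentile_approx is the standard bounded-memory alternative and is not part of this diff):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    Seq(("a", 10), ("a", 20), ("b", 30)).toDF("key", "value").createOrReplaceTempView("events")

    // Exact percentile: every value in a group is buffered on the heap, so a very
    // large group can trigger the GC pauses / OutOfMemoryErrors warned about above.
    spark.sql("SELECT key, percentile(value, 0.5) FROM events GROUP BY key").show()

    // Approximate percentile: bounded-memory sketch, usually safer at scale.
    spark.sql("SELECT key, percentile_approx(value, 0.5) FROM events GROUP BY key").show()
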
http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a..411f058 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
 import scala.collection.generic.Growable
 import scala.collection.mutable
 
@@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
- * The Collect aggregate function collects all seen expression values into a list of values.
+ * A base class for collect_list and collect_set aggregate functions.
  *
- * The operator is bound to the slower sort based aggregation path because the number of
- * elements (and their memory usage) can not be determined in advance. This also means that the
- * collected elements are stored on heap, and that too many elements can cause GC pauses and
- * eventually Out of Memory Errors.
+ * We have to store all the collected elements in memory. Note that too many elements can
+ * cause GC pauses and eventually OutOfMemoryErrors.
  */
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {
 
   val child: Expression
 
@@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def supportsPartial: Boolean = false
-
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
   // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
   // actual order of input rows.
   override def deterministic: Boolean = false
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
-
-  override def initialize(b: InternalRow): Unit = {
-    buffer.clear()
-  }
+  override def update(buffer: T, input: InternalRow): T = {
+    val value = child.eval(input)
 
-  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
-    val value = child.eval(input)
     if (value != null) {
       buffer += value
     }
+    buffer
   }
 
-  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
-    sys.error("Collect cannot be used in partial aggregations.")
+  override def merge(buffer: T, other: T): T = {
+    buffer ++= other
   }
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(buffer: T): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  private lazy val projection = UnsafeProjection.create(
+    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+  private lazy val row = new UnsafeRow(1)
+
+  override def serialize(obj: T): Array[Byte] = {
+    val array = new GenericArrayData(obj.toArray)
+    projection.apply(InternalRow.apply(array)).getBytes()
+  }
+
+  override def deserialize(bytes: Array[Byte]): T = {
+    val buffer = createAggregationBuffer()
+    row.pointTo(bytes, bytes.length)
+    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    buffer
+  }
 }
 
 /**
@@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate {
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -98,9 +102,9 @@ case class CollectList(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def prettyName: String = "collect_list"
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+  override def prettyName: String = "collect_list"
 }
 
 /**
@@ -111,7 +115,7 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -131,5 +135,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
 }
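
To make the new contract easier to follow, here is a minimal, Spark-free sketch of the same buffer lifecycle: each partition owns a typed buffer that is updated row by row, and the partial buffers are then merged. The TypedAgg trait and Demo driver below are hypothetical simplifications for illustration, not Spark's actual TypedImperativeAggregate API:

    import scala.collection.mutable

    // A stripped-down analogue of TypedImperativeAggregate[T].
    trait TypedAgg[T] {
      def createBuffer(): T
      def update(buffer: T, value: Any): T
      def merge(buffer: T, other: T): T
      def eval(buffer: T): Any
    }

    object CollectSetLike extends TypedAgg[mutable.HashSet[Any]] {
      def createBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
      def update(buffer: mutable.HashSet[Any], value: Any): mutable.HashSet[Any] = {
        if (value != null) buffer += value  // same null-dropping rule as Hive's collect_set
        buffer
      }
      def merge(buffer: mutable.HashSet[Any], other: mutable.HashSet[Any]): mutable.HashSet[Any] =
        buffer ++= other
      def eval(buffer: mutable.HashSet[Any]): Any = buffer.toArray
    }

    object Demo extends App {
      // Two "partitions" of input rows, aggregated partially and then merged --
      // exactly the shape of execution that supportsPartial = false used to forbid.
      // (Element order in the output is nondeterministic, as the patch notes.)
      val partitions: Seq[Seq[Any]] = Seq(Seq(1, 2, 2, null), Seq(2, 3))
      val partials = partitions.map(_.foldLeft(CollectSetLike.createBuffer())(CollectSetLike.update))
      val merged = partials.reduce(CollectSetLike.merge)
      println(CollectSetLike.eval(merged).asInstanceOf[Array[Any]].mkString(", "))
    }
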
http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba..ccd4ae6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
  * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
  * buffer's storage format, which is not supported by hash based aggregation. Hash based
  * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
  */
 abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
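
The serialize/deserialize pair exists because a typed, heap-resident buffer has to cross the shuffle as bytes. Below is a minimal sketch of such a round trip for a set of ints, using the plain java.io stream classes the patch imports; this illustrates the idea only, not the UnsafeProjection-based encoding the patch actually uses:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import scala.collection.mutable

    object BufferRoundTrip extends App {
      def serialize(buffer: mutable.HashSet[Int]): Array[Byte] = {
        val bytes = new ByteArrayOutputStream()
        val out = new DataOutputStream(bytes)
        out.writeInt(buffer.size)           // element-count header
        buffer.foreach(v => out.writeInt(v)) // fixed-width payload
        out.flush()
        bytes.toByteArray
      }

      def deserialize(data: Array[Byte]): mutable.HashSet[Int] = {
        val in = new DataInputStream(new ByteArrayInputStream(data))
        val buffer = mutable.HashSet.empty[Int]
        var n = in.readInt()
        while (n > 0) { buffer += in.readInt(); n -= 1 }
        buffer
      }

      // Round trip: what a map-side partial buffer goes through before the
      // merge side reads it back.
      assert(deserialize(serialize(mutable.HashSet(1, 2, 3))) == mutable.HashSet(1, 2, 3))
      println("round trip ok")
    }
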
http://git-wip-us.apache.org/repos/asf/spark/blob/52636226/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3..5c1faae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     comparePlans(input, rewrite)
   }
 
-  test("single distinct group with non-partial aggregates") {
-    val input = testRelation
-      .groupBy('a, 'd)(
-        countDistinct('e, 'c).as('agg1),
-        CollectSet('b).toAggregateExpression().as('agg2))
-      .analyze
-    checkRewrite(RewriteDistinctAggregates(input))
-  }
-
   test("multiple distinct groups") {
     val input = testRelation
       .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
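
The deleted test existed only because CollectSet used to force the non-partial branch of RewriteDistinctAggregates. One way to observe the new behavior from a spark-shell is to mix a distinct aggregate with collect_set, mirroring the deleted test, and inspect the physical plan (a sketch with invented data; the exact plan text varies by build):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{collect_set, countDistinct}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("b", 2)).toDF("key", "value")

    // With this patch, collect_set participates in partial aggregation, so this
    // mix no longer needs the special non-partial rewrite the test covered.
    df.groupBy("key")
      .agg(countDistinct("value").as("n"), collect_set("value").as("vals"))
      .explain()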