spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-11913][SQL] support typed aggregate with complex buffer schema
Date Mon, 23 Nov 2015 18:40:02 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 bd94793eb -> f10f1a1f7


[SPARK-11913][SQL] support typed aggregate with complex buffer schema

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9898 from cloud-fan/agg.

(cherry picked from commit 946b406519af58c79041217e6f93854b6cf80acd)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f10f1a1f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f10f1a1f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f10f1a1f

Branch: refs/heads/branch-1.6
Commit: f10f1a1f7ee38c6880e33bb38c630a973ecc1798
Parents: bd94793
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Mon Nov 23 10:39:33 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Mon Nov 23 10:39:55 2015 -0800

----------------------------------------------------------------------
 .../aggregate/TypedAggregateExpression.scala    | 25 +++++++-----
 .../spark/sql/DatasetAggregatorSuite.scala      | 41 +++++++++++++++++++-
 2 files changed, 56 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f10f1a1f/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 6ce41aa..a971912 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -23,9 +23,8 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -46,14 +45,12 @@ object TypedAggregateExpression {
 /**
  * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has
  * the following limitations:
- *  - It assumes the aggregator reduces and returns a single column of type `long`.
- *  - It might only work when there is a single aggregator in the first column.
  *  - It assumes the aggregator has a zero, `0`.
  */
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     aEncoder: Option[ExpressionEncoder[Any]], // Should be bound.
-    bEncoder: ExpressionEncoder[Any], // Should be bound.
+    unresolvedBEncoder: ExpressionEncoder[Any],
     cEncoder: ExpressionEncoder[Any],
     children: Seq[Attribute],
     mutableAggBufferOffset: Int,
@@ -80,10 +77,14 @@ case class TypedAggregateExpression(
 
   override lazy val inputTypes: Seq[DataType] = Nil
 
-  override val aggBufferSchema: StructType = bEncoder.schema
+  override val aggBufferSchema: StructType = unresolvedBEncoder.schema
 
   override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes
 
+  val bEncoder = unresolvedBEncoder
+    .resolve(aggBufferAttributes, OuterScopes.outerScopes)
+    .bind(aggBufferAttributes)
+
   // Note: although this simply copies aggBufferAttributes, this common code can not be placed
   // in the superclass because that will lead to initialization ordering issues.
   override val inputAggBufferAttributes: Seq[AttributeReference] =
@@ -93,12 +94,18 @@ case class TypedAggregateExpression(
   lazy val boundA = aEncoder.get
 
   private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = {
-    // todo: need a more neat way to assign the value.
     var i = 0
     while (i < aggBufferAttributes.length) {
+      val offset = mutableAggBufferOffset + i
       aggBufferSchema(i).dataType match {
-        case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i))
-        case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i))
+        case BooleanType => buffer.setBoolean(offset, value.getBoolean(i))
+        case ByteType => buffer.setByte(offset, value.getByte(i))
+        case ShortType => buffer.setShort(offset, value.getShort(i))
+        case IntegerType => buffer.setInt(offset, value.getInt(i))
+        case LongType => buffer.setLong(offset, value.getLong(i))
+        case FloatType => buffer.setFloat(offset, value.getFloat(i))
+        case DoubleType => buffer.setDouble(offset, value.getDouble(i))
+        case other => buffer.update(offset, value.get(i, other))
       }
       i += 1
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/f10f1a1f/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 9377589..19dce5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,7 +67,7 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
 }
 
 case class AggData(a: Int, b: String)
-object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+object ClassInputAgg extends Aggregator[AggData, Int, Int] {
   /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: Int = 0
 
@@ -88,6 +88,28 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
   override def merge(b1: Int, b2: Int): Int = b1 + b2
 }
 
+/** Counts its inputs using a buffer that nests a case class (`AggData`) inside a tuple. */
+object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
+  /** Empty buffer: a count of 0 paired with the placeholder record `AggData(0, "0")`. */
+  override def zero: (Int, AggData) = 0 -> AggData(0, "0")
+
+  /**
+   * Folds one input into the buffer: increments the count and keeps `a` as the
+   * latest record (the record previously held in `b` is discarded).
+   */
+  override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
+
+  /** Returns only the count; the buffered record is not part of the result. */
+  override def finish(reduction: (Int, AggData)): Int = reduction._1
+
+  /**
+   * Sums the counts of two partial buffers and keeps the record from `b1`
+   * (the record from `b2` is dropped).
+   */
+  override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
+    (b1._1 + b2._1, b1._2)
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -168,4 +190,21 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
       ("one", 1))
   }
+
+  test("typed aggregation: complex input") {
+    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+    checkAnswer(
+      ds.select(ComplexBufferAgg.toColumn),
+      2
+    )
+
+    checkAnswer(
+      ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
+      (1.5, 2))
+
+    checkAnswer(
+      ds.groupBy(_.b).agg(ComplexBufferAgg.toColumn),
+      ("one", 1), ("two", 1))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message