spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-14451][SQL] Move encoder definition into Aggregator interface
Date Sat, 09 Apr 2016 07:00:43 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2f0b882e5 -> 520dde48d


[SPARK-14451][SQL] Move encoder definition into Aggregator interface

## What changes were proposed in this pull request?
When we first introduced Aggregators, we required the user of Aggregators to (implicitly)
specify the encoders. It would actually make more sense to have the encoders be specified
by the implementation of Aggregators, since each implementation should have the most knowledge
about how to encode its own data type.

Note that this simplifies the Java API because Java users no longer need to explicitly specify
encoders for aggregators.

## How was this patch tested?
Updated unit tests.

Author: Reynold Xin <rxin@databricks.com>

Closes #12231 from rxin/SPARK-14451.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/520dde48
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/520dde48
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/520dde48

Branch: refs/heads/master
Commit: 520dde48d0d52dbbbbe1710a3275fdd5355dd69d
Parents: 2f0b882
Author: Reynold Xin <rxin@databricks.com>
Authored: Sat Apr 9 00:00:39 2016 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sat Apr 9 00:00:39 2016 -0700

----------------------------------------------------------------------
 project/MimaExcludes.scala                      |  5 ++
 .../scala/org/apache/spark/repl/ReplSuite.scala | 28 +-------
 .../scala/org/apache/spark/repl/ReplSuite.scala | 29 +-------
 .../execution/aggregate/typedaggregators.scala  | 47 ++++++------
 .../spark/sql/expressions/Aggregator.scala      | 39 ++++++----
 .../sql/sources/JavaDatasetAggregatorSuite.java | 17 +++--
 .../spark/sql/DatasetAggregatorSuite.scala      | 75 +++++++++++---------
 7 files changed, 113 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f240c30..290de79 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -329,6 +329,11 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
 
+        // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
+
         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c8b78bc..547da8f 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -285,7 +285,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
         |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.{Encoder, Encoders}
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -293,6 +293,8 @@ class ReplSuite extends SparkFunSuite {
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
         |  def finish(b: Int) = b                // Return the final result.
+        |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+        |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
         |}.toColumn
         |
         |val ds = Seq(1, 2, 3, 4).toDS()
@@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-        |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
-        |import org.apache.spark.sql.expressions.Aggregator
-        |import org.apache.spark.sql.TypedColumn
-        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-        |  val numeric = implicitly[Numeric[N]]
-        |  override def zero: N = numeric.zero
-        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-        |  override def finish(reduction: N): N = reduction
-        |}
-        |
-        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index dbfacba..7e10f15 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
         |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.{Encoder, Encoders}
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite {
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
         |  def finish(b: Int) = b                // Return the final result.
+        |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+        |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
         |}.toColumn
         |
         |val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-        |import org.apache.spark.sql.functions._
-        |import org.apache.spark.sql.Encoder
-        |import org.apache.spark.sql.expressions.Aggregator
-        |import org.apache.spark.sql.TypedColumn
-        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends
-        |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
-        |  val numeric = implicitly[Numeric[N]]
-        |  override def zero: N = numeric.zero
-        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-        |  override def finish(reduction: N): N = reduction
-        |}
-        |
-        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 7a18d0a..c39a78d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
@@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
-class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
-  val numeric = implicitly[Numeric[OUT]]
-  override def zero: OUT = numeric.zero
-  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
-  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
-  override def finish(reduction: OUT): OUT = reduction
-
-  // TODO(ekl) java api support once this is exposed in scala
-}
-
-
 class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
   override def finish(reduction: Double): Double = reduction
 
+  override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
 
@@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
     (b1._1 + b2._1, b1._2 + b2._2)
   }
 
+  override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 9cb356f..7da8379 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  *
  * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
  *
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
  * @since 1.6.0
  */
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
 
   /**
    * A zero value for this aggregation. Should satisfy the property that any b + zero = b.
    * @since 1.6.0
    */
-  def zero: B
+  def zero: BUF
 
   /**
    * Combine two values to produce a new value.  For performance, the function may modify `b` and
    * return it instead of constructing new object for b.
    * @since 1.6.0
    */
-  def reduce(b: B, a: I): B
+  def reduce(b: BUF, a: IN): BUF
 
   /**
    * Merge two intermediate values.
    * @since 1.6.0
    */
-  def merge(b1: B, b2: B): B
+  def merge(b1: BUF, b2: BUF): BUF
 
   /**
    * Transform the output of the reduction.
    * @since 1.6.0
    */
-  def finish(reduction: B): O
+  def finish(reduction: BUF): OUT
 
   /**
-   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+   * Specifies the [[Encoder]] for the intermediate value type.
+   * @since 2.0.0
+   */
+  def bufferEncoder: Encoder[BUF]
+
+  /**
+   * Specifies the [[Encoder]] for the final ouput value type.
+   * @since 2.0.0
+   */
+  def outputEncoder: Encoder[OUT]
+
+  /**
+   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]].
    * operations.
    * @since 1.6.0
    */
-  def toColumn(
-      implicit bEncoder: Encoder[B],
-      cEncoder: Encoder[O]): TypedColumn[I, O] = {
+  def toColumn: TypedColumn[IN, OUT] = {
+    implicit val bEncoder = bufferEncoder
+    implicit val cEncoder = outputEncoder
+
     val expr =
       AggregateExpression(
         TypedAggregateExpression(this),
         Complete,
         isDistinct = false)
 
-    new TypedColumn[I, O](expr, encoderFor[O])
+    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index 8cb174b..0e49f87 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -26,6 +26,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
@@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   public void testTypedAggregationAnonClass() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
 
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
 
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
@@ -73,6 +72,16 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
     public Integer finish(Integer reduction) {
       return reduction;
     }
+
+    @Override
+    public Encoder<Integer> bufferEncoder() {
+      return Encoders.INT();
+    }
+
+    @Override
+    public Encoder<Integer> outputEncoder() {
+      return Encoders.INT();
+    }
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/spark/blob/520dde48/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 08b3389..3a7215e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -26,74 +27,65 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
   override def zero: (Long, Long) = (0, 0)
-
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
     (countAndSum._1 + 1, countAndSum._2 + input._2)
   }
-
   override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
-
   override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+  override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
+  override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
 }
 
+
 case class AggData(a: Int, b: String)
+
 object ClassInputAgg extends Aggregator[AggData, Int, Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: Int = 0
-
-  /**
-   * Combine two values to produce a new value.  For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: Int, a: AggData): Int = b + a.a
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: Int): Int = reduction
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: Int, b2: Int): Int = b1 + b2
+  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
-  /**
-   * Combine two values to produce a new value.  For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: (Int, AggData)): Int = reduction._1
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
     (b1._1 + b2._1, b1._2)
+  override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)]
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
-
   def reduce(b: String, a: AggData): String = a.b + b
-
   def merge(b1: String, b2: String): String = b1 + b2
-
   def finish(r: String): String = r
+  override def bufferEncoder: Encoder[String] = Encoders.STRING
+  override def outputEncoder: Encoder[String] = Encoders.STRING
+}
+
+
+class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
+  extends Aggregator[IN, OUT, OUT] {
+
+  private val numeric = implicitly[Numeric[OUT]]
+  override def zero: OUT = numeric.zero
+  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+  override def finish(reduction: OUT): OUT = reduction
+  override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
+  override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
 }
 
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
 
   import testImplicits._
@@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
   }
 
+  test("generic typed sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn),
+      ("a", 4.0), ("b", 3.0))
+
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn),
+      ("a", 4), ("b", 3))
+  }
+
   test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
     val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message