spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-14285][SQL] Implement common type-safe aggregate functions
Date Sat, 02 Apr 2016 05:46:58 GMT
Repository: spark
Updated Branches:
  refs/heads/master fa1af0aff -> f41415441


[SPARK-14285][SQL] Implement common type-safe aggregate functions

## What changes were proposed in this pull request?
In the Dataset API, it is fairly difficult for users to perform simple aggregations in a type-safe
way at the moment because there are no aggregators that have been implemented. This pull request
adds a few common aggregate functions in the expressions.scala.typed package, and also creates
the expressions.java.typed package without implementation. The java implementation should
probably come as a separate pull request. One challenge there is to resolve the type difference
between Scala primitive types and Java boxed types.

## How was this patch tested?
Added unit tests for them.

Author: Reynold Xin <rxin@databricks.com>

Closes #12077 from rxin/SPARK-14285.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4141544
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4141544
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4141544

Branch: refs/heads/master
Commit: f414154418c2291448954b9f0890d592b2d823ae
Parents: fa1af0a
Author: Reynold Xin <rxin@databricks.com>
Authored: Fri Apr 1 22:46:56 2016 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri Apr 1 22:46:56 2016 -0700

----------------------------------------------------------------------
 .../execution/aggregate/typedaggregators.scala  |  69 +++++++++++
 .../apache/spark/sql/expressions/Window.scala   |   8 +-
 .../spark/sql/expressions/WindowSpec.scala      |   8 +-
 .../spark/sql/expressions/java/typed.java       |  34 +++++
 .../spark/sql/expressions/scala/typed.scala     |  89 ++++++++++++++
 .../org/apache/spark/sql/expressions/udaf.scala |   4 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  54 --------
 .../sql/sources/JavaDatasetAggregatorSuite.java | 123 +++++++++++++++++++
 .../spark/sql/DatasetAggregatorSuite.scala      |  64 +++-------
 9 files changed, 342 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
new file mode 100644
index 0000000..9afc290
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.expressions.Aggregator
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines internal implementations for aggregators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
+  val numeric = implicitly[Numeric[OUT]]
+  override def zero: OUT = numeric.zero
+  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+  override def finish(reduction: OUT): OUT = reduction
+}
+
+
+class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+  override def zero: Double = 0.0
+  override def reduce(b: Double, a: IN): Double = b + f(a)
+  override def merge(b1: Double, b2: Double): Double = b1 + b2
+  override def finish(reduction: Double): Double = reduction
+}
+
+
+class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+  override def zero: Long = 0L
+  override def reduce(b: Long, a: IN): Long = b + f(a)
+  override def merge(b1: Long, b2: Long): Long = b1 + b2
+  override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+  override def zero: Long = 0
+  override def reduce(b: Long, a: IN): Long = {
+    if (f(a) == null) b else b + 1
+  }
+  override def merge(b1: Long, b2: Long): Long = b1 + b2
+  override def finish(reduction: Long): Long = reduction
+}
+
+
+class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+  override def zero: (Double, Long) = (0.0, 0L)
+  override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
+  override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
+  override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
+    (b1._1 + b2._1, b1._2 + b2._2)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index e9b6084..350c283 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     spec.partitionBy(colName, colNames : _*)
   }
@@ -51,7 +51,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     spec.partitionBy(cols : _*)
   }
@@ -60,7 +60,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     spec.orderBy(colName, colNames : _*)
   }
@@ -69,7 +69,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     spec.orderBy(cols : _*)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 9e9c58c..d716da2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     partitionBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     new WindowSpec(cols.map(_.expr), orderSpec, frame)
   }
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     orderBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     val sortOrder: Seq[SortOrder] = cols.map { col =>
       col.expr match {

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
new file mode 100644
index 0000000..cdba970
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.java;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.Dataset;
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for {@link Dataset} operations in Java.
+ *
+ * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}.
+ *
+ * @since 2.0.0
+ */
+@Experimental
+public class typed {
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
new file mode 100644
index 0000000..d0eb190
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.scala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.aggregate._
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for [[Dataset]] operations in Scala.
+ *
+ * Java users should use [[org.apache.spark.sql.expressions.java.typed]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+// scalastyle:off
+object typed {
+  // scalastyle:on
+
+  // Note: whenever we update this file, we should update the corresponding Java version too.
+  // The reason we have separate files for Java and Scala is because in the Scala version, we can
+  // use tighter types (primitive types) for return types, whereas in the Java version we can only
+  // use boxed primitive types.
+  // For example, avg in the Scala version returns Scala primitive Double, whose bytecode
+  // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double.
+
+  // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders.
+  private val implicits = new SQLImplicits {
+    override protected def _sqlContext: SQLContext = null
+  }
+
+  import implicits._
+
+  /**
+   * Average aggregate function.
+   *
+   * @since 2.0.0
+   */
+  def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn
+
+  /**
+   * Count aggregate function.
+   *
+   * @since 2.0.0
+   */
+  def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn
+
+  /**
+   * Sum aggregate function for floating point (double) type.
+   *
+   * @since 2.0.0
+   */
+  def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn
+
+  /**
+   * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+   *
+   * @since 2.0.0
+   */
+  def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
+
+  // TODO:
+  // stddevOf: Double
+  // varianceOf: Double
+  // approxCountDistinct: Long
+
+  // minOf: T
+  // maxOf: T
+
+  // firstOf: T
+  // lastOf: T
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8b355be..4892591 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /**
    * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def apply(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
    * Creates a [[Column]] for this UDAF using the distinct values of the given
    * [[Column]]s as input arguments.
    */
-  @scala.annotation.varargs
+  @_root_.scala.annotation.varargs
   def distinct(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6c8193..a5ab446 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -37,7 +37,6 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
-import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -385,59 +384,6 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(data, ds.collectAsList());
   }
 
-  @Test
-  public void testTypedAggregation() {
-    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(),
Encoders.INT());
-    List<Tuple2<String, Integer>> data =
-      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
-    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
-    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
-      new MapFunction<Tuple2<String, Integer>, String>() {
-        @Override
-        public String call(Tuple2<String, Integer> value) throws Exception {
-          return value._1();
-        }
-      },
-      Encoders.STRING());
-
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
-    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
-
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
-      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
-    Assert.assertEquals(
-      Arrays.asList(
-        new Tuple2<>("a", 3),
-        new Tuple2<>("b", 3)),
-      agged2.collectAsList());
-  }
-
-  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer>
{
-
-    @Override
-    public Integer zero() {
-      return 0;
-    }
-
-    @Override
-    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
-      return l + t._2();
-    }
-
-    @Override
-    public Integer merge(Integer b1, Integer b2) {
-      return b1 + b2;
-    }
-
-    @Override
-    public Integer finish(Integer reduction) {
-      return reduction;
-    }
-  }
-
   public static class KryoSerializable {
     String value;
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
new file mode 100644
index 0000000..c4c455b
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.test.TestSQLContext;
+
+/**
+ * Suite for testing the aggregate functionality of Datasets in Java.
+ */
+public class JavaDatasetAggregatorSuite implements Serializable {
+  private transient JavaSparkContext jsc;
+  private transient TestSQLContext context;
+
+  @Before
+  public void setUp() {
+    // Trigger static initializer of TestData
+    SparkContext sc = new SparkContext("local[*]", "testing");
+    jsc = new JavaSparkContext(sc);
+    context = new TestSQLContext(sc);
+    context.loadTestData();
+  }
+
+  @After
+  public void tearDown() {
+    context.sparkContext().stop();
+    context = null;
+    jsc = null;
+  }
+
+  private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+    return new Tuple2<>(t1, t2);
+  }
+
+  private KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset()
{
+    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(),
Encoders.INT());
+    List<Tuple2<String, Integer>> data =
+      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+    return ds.groupByKey(
+      new MapFunction<Tuple2<String, Integer>, String>() {
+        @Override
+        public String call(Tuple2<String, Integer> value) throws Exception {
+          return value._1();
+        }
+      },
+      Encoders.STRING());
+  }
+
+  @Test
+  public void testTypedAggregationAnonClass() {
+    KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+
+    Dataset<Tuple2<String, Integer>> agged =
+      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
+    Assert.assertEquals(
+      Arrays.asList(
+        new Tuple2<>("a", 3),
+        new Tuple2<>("b", 3)),
+      agged2.collectAsList());
+  }
+
+  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer>
{
+
+    @Override
+    public Integer zero() {
+      return 0;
+    }
+
+    @Override
+    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+      return l + t._2();
+    }
+
+    @Override
+    public Integer merge(Integer b1, Integer b2) {
+      return b1 + b2;
+    }
+
+    @Override
+    public Integer finish(Integer reduction) {
+      return reduction;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f4141544/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 8477016..5430aff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -20,35 +20,10 @@ package org.apache.spark.sql
 import scala.language.postfixOps
 
 import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-  val numeric = implicitly[Numeric[N]]
-
-  override def zero: N = numeric.zero
-
-  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
-  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
-  override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
-  override def zero: (Long, Long) = (0, 0)
-
-  override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
-    (countAndSum._1 + 1, countAndSum._2 + input._2)
-  }
-
-  override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
-    (b1._1 + b2._1, b1._2 + b2._2)
-  }
-
-  override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
 
@@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext
{
 
   import testImplicits._
 
-  def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
-    new SumOf(f).toColumn
-
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkDataset(
-      ds.groupByKey(_._1).agg(sum(_._2)),
-      ("a", 30), ("b", 3), ("c", 1))
+      ds.groupByKey(_._1).agg(typed.sum(_._2)),
+      ("a", 30.0), ("b", 3.0), ("c", 1.0))
   }
 
   test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext
{
 
     checkDataset(
       ds.groupByKey(_._1).agg(
-        sum(_._2),
+        typed.sum(_._2),
         expr("sum(_2)").as[Long],
         count("*")),
-      ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
-  }
-
-  test("typed aggregation: complex case") {
-    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
-    checkDataset(
-      ds.groupByKey(_._1).agg(
-        expr("avg(_2)").as[Double],
-        TypedAverage.toColumn),
-      ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+      ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
   }
 
   test("typed aggregation: complex result type") {
@@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext
{
     val ds = Seq(1, 3, 2, 5).toDS()
 
     checkDataset(
-      ds.select(sum((i: Int) => i)),
-      11)
+      ds.select(typed.sum((i: Int) => i)),
+      11.0)
     checkDataset(
-      ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
-      11 -> 22)
+      ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+      11.0 -> 22.0)
   }
 
   test("typed aggregation: class input") {
@@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext
{
       ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
+
+  test("typed aggregate: avg, count, sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+      ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message