spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-16391][SQL] Support partial aggregation for reduceGroups
Date Thu, 18 Aug 2016 08:37:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3e6ef2e8a -> 1748f8241


[SPARK-16391][SQL] Support partial aggregation for reduceGroups

## What changes were proposed in this pull request?
This patch introduces a new private ReduceAggregator class, a subclass of Aggregator that requires
only a single associative and commutative reduce function. ReduceAggregator is used to implement
KeyValueGroupedDataset.reduceGroups so that the reduction can benefit from partial (map-side)
aggregation.
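
For illustration only (not part of the patch): caller code is unchanged, only the physical planning
differs. A minimal sketch, assuming a SparkSession in scope as `spark`:

```scala
import spark.implicits._

val ds = Seq(("a", 1), ("a", 2), ("b", 10)).toDS()

// Same user-facing API as before: reduceGroups folds all values per key
// with a single associative and commutative function. After this patch
// it is planned as an aggregate with a partial (map-side) phase instead
// of flatMapGroups over fully shuffled groups.
val sums = ds.groupByKey(_._1).reduceGroups((x, y) => (x._1, x._2 + y._2))
// sums: Dataset[(String, (String, Int))]
```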

Note that this pull request was initially authored by viirya.

## How was this patch tested?
Covered by original tests for reduceGroups, as well as a new test suite for ReduceAggregator.

Author: Reynold Xin <rxin@databricks.com>
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #14576 from rxin/reduceAggregator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1748f824
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1748f824
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1748f824

Branch: refs/heads/master
Commit: 1748f824101870b845dbbd118763c6885744f98a
Parents: 3e6ef2e
Author: Reynold Xin <rxin@databricks.com>
Authored: Thu Aug 18 16:37:25 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Aug 18 16:37:25 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/KeyValueGroupedDataset.scala      | 10 +--
 .../sql/expressions/ReduceAggregator.scala      | 68 ++++++++++++++++++
 .../sql/expressions/ReduceAggregatorSuite.scala | 73 ++++++++++++++++++++
 3 files changed, 146 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1748f824/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 65a725f..61a3e6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -21,10 +21,11 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.ReduceAggregator
 
 /**
  * :: Experimental ::
@@ -177,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
    * @since 1.6.0
    */
   def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
-    val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-
-    implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
-    flatMapGroups(func)
+    val vEncoder = encoderFor[V]
+    val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn
+    agg(aggregator)
   }
 
   /**
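
The reason this enables partial aggregation: flatMapGroups is planned as a MapGroups node, which
needs every value for a key shuffled to a single partition before the user function runs. An
Aggregator-backed TypedColumn, by contrast, is planned as an aggregate with partial and final
phases, so values are pre-reduced on the map side before the shuffle.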

http://git-wip-us.apache.org/repos/asf/spark/blob/1748f824/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
new file mode 100644
index 0000000..1743783
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+/**
+ * An aggregator that uses a single associative and commutative reduce function. The reduce
+ * function is applied to all input values, reducing them to a single value.
+ *
+ * This class currently requires at least one input row: if there is no input, finish() throws
+ * an IllegalStateException rather than returning a value. It is used to implement
+ * KeyValueGroupedDataset.reduceGroups with partial aggregation support.
+ */
+private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
+  extends Aggregator[T, (Boolean, T), T] {
+
+  private val encoder = implicitly[Encoder[T]]
+
+  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+
+  override def bufferEncoder: Encoder[(Boolean, T)] =
+    ExpressionEncoder.tuple(
+      ExpressionEncoder[Boolean](),
+      encoder.asInstanceOf[ExpressionEncoder[T]])
+
+  override def outputEncoder: Encoder[T] = encoder
+
+  override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
+    if (b._1) {
+      (true, func(b._2, a))
+    } else {
+      (true, a)
+    }
+  }
+
+  override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = {
+    if (!b1._1) {
+      b2
+    } else if (!b2._1) {
+      b1
+    } else {
+      (true, func(b1._2, b2._2))
+    }
+  }
+
+  override def finish(reduction: (Boolean, T)): T = {
+    if (!reduction._1) {
+      throw new IllegalStateException("ReduceAggregator requires at least one input row")
+    }
+    reduction._2
+  }
+}
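
A sketch of the (Boolean, T) buffer semantics above (illustrative, not part of the commit; assumes
Encoders.scalaInt and access to the private[sql] class):

    val agg = new ReduceAggregator[Int](_ + _)(Encoders.scalaInt)
    agg.reduce(agg.zero, 5)           // (true, 5): first value fills the empty buffer
    agg.reduce((true, 5), 7)          // (true, 12): later values are folded with func
    agg.merge(agg.zero, (true, 12))   // (true, 12): an empty partial buffer is ignored
    agg.finish((true, 12))            // 12; finish(agg.zero) would throw instead

The Boolean flag plays the role of an Option[T] that is cheap to encode: false means no input has
been seen yet, so empty map-side partial buffers merge away cleanly.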

http://git-wip-us.apache.org/repos/asf/spark/blob/1748f824/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
new file mode 100644
index 0000000..d826d3f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+
+class ReduceAggregatorSuite extends SparkFunSuite {
+
+  test("zero value") {
+    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+    val func = (v1: Int, v2: Int) => v1 + v2
+    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+    assert(aggregator.zero == (false, null))
+  }
+
+  test("reduce, merge and finish") {
+    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+    val func = (v1: Int, v2: Int) => v1 + v2
+    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+    val firstReduce = aggregator.reduce(aggregator.zero, 1)
+    assert(firstReduce == (true, 1))
+
+    val secondReduce = aggregator.reduce(firstReduce, 2)
+    assert(secondReduce == (true, 3))
+
+    val thirdReduce = aggregator.reduce(secondReduce, 3)
+    assert(thirdReduce == (true, 6))
+
+    val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce)
+    assert(mergeWithZero1 == (true, 1))
+
+    val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero)
+    assert(mergeWithZero2 == (true, 3))
+
+    val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce)
+    assert(mergeTwoReduced == (true, 4))
+
+    assert(aggregator.finish(firstReduce) == 1)
+    assert(aggregator.finish(secondReduce) == 3)
+    assert(aggregator.finish(thirdReduce) == 6)
+    assert(aggregator.finish(mergeWithZero1) == 1)
+    assert(aggregator.finish(mergeWithZero2) == 3)
+    assert(aggregator.finish(mergeTwoReduced) == 4)
+  }
+
+  test("requires at least one input row") {
+    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
+    val func = (v1: Int, v2: Int) => v1 + v2
+    val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
+
+    intercept[IllegalStateException] {
+      aggregator.finish(aggregator.zero)
+    }
+  }
+}
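
To see the change in planning, one can compare the physical plan before and after this commit (a
sketch, assuming a spark-shell session with implicits imported):

    val ds = Seq(1, 2, 3, 4).toDS()
    ds.groupByKey(_ % 2).reduceGroups(_ + _).explain()
    // With this patch the plan should show partial and final aggregate
    // operators; previously it contained a MapGroups node instead.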



