spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-11899][SQL] API audit for GroupedDataset.
Date Sat, 21 Nov 2015 23:00:50 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 05547183b -> 8c718a577


[SPARK-11899][SQL] API audit for GroupedDataset.

1. Renamed map to mapGroup, flatMap to flatMapGroup.
2. Renamed asKey -> keyAs.
3. Added more documentation.
4. Changed type parameter T to V on GroupedDataset.
5. Added since versions for all functions.

Author: Reynold Xin <rxin@databricks.com>

Closes #9880 from rxin/SPARK-11899.

(cherry picked from commit ff442bbcffd4f93cfcc2f76d160011e725d2fb3f)
Signed-off-by: Reynold Xin <rxin@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8c718a57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8c718a57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8c718a57

Branch: refs/heads/branch-1.6
Commit: 8c718a577e32d9f91dc4cacd58dab894e366d93d
Parents: 0554718
Author: Reynold Xin <rxin@databricks.com>
Authored: Sat Nov 21 15:00:37 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sat Nov 21 15:00:47 2015 -0800

----------------------------------------------------------------------
 .../api/java/function/MapGroupFunction.java     |   2 +-
 .../scala/org/apache/spark/sql/Encoder.scala    |   4 +
 .../spark/sql/catalyst/JavaTypeInference.scala  |   3 +-
 .../scala/org/apache/spark/sql/Column.scala     |   2 +
 .../scala/org/apache/spark/sql/DataFrame.scala  |   1 -
 .../org/apache/spark/sql/GroupedDataset.scala   | 132 +++++++++++++++----
 .../org/apache/spark/sql/JavaDatasetSuite.java  |   8 +-
 .../spark/sql/DatasetPrimitiveSuite.scala       |   4 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  20 +--
 9 files changed, 131 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
index 2935f99..4f3f222 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
@@ -21,7 +21,7 @@ import java.io.Serializable;
 import java.util.Iterator;
 
 /**
- * Base interface for a map function used in GroupedDataset's map function.
+ * Base interface for a map function used in GroupedDataset's mapGroup function.
  */
 public interface MapGroupFunction<K, V, R> extends Serializable {
   R call(K key, Iterator<V> values) throws Exception;

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 5cb8edf..03aa25e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.types._
  *
  * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
  * and reuse internal buffers to improve performance.
+ *
+ * @since 1.6.0
  */
 trait Encoder[T] extends Serializable {
 
@@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable {
 
 /**
  * Methods for creating encoders.
+ *
+ * @since 1.6.0
  */
 object Encoders {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 88a457f..7d4cfbe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
 /**
  * Type-inference utilities for POJOs and Java collections.
  */
-private [sql] object JavaTypeInference {
+object JavaTypeInference {
 
   private val iterableType = TypeToken.of(classOf[JIterable[_]])
   private val mapType = TypeToken.of(classOf[JMap[_, _]])
@@ -53,7 +53,6 @@ private [sql] object JavaTypeInference {
    * @return (SQL data type, nullable)
    */
   private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
-    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 82e9cd7..30c554a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -46,6 +46,8 @@ private[sql] object Column {
  * @tparam T The input type expected for this expression.  Can be `Any` if the expression is type
  *           checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
  * @tparam U The output type of this column.
+ *
+ * @since 1.6.0
  */
 class TypedColumn[-T, U](
     expr: Expression,

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 7abceca..5586fc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -110,7 +110,6 @@ private[sql] object DataFrame {
  * @groupname action Actions
  * @since 1.3.0
  */
-// TODO: Improve documentation.
 @Experimental
 class DataFrame private[sql](
     @transient val sqlContext: SQLContext,

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 263f049..7f43ce1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.Aggregator
 
 /**
  * :: Experimental ::
@@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution
  * making this change to the class hierarchy would break some function signatures. As such, this
  * class should be considered a preview of the final API.  Changes will be made to the interface
  * after Spark 1.6.
+ *
+ * @since 1.6.0
  */
 @Experimental
-class GroupedDataset[K, T] private[sql](
+class GroupedDataset[K, V] private[sql](
     kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
+    tEncoder: Encoder[V],
     val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
@@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql](
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
    * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+   *
+   * @since 1.6.0
    */
-  def asKey[L : Encoder]: GroupedDataset[L, T] =
+  def keyAs[L : Encoder]: GroupedDataset[L, V] =
     new GroupedDataset(
       encoderFor[L],
       unresolvedTEncoder,
@@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql](
 
   /**
    * Returns a [[Dataset]] that contains each unique key.
+   *
+   * @since 1.6.0
    */
   def keys: Dataset[K] = {
     new Dataset[K](
@@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql](
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
    *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
    * Internally, the implementation will spill to disk if any given group is too large to fit into
    * memory.  However, users must take care to avoid materializing the whole iterator for a group
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
+   *
+   * @since 1.6.0
    */
-  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
+  def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(
@@ -108,8 +121,25 @@ class GroupedDataset[K, T] private[sql](
         logicalPlan))
   }
 
-  def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  /**
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an iterator containing elements of an arbitrary type which will be returned
+   * as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }
 
   /**
@@ -117,32 +147,62 @@ class GroupedDataset[K, T] private[sql](
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
    *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
    * Internally, the implementation will spill to disk if any given group is too large to fit into
    * memory.  However, users must take care to avoid materializing the whole iterator for a group
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
+   *
+   * @since 1.6.0
    */
-  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
-    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
-    flatMap(func)
+  def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
+    flatMapGroup(func)
   }
 
-  def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
-    map((key, data) => f.call(key, data.asJava))(encoder)
+  /**
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroup((key, data) => f.call(key, data.asJava))(encoder)
   }
 
   /**
    * Reduces the elements of each group of data using the specified binary function.
    * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
    */
-  def reduce(f: (T, T) => T): Dataset[(K, T)] = {
-    val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+  def reduce(f: (V, V) => V): Dataset[(K, V)] = {
+    val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
 
     implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
-    flatMap(func)
+    flatMapGroup(func)
   }
 
-  def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+  /**
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
+   */
+  def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
     reduce(f.call _)
   }
 
@@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql](
   /**
    * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing this aggregation over all elements in the group.
+   *
+   * @since 1.6.0
    */
-  def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
+  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
     aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
-  def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
+  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
     aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
   def agg[U1, U2, U3](
-      col1: TypedColumn[T, U1],
-      col2: TypedColumn[T, U2],
-      col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
     aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]
 
   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
   def agg[U1, U2, U3, U4](
-      col1: TypedColumn[T, U1],
-      col2: TypedColumn[T, U2],
-      col3: TypedColumn[T, U3],
-      col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
     aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]
 
   /**
    * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
    * for that key.
+   *
+   * @since 1.6.0
    */
   def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]))
 
@@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql](
    * be passed the grouping key and 2 iterators containing all elements in the group from
    * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
    * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
    */
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
-      f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
     implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
     new Dataset[R](
       sqlContext,
@@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql](
         other.logicalPlan))
   }
 
+  /**
+   * Applies the given function to each cogrouped data.  For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`.  The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
+   */
   def cogroup[U, R](
       other: GroupedDataset[K, U],
-      f: CoGroupFunction[K, T, U, R],
+      f: CoGroupFunction[K, V, U, R],
       encoder: Encoder[R]): Dataset[R] = {
     cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index f32374b..cf335ef 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,7 +170,7 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, Encoders.INT());
 
-    Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
+    Dataset<String> mapped = grouped.mapGroup(new MapGroupFunction<Integer, String, String>() {
       @Override
       public String call(Integer key, Iterator<String> values) throws Exception {
         StringBuilder sb = new StringBuilder(key.toString());
@@ -183,7 +183,7 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
 
-    Dataset<String> flatMapped = grouped.flatMap(
+    Dataset<String> flatMapped = grouped.flatMapGroup(
       new FlatMapGroupFunction<Integer, String, String>() {
         @Override
         public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
@@ -247,9 +247,9 @@ public class JavaDatasetSuite implements Serializable {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped =
-      ds.groupBy(length(col("value"))).asKey(Encoders.INT());
+      ds.groupBy(length(col("value"))).keyAs(Encoders.INT());
 
-    Dataset<String> mapped = grouped.map(
+    Dataset<String> mapped = grouped.mapGroup(
       new MapGroupFunction<Integer, String, String>() {
         @Override
         public String call(Integer key, Iterator<String> data) throws Exception {

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index 63b0097..d387710 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.map { case (g, iter) =>
+    val agged = grouped.mapGroup { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
       (name, iter.size)
     }
@@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, flatMap") {
     val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
     val grouped = ds.groupBy(_.length)
-    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+    val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) }
 
     checkAnswer(
       agged,

http://git-wip-us.apache.org/repos/asf/spark/blob/8c718a57/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 89d964a..9da0255 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
+    val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
@@ -234,7 +234,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+    val agged = grouped.flatMapGroup { case (g, iter) =>
+      Iterator(g._1, iter.map(_._2).sum.toString)
+    }
 
     checkAnswer(
       agged,
@@ -253,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
+    val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
@@ -262,8 +264,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1").asKey[String]
-    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+    val grouped = ds.groupBy($"_1").keyAs[String]
+    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
@@ -272,8 +274,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
-    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+    val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)]
+    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
@@ -282,8 +284,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
   test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
-    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
+    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
+    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message