spark-commits mailing list archives

From marmb...@apache.org
Subject spark git commit: [SPARK-11554][SQL] add map/flatMap to GroupedDataset
Date Sun, 08 Nov 2015 21:00:13 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 27161f59e -> 6ade67e5f


[SPARK-11554][SQL] add map/flatMap to GroupedDataset

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9521 from cloud-fan/map.

(cherry picked from commit b2d195e137fad88d567974659fa7023ff4da96cd)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ade67e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ade67e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ade67e5

Branch: refs/heads/branch-1.6
Commit: 6ade67e5f0afb212d979af836597d5df58de5b60
Parents: 27161f5
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Sun Nov 8 12:59:35 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Sun Nov 8 12:59:48 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/logical/basicOperators.scala |  4 +-
 .../org/apache/spark/sql/GroupedDataset.scala   | 29 ++++++++++++--
 .../spark/sql/execution/basicOperators.scala    |  2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 16 ++++----
 .../spark/sql/DatasetPrimitiveSuite.scala       | 16 ++++++--
 .../org/apache/spark/sql/DatasetSuite.scala     | 40 ++++++++++----------
 6 files changed, 70 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
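For context, a minimal usage sketch of the new API, mirroring the Scala tests added below (this assumes a SQLContext whose implicits are imported; the value names are ours, not part of the commit):

    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupBy(_._1)

    // map: exactly one output record per group.
    val sums = grouped.map { (key, values) => (key, values.map(_._2).sum) }

    // flatMap: zero or more output records per group; any TraversableOnce works.
    val keysAndSums = grouped.flatMap { (key, values) =>
      Seq(key, values.map(_._2).sum.toString)
    }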


http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09aac00..e151ac0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -494,7 +494,7 @@ case class AppendColumn[T, U](
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
-      func: (K, Iterator[T]) => Iterator[U],
+      func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
@@ -514,7 +514,7 @@ object MapGroups {
  * object representation of all the rows with that key.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],

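The hunk above widens the function's return type from Iterator[U] to TraversableOnce[U]. Since TraversableOnce is a common supertype of Iterator and the strict collections, user functions may now return a Seq, Option, or Iterator without wrapping. A small standalone illustration (our own names, not from the commit):

    // Both satisfy (K, Iterator[T]) => TraversableOnce[U]:
    val returnsSeq: (Int, Iterator[Int]) => TraversableOnce[String] =
      (key, values) => Seq(s"$key:${values.sum}")       // strict collection
    val returnsIterator: (Int, Iterator[Int]) => TraversableOnce[String] =
      (key, values) => Iterator(s"$key:${values.sum}")  // lazy iterator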
http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b2803d5..5c3f626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql](
   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
   * constraints of their cluster.
    */
-  def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
-  def mapGroups[U](
+  def flatMap[U](
       f: JFunction2[K, JIterator[T], JIterator[U]],
       encoder: Encoder[U]): Dataset[U] = {
-    mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  }
+
+  /**
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group.  The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   */
+  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
+    new Dataset[U](
+      sqlContext,
+      MapGroups(func, groupingAttributes, logicalPlan))
+  }
+
+  def map[U](
+      f: JFunction2[K, JIterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
   // To ensure valid overloading.

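Note that map is implemented by lifting the user function into a one-element iterator and reusing the same MapGroups plan that backs flatMap. A sketch of that equivalence in plain Scala (the helper name is ours):

    // map-over-groups expressed as flatMap-over-groups:
    def liftToFlatMap[K, T, U](f: (K, Iterator[T]) => U): (K, Iterator[T]) => TraversableOnce[U] =
      (key, values) => Iterator(f(key, values))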
http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 799650a..2593b16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -356,7 +356,7 @@ case class AppendColumns[T, U](
  * being output.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a9493d5..0d3b1a5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());
@@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> ds = context.createDataset(data, e.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index e3b0346..fcf03f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       0, 1)
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.mapGroups { case (g, iter) =>
+    val agged = grouped.map { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
-      Iterator((name, iter.size))
+      (name, iter.size)
     }
 
     checkAnswer(
       agged,
       ("even", 5), ("odd", 6))
   }
+
+  test("groupBy function, flatMap") {
+    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
+    val grouped = ds.groupBy(_.length)
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+
+    checkAnswer(
+      agged,
+      "1", "abc", "3", "xyz", "5", "hello")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d61e17e..6f1174e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1))
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g._1, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns, mapGroups") {
+  test("groupBy function, fatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy(v => (v._1, "word"))
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+
+    checkAnswer(
+      agged,
+      "a", "30", "b", "3", "c", "1")
+  }
+
+  test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g.getString(0), iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey, mapGroups") {
+  test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").asKey[String]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey tuple, mapGroups") {
+  test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
   }
 
-  test("groupBy columns asKey class, mapGroups") {
+  test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

