From: marmbrus@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-11554][SQL] add map/flatMap to GroupedDataset
Date: Sun, 8 Nov 2015 21:00:13 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/branch-1.6 27161f59e -> 6ade67e5f

[SPARK-11554][SQL] add map/flatMap to GroupedDataset

Author: Wenchen Fan

Closes #9521 from cloud-fan/map.

(cherry picked from commit b2d195e137fad88d567974659fa7023ff4da96cd)
Signed-off-by: Michael Armbrust

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ade67e5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ade67e5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ade67e5

Branch: refs/heads/branch-1.6
Commit: 6ade67e5f0afb212d979af836597d5df58de5b60
Parents: 27161f5
Author: Wenchen Fan
Authored: Sun Nov 8 12:59:35 2015 -0800
Committer: Michael Armbrust
Committed: Sun Nov 8 12:59:48 2015 -0800

----------------------------------------------------------------------
 .../catalyst/plans/logical/basicOperators.scala |  4 +-
 .../org/apache/spark/sql/GroupedDataset.scala   | 29 ++++++++++++--
 .../spark/sql/execution/basicOperators.scala    |  2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 16 ++++----
 .../spark/sql/DatasetPrimitiveSuite.scala       | 16 ++++++--
 .../org/apache/spark/sql/DatasetSuite.scala     | 40 ++++++++++----------
 6 files changed, 70 insertions(+), 37 deletions(-)
----------------------------------------------------------------------
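For context, a minimal sketch of the API this commit adds: the old `mapGroups` is split into `map`, which returns exactly one element per group, and `flatMap`, which may return zero or more. This is a hedged illustration, not part of the commit; the `sqlContext` value and the sample data are assumptions.

  import sqlContext.implicits._

  // Hypothetical sample data, grouped by the first tuple field.
  val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
  val grouped = ds.groupBy(_._1)

  // map: exactly one output element per group.
  val sums = grouped.map { (key, iter) => (key, iter.map(_._2).sum) }

  // flatMap: zero or more output elements per group.
  val expanded = grouped.flatMap { (key, iter) => Seq(key, iter.map(_._2).sum.toString) }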
http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09aac00..e151ac0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -494,7 +494,7 @@ case class AppendColumn[T, U](
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
-      func: (K, Iterator[T]) => Iterator[U],
+      func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
@@ -514,7 +514,7 @@ object MapGroups {
  * object representation of all the rows with that key.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b2803d5..5c3f626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql](
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
    */
-  def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
-  def mapGroups[U](
+  def flatMap[U](
       f: JFunction2[K, JIterator[T], JIterator[U]],
       encoder: Encoder[U]): Dataset[U] = {
-    mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  }
+
+  /**
+   * Applies the given function to each group of data. For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   */
+  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
+    new Dataset[U](
+      sqlContext,
+      MapGroups(func, groupingAttributes, logicalPlan))
+  }
+
+  def map[U](
+      f: JFunction2[K, JIterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
   // To ensure valid overloading.
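The enabler here is relaxing the `MapGroups` function type from `Iterator[U]` to `TraversableOnce[U]`: `flatMap` can now accept any collection-like result (a List, Seq, Option, and so on), and `map` reuses the same `MapGroups` node by wrapping its single result in an iterator, as the hunk above shows. A small hedged sketch, reusing the hypothetical `grouped` value from the earlier example:

  // Any TraversableOnce return type now works for flatMap, e.g. a List...
  val described = grouped.flatMap { (key, iter) => List(s"$key: ${iter.size} rows") }

  // ...and map is equivalent to a flatMap that wraps its single result in an
  // Iterator, mirroring the implementation above.
  val counts = grouped.flatMap { (key, iter) => Iterator((key, iter.size)) }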
http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 799650a..2593b16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -356,7 +356,7 @@ case class AppendColumns[T, U](
  * being output.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a9493d5..0d3b1a5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       }, e.STRING());
@@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> ds = context.createDataset(data, e.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       }, e.STRING());

http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index e3b0346..fcf03f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       0, 1)
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.mapGroups { case (g, iter) =>
+    val agged = grouped.map { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
-      Iterator((name, iter.size))
+      (name, iter.size)
     }
 
     checkAnswer(
       agged,
       ("even", 5), ("odd", 6))
   }
+
+  test("groupBy function, flatMap") {
+    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
+    val grouped = ds.groupBy(_.length)
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+
+    checkAnswer(
+      agged,
+      "1", "abc", "3", "xyz", "5", "hello")
+  }
 }
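Besides grouping by a function, the DatasetSuite changes that follow exercise grouping by columns, where `asKey` converts the untyped Row key into a typed one. A brief sketch under the same assumptions as the earlier examples:

  // Group by a column; asKey[String] turns the Row key into a typed String key.
  val byColumn = ds.groupBy($"_1").asKey[String]
  val columnSums = byColumn.map { (key, iter) => (key, iter.map(_._2).sum) }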
http://git-wip-us.apache.org/repos/asf/spark/blob/6ade67e5/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d61e17e..6f1174e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1))
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g._1, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns, mapGroups") {
+  test("groupBy function, flatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy(v => (v._1, "word"))
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+
+    checkAnswer(
+      agged,
+      "a", "30", "b", "3", "c", "1")
+  }
+
+  test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g.getString(0), iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey, mapGroups") {
+  test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupBy($"_1").asKey[String]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey tuple, mapGroups") {
+  test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
   }
 
-  test("groupBy columns asKey class, mapGroups") {
+  test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org