spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-11370] [SQL] fix a bug in GroupedIterator and create unit test for it
Date Thu, 29 Oct 2015 10:49:56 GMT
Repository: spark
Updated Branches:
  refs/heads/master 87f28fc24 -> f79ebf2a9


[SPARK-11370] [SQL] fix a bug in GroupedIterator and create unit test for it

Before this PR, users had to consume the iterator of one group before processing the next group;
otherwise we would get into an infinite loop.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9330 from cloud-fan/group.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f79ebf2a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f79ebf2a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f79ebf2a

Branch: refs/heads/master
Commit: f79ebf2a9e99575908dad6f7a14c8cfcffdebd91
Parents: 87f28fc
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Thu Oct 29 11:49:45 2015 +0100
Committer: Michael Armbrust <michael@databricks.com>
Committed: Thu Oct 29 11:49:45 2015 +0100

----------------------------------------------------------------------
 .../spark/sql/execution/GroupedIterator.scala   | 99 ++++++++++++--------
 .../sql/execution/GroupedIteratorSuite.scala    | 82 ++++++++++++++++
 2 files changed, 144 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f79ebf2a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
index 10742cf..6a88501 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -27,7 +27,7 @@ object GroupedIterator {
       keyExpressions: Seq[Expression],
       inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = {
     if (input.hasNext) {
-      new GroupedIterator(input, keyExpressions, inputSchema)
+      new GroupedIterator(input.buffered, keyExpressions, inputSchema)
     } else {
       Iterator.empty
     }
@@ -64,7 +64,7 @@ object GroupedIterator {
  * @param inputSchema The schema of the rows in the `input` iterator.
  */
 class GroupedIterator private(
-    input: Iterator[InternalRow],
+    input: BufferedIterator[InternalRow],
     groupingExpressions: Seq[Expression],
     inputSchema: Seq[Attribute])
   extends Iterator[(InternalRow, Iterator[InternalRow])] {
@@ -83,10 +83,17 @@ class GroupedIterator private(
 
   /** Holds a copy of an input row that is in the current group. */
   var currentGroup = currentRow.copy()
-  var currentIterator: Iterator[InternalRow] = null
+
   assert(keyOrdering.compare(currentGroup, currentRow) == 0)
+  var currentIterator = createGroupValuesIterator()
 
-  // Return true if we already have the next iterator or fetching a new iterator is successful.
+  /**
+   * Return true if we already have the next iterator or fetching a new iterator is successful.
+   *
+   * Note that if we get the iterator by `next`, we should consume it before calling `hasNext`,
+   * because we will consume the input data to skip to the next group while fetching a new
+   * iterator, thus making the previous iterator empty.
+   */
   def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator
 
   def next(): (InternalRow, Iterator[InternalRow]) = {
@@ -96,46 +103,64 @@ class GroupedIterator private(
     ret
   }
 
-  def fetchNextGroupIterator(): Boolean = {
-    if (currentRow != null || input.hasNext) {
-      val inputIterator = new Iterator[InternalRow] {
-        // Return true if we have a row and it is in the current group, or if fetching a new row is
-        // successful.
-        def hasNext = {
-          (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) ||
-            fetchNextRowInGroup()
-        }
+  private def fetchNextGroupIterator(): Boolean = {
+    assert(currentIterator == null)
+
+    if (currentRow == null && input.hasNext) {
+      currentRow = input.next()
+    }
+
+    if (currentRow == null) {
+      // There is no data left, return false.
+      false
+    } else {
+      // Skip to next group.
+      while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+        currentRow = input.next()
+      }
+
+      if (keyOrdering.compare(currentGroup, currentRow) == 0) {
+        // We are in the last group, there are no more groups, return false.
+        false
+      } else {
+        // Now the `currentRow` is the first row of next group.
+        currentGroup = currentRow.copy()
+        currentIterator = createGroupValuesIterator()
+        true
+      }
+    }
+  }
+
+  private def createGroupValuesIterator(): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      def hasNext: Boolean = currentRow != null || fetchNextRowInGroup()
+
+      def next(): InternalRow = {
+        assert(hasNext)
+        val res = currentRow
+        currentRow = null
+        res
+      }
 
-        def fetchNextRowInGroup(): Boolean = {
-          if (currentRow != null || input.hasNext) {
+      private def fetchNextRowInGroup(): Boolean = {
+        assert(currentRow == null)
+
+        if (input.hasNext) {
+          // The inner iterator should NOT consume the input into next group, here we use `head` to
+          // peek the next input, to see if we should continue to process it.
+          if (keyOrdering.compare(currentGroup, input.head) == 0) {
+            // Next input is in the current group.  Continue the inner iterator.
             currentRow = input.next()
-            if (keyOrdering.compare(currentGroup, currentRow) == 0) {
-              // The row is in the current group.  Continue the inner iterator.
-              true
-            } else {
-              // We got a row, but its not in the right group.  End this inner iterator and prepare
-              // for the next group.
-              currentIterator = null
-              currentGroup = currentRow.copy()
-              false
-            }
+            true
           } else {
-            // There is no more input so we are done.
+            // Next input is not in the right group.  End this inner iterator.
             false
           }
-        }
-
-        def next(): InternalRow = {
-          assert(hasNext) // Ensure we have fetched the next row.
-          val res = currentRow
-          currentRow = null
-          res
+        } else {
+          // There is no more data, return false.
+          false
         }
       }
-      currentIterator = inputIterator
-      true
-    } else {
-      false
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f79ebf2a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
new file mode 100644
index 0000000..e7a0848
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}
+
+class GroupedIteratorSuite extends SparkFunSuite {
+
+  test("basic") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 1)
+        key.getInt(0) -> data.map(encoder.fromRow).toSeq
+    }.toSeq
+
+    assert(result ==
+      1 -> Seq(input(0), input(1)) ::
+      2 -> Seq(input(2)) :: Nil)
+  }
+
+  test("group by 2 columns") {
+    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+
+    val input = Seq(
+      Row(1, 2L, "a"),
+      Row(1, 2L, "b"),
+      Row(1, 3L, "c"),
+      Row(2, 1L, "d"),
+      Row(3, 2L, "e"))
+
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)
+
+    val result = grouped.map {
+      case (key, data) =>
+        assert(key.numFields == 2)
+        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
+    }.toSeq
+
+    assert(result ==
+      (1, 2L, Seq(input(0), input(1))) ::
+      (1, 3L, Seq(input(2))) ::
+      (2, 1L, Seq(input(3))) ::
+      (3, 2L, Seq(input(4))) :: Nil)
+  }
+
+  test("do nothing to the value iterator") {
+    val schema = new StructType().add("i", IntegerType).add("s", StringType)
+    val encoder = RowEncoder(schema)
+    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
+    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
+      Seq('i.int.at(0)), schema.toAttributes)
+
+    assert(grouped.length == 2)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message