spark-commits mailing list archives

From lix...@apache.org
Subject spark git commit: [SPARK-19471][SQL] AggregationIterator does not initialize the generated result projection before using it
Date Mon, 14 Aug 2017 16:37:23 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0326b69c9 -> fbc269252


[SPARK-19471][SQL] AggregationIterator does not initialize the generated result projection before using it

## What changes were proposed in this pull request?

Recently, we have also encountered such NPE issues in our production environment, as described in:
https://issues.apache.org/jira/browse/SPARK-19471

This issue can be reproduced with the following examples:

```scala
val df = spark.createDataFrame(Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4))).toDF("x", "y")

// HashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key = false
df.groupBy("x").agg(rand(), sum("y")).show()

// ObjectHashAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key = false
df.groupBy("x").agg(rand(), collect_list("y")).show()

// SortAggregate, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key = false && SQLConf.USE_OBJECT_HASH_AGG.key = false
df.groupBy("x").agg(rand(), collect_list("y")).show()
```
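
For reference, here is a short sketch (not part of this commit) of how those conf keys can be set and how to confirm which physical operator gets planned; the new test case in `DataFrameFunctionsSuite` below performs the same checks via `withSQLConf`:

```scala
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf

// Disable whole-stage codegen so the interpreted projection path is exercised.
spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")

// Reusing df from the example above: with codegen disabled, the aggregate
// operator sits at the root of the executed plan.
val plan = df.groupBy("x").agg(rand(), sum("y")).queryExecution.executedPlan
assert(plan.isInstanceOf[HashAggregateExec])
```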

This PR is based on PR-16820 (https://github.com/apache/spark/pull/16820), adding test cases for all aggregation paths. We want to push it forward.

> When AggregationIterator generates the result projection, it does not call the initialize
> method of the Projection class. This causes a runtime NullPointerException when the
> projection involves nondeterministic expressions.
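
The fix is the same one-liner on each result-projection path: initialize the generated projection with the partition index before its first use. A minimal sketch of that contract, using the names that appear in the diff below:

```scala
// Inside AggregationIterator.generateResultProjection() (per the diff below):
val resultProjection =
  UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
// Without this call, the per-partition state of nondeterministic expressions
// (rand(), randn(), monotonically_increasing_id(), spark_partition_id()) is never
// set up, and evaluating them throws a NullPointerException at runtime.
resultProjection.initialize(partIndex)
```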

## How was this patch tested?

- Unit tests covering all aggregation paths (see `DataFrameFunctionsSuite` in the diff below).
- Verified in our production environment.
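
The new regression test can be run on its own with, for example, `build/sbt "sql/testOnly *DataFrameFunctionsSuite -- -z SPARK-19471"` (a standard Spark dev-workflow invocation, not part of this commit; adjust to your build setup).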

Author: donnyzone <wellfengzhu@gmail.com>

Closes #18920 from DonnyZone/Branch-spark-19471.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fbc26925
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fbc26925
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fbc26925

Branch: refs/heads/master
Commit: fbc269252a1c99e04bd08906ad8404c031e9a097
Parents: 0326b69
Author: donnyzone <wellfengzhu@gmail.com>
Authored: Mon Aug 14 09:37:18 2017 -0700
Committer: gatorsmile <gatorsmile@gmail.com>
Committed: Mon Aug 14 09:37:18 2017 -0700

----------------------------------------------------------------------
 .../aggregate/AggregationIterator.scala         |  4 ++
 .../execution/aggregate/HashAggregateExec.scala |  3 +-
 .../aggregate/ObjectAggregationIterator.scala   |  2 +
 .../aggregate/ObjectHashAggregateExec.scala     |  3 +-
 .../execution/aggregate/SortAggregateExec.scala |  3 +-
 .../SortBasedAggregationIterator.scala          |  2 +
 .../aggregate/TungstenAggregationIterator.scala |  4 ++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 45 ++++++++++++++++++++
 8 files changed, 63 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 7c11fdb..98c4a51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
  *    is used to generate result.
  */
 abstract class AggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     inputAttributes: Seq[Attribute],
     aggregateExpressions: Seq[AggregateExpression],
@@ -217,6 +218,7 @@ abstract class AggregationIterator(
 
       val resultProjection =
         UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
+      resultProjection.initialize(partIndex)
 
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         // Generate results for all expression-based aggregate functions.
@@ -235,6 +237,7 @@ abstract class AggregationIterator(
       val resultProjection = UnsafeProjection.create(
         groupingAttributes ++ bufferAttributes,
         groupingAttributes ++ bufferAttributes)
+      resultProjection.initialize(partIndex)
 
       // TypedImperativeAggregate stores generic object in aggregation buffer, and requires
       // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
@@ -256,6 +259,7 @@ abstract class AggregationIterator(
     } else {
       // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
+      resultProjection.initialize(partIndex)
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         resultProjection(currentGroupingKey)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 56f61c3..80ea458 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -96,7 +96,7 @@ case class HashAggregateExec(
     val spillSize = longMetric("spillSize")
     val avgHashProbe = longMetric("avgHashProbe")
 
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
 
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
@@ -106,6 +106,7 @@ case class HashAggregateExec(
       } else {
         val aggregationIterator =
           new TungstenAggregationIterator(
+            partIndex,
             groupingExpressions,
             aggregateExpressions,
             aggregateAttributes,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index eef2c4e..c68dbc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 class ObjectAggregationIterator(
+    partIndex: Int,
     outputAttributes: Seq[Attribute],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -43,6 +44,7 @@ class ObjectAggregationIterator(
     fallbackCountThreshold: Int,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index b53521b..6316e06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
 
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@ case class ObjectHashAggregateExec(
       } else {
         val aggregationIterator =
           new ObjectAggregationIterator(
+            partIndex,
             child.output,
             groupingExpressions,
             aggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index be3198b..a432357 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -74,7 +74,7 @@ case class SortAggregateExec(
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -84,6 +84,7 @@ case class SortAggregateExec(
         Iterator[UnsafeRow]()
       } else {
         val outputIter = new SortBasedAggregationIterator(
+          partIndex,
           groupingExpressions,
           child.output,
           iter,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a5a444b..492b0f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
  * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     valueAttributes,
     aggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index cfa9306..756eeb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator
  *  - Part 8: A utility function used to generate a result when there is no
  *            input and there is no grouping expression.
  *
+ * @param partIndex
+ *   index of the partition
  * @param groupingExpressions
  *   expressions for grouping keys
  * @param aggregateExpressions
@@ -77,6 +79,7 @@ import org.apache.spark.unsafe.KVIterator
  *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
@@ -91,6 +94,7 @@ class TungstenAggregationIterator(
     spillSize: SQLMetric,
     avgHashProbe: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,

http://git-wip-us.apache.org/repos/asf/spark/blob/fbc26925/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0681b9c..fdb9f1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -24,6 +24,8 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -449,6 +451,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
 
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <-
+         Seq((true, true), (true, false), (false, true), (false, false))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+
+        // HashAggregate test case
+        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
+        val hashAggPlan = hashAggDF.queryExecution.executedPlan
+        if (wholeStage) {
+          assert(hashAggPlan.find {
+            case WholeStageCodegenExec(_: HashAggregateExec) => true
+            case _ => false
+          }.isDefined)
+        } else {
+          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
+        }
+        hashAggDF.collect()
+
+        // ObjectHashAggregate and SortAggregate test case
+        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
+        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
+        if (useObjectHashAgg) {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+        } else {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
+        }
+        objHashAggOrSortAggDF.collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection"
+
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions)
+  }
+
   test("SPARK-21281 use string types by default if array and map have no argument") {
     val ds = spark.range(1)
     var expectedSchema = new StructType()



