spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-21657][SQL] optimize explode quadratic memory consumption
Date Fri, 29 Dec 2017 13:08:42 GMT
Repository: spark
Updated Branches:
  refs/heads/master cc30ef800 -> fcf66a327


[SPARK-21657][SQL] optimize explode quadratic memory consumption

## What changes were proposed in this pull request?

The issue has been raised in two Jira tickets: [SPARK-21657](https://issues.apache.org/jira/browse/SPARK-21657),
[SPARK-16998](https://issues.apache.org/jira/browse/SPARK-16998). In collection generators
such as explode/inline, each input row produces many output rows. Currently each generated
row also contains the column it was generated from. For example, if one row holds a
10k-element array, the array is copied 10k times, once into each generated row. This results
in quadratic memory consumption. However, the original column is commonly projected out
after the explode, so we can avoid duplicating it.
In this solution we propose to identify this situation in the optimizer and turn on a flag
for omitting the original column in the generation process.
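
As a rough illustration of the idea (a simplified sketch, not the code in this patch), the
bookkeeping can be modeled with plain column names standing in for Catalyst attributes.
The object `UnrequiredChildIndexSketch` and its method names are hypothetical; only the
notions of "unrequired child index" and "required child output" come from the patch.

```scala
// Simplified model: Strings stand in for Catalyst attributes (an assumption for illustration).
object UnrequiredChildIndexSketch {

  /** Positions of child columns that the generator reads but the projection above never uses. */
  def unrequiredIndices(
      childOutput: Seq[String],
      generatorRefs: Set[String],
      projectRefs: Set[String]): Seq[Int] = {
    val unrequired = generatorRefs -- projectRefs
    childOutput.zipWithIndex.collect { case (name, i) if unrequired.contains(name) => i }
  }

  /** Child columns that are still emitted next to the generated columns. */
  def requiredChildOutput(childOutput: Seq[String], unrequiredChildIndex: Seq[Int]): Seq[String] = {
    val unrequiredSet = unrequiredChildIndex.toSet
    childOutput.zipWithIndex.collect { case (name, i) if !unrequiredSet.contains(i) => name }
  }
}
```

With a child producing `a, c1, c2`, a generator reading `c1` and `c2`, and a projection
that keeps `a`, `c1` and the exploded column, only index 2 (`c2`) is unrequired, so each
generated row carries `a` and `c1` but no copy of `c2`.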

## How was this patch tested?

1. We added a benchmark test to MiscBenchmark that shows a 16x improvement in runtime.
2. We ran some of the other tests in MiscBenchmark and they show a 15% improvement.
3. We ran this code on a specific case from our production data, with rows containing arrays
of size ~200k, and it reduced the runtime from 6 hours to 3 minutes.
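
For a sense of scale, here is a back-of-the-envelope model (an assumption made for
illustration, not a measurement): it counts only the array elements written when one row
holding an `n`-element array is exploded, and ignores row overhead and all other columns.
`ExplodeCostModel` and `elementsWritten` are hypothetical names.

```scala
// Counts array elements written per exploded input row under a deliberately simple model.
object ExplodeCostModel {
  def elementsWritten(arraySize: Long, keepOriginalColumn: Boolean): Long =
    if (keepOriginalColumn) {
      // each of the n output rows carries the full n-element array plus one exploded element
      arraySize * (arraySize + 1)
    } else {
      // each of the n output rows carries only its exploded element
      arraySize
    }
}
```

Under this model a 10k-element array costs about 10^8 element writes when the original
column is kept, and a 200k-element array about 4 * 10^10, versus 10k and 200k when it is
omitted.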

Author: oraviv <oraviv@paypal.com>
Author: uzadude <ohad.raviv@gmail.com>
Author: uzadude <15645757+uzadude@users.noreply.github.com>

Closes #19683 from uzadude/optimize_explode.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fcf66a32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fcf66a32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fcf66a32

Branch: refs/heads/master
Commit: fcf66a32760c74e601acb537c51b2311ece6e9d5
Parents: cc30ef8
Author: oraviv <oraviv@paypal.com>
Authored: Fri Dec 29 21:08:34 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Dec 29 21:08:34 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  6 +--
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  4 +-
 .../apache/spark/sql/catalyst/dsl/package.scala |  6 +--
 .../sql/catalyst/optimizer/Optimizer.scala      | 13 +++---
 .../spark/sql/catalyst/parser/AstBuilder.scala  |  2 +-
 .../plans/logical/basicLogicalOperators.scala   | 21 ++++++----
 .../catalyst/optimizer/ColumnPruningSuite.scala | 44 ++++++++++++--------
 .../optimizer/FilterPushdownSuite.scala         | 14 +++----
 .../sql/catalyst/parser/PlanParserSuite.scala   | 16 ++++---
 .../scala/org/apache/spark/sql/Dataset.scala    |  4 +-
 .../spark/sql/execution/GenerateExec.scala      | 29 ++++++-------
 .../spark/sql/execution/SparkStrategies.scala   |  6 +--
 .../sql/execution/benchmark/MiscBenchmark.scala | 37 ++++++++++++++++
 13 files changed, 128 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7f2128e..1f7191c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -696,7 +696,7 @@ class Analyzer(
           (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
         case oldVersion: Generate
-            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
@@ -1138,7 +1138,7 @@ class Analyzer(
           case g: Generate =>
             val maybeResolvedExprs = exprs.map(resolveExpression(_, g))
             val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child)
-            (newExprs, g.copy(join = true, child = newChild))
+            (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild))
 
           // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes
           // via its children.
@@ -1578,7 +1578,7 @@ class Analyzer(
             resolvedGenerator =
               Generate(
                 generator,
-                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+                unrequiredChildIndex = Nil,
                 outer = outer,
                 qualifier = None,
                 generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6894aed..bbcec56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -608,8 +608,8 @@ trait CheckAnalysis extends PredicateHelper {
       // allows to have correlation under it
       // but must not host any outer references.
       // Note:
-      // Generator with join=false is treated as Category 4.
-      case g: Generate if g.join =>
+      // Generator with requiredChildOutput.isEmpty is treated as Category 4.
+      case g: Generate if g.requiredChildOutput.nonEmpty =>
         failOnInvalidOuterReference(g)
 
       // Category 4: Any other operators not in the above 3 categories

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 7c100af..59cb26d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -359,12 +359,12 @@ package object dsl {
 
       def generate(
         generator: Generator,
-        join: Boolean = false,
+        unrequiredChildIndex: Seq[Int] = Nil,
         outer: Boolean = false,
         alias: Option[String] = None,
         outputNames: Seq[String] = Nil): LogicalPlan =
-        Generate(generator, join = join, outer = outer, alias,
-          outputNames.map(UnresolvedAttribute(_)), logicalPlan)
+        Generate(generator, unrequiredChildIndex, outer,
+          alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan)
 
       def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
         InsertIntoTable(

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6a4d1e9..eeb1b13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -456,12 +456,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
       f.copy(child = prunedChild(child, f.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
-    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
-      g.copy(child = prunedChild(g.child, g.references))
 
-    // Turn off `join` for Generate if no column from it's child is used
-    case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
-      p.copy(child = g.copy(join = false))
+    // prune unrequired references
+    case p @ Project(_, g: Generate) if p.references != g.outputSet =>
+      val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references
+      val newChild = prunedChild(g.child, requiredAttrs)
+      val unrequired = g.generator.references -- p.references
+      val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1))
+        .map(_._2)
+      p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
 
     // Eliminate unneeded attributes from right side of a Left Existence Join.
     case j @ Join(_, right, LeftExistence(_), _) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 7651d11..bdc357d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -623,7 +623,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     val expressions = expressionList(ctx.expression)
     Generate(
       UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
-      join = true,
+      unrequiredChildIndex = Nil,
       outer = ctx.OUTER != null,
       Some(ctx.tblName.getText.toLowerCase),
       ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index cd47455..95e099c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -73,8 +73,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  * their output.
  *
  * @param generator the generator expression
- * @param join  when true, each output row is implicitly joined with the input tuple that produced
- *              it.
+ * @param unrequiredChildIndex this parameter starts as Nil and gets filled by the Optimizer.
+ *                             It's used as an optimization for omitting data generation that will
+ *                             be discarded next by a projection.
+ *                             A common use case is when we explode(array(..)) and are interested
+ *                             only in the exploded data and not in the original array. before this
+ *                             optimization the array got duplicated for each of its elements,
+ *                             causing O(n^^2) memory consumption. (see [SPARK-21657])
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty.
  * @param qualifier Qualifier for the attributes of generator(UDTF)
@@ -83,15 +88,17 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
  */
 case class Generate(
     generator: Generator,
-    join: Boolean,
+    unrequiredChildIndex: Seq[Int],
     outer: Boolean,
     qualifier: Option[String],
     generatorOutput: Seq[Attribute],
     child: LogicalPlan)
   extends UnaryNode {
 
-  /** The set of all attributes produced by this node. */
-  def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+  lazy val requiredChildOutput: Seq[Attribute] = {
+    val unrequiredSet = unrequiredChildIndex.toSet
+    child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
+  }
 
   override lazy val resolved: Boolean = {
     generator.resolved &&
@@ -114,9 +121,7 @@ case class Generate(
     nullableOutput
   }
 
-  def output: Seq[Attribute] = {
-    if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
-  }
+  def output: Seq[Attribute] = requiredChildOutput ++ qualifiedGeneratorOutput
 }
 
 case class Filter(condition: Expression, child: LogicalPlan)

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 77e4eff..9f0f7e1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -38,54 +38,64 @@ class ColumnPruningSuite extends PlanTest {
       CollapseProject) :: Nil
   }
 
-  test("Column pruning for Generate when Generate.join = false") {
-    val input = LocalRelation('a.int, 'b.array(StringType))
+  test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") {
+    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
 
-    val query = input.generate(Explode('b), join = false).analyze
+    val query =
+      input
+        .generate(Explode('c), outputNames = "explode" :: Nil)
+        .select('c, 'explode)
+        .analyze
 
     val optimized = Optimize.execute(query)
 
-    val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
+    val correctAnswer =
+      input
+        .select('c)
+        .generate(Explode('c), outputNames = "explode" :: Nil)
+        .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Column pruning for Generate when Generate.join = true") {
-    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+  test("Fill Generate.unrequiredChildIndex if possible") {
+    val input = LocalRelation('b.array(StringType))
 
     val query =
       input
-        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
-        .select('a, 'explode)
+        .generate(Explode('b), outputNames = "explode" :: Nil)
+        .select(('explode + 1).as("result"))
         .analyze
 
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
       input
-        .select('a, 'c)
-        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
-        .select('a, 'explode)
+        .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2),
+          outputNames = "explode" :: Nil)
+         .select(('explode + 1).as("result"))
         .analyze
 
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Turn Generate.join to false if possible") {
-    val input = LocalRelation('b.array(StringType))
+  test("Another fill Generate.unrequiredChildIndex if possible") {
+    val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string)
 
     val query =
       input
-        .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
-        .select(('explode + 1).as("result"))
+        .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil)
+        .select('a, 'c1, 'explode)
         .analyze
 
     val optimized = Optimize.execute(query)
 
     val correctAnswer =
       input
-        .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
-        .select(('explode + 1).as("result"))
+        .select('a, 'c1, 'c2)
+        .generate(Explode(CreateArray(Seq('c1, 'c2))),
+          unrequiredChildIndex = Seq(2),
+          outputNames = "explode" :: Nil)
         .analyze
 
     comparePlans(optimized, correctAnswer)

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 641824e..4a23179 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -624,14 +624,14 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: predicate referenced no generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('b >= 5) && ('a > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = {
       testRelationWithArrayType
         .where(('b >= 5) && ('a > 6))
-        .generate(Explode('c_arr), true, false, Some("arr")).analyze
+        .generate(Explode('c_arr), alias = Some("arr")).analyze
     }
 
     comparePlans(optimized, correctAnswer)
@@ -640,14 +640,14 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: non-deterministic predicate referenced no generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = {
       testRelationWithArrayType
         .where('b >= 5)
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where('a + Rand(10).as("rnd") > 6 && 'col > 6)
         .analyze
     }
@@ -659,14 +659,14 @@ class FilterPushdownSuite extends PlanTest {
     val generator = Explode('c_arr)
     val originalQuery = {
       testRelationWithArrayType
-        .generate(generator, true, false, Some("arr"))
+        .generate(generator, alias = Some("arr"))
         .where(('b >= 5) && ('c > 6))
     }
     val optimized = Optimize.execute(originalQuery.analyze)
     val referenceResult = {
       testRelationWithArrayType
         .where('b >= 5)
-        .generate(generator, true, false, Some("arr"))
+        .generate(generator, alias = Some("arr"))
         .where('c > 6).analyze
     }
 
@@ -687,7 +687,7 @@ class FilterPushdownSuite extends PlanTest {
   test("generate: all conjuncts referenced generated column") {
     val originalQuery = {
       testRelationWithArrayType
-        .generate(Explode('c_arr), true, false, Some("arr"))
+        .generate(Explode('c_arr), alias = Some("arr"))
         .where(('col > 6) || ('b > 5)).analyze
     }
     val optimized = Optimize.execute(originalQuery)

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index d34a83c..812bfdd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -276,7 +276,7 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual(
       "select * from t lateral view explode(x) expl as x",
       table("t")
-        .generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
+        .generate(explode, alias = Some("expl"), outputNames = Seq("x"))
         .select(star()))
 
     // Multiple lateral views
@@ -286,12 +286,12 @@ class PlanParserSuite extends AnalysisTest {
         |lateral view explode(x) expl
         |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
       table("t")
-        .generate(explode, join = true, outer = false, Some("expl"), Seq.empty)
-        .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z"))
+        .generate(explode, alias = Some("expl"))
+        .generate(jsonTuple, outer = true, alias = Some("jtup"), outputNames = Seq("q", "z"))
         .select(star()))
 
     // Multi-Insert lateral views.
-    val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x"))
+    val from = table("t1").generate(explode, alias = Some("expl"), outputNames = Seq("x"))
     assertEqual(
       """from t1
         |lateral view explode(x) expl as x
@@ -303,7 +303,7 @@ class PlanParserSuite extends AnalysisTest {
         |where s < 10
       """.stripMargin,
       Union(from
-        .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z"))
+        .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z"))
         .select(star())
         .insertInto("t2"),
         from.where('s < 10).select(star()).insertInto("t3")))
@@ -312,10 +312,8 @@ class PlanParserSuite extends AnalysisTest {
     val expected = table("t")
       .generate(
         UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)),
-        join = true,
-        outer = false,
-        Some("posexpl"),
-        Seq("x", "y"))
+        alias = Some("posexpl"),
+        outputNames = Seq("x", "y"))
       .select(star())
     assertEqual(
       "select * from t lateral view posexplode(x) posexpl as x, y",

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 209b800..77e5712 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2095,7 +2095,7 @@ class Dataset[T] private[sql](
     val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
 
     withPlan {
-      Generate(generator, join = true, outer = false,
+      Generate(generator, unrequiredChildIndex = Nil, outer = false,
         qualifier = None, generatorOutput = Nil, planWithBarrier)
     }
   }
@@ -2136,7 +2136,7 @@ class Dataset[T] private[sql](
     val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)
 
     withPlan {
-      Generate(generator, join = true, outer = false,
+      Generate(generator, unrequiredChildIndex = Nil, outer = false,
         qualifier = None, generatorOutput = Nil, planWithBarrier)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index e1562be..0c2c4a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -47,8 +47,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  * terminate().
  *
  * @param generator the generator expression
- * @param join  when true, each output row is implicitly joined with the input tuple that produced
- *              it.
+ * @param requiredChildOutput required attributes from child's output
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty.
  * @param generatorOutput the qualified output attributes of the generator of this node, which
@@ -57,19 +56,13 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  */
 case class GenerateExec(
     generator: Generator,
-    join: Boolean,
+    requiredChildOutput: Seq[Attribute],
     outer: Boolean,
     generatorOutput: Seq[Attribute],
     child: SparkPlan)
   extends UnaryExecNode with CodegenSupport {
 
-  override def output: Seq[Attribute] = {
-    if (join) {
-      child.output ++ generatorOutput
-    } else {
-      generatorOutput
-    }
-  }
+  override def output: Seq[Attribute] = requiredChildOutput ++ generatorOutput
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -85,11 +78,19 @@ case class GenerateExec(
     val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val generatorNullRow = new GenericInternalRow(generator.elementSchema.length)
-      val rows = if (join) {
+      val rows = if (requiredChildOutput.nonEmpty) {
+
+        val pruneChildForResult: InternalRow => InternalRow =
+          if (child.outputSet == AttributeSet(requiredChildOutput)) {
+            identity
+          } else {
+            UnsafeProjection.create(requiredChildOutput, child.output)
+          }
+
         val joinedRow = new JoinedRow
         iter.flatMap { row =>
-          // we should always set the left (child output)
-          joinedRow.withLeft(row)
+          // we should always set the left (required child output)
+          joinedRow.withLeft(pruneChildForResult(row))
           val outputRows = boundGenerator.eval(row)
           if (outer && outputRows.isEmpty) {
             joinedRow.withRight(generatorNullRow) :: Nil
@@ -136,7 +137,7 @@ case class GenerateExec(
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     // Add input rows to the values when we are joining
-    val values = if (join) {
+    val values = if (requiredChildOutput.nonEmpty) {
       input
     } else {
       Seq.empty

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0ed7c2f..9102948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -499,10 +499,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.GlobalLimitExec(limit, planLater(child)) :: Nil
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
-      case g @ logical.Generate(generator, join, outer, _, _, child) =>
+      case g @ logical.Generate(generator, _, outer, _, _, child) =>
         execution.GenerateExec(
-          generator, join = join, outer = outer, g.qualifiedGeneratorOutput,
-          planLater(child)) :: Nil
+          generator, g.requiredChildOutput, outer,
+          g.qualifiedGeneratorOutput, planLater(child)) :: Nil
       case _: logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>

http://git-wip-us.apache.org/repos/asf/spark/blob/fcf66a32/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
index 01773c2..f039aea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
@@ -202,6 +202,42 @@ class MiscBenchmark extends BenchmarkBase {
     generate inline array wholestage off          6901 / 6928          2.4         411.3       1.0X
     generate inline array wholestage on           1001 / 1010         16.8          59.7       6.9X
      */
+
+    val M = 60000
+    runBenchmark("generate big struct array", M) {
+      import sparkSession.implicits._
+      val df = sparkSession.sparkContext.parallelize(Seq(("1",
+        Array.fill(M)({
+          val i = math.random
+          (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString)
+        })))).toDF("col", "arr")
+
+      df.selectExpr("*", "explode(arr) as arr_col")
+        .select("col", "arr_col.*").count
+    }
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_151-b12 on Mac OS X 10.12.6
+    Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz
+
+    test the impact of adding the optimization of Generate.unrequiredChildIndex,
+    we can see enormous improvement of x250 in this case! and it grows O(n^2).
+
+    with Optimization ON:
+
+    generate big struct array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate big struct array wholestage off       331 /  378          0.2        5524.9       1.0X
+    generate big struct array wholestage on        205 /  232          0.3        3413.1       1.6X
+
+    with Optimization OFF:
+
+    generate big struct array:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    generate big struct array wholestage off    49697 / 51496          0.0      828277.7       1.0X
+    generate big struct array wholestage on     50558 / 51434          0.0      842641.6       1.0X
+     */
+
   }
 
   ignore("generate regular generator") {
@@ -227,4 +263,5 @@ class MiscBenchmark extends BenchmarkBase {
     generate stack wholestage on                   836 /  847         20.1          49.8      15.5X
      */
   }
+
 }



