spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-13031][SQL] cleanup codegen and improve test coverage
Date Fri, 29 Jan 2016 10:00:02 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8d3cc3de7 -> 55561e769


[SPARK-13031][SQL] cleanup codegen and improve test coverage

1. enable whole stage codegen during tests even when there is only one operator that supports it.
2. split doProduce() into two APIs: upstream() and doProduce()
3. generate a prefix for the fresh names of each operator
4. pass UnsafeRow to parent directly (avoids going through getters and creating the UnsafeRow again)
5. fix bugs and tests.

This PR re-opens #10944 and fixes the bug.

Author: Davies Liu <davies@databricks.com>

Closes #10977 from davies/gen_refactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55561e76
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55561e76
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55561e76

Branch: refs/heads/master
Commit: 55561e7693dd2a5bf3c7f8026c725421801fd0ec
Parents: 8d3cc3d
Author: Davies Liu <davies@databricks.com>
Authored: Fri Jan 29 01:59:59 2016 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri Jan 29 01:59:59 2016 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  13 +-
 .../codegen/GenerateMutableProjection.scala     |   2 +-
 .../spark/sql/execution/WholeStageCodegen.scala | 188 +++++++++++++------
 .../aggregate/AggregationIterator.scala         |   2 +-
 .../execution/aggregate/TungstenAggregate.scala |  98 +++++++---
 .../spark/sql/execution/basicOperators.scala    |  96 +++++-----
 .../spark/sql/DataFrameAggregateSuite.scala     |   7 +
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 103 +++++-----
 .../sql/execution/metric/SQLMetricsSuite.scala  |  34 ++--
 .../apache/spark/sql/test/SQLTestUtils.scala    |   2 +-
 .../spark/sql/util/DataFrameCallbackSuite.scala |  10 +-
 11 files changed, 350 insertions(+), 205 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2747c31..e6704cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -145,13 +145,22 @@ class CodegenContext {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
   /**
+    * A prefix used to generate fresh name.
+    */
+  var freshNamePrefix = ""
+
+  /**
    * Returns a term name that is unique within this instance of a `CodeGenerator`.
    *
    * (Since we aren't in a macro context we do not seem to have access to the built in `freshName`
    * function.)
    */
-  def freshName(prefix: String): String = {
-    s"$prefix${curId.getAndIncrement}"
+  def freshName(name: String): String = {
+    if (freshNamePrefix == "") {
+      s"$name${curId.getAndIncrement}"
+    } else {
+      s"${freshNamePrefix}_$name${curId.getAndIncrement}"
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d9fe761..ec31db1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression],
() => Mu
             // Can't call setNullAt on DecimalType, because we need to keep the offset
             s"""
               if (this.isNull_$i) {
-                ${ctx.setColumn("mutableRow", e.dataType, i, null)};
+                ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
               } else {
                 ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
               }

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 57f4945..ef81ba6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,9 +22,11 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression,
LeafExpression}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.util.Utils
 
 /**
   * An interface for those physical operators that support codegen.
@@ -42,10 +44,16 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null
 
   /**
-    * Returns an input RDD of InternalRow and Java source code to process them.
+    * Returns the RDD of InternalRow which generates the input rows.
     */
-  def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], String) =
{
+  def upstream(): RDD[InternalRow]
+
+  /**
+    * Returns Java source code to process the rows from upstream.
+    */
+  def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
     this.parent = parent
+    ctx.freshNamePrefix = nodeName
     doProduce(ctx)
   }
 
@@ -66,16 +74,41 @@ trait CodegenSupport extends SparkPlan {
     *     # call consume(), wich will call parent.doConsume()
     *   }
     */
-  protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
+  protected def doProduce(ctx: CodegenContext): String
 
   /**
-    * Consume the columns generated from current SparkPlan, call it's parent or create an
iterator.
+    * Consume the columns generated from current SparkPlan, call it's parent.
     */
-  protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = {
-    assert(columns.length == output.length)
-    parent.doConsume(ctx, this, columns)
+  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): String = {
+    if (input != null) {
+      assert(input.length == output.length)
+    }
+    parent.consumeChild(ctx, this, input, row)
   }
 
+  /**
+    * Consume the columns generated from it's child, call doConsume() or emit the rows.
+    */
+  def consumeChild(
+      ctx: CodegenContext,
+      child: SparkPlan,
+      input: Seq[ExprCode],
+      row: String = null): String = {
+    ctx.freshNamePrefix = nodeName
+    if (row != null) {
+      ctx.currentVars = null
+      ctx.INPUT_ROW = row
+      val evals = child.output.zipWithIndex.map { case (attr, i) =>
+        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+      }
+      s"""
+         | ${evals.map(_.code).mkString("\n")}
+         | ${doConsume(ctx, evals)}
+       """.stripMargin
+    } else {
+      doConsume(ctx, input)
+    }
+  }
 
   /**
     * Generate the Java source code to process the rows from child SparkPlan.
@@ -89,7 +122,9 @@ trait CodegenSupport extends SparkPlan {
     *     # call consume(), which will call parent.doConsume()
     *   }
     */
-  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
+  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    throw new UnsupportedOperationException
+  }
 }
 
 
@@ -102,31 +137,36 @@ trait CodegenSupport extends SparkPlan {
 case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def doPrepare(): Unit = {
+    child.prepare()
+  }
 
-  override def supportCodegen: Boolean = true
+  override def doExecute(): RDD[InternalRow] = {
+    child.execute()
+  }
 
-  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+  override def supportCodegen: Boolean = false
+
+  override def upstream(): RDD[InternalRow] = {
+    child.execute()
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
-    val code = s"""
-       |  while (input.hasNext()) {
+    s"""
+       | while (input.hasNext()) {
        |   InternalRow $row = (InternalRow) input.next();
        |   ${columns.map(_.code).mkString("\n")}
        |   ${consume(ctx, columns)}
        | }
      """.stripMargin
-    (child.execute(), code)
-  }
-
-  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
-    throw new UnsupportedOperationException
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException
   }
 
   override def simpleString: String = "INPUT"
@@ -143,16 +183,20 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport
{
   *
   * -> execute()
   *     |
-  *  doExecute() -------->   produce()
+  *  doExecute() --------->   upstream() -------> upstream() ------> execute()
+  *     |
+  *      ----------------->   produce()
   *                             |
   *                          doProduce()  -------> produce()
   *                                                   |
-  *                                                doProduce() ---> execute()
+  *                                                doProduce()
   *                                                   |
   *                                                consume()
-  *                          doConsume()  ------------|
+  *                        consumeChild() <-----------|
   *                             |
-  *  doConsume()  <-----    consume()
+  *                          doConsume()
+  *                             |
+  *  consumeChild()  <-----  consume()
   *
   * SparkPlan A should override doProduce() and doConsume().
   *
@@ -162,37 +206,48 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport
{
 case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   extends SparkPlan with CodegenSupport {
 
+  override def supportCodegen: Boolean = false
+
   override def output: Seq[Attribute] = plan.output
+  override def outputPartitioning: Partitioning = plan.outputPartitioning
+  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+
+  override def doPrepare(): Unit = {
+    plan.prepare()
+  }
 
   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val (rdd, code) = plan.produce(ctx, this)
+    val code = plan.produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
-       return new GeneratedIterator(references);
+        return new GeneratedIterator(references);
       }
 
       class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator
{
 
-       private Object[] references;
-       ${ctx.declareMutableStates()}
+        private Object[] references;
+        ${ctx.declareMutableStates()}
+        ${ctx.declareAddedFunctions()}
 
-       public GeneratedIterator(Object[] references) {
+        public GeneratedIterator(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
-       }
+        }
 
-       protected void processNext() {
+        protected void processNext() throws java.io.IOException {
          $code
-       }
+        }
       }
-     """
+      """
+
     // try to compile, helpful for debug
     // println(s"${CodeFormatter.format(source)}")
     CodeGenerator.compile(source)
 
-    rdd.mapPartitions { iter =>
+    plan.upstream().mapPartitions { iter =>
+
       val clazz = CodeGenerator.compile(source)
       val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
       buffer.setInput(iter)
@@ -203,29 +258,47 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }
 
-  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+  override def upstream(): RDD[InternalRow] = {
     throw new UnsupportedOperationException
   }
 
-  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
= {
-    if (input.nonEmpty) {
-      val colExprs = output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable)
-      }
-      // generate the code to create a UnsafeRow
-      ctx.currentVars = input
-      val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-      s"""
-         | ${code.code.trim}
-         | currentRow = ${code.value};
-         | return;
-     """.stripMargin
-    } else {
-      // There is no columns
+  override def doProduce(ctx: CodegenContext): String = {
+    throw new UnsupportedOperationException
+  }
+
+  override def consumeChild(
+      ctx: CodegenContext,
+      child: SparkPlan,
+      input: Seq[ExprCode],
+      row: String = null): String = {
+
+    if (row != null) {
+      // There is an UnsafeRow already
       s"""
-         | currentRow = unsafeRow;
+         | currentRow = $row;
          | return;
        """.stripMargin
+    } else {
+      assert(input != null)
+      if (input.nonEmpty) {
+        val colExprs = output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable)
+        }
+        // generate the code to create a UnsafeRow
+        ctx.currentVars = input
+        val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+        s"""
+           | ${code.code.trim}
+           | currentRow = ${code.value};
+           | return;
+         """.stripMargin
+      } else {
+        // There is no columns
+        s"""
+           | currentRow = unsafeRow;
+           | return;
+         """.stripMargin
+      }
     }
   }
 
@@ -246,7 +319,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     builder.append(simpleString)
     builder.append("\n")
 
-    plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ true, builder)
+    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
     if (children.nonEmpty) {
       children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
       children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
@@ -286,13 +359,14 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext)
extends Ru
         case plan: CodegenSupport if supportCodegen(plan) &&
           // Whole stage codegen is only useful when there are at least two levels of operators
that
           // support it (save at least one projection/iterator).
-          plan.children.exists(supportCodegen) =>
+          (Utils.isTesting || plan.children.exists(supportCodegen)) =>
 
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             case p if !supportCodegen(p) =>
-              inputs += p
-              InputAdapter(p)
+              val input = apply(p)  // collapse them recursively
+              inputs += input
+              InputAdapter(input)
           }.asInstanceOf[CodegenSupport]
           WholeStageCodegen(combined, inputs)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 0c74df0..38da82c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -238,7 +238,7 @@ abstract class AggregationIterator(
         resultProjection(joinedRow(currentGroupingKey, currentBuffer))
       }
     } else {
-      // Grouping-only: we only output values of grouping expressions.
+      // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
       (currentGroupingKey: UnsafeRow, currentBuffer: MutableRow) => {
         resultProjection(currentGroupingKey)

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 23e54f3..ff2f38b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -117,9 +117,7 @@ case class TungstenAggregate(
   override def supportCodegen: Boolean = {
     groupingExpressions.isEmpty &&
       // ImperativeAggregate is not supported right now
-      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
&&
-      // final aggregation only have one row, do not need to codegen
-      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
   // The variables used as aggregation buffer
@@ -127,7 +125,11 @@ case class TungstenAggregate(
 
   private val modes = aggregateExpressions.map(_.mode).distinct
 
-  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
@@ -137,60 +139,96 @@ case class TungstenAggregate(
     bufVars = initExpr.map { e =>
       val isNull = ctx.freshName("bufIsNull")
       val value = ctx.freshName("bufValue")
+      ctx.addMutableState("boolean", isNull, "")
+      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
       // The initial expression should not access any column
       val ev = e.gen(ctx)
       val initVars = s"""
-         | boolean $isNull = ${ev.isNull};
-         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+         | $isNull = ${ev.isNull};
+         | $value = ${ev.value};
        """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
 
-    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
-    val source =
+    // generate variables for output
+    val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
+    val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete))
{
+      // evaluate aggregate results
+      ctx.currentVars = bufVars
+      val aggResults = functions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+      }
+      // evaluate result expressions
+      ctx.currentVars = aggResults
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
+      }
+      (resultVars, s"""
+        | ${aggResults.map(_.code).mkString("\n")}
+        | ${resultVars.map(_.code).mkString("\n")}
+       """.stripMargin)
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // output the aggregate buffer directly
+      (bufVars, "")
+    } else {
+      // no aggregate function, the result should be literals
+      val resultVars = resultExpressions.map(_.gen(ctx))
+      (resultVars, resultVars.map(_.code).mkString("\n"))
+    }
+
+    val doAgg = ctx.freshName("doAgg")
+    ctx.addNewFunction(doAgg,
       s"""
-         | if (!$initAgg) {
-         |   $initAgg = true;
-         |
+         | private void $doAgg() {
          |   // initialize aggregation buffer
          |   ${bufVars.map(_.code).mkString("\n")}
          |
-         |   $childSource
-         |
-         |   // output the result
-         |   ${consume(ctx, bufVars)}
+         |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
          | }
-       """.stripMargin
+       """.stripMargin)
 
-    (rdd, source)
+    s"""
+       | if (!$initAgg) {
+       |   $initAgg = true;
+       |   $doAgg();
+       |
+       |   // output the result
+       |   $genResult
+       |
+       |   ${consume(ctx, resultVars)}
+       | }
+     """.stripMargin
   }
 
-  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
= {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    // the mode could be only Partial or PartialMerge
-    val updateExpr = if (modes.contains(Partial)) {
-      functions.flatMap(_.updateExpressions)
-    } else {
-      functions.flatMap(_.mergeExpressions)
+    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val updateExpr = aggregateExpressions.flatMap { e =>
+      e.mode match {
+        case Partial | Complete =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+        case PartialMerge | Final =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+      }
     }
 
-    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
-    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
-      val ev = e.gen(ctx)
+    val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
+    // aggregate buffer should be updated atomic
+    val updates = aggVals.zipWithIndex.map { case (ev, i) =>
       s"""
-         | ${ev.code}
          | ${bufVars(i).isNull} = ${ev.isNull};
          | ${bufVars(i).value} = ${ev.value};
        """.stripMargin
     }
 
     s"""
-       | // do aggregate and update aggregation buffer
-       | ${codes.mkString("")}
+       | // do aggregate
+       | ${aggVals.map(_.code).mkString("\n")}
+       | // update aggregation buffer
+       | ${updates.mkString("")}
      """.stripMargin
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6deb72a..e7a73d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,11 +37,15 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
-  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
= {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
@@ -76,11 +80,15 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode
wit
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String
= {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val expr = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(condition, child.output))
     ctx.currentVars = input
@@ -153,17 +161,21 @@ case class Range(
     output: Seq[Attribute])
   extends LeafNode with CodegenSupport {
 
-  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
-    val initTerm = ctx.freshName("range_initRange")
+  override def upstream(): RDD[InternalRow] = {
+    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("range_partitionEnd")
+    val partitionEnd = ctx.freshName("partitionEnd")
     ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
-    val number = ctx.freshName("range_number")
+    val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("range_overflow")
+    val overflow = ctx.freshName("overflow")
     ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
-    val value = ctx.freshName("range_value")
+    val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
     val checkEnd = if (step > 0) {
@@ -172,38 +184,42 @@ case class Range(
       s"$number > $partitionEnd"
     }
 
-    val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
-      .map(i => InternalRow(i))
+    ctx.addNewFunction("initRange",
+      s"""
+        | private void initRange(int idx) {
+        |   $BigInt index = $BigInt.valueOf(idx);
+        |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+        |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+        |   $BigInt step = $BigInt.valueOf(${step}L);
+        |   $BigInt start = $BigInt.valueOf(${start}L);
+        |
+        |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+        |     $number = Long.MAX_VALUE;
+        |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+        |     $number = Long.MIN_VALUE;
+        |   } else {
+        |     $number = st.longValue();
+        |   }
+        |
+        |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+        |     .multiply(step).add(start);
+        |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+        |     $partitionEnd = Long.MAX_VALUE;
+        |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+        |     $partitionEnd = Long.MIN_VALUE;
+        |   } else {
+        |     $partitionEnd = end.longValue();
+        |   }
+        | }
+       """.stripMargin)
 
-    val code = s"""
+    s"""
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
       |   if (input.hasNext()) {
-      |     $BigInt index = $BigInt.valueOf(((InternalRow) input.next()).getInt(0));
-      |     $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
-      |     $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
-      |     $BigInt step = $BigInt.valueOf(${step}L);
-      |     $BigInt start = $BigInt.valueOf(${start}L);
-      |
-      |     $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
-      |     if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-      |       $number = Long.MAX_VALUE;
-      |     } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-      |       $number = Long.MIN_VALUE;
-      |     } else {
-      |       $number = st.longValue();
-      |     }
-      |
-      |     $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
-      |       .multiply(step).add(start);
-      |     if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-      |       $partitionEnd = Long.MAX_VALUE;
-      |     } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-      |       $partitionEnd = Long.MIN_VALUE;
-      |     } else {
-      |       $partitionEnd = end.longValue();
-      |     }
+      |     initRange(((InternalRow) input.next()).getInt(0));
       |   } else {
       |     return;
       |   }
@@ -218,12 +234,6 @@ case class Range(
       |  ${consume(ctx, Seq(ev))}
       | }
      """.stripMargin
-
-    (rdd, code)
-  }
-
-  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
-    throw new UnsupportedOperationException
   }
 
   protected override def doExecute(): RDD[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b1004bc..08fb7c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -153,6 +153,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("agg without groups and functions") {
+    checkAnswer(
+      testData2.agg(lit(1)),
+      Row(1)
+    )
+  }
+
   test("average") {
     checkAnswer(
       testData2.agg(avg('a), mean('a)),

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 989cb29..51a50c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1939,58 +1939,61 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("Common subexpression elimination") {
-    // select from a table to prevent constant folding.
-    val df = sql("SELECT a, b from testData2 limit 1")
-    checkAnswer(df, Row(1, 1))
-
-    checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
-    checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
-
-    // This does not work because the expressions get grouped like (a + a) + 1
-    checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
-    checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
-
-    // Identity udf that tracks the number of times it is called.
-    val countAcc = sparkContext.accumulator(0, "CallCount")
-    sqlContext.udf.register("testUdf", (x: Int) => {
-      countAcc.++=(1)
-      x
-    })
+    // TODO: support subexpression elimination in whole stage codegen
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      // select from a table to prevent constant folding.
+      val df = sql("SELECT a, b from testData2 limit 1")
+      checkAnswer(df, Row(1, 1))
+
+      checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+      checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+      // This does not work because the expressions get grouped like (a + a) + 1
+      checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+      checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+      // Identity udf that tracks the number of times it is called.
+      val countAcc = sparkContext.accumulator(0, "CallCount")
+      sqlContext.udf.register("testUdf", (x: Int) => {
+        countAcc.++=(1)
+        x
+      })
+
+      // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+      // is correct.
+      def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+        countAcc.setValue(0)
+        checkAnswer(df, expectedResult)
+        assert(countAcc.value == expectedCount)
+      }
 
-    // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
-    // is correct.
-    def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
-      countAcc.setValue(0)
-      checkAnswer(df, expectedResult)
-      assert(countAcc.value == expectedCount)
+      verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+      verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+      verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+      verifyCallCount(
+        df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+      verifyCallCount(
+        df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+      val testUdf = functions.udf((x: Int) => {
+        countAcc.++=(1)
+        x
+      })
+      verifyCallCount(
+        df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
+      // Would be nice if semantic equals for `+` understood commutative
+      verifyCallCount(
+        df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+      // Try disabling it via configuration.
+      sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+      sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
     }
-
-    verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
-    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
-    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
-    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
-    verifyCallCount(
-      df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
-
-    verifyCallCount(
-      df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
-
-    val testUdf = functions.udf((x: Int) => {
-      countAcc.++=(1)
-      x
-    })
-    verifyCallCount(
-      df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
-
-    // Would be nice if semantic equals for `+` understood commutative
-    verifyCallCount(
-      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
-
-    // Try disabling it via configuration.
-    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
-    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
-    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
-    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
   }
 
   test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index cbae19e..82f6811 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -335,22 +335,24 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
 
   test("save metrics") {
     withTempPath { file =>
-      val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-      // Assume the execution plan is
-      // PhysicalRDD(nodeId = 0)
-      person.select('name).write.format("json").save(file.getAbsolutePath)
-      sparkContext.listenerBus.waitUntilEmpty(10000)
-      val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
-      assert(executionIds.size === 1)
-      val executionId = executionIds.head
-      val jobs = sqlContext.listener.getExecution(executionId).get.jobs
-      // Use "<=" because there is a race condition that we may miss some jobs
-      // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
-      assert(jobs.size <= 1)
-      val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
-      // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
-      // However, we still can check the value.
-      assert(metricValues.values.toSeq === Seq("2"))
+      withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+        val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+        // Assume the execution plan is
+        // PhysicalRDD(nodeId = 0)
+        person.select('name).write.format("json").save(file.getAbsolutePath)
+        sparkContext.listenerBus.waitUntilEmpty(10000)
+        val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+        assert(executionIds.size === 1)
+        val executionId = executionIds.head
+        val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+        // Use "<=" because there is a race condition that we may miss some jobs
+        // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+        assert(jobs.size <= 1)
+        val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+        // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+        // However, we still can check the value.
+        assert(metricValues.values.toSeq === Seq("2"))
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index d481437..7d6bff8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
     val schema = df.schema
     val childRDD = df
       .queryExecution
-      .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+      .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
       .child
       .execute()
       .map(row => Row.fromSeq(row.copy().toSeq(schema)))

http://git-wip-us.apache.org/repos/asf/spark/blob/55561e76/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 9a24a24..a3e5243 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,10 +97,12 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     }
     sqlContext.listenerManager.register(listener)
 
-    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
-    df.collect()
-    df.collect()
-    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+      df.collect()
+      df.collect()
+      Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
+    }
 
     assert(metrics.length == 3)
     assert(metrics(0) == 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message