spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly
Date Thu, 22 Mar 2018 04:21:50 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.3 c9acd46be -> 4da8c22f7


[SPARK-23760][SQL] CodegenContext.withSubExprEliminationExprs should save/restore CSE state correctly

## What changes were proposed in this pull request?

Fixed `CodegenContext.withSubExprEliminationExprs()` so that it saves/restores CSE state correctly.

## How was this patch tested?

Added a new unit test to verify that the old CSE state is indeed saved and restored around the
`withSubExprEliminationExprs()` call. Manually verified that this test fails without this
patch.

Author: Kris Mok <kris.mok@databricks.com>

Closes #20870 from rednaxelafx/codegen-subexpr-fix.

(cherry picked from commit 95e51ff849a4c46cae463636b1ee393042469e7b)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4da8c22f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4da8c22f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4da8c22f

Branch: refs/heads/branch-2.3
Commit: 4da8c22f77475d1b328375e97e2825e1dea78fdd
Parents: c9acd46
Author: Kris Mok <kris.mok@databricks.com>
Authored: Wed Mar 21 21:21:36 2018 -0700
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Wed Mar 21 21:21:47 2018 -0700

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     | 16 ++++---
 .../expressions/CodeGenerationSuite.scala       | 44 ++++++++++++++++++++
 2 files changed, 51 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4da8c22f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a54af03..2631e7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -389,7 +389,7 @@ class CodegenContext {
   val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
 
   // Foreach expression that is participating in subexpression elimination, the state to use.
-  val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+  var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
 
   // The collection of sub-expression result resetting methods that need to be called on each row.
   val subexprFunctions = mutable.ArrayBuffer.empty[String]
@@ -1118,14 +1118,12 @@ class CodegenContext {
       newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
       f: => Seq[ExprCode]): Seq[ExprCode] = {
     val oldsubExprEliminationExprs = subExprEliminationExprs
-    subExprEliminationExprs.clear
-    newSubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+    subExprEliminationExprs = newSubExprEliminationExprs
 
     val genCodes = f
 
     // Restore previous subExprEliminationExprs
-    subExprEliminationExprs.clear
-    oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _)
+    subExprEliminationExprs = oldsubExprEliminationExprs
     genCodes
   }
 
@@ -1139,7 +1137,7 @@ class CodegenContext {
   def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = {
     // Create a clear EquivalentExpressions and SubExprEliminationState mapping
     val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
-    val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
+    val localSubExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState]
 
     // Add each expression tree and compute the common subexpressions.
     expressions.foreach(equivalentExpressions.addExprTree)
@@ -1152,10 +1150,10 @@ class CodegenContext {
       // Generate the code for this expression tree.
       val eval = expr.genCode(this)
       val state = SubExprEliminationState(eval.isNull, eval.value)
-      e.foreach(subExprEliminationExprs.put(_, state))
+      e.foreach(localSubExprEliminationExprs.put(_, state))
       eval.code.trim
     }
-    SubExprCodes(codes, subExprEliminationExprs.toMap)
+    SubExprCodes(codes, localSubExprEliminationExprs.toMap)
   }
 
   /**
@@ -1203,7 +1201,7 @@ class CodegenContext {
 
       subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
       val state = SubExprEliminationState(isNull, value)
-      e.foreach(subExprEliminationExprs.put(_, state))
+      subExprEliminationExprs ++= e.map(_ -> state).toMap
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4da8c22f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 3958bff..e3ce5af 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -442,4 +442,48 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(ctx.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
     assert(ctx.calculateParamLength(Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
   }
+
+  test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
+
+    val ref = BoundReference(0, IntegerType, true)
+    val add1 = Add(ref, ref)
+    val add2 = Add(add1, add1)
+
+    // raw testing of basic functionality
+    {
+      val ctx = new CodegenContext
+      val e = ref.genCode(ctx)
+      // before
+      ctx.subExprEliminationExprs += ref -> SubExprEliminationState(e.isNull, e.value)
+      assert(ctx.subExprEliminationExprs.contains(ref))
+      // call withSubExprEliminationExprs
+      ctx.withSubExprEliminationExprs(Map(add1 -> SubExprEliminationState("dummy", "dummy"))) {
+        assert(ctx.subExprEliminationExprs.contains(add1))
+        assert(!ctx.subExprEliminationExprs.contains(ref))
+        Seq.empty
+      }
+      // after
+      assert(ctx.subExprEliminationExprs.nonEmpty)
+      assert(ctx.subExprEliminationExprs.contains(ref))
+      assert(!ctx.subExprEliminationExprs.contains(add1))
+    }
+
+    // emulate an actual codegen workload
+    {
+      val ctx = new CodegenContext
+      // before
+      ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
+      assert(ctx.subExprEliminationExprs.contains(add1))
+      // call withSubExprEliminationExprs
+      ctx.withSubExprEliminationExprs(Map(ref -> SubExprEliminationState("dummy", "dummy"))) {
+        assert(ctx.subExprEliminationExprs.contains(ref))
+        assert(!ctx.subExprEliminationExprs.contains(add1))
+        Seq.empty
+      }
+      // after
+      assert(ctx.subExprEliminationExprs.nonEmpty)
+      assert(ctx.subExprEliminationExprs.contains(add1))
+      assert(!ctx.subExprEliminationExprs.contains(ref))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message