spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions()
Date Thu, 21 Dec 2017 16:21:34 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4c2efde93 -> 8a0ed5a5e


[SPARK-22668][SQL] Ensure no global variables in arguments of method split by CodegenContext.splitExpressions()

## What changes were proposed in this pull request?

Passing global variables to the split method is dangerous, as any mutation of them inside the
split method is ignored and may lead to unexpected behavior.

To prevent this, one approach is to make sure no expression would output global variables:
Localizing lifetime of mutable states in expressions.

Another approach is, when calling `ctx.splitExpressions`, make sure we don't use children's
output as parameter names.

Approach 1 is actually hard to do, as we need to check all expressions and operators that
support whole-stage codegen. Approach 2 is easier as the callers of `ctx.splitExpressions`
are not too many.

Besides, approach 2 is more flexible, as children's output may be other stuff that can't be a
parameter name: a literal, an inlined statement (a + 1), etc.

Closes https://github.com/apache/spark/pull/19865
Closes https://github.com/apache/spark/pull/19938

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #20021 from cloud-fan/codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a0ed5a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a0ed5a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a0ed5a5

Branch: refs/heads/master
Commit: 8a0ed5a5ee64a6e854c516f80df5a9729435479b
Parents: 4c2efde
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Fri Dec 22 00:21:27 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Dec 22 00:21:27 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 18 +++++------
 .../expressions/codegen/CodeGenerator.scala     | 32 +++++++++++++++++---
 .../expressions/conditionalExpressions.scala    |  8 ++---
 .../catalyst/expressions/nullExpressions.scala  |  9 +++---
 .../sql/catalyst/expressions/predicates.scala   |  2 +-
 5 files changed, 43 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d3a8cb5..8bb1459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && ($tmpIsNull ||
+         |if (!${eval.isNull} && (${ev.isNull} ||
          |  ${ctx.genGreater(dataType, ev.value, eval.value)})) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |}
       """.stripMargin
@@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression {
       foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          |$codes
-         |final boolean ${ev.isNull} = $tmpIsNull;
       """.stripMargin)
   }
 }
@@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && ($tmpIsNull ||
+         |if (!${eval.isNull} && (${ev.isNull} ||
          |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |}
       """.stripMargin
@@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
       foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
          |$codes
-         |final boolean ${ev.isNull} = $tmpIsNull;
       """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 41a920b..9adf632 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -128,7 +128,7 @@ class CodegenContext {
    * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
    * `Expression.genCode`.
    */
-  final var INPUT_ROW = "i"
+  var INPUT_ROW = "i"
 
   /**
    * Holding a list of generated columns as input of current operator, will be used by
@@ -146,22 +146,30 @@ class CodegenContext {
    * as a member variable
    *
    * They will be kept as member variables in generated classes like `SpecificProjection`.
+   *
+   * Exposed for tests only.
    */
-  val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
+  private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
     mutable.ArrayBuffer.empty[(String, String)]
 
   /**
    * The mapping between mutable state types and corrseponding compacted arrays.
    * The keys are java type string. The values are [[MutableStateArrays]] which encapsulates
    * the compacted arrays for the mutable states with the same java type.
+   *
+   * Exposed for tests only.
    */
-  val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
+  private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
     mutable.Map.empty[String, MutableStateArrays]
 
   // An array holds the code that will initialize each state
-  val mutableStateInitCode: mutable.ArrayBuffer[String] =
+  // Exposed for tests only.
+  private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
     mutable.ArrayBuffer.empty[String]
 
+  // Tracks the names of all the mutable states.
+  private val mutableStateNames: mutable.HashSet[String] = mutable.HashSet.empty
+
   /**
    * This class holds a set of names of mutableStateArrays that is used for compacting mutable
   * states for a certain type, and holds the next available slot of the current compacted array.
@@ -172,7 +180,11 @@ class CodegenContext {
 
     private[this] var currentIndex = 0
 
-    private def createNewArray() = arrayNames.append(freshName("mutableStateArray"))
+    private def createNewArray() = {
+      val newArrayName = freshName("mutableStateArray")
+      mutableStateNames += newArrayName
+      arrayNames.append(newArrayName)
+    }
 
     def getCurrentIndex: Int = currentIndex
 
@@ -241,6 +253,7 @@ class CodegenContext {
       val initCode = initFunc(varName)
       inlinedMutableStates += ((javaType, varName))
       mutableStateInitCode += initCode
+      mutableStateNames += varName
       varName
     } else {
       val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new MutableStateArrays)
@@ -930,6 +943,15 @@ class CodegenContext {
       // inline execution if only one block
       blocks.head
     } else {
+      if (Utils.isTesting) {
+        // Passing global variables to the split method is dangerous, as any mutating to it is
+        // ignored and may lead to unexpected behavior.
+        arguments.foreach { case (_, name) =>
+          assert(!mutableStateNames.contains(name),
+            s"split function argument $name cannot be a global variable.")
+        }
+      }
+
       val func = freshName(funcName)
       val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
       val functions = blocks.zipWithIndex.map { case (body, i) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 1a9b682..142dfb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -190,7 +190,7 @@ case class CaseWhen(
     // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
     // We won't go on anymore on the computation.
     val resultState = ctx.freshName("caseWhenResultState")
-    val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult")
+    ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
 
     // these blocks are meant to be inside a
     // do {
@@ -205,7 +205,7 @@ case class CaseWhen(
          |if (!${cond.isNull} && ${cond.value}) {
          |  ${res.code}
          |  $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
-         |  $tmpResult = ${res.value};
+         |  ${ev.value} = ${res.value};
          |  continue;
          |}
        """.stripMargin
@@ -216,7 +216,7 @@ case class CaseWhen(
       s"""
          |${res.code}
          |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
-         |$tmpResult = ${res.value};
+         |${ev.value} = ${res.value};
        """.stripMargin
     }
 
@@ -264,13 +264,11 @@ case class CaseWhen(
     ev.copy(code =
       s"""
          |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
-         |$tmpResult = ${ctx.defaultValue(dataType)};
          |do {
          |  $codes
          |} while (false);
         |// TRUE if any condition is met and the result is null, or no any condition is met.
          |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
-         |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
        """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index b4f895f..470d5da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
 
     // all the evals are meant to be in a do { ... } while (false); loop
     val evals = children.map { e =>
@@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
       s"""
          |${eval.code}
          |if (!${eval.isNull}) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |  continue;
          |}
@@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
       foldFunctions = _.map { funcCall =>
         s"""
            |${ev.value} = $funcCall;
-           |if (!$tmpIsNull) {
+           |if (!${ev.isNull}) {
            |  continue;
            |}
          """.stripMargin
@@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
          |do {
          |  $codes
          |} while (false);
-         |final boolean ${ev.isNull} = $tmpIsNull;
        """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ac9f56f..f4ee3d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
          |${valueGen.code}
          |byte $tmpResult = $HAS_NULL;
          |if (!${valueGen.isNull}) {
-         |  $tmpResult = 0;
+         |  $tmpResult = $NOT_MATCHED;
          |  $javaDataType $valueArg = ${valueGen.value};
          |  do {
          |    $codes


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message