spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-23628][SQL] calculateParamLength should not return 1 + num of expressions
Date Thu, 08 Mar 2018 19:09:44 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3be4adf64 -> ea480990e


[SPARK-23628][SQL] calculateParamLength should not return 1 + num of expressions

## What changes were proposed in this pull request?

There was a bug in `calculateParamLength` which caused it to always return 1 + the number
of expressions. This could lead to exceptions, especially with expressions of type long.

## How was this patch tested?

added UT + fixed previous UT

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20772 from mgaido91/SPARK-23628.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ea480990
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ea480990
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ea480990

Branch: refs/heads/master
Commit: ea480990e726aed59750f1cea8d40adba56d991a
Parents: 3be4adf
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Thu Mar 8 11:09:15 2018 -0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Mar 8 11:09:15 2018 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     | 51 ++++++++++----------
 .../expressions/CodeGenerationSuite.scala       |  6 +++
 .../sql/execution/WholeStageCodegenExec.scala   |  5 +-
 .../sql/execution/WholeStageCodegenSuite.scala  | 16 +++---
 4 files changed, 43 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ea480990/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 793824b..fe5e63e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1063,31 +1063,6 @@ class CodegenContext {
       ""
     }
   }
-
-  /**
-   * Returns the length of parameters for a Java method descriptor. `this` contributes one
unit
-   * and a parameter of type long or double contributes two units. Besides, for nullable
parameter,
-   * we also need to pass a boolean parameter for the null status.
-   */
-  def calculateParamLength(params: Seq[Expression]): Int = {
-    def paramLengthForExpr(input: Expression): Int = {
-      // For a nullable expression, we need to pass in an extra boolean parameter.
-      (if (input.nullable) 1 else 0) + javaType(input.dataType) match {
-        case JAVA_LONG | JAVA_DOUBLE => 2
-        case _ => 1
-      }
-    }
-    // Initial value is 1 for `this`.
-    1 + params.map(paramLengthForExpr(_)).sum
-  }
-
-  /**
-   * In Java, a method descriptor is valid only if it represents method parameters with a
total
-   * length less than a pre-defined constant.
-   */
-  def isValidParamLength(paramLength: Int): Boolean = {
-    paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
-  }
 }
 
 /**
@@ -1538,4 +1513,30 @@ object CodeGenerator extends Logging {
 
   def defaultValue(dt: DataType, typedNull: Boolean = false): String =
     defaultValue(javaType(dt), typedNull)
+
+  /**
+   * Returns the length of parameters for a Java method descriptor. `this` contributes one
unit
+   * and a parameter of type long or double contributes two units. Besides, for nullable
parameter,
+   * we also need to pass a boolean parameter for the null status.
+   */
+  def calculateParamLength(params: Seq[Expression]): Int = {
+    def paramLengthForExpr(input: Expression): Int = {
+      val javaParamLength = javaType(input.dataType) match {
+        case JAVA_LONG | JAVA_DOUBLE => 2
+        case _ => 1
+      }
+      // For a nullable expression, we need to pass in an extra boolean parameter.
+      (if (input.nullable) 1 else 0) + javaParamLength
+    }
+    // Initial value is 1 for `this`.
+    1 + params.map(paramLengthForExpr).sum
+  }
+
+  /**
+   * In Java, a method descriptor is valid only if it represents method parameters with a
total
+   * length less than a pre-defined constant.
+   */
+  def isValidParamLength(paramLength: Int): Boolean = {
+    paramLength <= MAX_JVM_METHOD_PARAMS_LENGTH
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ea480990/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 1e48c7b..64c13e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -436,4 +436,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper
{
     ctx.addImmutableStateIfNotExists("String", mutableState2)
     assert(ctx.inlinedMutableStates.length == 2)
   }
+
+  test("SPARK-23628: calculateParamLength should compute properly the param length") {
+    assert(CodeGenerator.calculateParamLength(Seq.range(0, 100).map(Literal(_))) == 101)
+    assert(CodeGenerator.calculateParamLength(
+      Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ea480990/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index f89e3fb..6ddaacf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -174,8 +174,9 @@ trait CodegenSupport extends SparkPlan {
     //    declaration.
     val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
     val requireAllOutput = output.forall(parent.usedInputs.contains(_))
-    val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0)
-    val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength))
{
+    val paramLength = CodeGenerator.calculateParamLength(output) + (if (row != null) 1 else
0)
+    val consumeFunc = if (confEnabled && requireAllOutput
+        && CodeGenerator.isValidParamLength(paramLength)) {
       constructDoConsumeFunction(ctx, inputVars, row)
     } else {
       parent.doConsume(ctx, inputVars, rowVar)

http://git-wip-us.apache.org/repos/asf/spark/blob/ea480990/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ef16292..0fb9dd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.expressions.scalalang.typed
-import org.apache.spark.sql.functions.{avg, broadcast, col, max}
+import org.apache.spark.sql.functions.{avg, broadcast, col, lit, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -249,12 +249,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext
{
   }
 
   test("Skip splitting consume function when parameter number exceeds JVM limit") {
-    import testImplicits._
-
-    Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) =>
+    // since every field is nullable we have 2 params for each input column (one for the
value
+    // and one for the isNull variable)
+    Seq((128, false), (127, true)).foreach { case (columnNum, hasSplit) =>
       withTempPath { dir =>
         val path = dir.getCanonicalPath
-        spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
+        spark.range(10).select(Seq.tabulate(columnNum) {i => lit(i).as(s"c$i")} : _*)
           .write.mode(SaveMode.Overwrite).parquet(path)
 
         withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
@@ -263,10 +263,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext
{
           val df = spark.read.parquet(path).selectExpr(projection: _*)
 
           val plan = df.queryExecution.executedPlan
-          val wholeStageCodeGenExec = plan.find(p => p match {
-            case wp: WholeStageCodegenExec => true
+          val wholeStageCodeGenExec = plan.find {
+            case _: WholeStageCodegenExec => true
             case _ => false
-          })
+          }
           assert(wholeStageCodeGenExec.isDefined)
           val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
           assert(code.body.contains("project_doConsume") == hasSplit)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message