spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22695][SQL] ScalaUDF should not use global variables
Date Wed, 06 Dec 2017 16:51:11 GMT
Repository: spark
Updated Branches:
  refs/heads/master 813c0f945 -> e98f9647f


[SPARK-22695][SQL] ScalaUDF should not use global variables

## What changes were proposed in this pull request?

ScalaUDF is using global variables which are not needed. This can generate some unneeded entries
in the constant pool.

The PR replaces the unneeded global variables with local variables.

## How was this patch tested?

added UT

Author: Marco Gaido <mgaido@hortonworks.com>
Author: Marco Gaido <marcogaido91@gmail.com>

Closes #19900 from mgaido91/SPARK-22695.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e98f9647
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e98f9647
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e98f9647

Branch: refs/heads/master
Commit: e98f9647f44d1071a6b070db070841b8cda6bd7a
Parents: 813c0f9
Author: Marco Gaido <mgaido@hortonworks.com>
Authored: Thu Dec 7 00:50:49 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Dec 7 00:50:49 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/ScalaUDF.scala     | 88 ++++++++++----------
 .../catalyst/expressions/ScalaUDFSuite.scala    |  6 ++
 2 files changed, 51 insertions(+), 43 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 1798530..4d26d98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -982,35 +982,28 @@ case class ScalaUDF(
 
   // scalastyle:on line.size.limit
 
-  // Generate codes used to convert the arguments to Scala type for user-defined functions
-  private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = {
-    val converterClassName = classOf[Any => Any].getName
-    val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
-    val expressionClassName = classOf[Expression].getName
-    val scalaUDFClassName = classOf[ScalaUDF].getName
+  private val converterClassName = classOf[Any => Any].getName
+  private val scalaUDFClassName = classOf[ScalaUDF].getName
+  private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
 
+  // Generate codes used to convert the arguments to Scala type for user-defined functions
+  private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = {
     val converterTerm = ctx.freshName("converter")
     val expressionIdx = ctx.references.size - 1
-    ctx.addMutableState(converterClassName, converterTerm,
-      s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
-        s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
-          s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
-    converterTerm
+    (converterTerm,
+      s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" +
+        s".createToScalaConverter(((Expression)((($scalaUDFClassName)" +
+        s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
   }
 
   override def doGenCode(
       ctx: CodegenContext,
       ev: ExprCode): ExprCode = {
+    val scalaUDF = ctx.freshName("scalaUDF")
+    val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName)
 
-    val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
-    val converterClassName = classOf[Any => Any].getName
-    val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
-
-    // Generate codes used to convert the returned value of user-defined functions to Catalyst type
+    // Object to convert the returned value of user-defined functions to Catalyst type
     val catalystConverterTerm = ctx.freshName("catalystConverter")
-    ctx.addMutableState(converterClassName, catalystConverterTerm,
-      s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
-        s".createToCatalystConverter($scalaUDF.dataType());")
 
     val resultTerm = ctx.freshName("result")
 
@@ -1022,8 +1015,6 @@ case class ScalaUDF(
     val funcClassName = s"scala.Function${children.size}"
 
     val funcTerm = ctx.freshName("udf")
-    ctx.addMutableState(funcClassName, funcTerm,
-      s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
 
     // codegen for children expressions
     val evals = children.map(_.genCode(ctx))
@@ -1033,34 +1024,45 @@ case class ScalaUDF(
     // such as IntegerType, its javaType is `int` and the returned type of user-defined
     // function is Object. Trying to convert an Object to `int` will cause casting exception.
     val evalCode = evals.map(_.code).mkString
-    val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
-      val eval = evals(i)
-      val argTerm = ctx.freshName("arg")
-      val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
-      (convert, argTerm)
+    val (converters, funcArguments) = converterTerms.zipWithIndex.map {
+      case ((convName, convInit), i) =>
+        val eval = evals(i)
+        val argTerm = ctx.freshName("arg")
+        val convert =
+          s"""
+             |$convInit
+             |Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value});
+           """.stripMargin
+        (convert, argTerm)
     }.unzip
 
     val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
     val callFunc =
       s"""
-         ${ctx.boxedType(dataType)} $resultTerm = null;
-         try {
-           $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
-         } catch (Exception e) {
-           throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
-         }
-       """
+         |${ctx.boxedType(dataType)} $resultTerm = null;
+         |$scalaUDFClassName $scalaUDF = $scalaUDFRef;
+         |try {
+         |  $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();
+         |  $converterClassName $catalystConverterTerm = ($converterClassName)
+         |    $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType());
+         |  $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+         |} catch (Exception e) {
+         |  throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+         |}
+       """.stripMargin
 
-    ev.copy(code = s"""
-      $evalCode
-      ${converters.mkString("\n")}
-      $callFunc
-
-      boolean ${ev.isNull} = $resultTerm == null;
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.value} = $resultTerm;
-      }""")
+    ev.copy(code =
+      s"""
+         |$evalCode
+         |${converters.mkString("\n")}
+         |$callFunc
+         |
+         |boolean ${ev.isNull} = $resultTerm == null;
+         |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+       """.stripMargin)
   }
 
   private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/e98f9647/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 13bd363..70dea4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.util.Locale
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -47,4 +48,9 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(e2.getMessage.contains("Failed to execute user defined function"))
   }
 
+  test("SPARK-22695: ScalaUDF should not use global variables") {
+    val ctx = new CodegenContext
+    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
+    assert(ctx.mutableStates.isEmpty)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message