spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22750][SQL] Reuse mutable states when possible
Date Fri, 22 Dec 2017 02:13:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master c0abb1d99 -> c6f01cade


[SPARK-22750][SQL] Reuse mutable states when possible

## What changes were proposed in this pull request?

The PR introduces a new method `addImmutableStateIfNotExists` to `CodegenContext` (defined in
`CodeGenerator.scala`) to allow reusing and sharing the same global variable between different
Expressions. This helps reduce the number of global variables needed, which is important to limit
the impact on the constant pool.

## How was this patch tested?

Added a unit test ("SPARK-22750: addImmutableStateIfNotExists") to `CodeGenerationSuite`.

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Marco Gaido <mgaido@hortonworks.com>

Closes #19940 from mgaido91/SPARK-22750.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c6f01cad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c6f01cad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c6f01cad

Branch: refs/heads/master
Commit: c6f01cadede490bde987d067becef14442f1e4a1
Parents: c0abb1d
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Fri Dec 22 10:13:26 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Fri Dec 22 10:13:26 2017 +0800

----------------------------------------------------------------------
 .../expressions/MonotonicallyIncreasingID.scala |  3 +-
 .../catalyst/expressions/SparkPartitionID.scala |  3 +-
 .../expressions/codegen/CodeGenerator.scala     | 40 ++++++++++++++++++++
 .../expressions/datetimeExpressions.scala       | 12 ++++--
 .../catalyst/expressions/objects/objects.scala  | 24 ++++++++----
 .../expressions/CodeGenerationSuite.scala       | 12 ++++++
 6 files changed, 80 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 784eaf8..11fb579 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -66,7 +66,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
-    val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask")
+    val partitionMaskTerm = "partitionMask"
+    ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
     ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
     ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex)
<< 33;")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 736ca37..a160b9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -43,7 +43,8 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic
{
   override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId")
+    val idTerm = "partitionId"
+    ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
     ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9adf632..d6eccad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -208,6 +208,14 @@ class CodegenContext {
   }
 
   /**
+   * A map containing the mutable states which have been defined so far using
+   * `addImmutableStateIfNotExists`. Each entry contains the name of the mutable state as
key and
+   * its Java type and init code as value.
+   */
+  private val immutableStates: mutable.Map[String, (String, String)] =
+    mutable.Map.empty[String, (String, String)]
+
+  /**
    * Add a mutable state as a field to the generated class. c.f. the comments above.
    *
    * @param javaType Java type of the field. Note that short names can be used for some types,
@@ -266,6 +274,38 @@ class CodegenContext {
   }
 
   /**
+   * Add an immutable state as a field to the generated class only if it does not exist yet
a field
+   * with that name. This helps reducing the number of the generated class' fields, since
the same
+   * variable can be reused by many functions.
+   *
+   * Even though the added variables are not declared as final, they should never be reassigned
in
+   * the generated code to prevent errors and unexpected behaviors.
+   *
+   * Internally, this method calls `addMutableState`.
+   *
+   * @param javaType Java type of the field.
+   * @param variableName Name of the field.
+   * @param initFunc Function includes statement(s) to put into the init() method to initialize
+   *                 this field. The argument is the name of the mutable state variable.
+   */
+  def addImmutableStateIfNotExists(
+      javaType: String,
+      variableName: String,
+      initFunc: String => String = _ => ""): Unit = {
+    val existingImmutableState = immutableStates.get(variableName)
+    if (existingImmutableState.isEmpty) {
+      addMutableState(javaType, variableName, initFunc, useFreshName = false, forceInline
= true)
+      immutableStates(variableName) = (javaType, initFunc(variableName))
+    } else {
+      val (prevJavaType, prevInitCode) = existingImmutableState.get
+      assert(prevJavaType == javaType, s"$variableName has already been defined with type
" +
+        s"$prevJavaType and now it is tried to define again with type $javaType.")
+      assert(prevInitCode == initFunc(variableName), s"$variableName has already been defined
" +
+        s"with different initialization statements.")
+    }
+  }
+
+  /**
    * Add buffer variable which stores data coming from an [[InternalRow]]. This methods guarantees
    * that the variable is safely stored, which is important for (potentially) byte array
backed
    * data types like: UTF8String, ArrayData, MapData & InternalRow.

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 59c3e3d..7a674ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -443,7 +443,8 @@ case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCas
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      val c = ctx.addMutableState(cal, "cal",
+      val c = "calDayOfWeek"
+      ctx.addImmutableStateIfNotExists(cal, c,
         v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
       s"""
         $c.setTimeInMillis($time * 1000L * 3600L * 24L);
@@ -484,8 +485,9 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with
ImplicitCa
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
+      val c = "calWeekOfYear"
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      val c = ctx.addMutableState(cal, "cal", v =>
+      ctx.addImmutableStateIfNotExists(cal, c, v =>
         s"""
            |$v = $cal.getInstance($dtu.getTimeZone("UTC"));
            |$v.setFirstDayOfWeek($cal.MONDAY);
@@ -1017,7 +1019,8 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
         val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
         val tzTerm = ctx.addMutableState(tzClass, "tz",
           v => s"""$v = $dtu.getTimeZone("$tz");""")
-        val utcTerm = ctx.addMutableState(tzClass, "utc",
+        val utcTerm = "tzUTC"
+        ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
           v => s"""$v = $dtu.getTimeZone("UTC");""")
         val eval = left.genCode(ctx)
         ev.copy(code = s"""
@@ -1193,7 +1196,8 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
         val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
         val tzTerm = ctx.addMutableState(tzClass, "tz",
           v => s"""$v = $dtu.getTimeZone("$tz");""")
-        val utcTerm = ctx.addMutableState(tzClass, "utc",
+        val utcTerm = "tzUTC"
+        ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
           v => s"""$v = $dtu.getTimeZone("UTC");""")
         val eval = left.genCode(ctx)
         ev.copy(code = s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index a59aad5..4af8134 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1148,17 +1148,21 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Code to initialize the serializer.
-    val (serializerClass, serializerInstanceClass) = {
+    val (serializer, serializerClass, serializerInstanceClass) = {
       if (kryo) {
-        (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+        ("kryoSerializer",
+          classOf[KryoSerializer].getName,
+          classOf[KryoSerializerInstance].getName)
       } else {
-        (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+        ("javaSerializer",
+          classOf[JavaSerializer].getName,
+          classOf[JavaSerializerInstance].getName)
       }
     }
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForEncode",
v =>
+    ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
       s"""
          |if ($env == null) {
          |  $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();
@@ -1193,17 +1197,21 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T],
kryo: B
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Code to initialize the serializer.
-    val (serializerClass, serializerInstanceClass) = {
+    val (serializer, serializerClass, serializerInstanceClass) = {
       if (kryo) {
-        (classOf[KryoSerializer].getName, classOf[KryoSerializerInstance].getName)
+        ("kryoSerializer",
+          classOf[KryoSerializer].getName,
+          classOf[KryoSerializerInstance].getName)
       } else {
-        (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName)
+        ("javaSerializer",
+          classOf[JavaSerializer].getName,
+          classOf[JavaSerializerInstance].getName)
       }
     }
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializer = ctx.addMutableState(serializerInstanceClass, "serializerForDecode",
v =>
+    ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
       s"""
          |if ($env == null) {
          |  $v = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b1a4452..676ba39 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -424,4 +424,16 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper
{
     assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex == 10)
     assert(ctx2.mutableStateInitCode.size == CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT +
10)
   }
+
+  test("SPARK-22750: addImmutableStateIfNotExists") {
+    val ctx = new CodegenContext
+    val mutableState1 = "field1"
+    val mutableState2 = "field2"
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("String", mutableState2)
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("String", mutableState2)
+    assert(ctx.inlinedMutableStates.length == 2)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message