spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22520][SQL] Support code generation for large CaseWhen
Date Mon, 27 Nov 2017 23:46:24 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1ff4a77be -> 087879a77


[SPARK-22520][SQL] Support code generation for large CaseWhen

## What changes were proposed in this pull request?

Code generation is disabled for CaseWhen when the number of branches is higher than `spark.sql.codegen.maxCaseBranches`
(which defaults to 20). This was done to prevent the well-known 64KB method limit exception.
This PR proposes to support code generation in those cases as well (without causing exceptions,
of course). As a side effect, we can get rid of the `spark.sql.codegen.maxCaseBranches`
configuration.

## How was this patch tested?

existing UTs

Author: Marco Gaido <mgaido@hortonworks.com>
Author: Marco Gaido <marcogaido91@gmail.com>

Closes #19752 from mgaido91/SPARK-22520.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/087879a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/087879a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/087879a7

Branch: refs/heads/master
Commit: 087879a77acb37b790c36f8da67355b90719c2dc
Parents: 1ff4a77
Author: Marco Gaido <mgaido@hortonworks.com>
Authored: Tue Nov 28 07:46:18 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Nov 28 07:46:18 2017 +0800

----------------------------------------------------------------------
 .../expressions/EquivalentExpressions.scala     |   3 +-
 .../expressions/conditionalExpressions.scala    | 192 ++++++++++---------
 .../sql/catalyst/optimizer/Optimizer.scala      |   2 -
 .../sql/catalyst/optimizer/expressions.scala    |  15 --
 .../org/apache/spark/sql/internal/SQLConf.scala |   8 -
 .../expressions/CodeGenerationSuite.scala       |   8 +-
 .../optimizer/OptimizeCodegenSuite.scala        | 101 ----------
 .../FlatMapGroupsWithState_StateManager.scala   |   2 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  15 +-
 .../spark/sql/internal/SQLConfSuite.scala       |  29 ---
 10 files changed, 122 insertions(+), 253 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index f8644c2..8d06804 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -87,8 +87,7 @@ class EquivalentExpressions {
     def childrenToRecurse: Seq[Expression] = expr match {
       case _: CodegenFallback => Nil
       case i: If => i.predicate :: Nil
-      // `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
-      case c: CaseWhenCodegen => c.children.head :: Nil
+      case c: CaseWhen => c.children.head :: Nil
       case c: Coalesce => c.children.head :: Nil
       case other => other.children
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 6195be3..a8629c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -88,14 +88,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
 }
 
 /**
- * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
  *
  * @param branches seq of (branch condition, branch value)
  * @param elseValue optional value for the else branch
  */
-abstract class CaseWhenBase(
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
+  arguments = """
+    Arguments:
+      * expr1, expr3 - the branch condition expressions should all be boolean type.
+      * expr2, expr4, expr5 - the branch value expressions and else value expression should all be
+          same type or coercible to a common type.
+  """,
+  examples = """
+    Examples:
+      > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
+       1
+      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
+       2
+      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END;
+       NULL
+  """)
+// scalastyle:on line.size.limit
+case class CaseWhen(
     branches: Seq[(Expression, Expression)],
-    elseValue: Option[Expression])
+    elseValue: Option[Expression] = None)
   extends Expression with Serializable {
 
   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -158,111 +178,103 @@ abstract class CaseWhenBase(
     val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
     "CASE" + cases + elseCase + " END"
   }
-}
-
-
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
-  arguments = """
-    Arguments:
-      * expr1, expr3 - the branch condition expressions should all be boolean type.
-      * expr2, expr4, expr5 - the branch value expressions and else value expression should all be
-          same type or coercible to a common type.
-  """,
-  examples = """
-    Examples:
-      > SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
-       1
-      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
-       2
-      > SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
-       NULL
-  """)
-// scalastyle:on line.size.limit
-case class CaseWhen(
-    val branches: Seq[(Expression, Expression)],
-    val elseValue: Option[Expression] = None)
-  extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
-
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    super[CodegenFallback].doGenCode(ctx, ev)
-  }
-
-  def toCodegen(): CaseWhenCodegen = {
-    CaseWhenCodegen(branches, elseValue)
-  }
-}
-
-/**
- * CaseWhen expression used when code generation condition is satisfied.
- * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
- *
- * @param branches seq of (branch condition, branch value)
- * @param elseValue optional value for the else branch
- */
-case class CaseWhenCodegen(
-    val branches: Seq[(Expression, Expression)],
-    val elseValue: Option[Expression] = None)
-  extends CaseWhenBase(branches, elseValue) with Serializable {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    // Generate code that looks like:
-    //
-    // condA = ...
-    // if (condA) {
-    //   valueA
-    // } else {
-    //   condB = ...
-    //   if (condB) {
-    //     valueB
-    //   } else {
-    //     condC = ...
-    //     if (condC) {
-    //       valueC
-    //     } else {
-    //       elseValue
-    //     }
-    //   }
-    // }
+    // This variable represents whether the first successful condition is met or not.
+    // It is initialized to `false` and it is set to `true` when the first condition which
+    // evaluates to `true` is met and therefore is not needed to go on anymore on the computation
+    // of the following conditions.
+    val conditionMet = ctx.freshName("caseWhenConditionMet")
+    ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
+    ctx.addMutableState(ctx.javaType(dataType), ev.value)
+
+    // these blocks are meant to be inside a
+    // do {
+    //   ...
+    // } while (false);
+    // loop
     val cases = branches.map { case (condExpr, valueExpr) =>
       val cond = condExpr.genCode(ctx)
       val res = valueExpr.genCode(ctx)
       s"""
-        ${cond.code}
-        if (!${cond.isNull} && ${cond.value}) {
-          ${res.code}
-          ${ev.isNull} = ${res.isNull};
-          ${ev.value} = ${res.value};
+        if(!$conditionMet) {
+          ${cond.code}
+          if (!${cond.isNull} && ${cond.value}) {
+            ${res.code}
+            ${ev.isNull} = ${res.isNull};
+            ${ev.value} = ${res.value};
+            $conditionMet = true;
+            continue;
+          }
         }
       """
     }
 
-    var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
-
-    elseValue.foreach { elseExpr =>
+    val elseCode = elseValue.map { elseExpr =>
       val res = elseExpr.genCode(ctx)
-      generatedCode +=
-        s"""
+      s"""
+        if(!$conditionMet) {
           ${res.code}
           ${ev.isNull} = ${res.isNull};
           ${ev.value} = ${res.value};
-        """
+        }
+      """
     }
 
-    generatedCode += "}\n" * cases.size
+    val allConditions = cases ++ elseCode
+
+    val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+        allConditions.mkString("\n")
+      } else {
+        // This generates code like:
+        //   conditionMet = caseWhen_1(i);
+        //   if(conditionMet) {
+        //     continue;
+        //   }
+        //   conditionMet = caseWhen_2(i);
+        //   if(conditionMet) {
+        //     continue;
+        //   }
+        //   ...
+        // and the declared methods are:
+        //   private boolean caseWhen_1234() {
+        //     boolean conditionMet = false;
+        //     do {
+        //       // here the evaluation of the conditions
+        //     } while (false);
+        //     return conditionMet;
+        //   }
+        ctx.splitExpressions(allConditions, "caseWhen",
+          ("InternalRow", ctx.INPUT_ROW) :: Nil,
+          returnType = ctx.JAVA_BOOLEAN,
+          makeSplitFunction = {
+            func =>
+              s"""
+                ${ctx.JAVA_BOOLEAN} $conditionMet = false;
+                do {
+                  $func
+                } while (false);
+                return $conditionMet;
+              """
+          },
+          foldFunctions = { funcCalls =>
+            funcCalls.map { funcCall =>
+              s"""
+                $conditionMet = $funcCall;
+                if ($conditionMet) {
+                  continue;
+                }"""
+            }.mkString
+          })
+      }
 
     ev.copy(code = s"""
-      boolean ${ev.isNull} = true;
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      $generatedCode""")
+      ${ev.isNull} = true;
+      ${ev.value} = ${ctx.defaultValue(dataType)};
+      ${ctx.JAVA_BOOLEAN} $conditionMet = false;
+      do {
+        $code
+      } while (false);""")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3a3ccd5..0d961bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) ::
-    Batch("OptimizeCodegen", Once,
-      OptimizeCodegen) ::
     Batch("RewriteSubquery", Once,
       RewritePredicateSubquery,
       CollapseProject) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 523b53b..785e815 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -553,21 +553,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
 
 
 /**
- * Optimizes expressions by replacing according to CodeGen configuration.
- */
-object OptimizeCodegen extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case e: CaseWhen if canCodegen(e) => e.toCodegen()
-  }
-
-  private def canCodegen(e: CaseWhen): Boolean = {
-    val numBranches = e.branches.size + e.elseValue.size
-    numBranches <= SQLConf.get.maxCaseBranchesForCodegen
-  }
-}
-
-
-/**
  * Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
  */
 object SimplifyCasts extends Rule[LogicalPlan] {

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4eda9f3..ce68dbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -599,12 +599,6 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
-  val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches")
-    .internal()
-    .doc("The maximum number of switches supported with codegen.")
-    .intConf
-    .createWithDefault(20)
-
   val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
     .internal()
     .doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
@@ -1140,8 +1134,6 @@ class SQLConf extends Serializable with Logging {
 
   def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
 
-  def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
-
   def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)
 
   def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 6e33087..a4198f8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("SPARK-13242: case-when expression with large number of branches (or cases)") {
-    val cases = 50
+    val cases = 500
     val clauses = 20
 
     // Generate an individual case
@@ -88,13 +88,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       (condition, Literal(n))
     }
 
-    val expression = CaseWhen((1 to cases).map(generateCase(_)))
+    val expression = CaseWhen((1 to cases).map(generateCase))
 
     val plan = GenerateMutableProjection.generate(Seq(expression))
-    val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
+    val input = new GenericInternalRow(Array[Any](UTF8String.fromString(s"$clauses:$cases")))
     val actual = plan(input).toSeq(Seq(expression.dataType))
 
-    assert(actual(0) == cases)
+    assert(actual.head == cases)
   }
 
   test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
deleted file mode 100644
index b1157f3..0000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.Literal._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules._
-
-
-class OptimizeCodegenSuite extends PlanTest {
-
-  object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Nil
-  }
-
-  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
-    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
-    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
-    comparePlans(actual, correctAnswer)
-  }
-
-  test("Codegen only when the number of branches is small.") {
-    assertEquivalent(
-      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
-      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())
-
-    assertEquivalent(
-      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)),
-      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)))
-  }
-
-  test("Nested CaseWhen Codegen.") {
-    assertEquivalent(
-      CaseWhen(
-        Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))),
-        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
-      CaseWhen(
-        Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))),
-        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
-  }
-
-  test("Multiple CaseWhen in one operator.") {
-    val plan = OneRowRelation()
-      .select(
-        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
-        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
-        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
-        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
-    val correctAnswer = OneRowRelation()
-      .select(
-        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
-        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
-        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
-        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
-    val optimized = Optimize.execute(plan)
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("Multiple CaseWhen in different operators") {
-    val plan = OneRowRelation()
-      .select(
-        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
-        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
-        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
-      .where(
-        LessThan(
-          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
-          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
-      ).analyze
-    val correctAnswer = OneRowRelation()
-      .select(
-        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
-        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
-        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
-      .where(
-        LessThan(
-          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
-          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
-      ).analyze
-    val optimized = Optimize.execute(plan)
-    comparePlans(optimized, correctAnswer)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
index d077836..e495468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala
@@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager(
     val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
       case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
     }
-    CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen()
+    CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser)
   }
 
   // Converters for translating state between rows and Java objects

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 644e72c..72a5cc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
+import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.functions._
@@ -2158,4 +2158,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val mean = result.select("DecimalCol").where($"summary" === "mean")
     assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
   }
+
+  test("SPARK-22520: support code generation for large CaseWhen") {
+    val N = 30
+    var expr1 = when($"id" === lit(0), 0)
+    var expr2 = when($"id" === lit(0), 10)
+    (1 to N).foreach { i =>
+      expr1 = expr1.when($"id" === lit(i), -i)
+      expr2 = expr2.when($"id" === lit(i + 10), i)
+    }
+    val df = spark.range(1).select(expr1, expr2.otherwise(0))
+    checkAnswer(df, Row(0, 10) :: Nil)
+    assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/087879a7/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index f9d75fc..8b1521b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -221,35 +221,6 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
       .sessionState.conf.warehousePath.stripSuffix("/"))
   }
 
-  test("MAX_CASES_BRANCHES") {
-    withTable("tab1") {
-      spark.range(10).write.saveAsTable("tab1")
-      val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1"
-      val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1"
-
-      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") {
-        assert(!sql(sql_one_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-        assert(!sql(sql_two_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-      }
-
-      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") {
-        assert(sql(sql_one_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-        assert(!sql(sql_two_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-      }
-
-      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") {
-        assert(sql(sql_one_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-        assert(sql(sql_two_branch_caseWhen)
-          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
-      }
-    }
-  }
-
   test("static SQL conf comes from SparkConf") {
     val previousValue = sparkContext.conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)
     try {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message