spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen
Date Wed, 20 Jan 2016 00:14:45 GMT
Repository: spark
Updated Branches:
  refs/heads/master f6f7ca9d2 -> 3e84ef0a5


[SPARK-12770][SQL] Implement rules for branch elimination for CaseWhen

The three optimization cases are:

1. If the first branch's condition is a true literal, remove the CaseWhen and use the value
from that branch.
2. If a branch's condition is a false or null literal, remove that branch.
3. If only the else branch is left, remove the CaseWhen and use the value from the else branch.

Author: Reynold Xin <rxin@databricks.com>

Closes #10827 from rxin/SPARK-12770.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e84ef0a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e84ef0a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e84ef0a

Branch: refs/heads/master
Commit: 3e84ef0a54c53c45d7802cd2fecfa1c223580aee
Parents: f6f7ca9
Author: Reynold Xin <rxin@databricks.com>
Authored: Tue Jan 19 16:14:41 2016 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue Jan 19 16:14:41 2016 -0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 18 ++++++++++
 .../optimizer/SimplifyConditionalSuite.scala    | 37 ++++++++++++++++++++
 2 files changed, 55 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e84ef0a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index cc3371c..b7caa49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -635,6 +635,24 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
     case q: LogicalPlan => q transformExpressionsUp {
       case If(TrueLiteral, trueValue, _) => trueValue
       case If(FalseLiteral, _, falseValue) => falseValue
+
+      case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+        // If there are branches that are always false, remove them.
+        // If there are no more branches left, just use the else value.
+        // Note that these two are handled together here in a single case statement because
+        // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
+        val newBranches = branches.filter(_._1 != FalseLiteral)
+        if (newBranches.isEmpty) {
+          elseValue.getOrElse(Literal.create(null, e.dataType))
+        } else {
+          e.copy(branches = newBranches)
+        }
+
+      case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) =>
+        // If the first branch is a true literal, remove the entire CaseWhen and use the value
+        // from that. Note that CaseWhen.branches should never be empty, and as a result the
+        // headOption (rather than head) added above is just an extra (and unnecessary) safeguard.
+        branches.head._2
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3e84ef0a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index 8e5d7ef..d436b62 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLite
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.IntegerType
 
 
 class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -37,6 +38,10 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
     comparePlans(actual, correctAnswer)
   }
 
+  private val trueBranch = (TrueLiteral, Literal(5))
+  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
+  private val unreachableBranch = (FalseLiteral, Literal(20))
+
   test("simplify if") {
     assertEquivalent(
       If(TrueLiteral, Literal(10), Literal(20)),
@@ -47,4 +52,36 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
       Literal(20))
   }
 
+  test("remove unreachable branches") {
+    // i.e. removing branches whose conditions are always false
+    assertEquivalent(
+      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+      CaseWhen(normalBranch :: Nil, None))
+  }
+
+  test("remove entire CaseWhen if only the else branch is reachable") {
+    assertEquivalent(
+      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+      Literal(30))
+
+    assertEquivalent(
+      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
+      Literal.create(null, IntegerType))
+  }
+
+  test("remove entire CaseWhen if the first branch is always true") {
+    assertEquivalent(
+      CaseWhen(trueBranch :: normalBranch :: Nil, None),
+      Literal(5))
+
+    // Test branch elimination and simplification in combination
+    assertEquivalent(
+      CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+      Literal(5))
+
+    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
+    assertEquivalent(
+      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
+      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message