spark-commits mailing list archives

From: yh...@apache.org
Subject: spark git commit: [SPARK-14541][SQL] Support IFNULL, NULLIF, NVL and NVL2
Date: Fri, 13 May 2016 05:18:46 GMT
Repository: spark
Updated Branches:
  refs/heads/master ba169c323 -> eda2800d4


[SPARK-14541][SQL] Support IFNULL, NULLIF, NVL and NVL2

## What changes were proposed in this pull request?
This patch adds support for a few SQL functions to improve compatibility with other databases:
IFNULL, NULLIF, NVL and NVL2. To do this, it introduces a RuntimeReplaceable expression trait
that allows an unevaluable expression to be replaced in the optimizer before evaluation.

Note that the semantics are not completely identical to those of other databases in esoteric cases.
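
For illustration, here is a minimal sketch of how the new functions behave from SQL, mirroring the queries in the SQLCompatibilityFunctionSuite added below (the `spark` SparkSession handle is assumed and not part of this patch):

    spark.sql("SELECT ifnull(null, 'x'), ifnull('y', 'x')").show()        // x, y
    spark.sql("SELECT nullif('x', 'x'), nullif('x', 'y')").show()         // NULL, x
    spark.sql("SELECT nvl(null, 'x'), nvl('y', 'x')").show()              // x, y
    spark.sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y')").show()  // y, x
    // Mixed int/double arguments are coerced to double, e.g. ifnull(1, 2.1d) yields 1.0.
    spark.sql("SELECT ifnull(1, 2.1d), nvl2('n', 1, 2.1d)").show()        // 1.0, 1.0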

## How was this patch tested?
Added a new test suite SQLCompatibilityFunctionSuite.

Closes #12373.

Author: Reynold Xin <rxin@databricks.com>

Closes #13084 from rxin/SPARK-14541.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/eda2800d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/eda2800d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/eda2800d

Branch: refs/heads/master
Commit: eda2800d44843b6478e22d2c99bca4af7e9c9613
Parents: ba169c3
Author: Reynold Xin <rxin@databricks.com>
Authored: Thu May 12 22:18:39 2016 -0700
Committer: Yin Huai <yhuai@databricks.com>
Committed: Thu May 12 22:18:39 2016 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  5 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    |  2 +
 .../sql/catalyst/expressions/Expression.scala   | 27 +++++++
 .../catalyst/expressions/nullExpressions.scala  | 78 +++++++++++++++++++-
 .../sql/catalyst/optimizer/Optimizer.scala      | 12 +++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  6 --
 .../sql/SQLCompatibilityFunctionSuite.scala     | 72 ++++++++++++++++++
 .../sql/catalyst/ExpressionToSQLSuite.scala     |  1 -
 8 files changed, 194 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c459fe5..eca837c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -165,13 +165,16 @@ object FunctionRegistry {
     expression[Greatest]("greatest"),
     expression[If]("if"),
     expression[IsNaN]("isnan"),
+    expression[IfNull]("ifnull"),
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[NaNvl]("nanvl"),
-    expression[Coalesce]("nvl"),
+    expression[NullIf]("nullif"),
+    expression[Nvl]("nvl"),
+    expression[Nvl2]("nvl2"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
     expression[CreateStruct]("struct"),

http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 8319ec0..537dda6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -521,6 +521,8 @@ object HiveTypeCoercion {
         NaNvl(l, Cast(r, DoubleType))
       case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
         NaNvl(Cast(l, DoubleType), r)
+
+      case e: RuntimeReplaceable => e.replaceForTypeCoercion()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c26faee..fab1634 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -222,6 +222,33 @@ trait Unevaluable extends Expression {
 
 
 /**
+ * An expression that gets replaced at runtime (currently by the optimizer) into a different
+ * expression for evaluation. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+trait RuntimeReplaceable extends Unevaluable {
+  /**
+   * Method for concrete implementations to override that specifies how to construct the expression
+   * that should replace the current one.
+   */
+  def replaceForEvaluation(): Expression
+
+  /**
+   * Method for concrete implementations to override that specifies how to coerce the input types.
+   */
+  def replaceForTypeCoercion(): Expression
+
+  /** The expression that should be used during evaluation. */
+  lazy val replaced: Expression = replaceForEvaluation()
+
+  override def nullable: Boolean = replaced.nullable
+  override def foldable: Boolean = replaced.foldable
+  override def dataType: DataType = replaced.dataType
+  override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes()
+}
+
+
+/**
  * Expressions that don't have SQL representation should extend this trait.  Examples are
  * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`.
  */

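To make the trait's contract concrete, here is a minimal sketch of how a RuntimeReplaceable expression is consumed, using the Nvl expression added in nullExpressions.scala below (the attribute and literal are invented for illustration):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, Nvl}
    import org.apache.spark.sql.types.StringType

    val a = AttributeReference("a", StringType)()   // hypothetical input column
    val nvl = Nvl(a, Literal("x"))

    nvl.replaced    // Coalesce(Seq(a, Literal("x"))) -- what the optimizer evaluates instead
    nvl.dataType    // StringType, delegated to the replacement
    nvl.nullable    // likewise delegated; Nvl itself remains Unevaluable
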
http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 421200e..641c81b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{HiveTypeCoercion, TypeCheckResult}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -88,6 +88,82 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 }
 
 
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable {
+  override def children: Seq[Expression] = Seq(left, right)
+
+  override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+  override def replaceForTypeCoercion(): Expression = {
+    if (left.dataType != right.dataType) {
+      HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+        copy(left = Cast(left, dtype), right = Cast(right, dtype))
+      }.getOrElse(this)
+    } else {
+      this
+    }
+  }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals b, or a otherwise.")
+case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable {
+  override def children: Seq[Expression] = Seq(left, right)
+
+  override def replaceForEvaluation(): Expression = {
+    If(EqualTo(left, right), Literal.create(null, left.dataType), left)
+  }
+
+  override def replaceForTypeCoercion(): Expression = {
+    if (left.dataType != right.dataType) {
+      HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+        copy(left = Cast(left, dtype), right = Cast(right, dtype))
+      }.getOrElse(this)
+    } else {
+      this
+    }
+  }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
+case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable {
+  override def children: Seq[Expression] = Seq(left, right)
+
+  override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
+
+  override def replaceForTypeCoercion(): Expression = {
+    if (left.dataType != right.dataType) {
+      HiveTypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
+        copy(left = Cast(left, dtype), right = Cast(right, dtype))
+      }.getOrElse(this)
+    } else {
+      this
+    }
+  }
+}
+
+
+@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.")
+case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression)
+  extends RuntimeReplaceable {
+
+  override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3)
+
+  override def children: Seq[Expression] = Seq(expr1, expr2, expr3)
+
+  override def replaceForTypeCoercion(): Expression = {
+    if (expr2.dataType != expr3.dataType) {
+      HiveTypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype =>
+        copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype))
+      }.getOrElse(this)
+    } else {
+      this
+    }
+  }
+}
+
+
 /**
  * Evaluates to `true` iff it's NaN.
  */

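For reference, the two If-based replacements above expand as follows (a sketch only; a, b and c stand for arbitrary child expressions):

    NullIf(a, b).replaced    // If(EqualTo(a, b), Literal.create(null, a.dataType), a)
    Nvl2(a, b, c).replaced   // If(IsNotNull(a), b, c)
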
http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 928ba21..af7532e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -49,6 +49,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     // we do not eliminate subqueries or compute current time in the analyzer.
     Batch("Finish Analysis", Once,
       EliminateSubqueryAliases,
+      ReplaceExpressions,
       ComputeCurrentTime,
       GetCurrentDatabase(sessionCatalog),
       DistinctAggregationRewriter) ::
@@ -1512,6 +1513,17 @@ object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
 }
 
 /**
+ * Finds all [[RuntimeReplaceable]] expressions and replaces them with the expressions that can
+ * be evaluated. This is mainly used to provide compatibility with other databases.
+ * For example, we use this to support "nvl" by replacing it with "coalesce".
+ */
+object ReplaceExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case e: RuntimeReplaceable => e.replaced
+  }
+}
+
+/**
  * Computes the current date and time to make sure we return the same result in a single query.
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {

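Below is a minimal sketch of the rule's effect on a tiny, hypothetical plan (the relation and column are invented; Nvl and ReplaceExpressions are the additions in this commit). After the Finish Analysis batch runs, the unevaluable Nvl node is gone and only its Coalesce replacement remains:

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal, Nvl}
    import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
    import org.apache.spark.sql.types.StringType

    val a = AttributeReference("a", StringType)()
    val plan = Project(Seq(Alias(Nvl(a, Literal("x")), "c")()), LocalRelation(a))

    val rewritten = ReplaceExpressions(plan)
    // rewritten now contains Coalesce(Seq(a, Literal("x"))) where the Nvl node used to be.
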
http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 746e25a..73d7765 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -152,12 +152,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("one", "not_one"))
   }
 
-  test("nvl function") {
-    checkAnswer(
-      sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
-      Row("x", "y", null))
-  }
-
   test("misc md5 function") {
     val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b")
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
new file mode 100644
index 0000000..1e32395
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL.
+ * These functions are typically implemented using the trait
+ * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]].
+ */
+class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {
+
+  test("ifnull") {
+    checkAnswer(
+      sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"),
+      Row("x", "y", null))
+
+    // Type coercion
+    checkAnswer(
+      sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"),
+      Row(1.0, 2.1))
+  }
+
+  test("nullif") {
+    checkAnswer(
+      sql("SELECT nullif('x', 'x'), nullif('x', 'y')"),
+      Row(null, "x"))
+
+    // Type coercion
+    checkAnswer(
+      sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"),
+      Row(1.0, null))
+  }
+
+  test("nvl") {
+    checkAnswer(
+      sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+      Row("x", "y", null))
+
+    // Type coercion
+    checkAnswer(
+      sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"),
+      Row(1.0, 2.1))
+  }
+
+  test("nvl2") {
+    checkAnswer(
+      sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"),
+      Row("y", "x", null))
+
+    // Type coercion
+    checkAnswer(
+      sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
+      Row(2.1, 1.0))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/eda2800d/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index 72736ee..b4eb50e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -102,7 +102,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
     checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
     checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
-    checkSqlGeneration("SELECT nvl(null, 1, 2)")
     checkSqlGeneration("SELECT rand(1)")
     checkSqlGeneration("SELECT randn(3)")
     checkSqlGeneration("SELECT struct(1,2,3)")


