spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-25519][SQL] ArrayRemove function may return incorrect result when right expression is implicitly downcasted.
Date Tue, 25 Sep 2018 04:05:13 GMT
Repository: spark
Updated Branches:
  refs/heads/master 615792da4 -> 7d8f5b62c


[SPARK-25519][SQL] ArrayRemove function may return incorrect result when right expression
is implicitly downcasted.

## What changes were proposed in this pull request?
In ArrayRemove, we currently cast the right-hand side expression to match the element type
of the left-hand side array. This may result in an unsafe downcast and produce an incorrect or
questionable result.

Example :
```SQL
spark-sql> select array_remove(array(1,2,3), 1.23D);
       [2,3]
```
```SQL
spark-sql> select array_remove(array(1,2,3), 'foo');
        NULL
```
We should safely coerce both the left-hand and right-hand side expressions to a common type.
## How was this patch tested?
Added tests in DataFrameFunctionsSuite

Closes #22542 from dilipbiswal/SPARK-25519.

Authored-by: Dilip Biswal <dbiswal@us.ibm.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d8f5b62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d8f5b62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d8f5b62

Branch: refs/heads/master
Commit: 7d8f5b62c57c9e2903edd305e8b9c5400652fdb0
Parents: 615792d
Author: Dilip Biswal <dbiswal@us.ibm.com>
Authored: Tue Sep 25 12:05:04 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Sep 25 12:05:04 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 29 +++++++-----
 .../spark/sql/DataFrameFunctionsSuite.scala     | 48 +++++++++++++++++++-
 2 files changed, 63 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d8f5b62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 85bc1cd..9cc7dba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3088,11 +3088,24 @@ case class ArrayRemove(left: Expression, right: Expression)
   override def dataType: DataType = left.dataType
 
   override def inputTypes: Seq[AbstractDataType] = {
-    val elementType = left.dataType match {
-      case t: ArrayType => t.elementType
-      case _ => AnyDataType
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, hasNull), e2) =>
+        TypeCoercion.findTightestCommonType(e1, e2) match {
+          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
+          case _ => Seq.empty
+        }
+      case _ => Seq.empty
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    (left.dataType, right.dataType) match {
+      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
+      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should
have " +
+        s"been ${ArrayType.simpleString} followed by a value with same element type, but
it's " +
+        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
     }
-    Seq(ArrayType, elementType)
   }
 
   private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
@@ -3100,14 +3113,6 @@ case class ArrayRemove(left: Expression, right: Expression)
   @transient private lazy val ordering: Ordering[Any] =
     TypeUtils.getInterpretedOrdering(right.dataType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case f: TypeCheckResult.TypeCheckFailure => f
-      case TypeCheckResult.TypeCheckSuccess =>
-        TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
-    }
-  }
-
   override def nullSafeEval(arr: Any, value: Any): Any = {
     val newArray = new Array[Any](arr.asInstanceOf[ArrayData].numElements())
     var pos = 0

http://git-wip-us.apache.org/repos/asf/spark/blob/7d8f5b62/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index fd71f24..88dbae8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1575,6 +1575,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
     )
 
     checkAnswer(
+      OneRowRelation().selectExpr("array_remove(array(1, 2), 1.23D)"),
+      Seq(
+        Row(Seq(1.0, 2.0))
+      )
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_remove(array(1, 2), 1.0D)"),
+      Seq(
+        Row(Seq(2.0))
+      )
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_remove(array(1.0D, 2.0D), 2)"),
+      Seq(
+        Row(Seq(1.0))
+      )
+    )
+
+    checkAnswer(
+      OneRowRelation().selectExpr("array_remove(array(1.1D, 1.2D), 1)"),
+      Seq(
+        Row(Seq(1.1, 1.2))
+      )
+    )
+
+    checkAnswer(
       df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")",
         "array_remove(c, \"\")"),
       Seq(
@@ -1583,10 +1611,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
         Row(null, null, null))
     )
 
-    val e = intercept[AnalysisException] {
+    val e1 = intercept[AnalysisException] {
       Seq(("a string element", "a")).toDF().selectExpr("array_remove(_1, _2)")
     }
-    assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string
type"))
+    val errorMsg1 =
+      s"""
+         |Input to function array_remove should have been array followed by a
+         |value with same element type, but it's [string, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e1.message.contains(errorMsg1))
+
+    val e2 = intercept[AnalysisException] {
+      OneRowRelation().selectExpr("array_remove(array(1, 2), '1')")
+    }
+
+    val errorMsg2 =
+      s"""
+         |Input to function array_remove should have been array followed by a
+         |value with same element type, but it's [array<int>, string].
+       """.stripMargin.replace("\n", " ").trim()
+    assert(e2.message.contains(errorMsg2))
   }
 
   test("array_distinct functions") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message