spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gurwls...@apache.org
Subject spark git commit: [SPARK-24574][SQL] array_contains, array_position, array_remove and element_at functions deal with Column type
Date Thu, 21 Jun 2018 06:59:03 GMT
Repository: spark
Updated Branches:
  refs/heads/master 54fcaafb0 -> 7236e759c


[SPARK-24574][SQL] array_contains, array_position, array_remove and element_at functions deal
with Column type

## What changes were proposed in this pull request?

For the function ```def array_contains(column: Column, value: Any): Column ```, passing
the `value` parameter as a value of type `Column` yields a runtime exception.

This PR proposes pattern matching to detect whether `value` is of type Column. If so, it will
use the .expr of the column; otherwise it will work as it used to.

The same applies to the ```array_position, array_remove and element_at``` functions.

## How was this patch tested?

Unit test modified to cover this code change.

Ping ueshin

Author: Chongguang LIU <chong@Chongguangs-MacBook-Pro.local>

Closes #21581 from chongguang/SPARK-24574.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7236e759
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7236e759
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7236e759

Branch: refs/heads/master
Commit: 7236e759c9856970116bf4dd20813dbf14440462
Parents: 54fcaaf
Author: Chongguang LIU <lcg31439@gmail.com>
Authored: Thu Jun 21 14:58:57 2018 +0800
Committer: hyukjinkwon <gurwls223@apache.org>
Committed: Thu Jun 21 14:58:57 2018 +0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/functions.scala  |  8 +--
 .../spark/sql/DataFrameFunctionsSuite.scala     | 69 +++++++++++++++-----
 2 files changed, 58 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7236e759/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 965dbb6..c296a1b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3093,7 +3093,7 @@ object functions {
    * @since 1.5.0
    */
   def array_contains(column: Column, value: Any): Column = withExpr {
-    ArrayContains(column.expr, Literal(value))
+    ArrayContains(column.expr, lit(value).expr)
   }
 
   /**
@@ -3157,7 +3157,7 @@ object functions {
    * @since 2.4.0
    */
   def array_position(column: Column, value: Any): Column = withExpr {
-    ArrayPosition(column.expr, Literal(value))
+    ArrayPosition(column.expr, lit(value).expr)
   }
 
   /**
@@ -3168,7 +3168,7 @@ object functions {
    * @since 2.4.0
    */
   def element_at(column: Column, value: Any): Column = withExpr {
-    ElementAt(column.expr, Literal(value))
+    ElementAt(column.expr, lit(value).expr)
   }
 
   /**
@@ -3186,7 +3186,7 @@ object functions {
    * @since 2.4.0
    */
   def array_remove(column: Column, element: Any): Column = withExpr {
-    ArrayRemove(column.expr, Literal(element))
+    ArrayRemove(column.expr, lit(element).expr)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/7236e759/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 3dc696b..fcdd33f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -635,9 +635,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
 
   test("array contains function") {
     val df = Seq(
-      (Seq[Int](1, 2), "x"),
-      (Seq[Int](), "x")
-    ).toDF("a", "b")
+      (Seq[Int](1, 2), "x", 1),
+      (Seq[Int](), "x", 1)
+    ).toDF("a", "b", "c")
 
     // Simple test cases
     checkAnswer(
@@ -648,6 +648,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
       df.selectExpr("array_contains(a, 1)"),
       Seq(Row(true), Row(false))
     )
+    checkAnswer(
+      df.select(array_contains(df("a"), df("c"))),
+      Seq(Row(true), Row(false))
+    )
+    checkAnswer(
+      df.selectExpr("array_contains(a, c)"),
+      Seq(Row(true), Row(false))
+    )
 
     // In hive, this errors because null has no type information
     intercept[AnalysisException] {
@@ -862,9 +870,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
 
   test("array position function") {
     val df = Seq(
-      (Seq[Int](1, 2), "x"),
-      (Seq[Int](), "x")
-    ).toDF("a", "b")
+      (Seq[Int](1, 2), "x", 1),
+      (Seq[Int](), "x", 1)
+    ).toDF("a", "b", "c")
 
     checkAnswer(
       df.select(array_position(df("a"), 1)),
@@ -874,7 +882,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
       df.selectExpr("array_position(a, 1)"),
       Seq(Row(1L), Row(0L))
     )
-
+    checkAnswer(
+      df.selectExpr("array_position(a, c)"),
+      Seq(Row(1L), Row(0L))
+    )
+    checkAnswer(
+      df.select(array_position(df("a"), df("c"))),
+      Seq(Row(1L), Row(0L))
+    )
     checkAnswer(
       df.select(array_position(df("a"), null)),
       Seq(Row(null), Row(null))
@@ -901,10 +916,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
 
   test("element_at function") {
     val df = Seq(
-      (Seq[String]("1", "2", "3")),
-      (Seq[String](null, "")),
-      (Seq[String]())
-    ).toDF("a")
+      (Seq[String]("1", "2", "3"), 1),
+      (Seq[String](null, ""), -1),
+      (Seq[String](), 2)
+    ).toDF("a", "b")
 
     intercept[Exception] {
       checkAnswer(
@@ -922,6 +937,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
       df.select(element_at(df("a"), 4)),
       Seq(Row(null), Row(null), Row(null))
     )
+    checkAnswer(
+      df.select(element_at(df("a"), df("b"))),
+      Seq(Row("1"), Row(""), Row(null))
+    )
+    checkAnswer(
+      df.selectExpr("element_at(a, b)"),
+      Seq(Row("1"), Row(""), Row(null))
+    )
 
     checkAnswer(
       df.select(element_at(df("a"), 1)),
@@ -1189,10 +1212,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
 
   test("array remove") {
     val df = Seq(
-      (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", "")),
-      (Array.empty[Int], Array.empty[String], Array.empty[String]),
-      (null, null, null)
-    ).toDF("a", "b", "c")
+      (Array[Int](2, 1, 2, 3), Array("a", "b", "c", "a"), Array("", ""), 2),
+      (Array.empty[Int], Array.empty[String], Array.empty[String], 2),
+      (null, null, null, 2)
+    ).toDF("a", "b", "c", "d")
     checkAnswer(
       df.select(array_remove($"a", 2), array_remove($"b", "a"), array_remove($"c", "")),
       Seq(
@@ -1202,6 +1225,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
     )
 
     checkAnswer(
+      df.select(array_remove($"a", $"d")),
+      Seq(
+        Row(Seq(1, 3)),
+        Row(Seq.empty[Int]),
+        Row(null))
+    )
+
+    checkAnswer(
+      df.selectExpr("array_remove(a, d)"),
+      Seq(
+        Row(Seq(1, 3)),
+        Row(Seq.empty[Int]),
+        Row(null))
+    )
+
+    checkAnswer(
       df.selectExpr("array_remove(a, 2)", "array_remove(b, \"a\")",
         "array_remove(c, \"\")"),
       Seq(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message