spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.
Date Wed, 28 Sep 2016 23:25:19 GMT
Repository: spark
Updated Branches:
  refs/heads/master 557d6e322 -> 7d0923202


[SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.

## What changes were proposed in this pull request?
We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently
also (try to) collect null values, which differs from the original Hive implementation.
This PR fixes this by adding a null check to the `Collect.update` method.

## How was this patch tested?
Added a regression test to `DataFrameAggregateSuite`.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15208 from hvanhovell/SPARK-17641.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d092320
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d092320
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d092320

Branch: refs/heads/master
Commit: 7d09232028967978d9db314ec041a762599f636b
Parents: 557d6e3
Author: Herman van Hovell <hvanhovell@databricks.com>
Authored: Wed Sep 28 16:25:10 2016 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Wed Sep 28 16:25:10 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregate/collect.scala    |  7 ++++++-
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala  | 12 ++++++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d092320/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 896ff61..78a388d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7d092320/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e172be..7aa4f00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message