spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization
Date Tue, 17 Apr 2018 13:46:13 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.3 d4f204c53 -> 9857e249c


[SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization

## What changes were proposed in this pull request?

There was no nullability check on the arguments of `Tuple`s. This could lead to weird
behavior when a null value had to be deserialized into a non-nullable Scala object: in those
cases, the `null` was silently transformed into a valid value (like `-1` for `Int`), corresponding
to the default value used in the SQL codebase. This situation was very likely to happen
when deserializing to a Tuple of primitive Scala types (like Double, Int, ...).

This PR wraps tuple arguments that are deserialized into non-nullable types in an
`AssertNotNull` check.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20976 from mgaido91/SPARK-23835.

(cherry picked from commit 0a9172a05e604a4a94adbb9208c8c02362afca00)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9857e249
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9857e249
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9857e249

Branch: refs/heads/branch-2.3
Commit: 9857e249c20f842868cb4681ea374b8e316c3ead
Parents: d4f204c
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Tue Apr 17 21:45:20 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Apr 17 21:46:00 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/kafka010/KafkaContinuousSinkSuite.scala |  6 +++---
 .../apache/spark/sql/kafka010/KafkaSinkSuite.scala    |  2 +-
 .../apache/spark/sql/catalyst/ScalaReflection.scala   | 14 +++++++-------
 .../spark/sql/catalyst/ScalaReflectionSuite.scala     | 12 +++++++++++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala     |  5 +++++
 5 files changed, 27 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9857e249/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
index fc890a0..ddfc0c1 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)
 
     try {
@@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)
 
     try {
@@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
       .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)
 
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/9857e249/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 42f8b4c..7079ac6 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
     val reader = createKafkaReader(topic)
       .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
       .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
-      .as[(Int, Int)]
+      .as[(Option[Int], Int)]
       .map(_._2)
 
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/9857e249/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 9a4bf00..fabf895 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection {
           val clsName = getClassNameFromType(fieldType)
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we based grab the inner fields by ordinal instead of name.
-          if (cls.getName startsWith "scala.Tuple") {
+          val constructor = if (cls.getName startsWith "scala.Tuple") {
             deserializerFor(
               fieldType,
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            val constructor = deserializerFor(
+            deserializerFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
+          }
 
-            if (!nullable) {
-              AssertNotNull(constructor, newTypePath)
-            } else {
-              constructor
-            }
+          if (!nullable) {
+            AssertNotNull(constructor, newTypePath)
+          } else {
+            constructor
           }
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9857e249/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 8c3db48..353b834 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite {
         StructField("_2", NullType, nullable = true))),
       nullable = true))
   }
+
+  test("SPARK-23835: add null check to non-nullable types in Tuples") {
+    def numberOfCheckedArguments(deserializer: Expression): Int = {
+      assert(deserializer.isInstanceOf[NewInstance])
+      deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
+    }
+    assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
+    assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9857e249/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 9b745be..e0f4d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
     checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
   }
+
+  test("SPARK-23835: null primitive data type should throw NullPointerException") {
+    val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
+    intercept[NullPointerException](ds.as[(Int, Int)].collect())
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message