spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l...@apache.org
Subject spark git commit: [SPARK-16097][SQL] Encoders.tuple should handle null object correctly
Date Wed, 22 Jun 2016 10:37:55 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 60bd704b5 -> 503eb882c


[SPARK-16097][SQL] Encoders.tuple should handle null object correctly

## What changes were proposed in this pull request?

Although the top-level input object cannot be null, when we use `Encoders.tuple` to combine
two encoders, their input objects are no longer top level and can therefore be null. We should
handle this case.

## How was this patch tested?

A new test in `DatasetSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13807 from cloud-fan/bug.

(cherry picked from commit 01277d4b259dcf9cad25eece1377162b7a8c946d)
Signed-off-by: Cheng Lian <lian@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/503eb882
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/503eb882
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/503eb882

Branch: refs/heads/branch-2.0
Commit: 503eb882c14eac9681981199ccf8f699cab23bf0
Parents: 60bd704
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Wed Jun 22 18:32:14 2016 +0800
Committer: Cheng Lian <lian@databricks.com>
Committed: Wed Jun 22 18:37:36 2016 +0800

----------------------------------------------------------------------
 .../catalyst/encoders/ExpressionEncoder.scala   | 48 ++++++++++++++------
 .../org/apache/spark/sql/DatasetSuite.scala     |  7 +++
 2 files changed, 42 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/503eb882/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0023ce6..1fac26c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 /**
@@ -110,16 +110,34 @@ object ExpressionEncoder {
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val serializer = encoders.map {
-      case e if e.flat => e.serializer.head
-      case other => CreateStruct(other.serializer)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
+    val serializer = encoders.zipWithIndex.map { case (enc, index) =>
+      val originalInputObject = enc.serializer.head.collect { case b: BoundReference =>
b }.head
+      val newInputObject = Invoke(
+        BoundReference(0, ObjectType(cls), nullable = true),
+        s"_${index + 1}",
+        originalInputObject.dataType)
+
+      val newSerializer = enc.serializer.map(_.transformUp {
+        case b: BoundReference if b == originalInputObject => newInputObject
+      })
+
+      if (enc.flat) {
+        newSerializer.head
+      } else {
+        // For non-flat encoder, the input object is not top level anymore after being combined
to
+        // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with
`If` and
+        // null check to handle null case correctly.
+        // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns,
and is
+        // not able to handle the case when the input tuple is null. This is not a problem
as there
+        // is a check to make sure the input object won't be null. However, if this encoder
is used
+        // to create a bigger tuple encoder, the original input object becomes a filed of
the new
+        // input tuple and can be null. So instead of creating a struct directly here, we
should add
+        // a null/None check and return a null struct if the null/None check fails.
+        val struct = CreateStruct(newSerializer)
+        val nullCheck = Or(
+          IsNull(newInputObject),
+          Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
+        If(nullCheck, Literal.create(null, struct.dataType), struct)
       }
     }
 
@@ -203,8 +221,12 @@ case class ExpressionEncoder[T](
   // (intermediate value is not an attribute). We assume that all serializer expressions
use a same
   // `BoundReference` to refer to the object, and throw exception if they don't.
   assert(serializer.forall(_.references.isEmpty), "serializer cannot reference to any attributes.")
-  assert(serializer.flatMap(_.collect { case b: BoundReference => b}).distinct.length
<= 1,
-    "all serializer expressions must use the same BoundReference.")
+  assert(serializer.flatMap { ser =>
+    val boundRefs = ser.collect { case b: BoundReference => b }
+    assert(boundRefs.nonEmpty,
+      "each serializer expression should contains at least one `BoundReference`")
+    boundRefs
+  }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.")
 
   /**
    * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to
the

http://git-wip-us.apache.org/repos/asf/spark/blob/503eb882/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index f02a314..bd8479b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -830,6 +830,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ds.dropDuplicates("_1", "_2"),
       ("a", 1), ("a", 2), ("b", 1))
   }
+
+  test("SPARK-16097: Encoders.tuple should handle null object correctly") {
+    val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING)
+    val data = Seq((("a", "b"), "c"), (null, "d"))
+    val ds = spark.createDataset(data)(enc)
+    checkDataset(ds, (("a", "b"), "c"), (null, "d"))
+  }
 }
 
 case class Generic[T](id: T, value: Double)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message