Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id EE5E52009F8 for ; Fri, 3 Jun 2016 23:28:31 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id ECFFE160A50; Fri, 3 Jun 2016 21:28:31 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 183CD160A3B for ; Fri, 3 Jun 2016 23:28:30 +0200 (CEST) Received: (qmail 50104 invoked by uid 500); 3 Jun 2016 21:28:30 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 50095 invoked by uid 99); 3 Jun 2016 21:28:30 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 03 Jun 2016 21:28:30 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 31C59DFDEF; Fri, 3 Jun 2016 21:28:30 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: lian@apache.org To: commits@spark.apache.org Message-Id: <0e64251a0af345efbe59ee4cc56a1b3b@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-15140][SQL] make the semantics of null input object for encoder clear Date: Fri, 3 Jun 2016 21:28:30 +0000 (UTC) archived-at: Fri, 03 Jun 2016 21:28:32 -0000 Repository: spark Updated Branches: refs/heads/branch-2.0 52376e067 -> 7315acf89 [SPARK-15140][SQL] make the semantics of null input object for encoder clear ## What changes were proposed in this pull request? 
For an input object of non-flat type, we can't encode it to a row if it's null, as Spark SQL doesn't allow a row to be null; only its columns can be null. This PR explicitly adds this constraint and throws an exception if users break it. ## How was this patch tested? several new tests Author: Wenchen Fan Closes #13469 from cloud-fan/null-object. (cherry picked from commit 11c83f83d5172167cb64513d5311b4178797d40e) Signed-off-by: Cheng Lian Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7315acf8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7315acf8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7315acf8 Branch: refs/heads/branch-2.0 Commit: 7315acf896b2474a4b7513434f5ba2faf468abd9 Parents: 52376e0 Author: Wenchen Fan Authored: Fri Jun 3 14:28:19 2016 -0700 Committer: Cheng Lian Committed: Fri Jun 3 14:28:26 2016 -0700 ---------------------------------------------------------------------- .../sql/catalyst/encoders/ExpressionEncoder.scala | 13 ++++++++++--- .../spark/sql/catalyst/encoders/RowEncoder.scala | 7 +++---- .../sql/catalyst/expressions/objects/objects.scala | 4 ++-- .../spark/sql/catalyst/encoders/RowEncoderSuite.scala | 8 ++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 10 ++++++++++ 5 files changed, 33 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7315acf8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cc59d06..688082d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} import org.apache.spark.sql.types.{ObjectType, StructField, StructType} @@ -50,8 +50,15 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val serializer = ScalaReflection.serializerFor[T](inputObject) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val nullSafeInput = if (flat) { + inputObject + } else { + // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. 
+ AssertNotNull(inputObject, Seq("top level non-flat input object")) + } + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { http://git-wip-us.apache.org/repos/asf/spark/blob/7315acf8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3c6ae1c..6cd7b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -57,8 +57,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = false) - val serializer = serializerFor(inputObject, schema) + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -153,8 +153,7 @@ object RowEncoder { val fieldValue = serializerFor( GetExternalRowField( inputObject, index, field.name, externalDataTypeForInput(field.dataType)), - field.dataType - ) + field.dataType) val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), http://git-wip-us.apache.org/repos/asf/spark/blob/7315acf8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c2e3ab8..d4c71bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val code = s""" $values = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """ ev.copy(code = code, isNull = "false") } @@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException(this.$errMsgField); + throw new RuntimeException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) http://git-wip-us.apache.org/repos/asf/spark/blob/7315acf8/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 6f1bc80..16abde0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite { assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null)) } + test("RowEncoder should throw RuntimeException if input row object is null") { + val schema = new StructType().add("int", 
IntegerType) + val encoder = RowEncoder(schema) + val e = intercept[RuntimeException](encoder.toRow(null)) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level row object")) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() http://git-wip-us.apache.org/repos/asf/spark/blob/7315acf8/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d1c2329..bf2b0a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -790,6 +790,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getMessage.contains( "`abstract` is a reserved keyword and cannot be used as field name")) } + + test("Dataset should support flat input object to be null") { + checkDataset(Seq("a", null).toDS(), "a", null) + } + + test("Dataset should throw RuntimeException if non-flat input object is null") { + val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level non-flat input object")) + } } case class Generic[T](id: T, value: Double) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org