spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l...@apache.org
Subject spark git commit: [SPARK-15657][SQL] RowEncoder should validate the data type of input object
Date Sun, 05 Jun 2016 22:59:55 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8a9110510 -> 30c4774f3


[SPARK-15657][SQL] RowEncoder should validate the data type of input object

## What changes were proposed in this pull request?

This PR improves the error handling of `RowEncoder`. When we create a `RowEncoder` with a
given schema, we should validate the data type of the input object, e.g. we should throw an
exception when a field value is a boolean but the field is declared as a string column.

This PR also removes support for using `Product` as a valid external type of struct type.
This support was added in https://github.com/apache/spark/pull/9712, but it is incomplete, e.g.
nested products and products in arrays do not work. Since we never officially supported
this feature, I think it's ok to ban it.

## How was this patch tested?

new tests in `RowEncoderSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13401 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30c4774f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30c4774f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30c4774f

Branch: refs/heads/master
Commit: 30c4774f33fed63b7d400d220d710fb432f599a8
Parents: 8a91105
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Sun Jun 5 15:59:52 2016 -0700
Committer: Cheng Lian <lian@databricks.com>
Committed: Sun Jun 5 15:59:52 2016 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   | 10 +---
 .../sql/catalyst/encoders/RowEncoder.scala      | 17 ++++--
 .../catalyst/expressions/objects/objects.scala  | 61 +++++++++++++++++---
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 47 ++++++++++-----
 4 files changed, 95 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30c4774f/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index a257b83..391001d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -304,15 +304,7 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  def getStruct(i: Int): Row = {
-    // Product and Row both are recognized as StructType in a Row
-    val t = get(i)
-    if (t.isInstanceOf[Product]) {
-      Row.fromTuple(t.asInstanceOf[Product])
-    } else {
-      t.asInstanceOf[Row]
-    }
-  }
+  def getStruct(i: Int): Row = getAs[Row](i)
 
   /**
    * Returns the value at position i.

http://git-wip-us.apache.org/repos/asf/spark/blob/30c4774f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 6cd7b34..67fca15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -51,7 +51,7 @@ import org.apache.spark.unsafe.types.UTF8String
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq or Array
  *   MapType -> scala.collection.Map
- *   StructType -> org.apache.spark.sql.Row or Product
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 object RowEncoder {
@@ -121,11 +121,15 @@ object RowEncoder {
 
     case t @ ArrayType(et, _) => et match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        // TODO: validate input type for primitive array.
         NewInstance(
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
+      case _ => MapObjects(
+        element => serializerFor(ValidateExternalType(element, et), et),
+        inputObject,
+        ObjectType(classOf[Object]))
     }
 
     case t @ MapType(kt, vt, valueNullable) =>
@@ -151,8 +155,9 @@ object RowEncoder {
     case StructType(fields) =>
       val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
         val fieldValue = serializerFor(
-          GetExternalRowField(
-            inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
+          ValidateExternalType(
+            GetExternalRowField(inputObject, index, field.name),
+            field.dataType),
           field.dataType)
         val convertedField = if (field.nullable) {
           If(
@@ -183,7 +188,7 @@ object RowEncoder {
    * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
    * `org.apache.spark.sql.types.Decimal`.
    */
-  private def externalDataTypeForInput(dt: DataType): DataType = dt match {
+  def externalDataTypeForInput(dt: DataType): DataType = dt match {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
@@ -192,7 +197,7 @@ object RowEncoder {
     case _ => externalDataTypeFor(dt)
   }
 
-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  def externalDataTypeFor(dt: DataType): DataType = dt match {
     case _ if ScalaReflection.isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])

http://git-wip-us.apache.org/repos/asf/spark/blob/30c4774f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d4c71bf..87c8a2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -26,6 +26,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -692,22 +693,17 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
 case class GetExternalRowField(
     child: Expression,
     index: Int,
-    fieldName: String,
-    dataType: DataType) extends UnaryExpression with NonSQLExpression {
+    fieldName: String) extends UnaryExpression with NonSQLExpression {
 
   override def nullable: Boolean = false
 
+  override def dataType: DataType = ObjectType(classOf[Object])
+
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val row = child.genCode(ctx)
-
-    val getField = dataType match {
-      case ObjectType(x) if x == classOf[Row] => s"""${row.value}.getStruct($index)"""
-      case _ => s"""(${ctx.boxedType(dataType)}) ${row.value}.get($index)"""
-    }
-
     val code = s"""
       ${row.code}
 
@@ -720,8 +716,55 @@ case class GetExternalRowField(
           "cannot be null.");
       }
 
-      final ${ctx.javaType(dataType)} ${ev.value} = $getField;
+      final Object ${ev.value} = ${row.value}.get($index);
      """
     ev.copy(code = code, isNull = "false")
   }
 }
+
+/**
+ * Validates the actual data type of input expression at runtime.  If it doesn't match the
+ * expectation, throw an exception.
+ */
+case class ValidateExternalType(child: Expression, expected: DataType)
+  extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
+
+  override def nullable: Boolean = child.nullable
+
+  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val input = child.genCode(ctx)
+    val obj = input.value
+
+    val typeCheck = expected match {
+      case _: DecimalType =>
+        Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
+          .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
+      case _: ArrayType =>
+        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      case _ =>
+        s"$obj instanceof ${ctx.boxedType(dataType)}"
+    }
+
+    val code = s"""
+      ${input.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+      if (!${input.isNull}) {
+        if ($typeCheck) {
+          ${ev.value} = (${ctx.boxedType(dataType)}) $obj;
+        } else {
+          throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
+            "external type for schema of ${expected.simpleString}");
+        }
+      }
+
+    """
+    ev.copy(code = code, isNull = input.isNull)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/30c4774f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 16abde0..2e513ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -127,22 +127,6 @@ class RowEncoderSuite extends SparkFunSuite {
         new StructType().add("array", arrayOfString).add("map", mapOfString))
       .add("structOfUDT", structOfUDT))
 
-  test(s"encode/decode: Product") {
-    val schema = new StructType()
-      .add("structAsProduct",
-        new StructType()
-          .add("int", IntegerType)
-          .add("string", StringType)
-          .add("double", DoubleType))
-
-    val encoder = RowEncoder(schema).resolveAndBind()
-
-    val input: Row = Row((100, "test", 0.123))
-    val row = encoder.toRow(input)
-    val convertedBack = encoder.fromRow(row)
-    assert(input.getStruct(0) == convertedBack.getStruct(0))
-  }
-
   test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
@@ -232,6 +216,37 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(e.getMessage.contains("top level row object"))
   }
 
+  test("RowEncoder should validate external type") {
+    val e1 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", IntegerType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1.toShort))
+    }
+    assert(e1.getMessage.contains("java.lang.Short is not a valid external type"))
+
+    val e2 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", StringType)
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1))
+    }
+    assert(e2.getMessage.contains("java.lang.Integer is not a valid external type"))
+
+    val e3 = intercept[RuntimeException] {
+      val schema = new StructType().add("a",
+        new StructType().add("b", IntegerType).add("c", StringType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(1 -> "a"))
+    }
+    assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type"))
+
+    val e4 = intercept[RuntimeException] {
+      val schema = new StructType().add("a", ArrayType(TimestampType))
+      val encoder = RowEncoder(schema)
+      encoder.toRow(Row(Array("a")))
+    }
+    assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message