spark-commits mailing list archives

From marmb...@apache.org
Subject spark git commit: [SPARK-15190][SQL] Support using SQLUserDefinedType for case classes
Date Fri, 20 May 2016 19:38:58 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e99b22080 -> 42e63c35a


[SPARK-15190][SQL] Support using SQLUserDefinedType for case classes

## What changes were proposed in this pull request?

Right now, schema inference for case classes happens before the SQLUserDefinedType
annotation is checked, so the SQLUserDefinedType annotation has no effect on case classes.

This PR simply changes the inference order to resolve the issue. I also re-enabled the
java.math.BigDecimal test and added two tests for `List`.
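
For reference, the pattern this enables looks like the following; it mirrors the `UDTCaseClass`/`UDTForCaseClass` pair added to `ExpressionEncoderSuite` in this patch:

```scala
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// With this change, the annotation is honored because the UDT cases in
// ScalaReflection are matched before the case-class (constructor-parameter)
// fallback.
@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
case class UDTCaseClass(uri: java.net.URI)

class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {

  // Store the URI as a single string column instead of a one-field struct.
  override def sqlType: DataType = StringType

  override def serialize(obj: UDTCaseClass): UTF8String =
    UTF8String.fromString(obj.uri.toString)

  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]

  override def deserialize(datum: Any): UDTCaseClass = datum match {
    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
  }
}
```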

## How was this patch tested?

`encodeDecodeTest(UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")`

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12965 from zsxwing/SPARK-15190.

(cherry picked from commit dfa61f7b136ae060bbe04e3c0da1148da41018c7)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42e63c35
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42e63c35
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42e63c35

Branch: refs/heads/branch-2.0
Commit: 42e63c35a60dc256759cb42260ba1113df05c74b
Parents: e99b220
Author: Shixiong Zhu <shixiong@databricks.com>
Authored: Fri May 20 12:38:46 2016 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Fri May 20 12:38:54 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 70 ++++++++++----------
 .../encoders/ExpressionEncoderSuite.scala       | 28 +++++++-
 2 files changed, 62 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42e63c35/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 58df651..36989a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection {
           "toScalaMap",
           keyData :: valueData :: Nil)
 
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
 
@@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
-
-      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
-      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-          .asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection {
           val TypeRef(_, _, Seq(elementType)) = t
           toCatalystArray(inputObject, elementType)
 
-        case t if definedByConstructorParams(t) =>
-          val params = getConstructorParameters(t)
-          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-            val clsName = getClassNameFromType(fieldType)
-            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-          })
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
         case t if t <:< localTypeOf[Array[_]] =>
           val TypeRef(_, _, Seq(elementType)) = t
           toCatalystArray(inputObject, elementType)
@@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection {
             dataType = ObjectType(udt.getClass))
           Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
+        case t if definedByConstructorParams(t) =>
+          val params = getConstructorParameters(t)
+          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+            val clsName = getClassNameFromType(fieldType)
+            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+          })
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
         case other =>
           throw new UnsupportedOperationException(
             s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection {
         val Schema(valueDataType, valueNullable) = schemaFor(valueType)
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
-      case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
-        Schema(StructType(
-          params.map { case (fieldName, fieldType) =>
-            val Schema(dataType, nullable) = schemaFor(fieldType)
-            StructField(fieldName, dataType, nullable)
-          }), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
      case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
      case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
       case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        Schema(StructType(
+          params.map { case (fieldName, fieldType) =>
+            val Schema(dataType, nullable) = schemaFor(fieldType)
+            StructField(fieldName, dataType, nullable)
+          }), nullable = true)
       case other =>
         throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/42e63c35/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d438789..3d97113 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
+/** For testing UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+
+  override def sqlType: DataType = StringType
+
+  override def serialize(obj: UDTCaseClass): UTF8String = {
+    UTF8String.fromString(obj.uri.toString)
+  }
+
+  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+  override def deserialize(datum: Any): UDTCaseClass = datum match {
+    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+  }
+}
+
 class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
   encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
 
+  encodeDecodeTest(List(1, 2), "list of int")
+  encodeDecodeTest(List("a", null), "list with String and null")
+
+  encodeDecodeTest(
+    UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(



