spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lix...@apache.org
Subject spark git commit: [SPARK-19311][SQL] fix UDT hierarchy issue
Date Wed, 25 Jan 2017 16:17:31 GMT
Repository: spark
Updated Branches:
  refs/heads/master f1ddca5fc -> f6480b146


[SPARK-19311][SQL] fix UDT hierarchy issue

## What changes were proposed in this pull request?
acceptsType() in a UDT will now accept not only the same UDT class but also any UDT whose user class is a subtype of its own user class

## How was this patch tested?
Manual test using a set of generated UDTs, fixing acceptsType() in my user-defined types

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: gmoehler <moehler@de.ibm.com>

Closes #16660 from gmoehler/master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6480b14
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6480b14
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6480b14

Branch: refs/heads/master
Commit: f6480b1467d0432fb2aa48c7a3a8a6e6679fd481
Parents: f1ddca5
Author: gmoehler <moehler@de.ibm.com>
Authored: Wed Jan 25 08:17:24 2017 -0800
Committer: gatorsmile <gatorsmile@gmail.com>
Committed: Wed Jan 25 08:17:24 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/types/UserDefinedType.scala       |   8 +-
 .../apache/spark/sql/UserDefinedTypeSuite.scala | 105 ++++++++++++++++++-
 2 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6480b14/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index c33219c..5a944e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -78,8 +78,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
    */
   override private[spark] def asNullable: UserDefinedType[UserType] = this
 
-  override private[sql] def acceptsType(dataType: DataType) =
-    this.getClass == dataType.getClass
+  override private[sql] def acceptsType(dataType: DataType) = dataType match {
+    case other: UserDefinedType[_] =>
+      this.getClass == other.getClass ||
+        this.userClass.isAssignableFrom(other.userClass)
+    case _ => false
+  }
 
   override def sql: String = sqlType.sql
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6480b14/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 474f17f..ea4a8ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql
 import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
 import org.apache.spark.sql.functions._
@@ -71,6 +72,77 @@ object UDT {
 
 }
 
+// object and classes to test SPARK-19311
+
+// Trait/Interface for base type
+sealed trait IExampleBaseType extends Serializable {
+  def field: Int
+}
+
+// Trait/Interface for derived type
+sealed trait IExampleSubType extends IExampleBaseType
+
+// a base class
+class ExampleBaseClass(override val field: Int) extends IExampleBaseType
+
+// a derived class
+class ExampleSubClass(override val field: Int)
+  extends ExampleBaseClass(field) with IExampleSubType
+
+// UDT for base class
+class ExampleBaseTypeUDT extends UserDefinedType[IExampleBaseType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
+
+  override def serialize(obj: IExampleBaseType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
+
+  override def deserialize(datum: Any): IExampleBaseType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          "ExampleBaseTypeUDT requires row with length == 1")
+        val field = row.getInt(0)
+        new ExampleBaseClass(field)
+    }
+  }
+
+  override def userClass: Class[IExampleBaseType] = classOf[IExampleBaseType]
+}
+
+// UDT for derived class
+private[spark] class ExampleSubTypeUDT extends UserDefinedType[IExampleSubType] {
+
+  override def sqlType: StructType = {
+    StructType(Seq(
+      StructField("intfield", IntegerType, nullable = false)))
+  }
+
+  override def serialize(obj: IExampleSubType): InternalRow = {
+    val row = new GenericInternalRow(1)
+    row.setInt(0, obj.field)
+    row
+  }
+
+  override def deserialize(datum: Any): IExampleSubType = {
+    datum match {
+      case row: InternalRow =>
+        require(row.numFields == 1,
+          "ExampleSubTypeUDT requires row with length == 1")
+        val field = row.getInt(0)
+        new ExampleSubClass(field)
+    }
+  }
+
+  override def userClass: Class[IExampleSubType] = classOf[IExampleSubType]
+}
+
 class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
   import testImplicits._
 
@@ -194,4 +266,35 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
     // call `collect` to make sure this query can pass analysis.
     pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect()
   }
+
+  test("SPARK-19311: UDFs disregard UDT type hierarchy") {
+    UDTRegistration.register(classOf[IExampleBaseType].getName,
+      classOf[ExampleBaseTypeUDT].getName)
+    UDTRegistration.register(classOf[IExampleSubType].getName,
+      classOf[ExampleSubTypeUDT].getName)
+
+    // UDF that returns a base class object
+    sqlContext.udf.register("doUDF", (param: Int) => {
+      new ExampleBaseClass(param)
+    }: IExampleBaseType)
+
+    // UDF that returns a derived class object
+    sqlContext.udf.register("doSubTypeUDF", (param: Int) => {
+      new ExampleSubClass(param)
+    }: IExampleSubType)
+
+    // UDF that takes a base class object as parameter
+    sqlContext.udf.register("doOtherUDF", (obj: IExampleBaseType) => {
+      obj.field
+    }: Int)
+
+    // this worked already before the fix SPARK-19311:
+    // return type of doUDF equals parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doUDF(41))")
+
+    // this one passes only with the fix SPARK-19311:
+    // return type of doSubTypeUDF is a subtype of the parameter type of doOtherUDF
+    sql("SELECT doOtherUDF(doSubTypeUDF(42))")
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message