spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lix...@apache.org
Subject spark git commit: [SPARK-19425][SQL] Make ExtractEquiJoinKeys support UDT columns
Date Sat, 04 Feb 2017 23:58:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2f3c20bbd -> 0674e7eb8


[SPARK-19425][SQL] Make ExtractEquiJoinKeys support UDT columns

## What changes were proposed in this pull request?

`DataFrame.except` doesn't work for UDT columns. This is because `ExtractEquiJoinKeys` runs
`Literal.default` against the UDT, but `Literal.default` does not handle UDTs, so an exception
like the following is thrown:

    java.lang.RuntimeException: no default for type
    org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7
      at org.apache.spark.sql.catalyst.expressions.Literal$.default(literals.scala:179)
      at org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys$$anonfun$4.apply(patterns.scala:117)
      at org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys$$anonfun$4.apply(patterns.scala:110)

A simpler fix is to let `Literal.default` handle a UDT via its SQL type, so that a more
efficient join type can be used on UDT columns.

Besides `except`, this also fixes other similar scenarios, so in summary this fixes:

* `except` on two Datasets with UDT
* `intersect` on two Datasets with UDT
* `Join` with the join conditions using `<=>` on UDT columns

## How was this patch tested?

Jenkins tests.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #16765 from viirya/df-except-for-udt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0674e7eb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0674e7eb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0674e7eb

Branch: refs/heads/master
Commit: 0674e7eb85160e3f8da333b5243d76063824d58c
Parents: 2f3c20b
Author: Liang-Chi Hsieh <viirya@gmail.com>
Authored: Sat Feb 4 15:57:56 2017 -0800
Committer: gatorsmile <gatorsmile@gmail.com>
Committed: Sat Feb 4 15:57:56 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/expressions/literals.scala    | 1 +
 .../sql/catalyst/expressions/LiteralExpressionSuite.scala   | 3 +++
 .../scala/org/apache/spark/sql/UserDefinedTypeSuite.scala   | 9 +++++++++
 3 files changed, 13 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0674e7eb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index cb0c4d3..e66fb89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -175,6 +175,7 @@ object Literal {
     case map: MapType => create(Map(), map)
     case struct: StructType =>
       create(InternalRow.fromSeq(struct.fields.map(f => default(f.dataType).value)), struct)
+    case udt: UserDefinedType[_] => default(udt.sqlType)
     case other =>
       throw new RuntimeException(s"no default for type $dataType")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0674e7eb/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index 4af4da8..15e8e6c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -22,6 +22,7 @@ import java.nio.charset.StandardCharsets
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -67,6 +68,8 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
{
     checkEvaluation(Literal.default(ArrayType(StringType)), Array())
     checkEvaluation(Literal.default(MapType(IntegerType, StringType)), Map())
     checkEvaluation(Literal.default(StructType(StructField("a", StringType) :: Nil)), Row(""))
+    // ExamplePointUDT.sqlType is ArrayType(DoubleType, false).
+    checkEvaluation(Literal.default(new ExamplePointUDT), Array())
   }
 
   test("boolean literals") {

http://git-wip-us.apache.org/repos/asf/spark/blob/0674e7eb/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ea4a8ee..c7a77da 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -150,6 +150,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with
ParquetT
     MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
     MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF()
 
+  private lazy val pointsRDD2 = Seq(
+    MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))),
+    MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.3, 3.0)))).toDF()
+
   test("register user type: MyDenseVector for MyLabeledPoint") {
     val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) =>
v }
     val labelsArrays: Array[Double] = labels.collect()
@@ -297,4 +301,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with
ParquetT
     sql("SELECT doOtherUDF(doSubTypeUDF(42))")
   }
 
+  test("except on UDT") {
+    checkAnswer(
+      pointsRDD.except(pointsRDD2),
+      Seq(Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message