spark-commits mailing list archives

From: marmb...@apache.org
Subject: spark git commit: [SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type
Date: Thu, 17 Sep 2015 18:33:34 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 fd58ed48d -> 464d6e7d1


[SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type

https://issues.apache.org/jira/browse/SPARK-10639

Author: Yin Huai <yhuai@databricks.com>

Closes #8788 from yhuai/udafConversion.

(cherry picked from commit aad644fbe29151aec9004817d42e4928bdb326f3)
Signed-off-by: Michael Armbrust <michael@databricks.com>

Conflicts:
	sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
	sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
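
Background for the fix: a UDAF's evaluate() returns a value in the external Scala representation (for example a Row when dataType is a struct type, or java.sql.Timestamp for TimestampType), while the aggregation operators expect Catalyst's internal representation. Before this patch, ScalaUDAF.eval passed the evaluate() result through unconverted. A minimal sketch of a UDAF whose result type exercises that path (the class name and schema below are illustrative only, not part of this patch):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    // Illustrative UDAF: dataType is a struct, so evaluate() returns a Scala-side Row
    // that must be converted to Catalyst's internal form before leaving ScalaUDAF.eval.
    class LastValueAsStruct extends UserDefinedAggregateFunction {
      def inputSchema: StructType = StructType(StructField("x", IntegerType) :: Nil)
      def bufferSchema: StructType = inputSchema
      def dataType: DataType = StructType(StructField("x", IntegerType) :: Nil)
      def deterministic: Boolean = true
      def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, null)
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = buffer.update(0, input.get(0))
      def merge(b1: MutableAggregationBuffer, b2: Row): Unit = b1.update(0, b2.get(0))
      def evaluate(buffer: Row): Any = Row(buffer.get(0)) // external Row, not an InternalRow
    }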


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/464d6e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/464d6e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/464d6e7d

Branch: refs/heads/branch-1.5
Commit: 464d6e7d1be33f4bb11b38c0bcabb8d90aed1246
Parents: fd58ed4
Author: Yin Huai <yhuai@databricks.com>
Authored: Thu Sep 17 11:14:52 2015 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Thu Sep 17 11:33:10 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/CatalystTypeConverters.scala   |   7 +-
 .../apache/spark/sql/RandomDataGenerator.scala  |  16 ++-
 .../spark/sql/execution/aggregate/udaf.scala    |  37 ++++++-
 .../scala/org/apache/spark/sql/QueryTest.scala  |  21 ++--
 .../apache/spark/sql/UserDefinedTypeSuite.scala |  11 ++
 .../hive/execution/AggregationQuerySuite.scala  | 107 ++++++++++++++++++-
 6 files changed, 187 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 966623e..f255917 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -138,8 +138,13 @@ object CatalystTypeConverters {
 
   private case class UDTConverter(
       udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+    // toCatalyst (which calls toCatalystImpl) already does the null check.
     override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
-    override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
+
+    override def toScala(catalystValue: Any): Any = {
+      if (catalystValue == null) null else udt.deserialize(catalystValue)
+    }
+
     override def toScalaImpl(row: InternalRow, column: Int): Any =
       toScala(row.get(column, udt.sqlType))
   }
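
These converters are normally obtained through CatalystTypeConverters.createToScalaConverter / createToCatalystConverter. With the null check added above, a null column value no longer reaches udt.deserialize. A minimal sketch of the intended behavior (udt stands for any UserDefinedType instance, e.g. the MyDenseVectorUDT used in the new test further down):

    val toScala = CatalystTypeConverters.createToScalaConverter(udt)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(udt)
    assert(toScala(null) == null)    // previously this forwarded null into udt.deserialize
    assert(toCatalyst(null) == null) // toCatalyst already null-checks before toCatalystImpl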

http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 4025cbc..e483950 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -108,7 +108,21 @@ object RandomDataGenerator {
         arr
       })
       case BooleanType => Some(() => rand.nextBoolean())
-      case DateType => Some(() => new java.sql.Date(rand.nextInt()))
+      case DateType =>
+        val generator =
+          () => {
+            var milliseconds = rand.nextLong() % 253402329599999L
+            // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
+            // for "0001-01-01 00:00:00.000000". We need to find a
+            // number that is greater than or equal to this number as a valid timestamp value.
+            while (milliseconds < -62135740800000L) {
+              // 253402329599999L is the number of milliseconds since
+              // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
+              milliseconds = rand.nextLong() % 253402329599999L
+            }
+            DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt)
+          }
+        Some(generator)
       case TimestampType =>
         val generator =
           () => {
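
The new DateType generator draws a random millisecond value and keeps redrawing until it falls inside the range Catalyst accepts (between "0001-01-01" and "9999-12-31") before converting it to a day count. A hedged usage sketch (the seed is arbitrary; forType is the test-only factory used the same way in the new AggregationQuerySuite test below):

    // Sketch: producing valid java.sql.Date values for randomly generated test data.
    val gen = RandomDataGenerator.forType(DateType, nullable = false, seed = Some(42L))
      .getOrElse(sys.error("no generator for DateType"))
    val d = gen().asInstanceOf[java.sql.Date] // always within the supported date range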

http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index d43d3dd..1114fe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils {
     var i = 0
     while (i < getters.length) {
       getters(i) = dataTypes(i) match {
+        case NullType =>
+          (row: InternalRow, ordinal: Int) => null
+
         case BooleanType =>
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
@@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils {
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
 
+        case DateType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+        case TimestampType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
         case other =>
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
@@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils {
     var i = 0
     while (i < setters.length) {
       setters(i) = dataTypes(i) match {
+        case NullType =>
+          (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal)
+
         case b: BooleanType =>
           (row: MutableRow, ordinal: Int, value: Any) =>
             if (value != null) {
@@ -151,8 +165,22 @@ sealed trait BufferSetterGetterUtils {
         case dt: DecimalType =>
           val precision = dt.precision
           (row: MutableRow, ordinal: Int, value: Any) =>
+            // To make it work with UnsafeRow, we cannot use setNullAt.
+            // Please see the comment of UnsafeRow's setDecimal.
+            row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+
+        case DateType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
             if (value != null) {
-              row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+              row.setInt(ordinal, value.asInstanceOf[Int])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case TimestampType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setLong(ordinal, value.asInstanceOf[Long])
             } else {
               row.setNullAt(ordinal)
             }
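
The new getter/setter cases mirror Catalyst's internal encodings: DateType is stored as an Int (days since 1970-01-01) and TimestampType as a Long (microseconds since the epoch), so the aggregation buffer has to move primitive values with getInt/setInt and getLong/setLong rather than boxed java.sql values. For illustration (DateTimeUtils is the existing catalyst utility; the literal dates are made up):

    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    // The internal forms that the buffer getters/setters above read and write.
    val days: Int = DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-09-17"))
    val micros: Long = DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-09-17 11:14:52"))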
@@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
+
     toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
   }
 
@@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF(
     }
   }
 
+  private[this] lazy val outputToCatalystConverter: Any => Any = {
+    CatalystTypeConverters.createToCatalystConverter(dataType)
+  }
+
   // This buffer is only used at executor side.
   private[this] var inputAggregateBuffer: InputAggregationBuffer = null
 
@@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF(
   override def eval(buffer: InternalRow): Any = {
     evalAggregateBuffer.underlyingInputBuffer = buffer
 
-    udaf.evaluate(evalAggregateBuffer)
+    outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))
   }
 
   override def toString: String = {
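
This is the heart of the fix: udaf.evaluate returns a value in the external representation, and eval now routes it through a converter built from the UDAF's declared dataType. A minimal sketch of what that converter does for a struct-typed result (field names are illustrative):

    // Sketch: the external-to-internal conversion performed by outputToCatalystConverter.
    val resultType = StructType(StructField("a", IntegerType) :: StructField("b", StringType) :: Nil)
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(resultType)
    val internal = toCatalyst(Row(1, "x")) // external Row -> Catalyst InternalRow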

http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 4adcefb..c9f40ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -98,19 +98,26 @@ object QueryTest {
    */
   def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
     val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+
+    // We need to call prepareRow recursively to handle schemas with struct types.
+    def prepareRow(row: Row): Row = {
+      Row.fromSeq(row.toSeq.map {
+        case null => null
+        case d: java.math.BigDecimal => BigDecimal(d)
+        // Convert array to Seq for easy equality check.
+        case b: Array[_] => b.toSeq
+        case r: Row => prepareRow(r)
+        case o => o
+      })
+    }
+
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
      // Java's java.math.BigDecimal.compareTo).
      // For binary arrays, we convert it to Seq to avoid calling java.util.Arrays.equals for
      // equality test.
-      val converted: Seq[Row] = answer.map { s =>
-        Row.fromSeq(s.toSeq.map {
-          case d: java.math.BigDecimal => BigDecimal(d)
-          case b: Array[Byte] => b.toSeq
-          case o => o
-        })
-      }
+      val converted: Seq[Row] = answer.map(prepareRow)
       if (!isSorted) converted.sortBy(_.toString()) else converted
     }
     val sparkAnswer = try df.collect().toSeq catch {
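
prepareRow recurses so that values nested inside struct columns get the same normalization as top-level values; without the recursion, a java.math.BigDecimal or an array buried in a nested Row would still be compared with Java reference/compareTo semantics. A small illustration (the values are made up, inside QueryTest where prepareRow is in scope):

    // Sketch: nested rows are normalized recursively before comparison.
    val nested = Row(new java.math.BigDecimal("1.0"), Array[Byte](1, 2))
    val outer  = Row(1, nested)
    // prepareRow(outer) == Row(1, Row(BigDecimal("1.0"), Seq[Byte](1, 2)))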

http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index b6d279a..5e5a24b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -157,4 +158,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
         Nil
     )
   }
+
+  test("Catalyst type converter null handling for UDTs") {
+    val udt = new MyDenseVectorUDT()
+    val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
+    assert(toScalaConverter(null) === null)
+
+    val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt)
+    assert(toCatalystConverter(null) === null)
+
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/464d6e7d/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4886a85..9571100 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -18,14 +18,55 @@
 package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
+import scala.collection.JavaConverters._
 
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 
+class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
+
+  def inputSchema: StructType = schema
+
+  def bufferSchema: StructType = schema
+
+  def dataType: DataType = schema
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    (0 until schema.length).foreach { i =>
+      buffer.update(i, null)
+    }
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    if (!input.isNullAt(0) && input.getInt(0) == 50) {
+      (0 until schema.length).foreach { i =>
+        buffer.update(i, input.get(i))
+      }
+    }
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
+      (0 until schema.length).foreach { i =>
+        buffer1.update(i, buffer2.get(i))
+      }
+    }
+  }
+
+  def evaluate(buffer: Row): Any = {
+    Row.fromSeq(buffer.toSeq)
+  }
+}
+
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
   override def _sqlContext: SQLContext = TestHive
   protected val sqlContext = _sqlContext
@@ -547,6 +588,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
       assert(newAggregateOperators.isEmpty, message)
     }
   }
+
+  test("udaf with all data types") {
+    val struct =
+      StructType(
+        StructField("f1", FloatType, true) ::
+          StructField("f2", ArrayType(BooleanType), true) :: Nil)
+    val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
+      ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+      DateType, TimestampType,
+      ArrayType(IntegerType), MapType(StringType, LongType), struct,
+      new MyDenseVectorUDT())
+    // Right now, we use SortBasedAggregate to handle UDAFs.
+    // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
+    // UnsafeRow as the aggregation buffer, while dataTypes will trigger
+    // SortBasedAggregate to use a safe row as the aggregation buffer.
+    Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
+      val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+        StructField(s"col$index", dataType, nullable = true)
+      }
+      // The schema used for data generator.
+      val schemaForGenerator = StructType(fields)
+      // The schema used for the DataFrame df.
+      val schema = StructType(StructField("id", IntegerType) +: fields)
+
+      logInfo(s"Testing schema: ${schema.treeString}")
+
+      val udaf = new ScalaAggregateFunction(schema)
+      // Generate data on the driver side. We need to materialize the data first and then
+      // create the RDD.
+      val maybeDataGenerator =
+        RandomDataGenerator.forType(
+          dataType = schemaForGenerator,
+          nullable = true,
+          seed = Some(System.nanoTime()))
+      val dataGenerator =
+        maybeDataGenerator
+          .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
+      val data = (1 to 50).map { i =>
+        dataGenerator.apply() match {
+          case row: Row => Row.fromSeq(i +: row.toSeq)
+          case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+          case other =>
+            fail(s"Row or null is expected to be generated, " +
+              s"but a ${other.getClass.getCanonicalName} is generated.")
+        }
+      }
+
+      // Create a DF for the schema with random data.
+      val rdd = sqlContext.sparkContext.parallelize(data, 1)
+      val df = sqlContext.createDataFrame(rdd, schema)
+
+      val allColumns = df.schema.fields.map(f => col(f.name))
+      val expectedAnswer =
+        data
+          .find(r => r.getInt(0) == 50)
+          .getOrElse(fail("A row with id 50 should be the expected answer."))
+      checkAnswer(
+        df.groupBy().agg(udaf(allColumns: _*)),
+        // udaf returns a Row as the output value.
+        Row(expectedAnswer)
+      )
+    }
+  }
 }
 
 class SortBasedAggregationQuerySuite extends AggregationQuerySuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

