spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject git commit: [SPARK-4192][SQL] Internal API for Python UDT
Date Tue, 04 Nov 2014 03:30:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master c5912ecc7 -> 04450d115


[SPARK-4192][SQL] Internal API for Python UDT

Following #2919, this PR adds a Python UDT API (for internal use only), with tests in "pyspark.tests".
Before `SQLContext.applySchema`, we check whether user-type instances need to be converted into
SQL-recognizable data. In the current implementation, a Python UDT must be paired with a Scala
UDT for serialization on the JVM side. A follow-up PR will add VectorUDT in MLlib for both
Scala and Python.

marmbrus jkbradley davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits:

acff637 [Xiangrui Meng] merge master
dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well
2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion
7c4a6a9 [Xiangrui Meng] address comments
75223db [Xiangrui Meng] minor update
f740379 [Xiangrui Meng] remove UDT from default imports
e98d9d0 [Xiangrui Meng] fix py style
4e84fce [Xiangrui Meng] remove local hive tests and add more tests
39f19e0 [Xiangrui Meng] add tests
b7f666d [Xiangrui Meng] add Python UDT


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/04450d11
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/04450d11
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/04450d11

Branch: refs/heads/master
Commit: 04450d11548cfb25d4fb77d4a33e3a7cd4254183
Parents: c5912ec
Author: Xiangrui Meng <meng@databricks.com>
Authored: Mon Nov 3 19:29:11 2014 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Nov 3 19:29:11 2014 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                           | 206 ++++++++++++++++++-
 python/pyspark/tests.py                         |  93 ++++++++-
 .../spark/sql/catalyst/types/dataTypes.scala    |   9 +-
 .../scala/org/apache/spark/sql/SQLContext.scala |   2 +
 .../apache/spark/sql/execution/pythonUdfs.scala |   5 +
 .../apache/spark/sql/test/ExamplePointUDT.scala |  64 ++++++
 .../sql/types/util/DataTypeConversions.scala    |   1 -
 7 files changed, 375 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 675df08..d16c18b 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -417,6 +417,94 @@ class StructType(DataType):
         return StructType([StructField.fromJson(f) for f in json["fields"]])


+class UserDefinedType(DataType):
+    """
+    :: WARN: Spark Internal Use Only ::
+    SQL User-Defined Type (UDT).
+
+    A Python UDT must be paired with a Scala UDT (see `scalaUDT`), which handles
+    serialization on the JVM side.
+    """
+
+    @classmethod
+    def typeName(cls):
+        return cls.__name__.lower()
+
+    @classmethod
+    def sqlType(cls):
+        """
+        Underlying SQL storage type for this UDT.
+        """
+        raise NotImplementedError("UDT must implement sqlType().")
+
+    @classmethod
+    def module(cls):
+        """
+        The Python module of the UDT, used to build the "pyClass" name in `jsonValue`.
+        """
+        raise NotImplementedError("UDT must implement module().")
+
+    @classmethod
+    def scalaUDT(cls):
+        """
+        The fully-qualified class name of the paired Scala UDT.
+        """
+        raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+    def serialize(self, obj):
+        """
+        Converts a user-type object into a SQL datum.
+        """
+        raise NotImplementedError("UDT must implement serialize().")
+
+    def deserialize(self, datum):
+        """
+        Converts a SQL datum into a user-type object.
+        """
+        raise NotImplementedError("UDT must implement deserialize().")
+
+    def json(self):
+        """Compact JSON string of `jsonValue()` with sorted keys."""
+        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+    def jsonValue(self):
+        """
+        JSON-serializable dict describing this UDT: the paired Scala class, the
+        fully-qualified Python class, and the underlying SQL type.
+        """
+        schema = {
+            "type": "udt",
+            "class": self.scalaUDT(),
+            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+            "sqlType": self.sqlType().jsonValue()
+        }
+        return schema
+
+    @classmethod
+    def fromJson(cls, json):
+        """
+        Instantiates the Python UDT class named by the "pyClass" entry of `json`.
+
+        Raises ValueError if "pyClass" is missing or null, e.g. when the JSON was
+        produced by a Scala UDT whose `pyUDT` is null.
+        """
+        pyUDT = json.get("pyClass")
+        if not pyUDT:
+            raise ValueError("No paired Python UDT for %s" % json.get("class"))
+        pyModule, _, pyClass = pyUDT.rpartition(".")
+        m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+        UDT = getattr(m, pyClass)
+        return UDT()
+
+    def __eq__(self, other):
+        # UDTs are instantiated without arguments (see fromJson), so equality is by class.
+        return type(self) == type(other)
+
+    def __ne__(self, other):
+        # Python 2 does not derive != from __eq__.
+        return not self.__eq__(other)
+
+
 _all_primitive_types = dict((v.typeName(), v)
                             for v in globals().itervalues()
                             if type(v) is PrimitiveTypeSingleton and
@@ -469,6 +538,12 @@ def _parse_datatype_json_string(json_string):
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
     True
+    >>> check_datatype(ExamplePointUDT())
+    True
+    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+    ...                                   StructField("point", ExamplePointUDT(), False)])
+    >>> check_datatype(structtype_with_udt)
+    True
     """
     return _parse_datatype_json_value(json.loads(json_string))
 
@@ -488,7 +563,13 @@ def _parse_datatype_json_value(json_value):
         else:
             raise ValueError("Could not parse datatype: %s" % json_value)
     else:
-        return _all_complex_types[json_value["type"]].fromJson(json_value)
+        tpe = json_value["type"]
+        if tpe in _all_complex_types:
+            return _all_complex_types[tpe].fromJson(json_value)
+        elif tpe == 'udt':
+            return UserDefinedType.fromJson(json_value)
+        else:
+            raise ValueError("not supported type: %s" % tpe)
 
 
 # Mapping Python types to Spark SQL DataType
@@ -509,7 +590,18 @@ _type_mappings = {
 
 
 def _infer_type(obj):
-    """Infer the DataType from obj"""
+    """Infer the DataType from obj
+
+    >>> p = ExamplePoint(1.0, 2.0)
+    >>> _infer_type(p)
+    ExamplePointUDT
+    """
+    if obj is None:
+        raise ValueError("Can not infer type for None")
+
+    if hasattr(obj, '__UDT__'):
+        return obj.__UDT__
+
     dataType = _type_mappings.get(type(obj))
     if dataType is not None:
         return dataType()
@@ -558,6 +650,110 @@ def _infer_schema(row):
     return StructType(fields)


+def _need_python_to_sql_conversion(dataType):
+    """
+    Checks whether we need python to sql conversion for the given type.
+    For now, only UDTs need this conversion.
+
+    >>> _need_python_to_sql_conversion(DoubleType())
+    False
+    >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+    ...                       StructField("values", ArrayType(DoubleType(), False), False)])
+    >>> _need_python_to_sql_conversion(schema0)
+    False
+    >>> _need_python_to_sql_conversion(ExamplePointUDT())
+    True
+    >>> schema1 = ArrayType(ExamplePointUDT(), False)
+    >>> _need_python_to_sql_conversion(schema1)
+    True
+    >>> schema2 = StructType([StructField("label", DoubleType(), False),
+    ...                       StructField("point", ExamplePointUDT(), False)])
+    >>> _need_python_to_sql_conversion(schema2)
+    True
+    """
+    if isinstance(dataType, StructType):
+        return any(_need_python_to_sql_conversion(f.dataType) for f in dataType.fields)
+    elif isinstance(dataType, ArrayType):
+        return _need_python_to_sql_conversion(dataType.elementType)
+    elif isinstance(dataType, MapType):
+        return _need_python_to_sql_conversion(dataType.keyType) or \
+            _need_python_to_sql_conversion(dataType.valueType)
+    elif isinstance(dataType, UserDefinedType):
+        return True
+    else:
+        return False
+
+
+def _python_to_sql_converter(dataType):
+    """
+    Returns a converter that converts a Python object into a SQL datum for the given type.
+
+    None is passed through unchanged at every level, because all objects are nullable.
+
+    >>> conv = _python_to_sql_converter(DoubleType())
+    >>> conv(1.0)
+    1.0
+    >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+    >>> conv([1.0, 2.0])
+    [1.0, 2.0]
+    >>> conv = _python_to_sql_converter(ExamplePointUDT())
+    >>> conv(ExamplePoint(1.0, 2.0))
+    [1.0, 2.0]
+    >>> conv(None) is None
+    True
+    >>> schema = StructType([StructField("label", DoubleType(), False),
+    ...                      StructField("point", ExamplePointUDT(), False)])
+    >>> conv = _python_to_sql_converter(schema)
+    >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+    (1.0, [1.0, 2.0])
+    >>> conv([1.0, ExamplePoint(1.0, 2.0)])
+    (1.0, [1.0, 2.0])
+    >>> conv((1.0, None))
+    (1.0, None)
+    """
+    if not _need_python_to_sql_conversion(dataType):
+        return lambda x: x
+
+    def nullable(convert):
+        # _verify_type lets None through for every type, so never pass None to `convert`.
+        return lambda obj: None if obj is None else convert(obj)
+
+    if isinstance(dataType, StructType):
+        names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+        # a list (not a lazy iterator) so it can be reused for every row
+        converters = [_python_to_sql_converter(t) for t in types]
+
+        def converter(obj):
+            if isinstance(obj, dict):
+                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+            elif isinstance(obj, tuple):
+                if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+                    return tuple(c(v) for c, v in zip(converters, obj))
+                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
+                    d = dict(obj)
+                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
+                else:
+                    return tuple(c(v) for c, v in zip(converters, obj))
+            elif isinstance(obj, list):
+                # a plain list is a positional row, like a plain tuple
+                return tuple(c(v) for c, v in zip(converters, obj))
+            else:
+                raise ValueError("Unexpected object %r for type %r" % (obj, dataType))
+        return nullable(converter)
+    elif isinstance(dataType, ArrayType):
+        element_converter = _python_to_sql_converter(dataType.elementType)
+        return nullable(lambda a: [element_converter(v) for v in a])
+    elif isinstance(dataType, MapType):
+        key_converter = _python_to_sql_converter(dataType.keyType)
+        value_converter = _python_to_sql_converter(dataType.valueType)
+        return nullable(
+            lambda m: dict((key_converter(k), value_converter(v)) for k, v in m.items()))
+    elif isinstance(dataType, UserDefinedType):
+        return nullable(dataType.serialize)
+    else:
+        raise ValueError("Unexpected type %r" % dataType)
+
+
 def _has_nulltype(dt):
     """ Return whether there is NullType in `dt` or not """
     if isinstance(dt, StructType):
@@ -818,11 +997,22 @@ def _verify_type(obj, dataType):
     Traceback (most recent call last):
         ...
     ValueError:...
+    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+    Traceback (most recent call last):
+        ...
+    ValueError:...
     """
     # all objects are nullable
     if obj is None:
         return
 
+    if isinstance(dataType, UserDefinedType):
+        if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+            raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+        _verify_type(dataType.serialize(obj), dataType.sqlType())
+        return
+
     _type = type(dataType)
     assert _type in _acceptable_types, "unkown datatype: %s" % dataType
 
@@ -897,6 +1087,8 @@ def _has_struct_or_date(dt):
         return _has_struct_or_date(dt.valueType)
     elif isinstance(dt, DateType):
         return True
+    elif isinstance(dt, UserDefinedType):
+        return True
     return False
 
 
@@ -967,6 +1159,9 @@ def _create_cls(dataType):
     elif isinstance(dataType, DateType):
         return datetime.date
 
+    elif isinstance(dataType, UserDefinedType):
+        return lambda datum: dataType.deserialize(datum)
+
     elif not isinstance(dataType, StructType):
         raise Exception("unexpected data type: %s" % dataType)
 
@@ -1244,6 +1439,10 @@ class SQLContext(object):
         for row in rows:
             _verify_type(row, schema)
 
+        # convert python objects to sql data
+        converter = _python_to_sql_converter(schema)
+        rdd = rdd.map(converter)
+
         batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
         jrdd = self._pythonToJava(rdd._jrdd, batched)
         srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
@@ -1877,6 +2076,7 @@ def _test():
     # let doctest run in pyspark.sql, so DataTypes can be picklable
     import pyspark.sql
     from pyspark.sql import Row, SQLContext
+    from pyspark.tests import ExamplePoint, ExamplePointUDT
     globs = pyspark.sql.__dict__.copy()
     # The small batch size here ensures that we see multiple batches,
     # even in these small test examples:
@@ -1888,6 +2088,8 @@ def _test():
          Row(field1=2, field2="row2"),
          Row(field1=3, field2="row3")]
     )
+    globs['ExamplePoint'] = ExamplePoint
+    globs['ExamplePointUDT'] = ExamplePointUDT
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'

http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 68fd756..e947b09 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,8 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+    UserDefinedType, DoubleType
 from pyspark import shuffle
 
 _have_scipy = False
@@ -694,8 +695,72 @@ class ProfilerTests(PySparkTestCase):
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint, stored as [x, y] in an
+    ArrayType(DoubleType(), False). The JVM side is the Scala UDT named by scalaUDT().
+    """
+
+    @classmethod
+    def sqlType(cls):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint(object):
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    # _infer_type and _verify_type in pyspark.sql find the UDT through this attribute.
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, ExamplePoint) and \
+            other.x == self.x and other.y == self.y
+
+    def __ne__(self, other):
+        # Python 2 does not derive != from __eq__.
+        return not self.__eq__(other)
+
+
 class SQLTests(ReusedPySparkTestCase):

+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        # reserve a unique path (the file itself is removed); tests write directories under it
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
     def setUp(self):
         self.sqlCtx = SQLContext(self.sc)

@@ -824,6 +882,44 @@ class SQLTests(ReusedPySparkTestCase):
         row = self.sqlCtx.sql("select l[0].a AS la from test").first()
         self.assertEqual(1, row.asDict()["la"])

+    def test_infer_schema_with_udt(self):
+        # Import from pyspark.tests rather than relying on this module's globals: if this
+        # file runs as __main__, its classes differ from the pyspark.tests ones named by
+        # ExamplePointUDT.module() and produced by deserialization.
+        from pyspark.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        srdd = self.sqlCtx.inferSchema(rdd)
+        schema = srdd.schema()
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        srdd.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        # see the import note in test_infer_schema_with_udt
+        from pyspark.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        srdd = self.sqlCtx.applySchema(rdd, schema)
+        point = srdd.first().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_parquet_with_udt(self):
+        # see the import note in test_infer_schema_with_udt
+        from pyspark.tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        srdd0 = self.sqlCtx.inferSchema(rdd)
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        srdd0.saveAsParquetFile(output_dir)
+        srdd1 = self.sqlCtx.parquetFile(output_dir)
+        point = srdd1.first().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+

 class InputFormatTests(ReusedPySparkTestCase):


http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index e1b5992..5dd19dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -71,6 +71,8 @@ object DataType {
 
     case JSortedObject(
         ("class", JString(udtClass)),
+        ("pyClass", _),
+        ("sqlType", _),
         ("type", JString("udt"))) =>
       Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
   }
@@ -593,6 +595,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
   /** Underlying storage type for this UDT */
   def sqlType: DataType
 
+  /** Fully-qualified name of the paired Python UDT class, or null if there is none. */
+  def pyUDT: String = null
+
   /**
    * Convert the user type to a SQL datum
    *
@@ -606,7 +611,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
 
   override private[sql] def jsonValue: JValue = {
     ("type" -> "udt") ~
-      ("class" -> this.getClass.getName)
+      ("class" -> this.getClass.getName) ~
+      ("pyClass" -> pyUDT) ~  // null when no Python UDT is paired (the pyUDT default)
+      ("sqlType" -> sqlType.jsonValue)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9e61d18..84eaf40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.execution.{SparkStrategies, _}
 import org.apache.spark.sql.json._
 import org.apache.spark.sql.parquet.ParquetRelation
@@ -483,6 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
       case StructType(_) => true
+      case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
       case other => false
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 9976690..a83cf5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -135,6 +135,8 @@ object EvaluatePython {
       case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
     }.asJava
 
+    case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+
     case (dec: BigDecimal, dt: DecimalType) => dec.underlying()  // Pyrolite can handle BigDecimal
 
     // Pyrolite can handle Timestamp
@@ -177,6 +179,9 @@ object EvaluatePython {
     case (c: java.util.Calendar, TimestampType) =>
       new java.sql.Timestamp(c.getTime().getTime())
 
+    case (_, udt: UserDefinedType[_]) =>
+      fromJava(obj, udt.sqlType)
+
     case (c: Int, ByteType) => c.toByte
     case (c: Long, ByteType) => c.toByte
     case (c: Int, ShortType) => c.toShort

http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
new file mode 100644
index 0000000..b9569e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types._
+
+/**
+ * Example point for demonstrating UDTs in Scala, Java, and Python; see [[ExamplePointUDT]].
+ * @param x x coordinate
+ * @param y y coordinate
+ */
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+private[sql] class ExamplePoint(val x: Double, val y: Double)
+
+/**
+ * User-defined type for [[ExamplePoint]], stored as an array of two non-null doubles (x, y).
+ * Paired with the Python UDT `pyspark.tests.ExamplePointUDT` (see `pyUDT`).
+ */
+private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+  override def sqlType: DataType = ArrayType(DoubleType, false)
+
+  override def pyUDT: String = "pyspark.tests.ExamplePointUDT"
+
+  override def serialize(obj: Any): Seq[Double] = obj match {
+    case p: ExamplePoint => Seq(p.x, p.y)
+    case other => throw new IllegalArgumentException(s"Cannot serialize $other as ExamplePoint")
+  }
+
+  /** Accepts a Scala Seq or a java.util.ArrayList (presumably from unpickled Python data). */
+  override def deserialize(datum: Any): ExamplePoint = {
+    val xy: Seq[Double] = datum match {
+      case values: Seq[_] => values.asInstanceOf[Seq[Double]]
+      case values: util.ArrayList[_] => values.asInstanceOf[util.ArrayList[Double]].asScala
+      case other =>
+        throw new IllegalArgumentException(s"Cannot deserialize $other to ExamplePoint")
+    }
+    // Validate the length for both representations, not only for Seq.
+    assert(xy.length == 2)
+    new ExamplePoint(xy(0), xy(1))
+  }
+
+  override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/04450d11/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 1bc1514..3fa4a7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.UserDefinedType
 
-
 protected[sql] object DataTypeConversions {
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message