spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-11727][SQL] Split ExpressionEncoder into FlatEncoder and ProductEncoder
Date Fri, 13 Nov 2015 19:25:40 GMT
Repository: spark
Updated Branches:
  refs/heads/master 23b8188f7 -> d7b2b97ad


[SPARK-11727][SQL] Split ExpressionEncoder into FlatEncoder and ProductEncoder

Also add more tests for encoders, and fix the bugs I found:

* When converting an array to a Catalyst array, we can only skip element conversion for native types (e.g. int, long, boolean), not for every `AtomicType` (String is an `AtomicType`, but it still needs to be converted).
* We should also handle Scala `BigDecimal` when converting from Catalyst `Decimal`.
* Complex map types should be supported.

Other issues that are still under investigation:

* When a Java `BigDecimal` is encoded and then decoded back, precision information seems to be lost.
* When encoding a case class that is defined inside an object, a `ClassNotFound` exception is thrown.

I'll remove unused code in a follow-up PR.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9693 from cloud-fan/split.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7b2b97a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7b2b97a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7b2b97a

Branch: refs/heads/master
Commit: d7b2b97ad67f9700fb8c13422c399f2edb72f770
Parents: 23b8188
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Fri Nov 13 11:25:33 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Fri Nov 13 11:25:33 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |   2 +-
 .../sql/catalyst/encoders/FlatEncoder.scala     |  50 ++
 .../sql/catalyst/encoders/ProductEncoder.scala  | 452 +++++++++++++++++++
 .../sql/catalyst/encoders/RowEncoder.scala      |  58 +--
 .../spark/sql/catalyst/util/DateTimeUtils.scala |   2 +-
 .../sql/catalyst/util/GenericArrayData.scala    |   2 +-
 .../encoders/ExpressionEncoderSuite.scala       | 259 ++---------
 .../catalyst/encoders/FlatEncoderSuite.scala    |  74 +++
 .../catalyst/encoders/ProductEncoderSuite.scala | 123 +++++
 .../org/apache/spark/sql/GroupedDataset.scala   |   7 +-
 .../org/apache/spark/sql/SQLImplicits.scala     |  22 +-
 .../scala/org/apache/spark/sql/functions.scala  |   4 +-
 12 files changed, 766 insertions(+), 289 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6d82226..0b3dd35 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -75,7 +75,7 @@ trait ScalaReflection {
    *
    * @see SPARK-5281
    */
-  private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
+  def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
 
   /**
    * Returns the Spark SQL DataType for a given scala type.  Where this is not an exact mapping

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
new file mode 100644
index 0000000..6d307ab
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference}
+import org.apache.spark.sql.catalyst.ScalaReflection
+
+object FlatEncoder {
+  import ScalaReflection.schemaFor
+  import ScalaReflection.dataTypeFor
+
+  def apply[T : TypeTag]: ExpressionEncoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val tpe = typeTag[T].tpe
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(tpe)
+    assert(!schemaFor(tpe).dataType.isInstanceOf[StructType])
+
+    val input = BoundReference(0, dataTypeFor(tpe), nullable = true)
+    val toRowExpression = CreateNamedStruct(
+      Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil)
+    val fromRowExpression = ProductEncoder.constructorFor(tpe)
+
+    new ExpressionEncoder[T](
+      toRowExpression.dataType,
+      flat = true,
+      toRowExpression.flatten,
+      fromRowExpression,
+      ClassTag[T](cls))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
new file mode 100644
index 0000000..414adb2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData}
+
+import scala.reflect.ClassTag
+
+object ProductEncoder {
+  import ScalaReflection.universe._
+  import ScalaReflection.localTypeOf
+  import ScalaReflection.dataTypeFor
+  import ScalaReflection.Schema
+  import ScalaReflection.schemaFor
+  import ScalaReflection.arrayClassFor
+
+  def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = {
+    // We convert the not-serializable TypeTag into StructType and ClassTag.
+    val tpe = typeTag[T].tpe
+    val mirror = typeTag[T].mirror
+    val cls = mirror.runtimeClass(tpe)
+
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct]
+    val fromRowExpression = constructorFor(tpe)
+
+    new ExpressionEncoder[T](
+      toRowExpression.dataType,
+      flat = false,
+      toRowExpression.flatten,
+      fromRowExpression,
+      ClassTag[T](cls))
+  }
+
+  // The Predef.Map is scala.collection.immutable.Map.
+  // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
+  import scala.collection.Map
+
+  def extractorFor(
+      inputObject: Expression,
+      tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
+    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+      inputObject
+    } else {
+      tpe match {
+        case t if t <:< localTypeOf[Option[_]] =>
+          val TypeRef(_, _, Seq(optType)) = t
+          optType match {
+            // For primitive types we must manually unbox the value of the object.
+            case t if t <:< definitions.IntTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
+                "intValue",
+                IntegerType)
+            case t if t <:< definitions.LongTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
+                "longValue",
+                LongType)
+            case t if t <:< definitions.DoubleTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
+                "doubleValue",
+                DoubleType)
+            case t if t <:< definitions.FloatTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
+                "floatValue",
+                FloatType)
+            case t if t <:< definitions.ShortTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
+                "shortValue",
+                ShortType)
+            case t if t <:< definitions.ByteTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
+                "byteValue",
+                ByteType)
+            case t if t <:< definitions.BooleanTpe =>
+              Invoke(
+                UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
+                "booleanValue",
+                BooleanType)
+
+            // For non-primitives, we can just extract the object from the Option and then recurse.
+            case other =>
+              val className: String = optType.erasure.typeSymbol.asClass.fullName
+              val classObj = Utils.classForName(className)
+              val optionObjectType = ObjectType(classObj)
+
+              val unwrapped = UnwrapOption(optionObjectType, inputObject)
+              expressions.If(
+                IsNull(unwrapped),
+                expressions.Literal.create(null, schemaFor(optType).dataType),
+                extractorFor(unwrapped, optType))
+          }
+
+        case t if t <:< localTypeOf[Product] =>
+          val formalTypeArgs = t.typeSymbol.asClass.typeParams
+          val TypeRef(_, _, actualTypeArgs) = t
+          val constructorSymbol = t.member(nme.CONSTRUCTOR)
+          val params = if (constructorSymbol.isMethod) {
+            constructorSymbol.asMethod.paramss
+          } else {
+            // Find the primary constructor, and use its parameter ordering.
+            val primaryConstructorSymbol: Option[Symbol] =
+              constructorSymbol.asTerm.alternatives.find(s =>
+                s.isMethod && s.asMethod.isPrimaryConstructor)
+
+            if (primaryConstructorSymbol.isEmpty) {
+              sys.error("Internal SQL error: Product object did not have a primary constructor.")
+            } else {
+              primaryConstructorSymbol.get.asMethod.paramss
+            }
+          }
+
+          CreateNamedStruct(params.head.flatMap { p =>
+            val fieldName = p.name.toString
+            val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+            expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+          })
+
+        case t if t <:< localTypeOf[Array[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          toCatalystArray(inputObject, elementType)
+
+        case t if t <:< localTypeOf[Seq[_]] =>
+          val TypeRef(_, _, Seq(elementType)) = t
+          toCatalystArray(inputObject, elementType)
+
+        case t if t <:< localTypeOf[Map[_, _]] =>
+          val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+          val keys =
+            Invoke(
+              Invoke(inputObject, "keysIterator",
+                ObjectType(classOf[scala.collection.Iterator[_]])),
+              "toSeq",
+              ObjectType(classOf[scala.collection.Seq[_]]))
+          val convertedKeys = toCatalystArray(keys, keyType)
+
+          val values =
+            Invoke(
+              Invoke(inputObject, "valuesIterator",
+                ObjectType(classOf[scala.collection.Iterator[_]])),
+              "toSeq",
+              ObjectType(classOf[scala.collection.Seq[_]]))
+          val convertedValues = toCatalystArray(values, valueType)
+
+          val Schema(keyDataType, _) = schemaFor(keyType)
+          val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+          NewInstance(
+            classOf[ArrayBasedMapData],
+            convertedKeys :: convertedValues :: Nil,
+            dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+        case t if t <:< localTypeOf[String] =>
+          StaticInvoke(
+            classOf[UTF8String],
+            StringType,
+            "fromString",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Timestamp] =>
+          StaticInvoke(
+            DateTimeUtils,
+            TimestampType,
+            "fromJavaTimestamp",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.sql.Date] =>
+          StaticInvoke(
+            DateTimeUtils,
+            DateType,
+            "fromJavaDate",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.math.BigDecimal] =>
+          StaticInvoke(
+            Decimal,
+            DecimalType.SYSTEM_DEFAULT,
+            "apply",
+            inputObject :: Nil)
+
+        case t if t <:< localTypeOf[java.lang.Integer] =>
+          Invoke(inputObject, "intValue", IntegerType)
+        case t if t <:< localTypeOf[java.lang.Long] =>
+          Invoke(inputObject, "longValue", LongType)
+        case t if t <:< localTypeOf[java.lang.Double] =>
+          Invoke(inputObject, "doubleValue", DoubleType)
+        case t if t <:< localTypeOf[java.lang.Float] =>
+          Invoke(inputObject, "floatValue", FloatType)
+        case t if t <:< localTypeOf[java.lang.Short] =>
+          Invoke(inputObject, "shortValue", ShortType)
+        case t if t <:< localTypeOf[java.lang.Byte] =>
+          Invoke(inputObject, "byteValue", ByteType)
+        case t if t <:< localTypeOf[java.lang.Boolean] =>
+          Invoke(inputObject, "booleanValue", BooleanType)
+
+        case other =>
+          throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+      }
+    }
+  }
+
+  private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
+    val externalDataType = dataTypeFor(elementType)
+    val Schema(catalystType, nullable) = schemaFor(elementType)
+    if (RowEncoder.isNativeType(catalystType)) {
+      NewInstance(
+        classOf[GenericArrayData],
+        input :: Nil,
+        dataType = ArrayType(catalystType, nullable))
+    } else {
+      MapObjects(extractorFor(_, elementType), input, externalDataType)
+    }
+  }
+
+  def constructorFor(
+      tpe: `Type`,
+      path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized {
+
+    /** Returns the current path with a sub-field extracted. */
+    def addToPath(part: String): Expression = path
+      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+      .getOrElse(UnresolvedAttribute(part))
+
+    /** Returns the current path with a field at ordinal extracted. */
+    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
+      .map(p => GetInternalRowField(p, ordinal, dataType))
+      .getOrElse(BoundReference(ordinal, dataType, false))
+
+    /** Returns the current path or `BoundReference`. */
+    def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+
+    tpe match {
+      case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
+
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        WrapOption(null, constructorFor(optType, path))
+
+      case t if t <:< localTypeOf[java.lang.Integer] =>
+        val boxedType = classOf[java.lang.Integer]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Long] =>
+        val boxedType = classOf[java.lang.Long]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Double] =>
+        val boxedType = classOf[java.lang.Double]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Float] =>
+        val boxedType = classOf[java.lang.Float]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Short] =>
+        val boxedType = classOf[java.lang.Short]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Byte] =>
+        val boxedType = classOf[java.lang.Byte]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.lang.Boolean] =>
+        val boxedType = classOf[java.lang.Boolean]
+        val objectType = ObjectType(boxedType)
+        NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+      case t if t <:< localTypeOf[java.sql.Date] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(classOf[java.sql.Date]),
+          "toJavaDate",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+        StaticInvoke(
+          DateTimeUtils,
+          ObjectType(classOf[java.sql.Timestamp]),
+          "toJavaTimestamp",
+          getPath :: Nil,
+          propagateNull = true)
+
+      case t if t <:< localTypeOf[java.lang.String] =>
+        Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+        Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+      case t if t <:< localTypeOf[BigDecimal] =>
+        Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+
+      case t if t <:< localTypeOf[Array[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val primitiveMethod = elementType match {
+          case t if t <:< definitions.IntTpe => Some("toIntArray")
+          case t if t <:< definitions.LongTpe => Some("toLongArray")
+          case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+          case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+          case t if t <:< definitions.ShortTpe => Some("toShortArray")
+          case t if t <:< definitions.ByteTpe => Some("toByteArray")
+          case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+          case _ => None
+        }
+
+        primitiveMethod.map { method =>
+          Invoke(getPath, method, arrayClassFor(elementType))
+        }.getOrElse {
+          Invoke(
+            MapObjects(
+              p => constructorFor(elementType, Some(p)),
+              getPath,
+              schemaFor(elementType).dataType),
+            "array",
+            arrayClassFor(elementType))
+        }
+
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val arrayData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(elementType, Some(p)),
+              getPath,
+              schemaFor(elementType).dataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(
+          scala.collection.mutable.WrappedArray,
+          ObjectType(classOf[Seq[_]]),
+          "make",
+          arrayData :: Nil)
+
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+        val keyData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(keyType, Some(p)),
+              Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+              schemaFor(keyType).dataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        val valueData =
+          Invoke(
+            MapObjects(
+              p => constructorFor(valueType, Some(p)),
+              Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+              schemaFor(valueType).dataType),
+            "array",
+            ObjectType(classOf[Array[Any]]))
+
+        StaticInvoke(
+          ArrayBasedMapData,
+          ObjectType(classOf[Map[_, _]]),
+          "toScalaMap",
+          keyData :: valueData :: Nil)
+
+      case t if t <:< localTypeOf[Product] =>
+        val formalTypeArgs = t.typeSymbol.asClass.typeParams
+        val TypeRef(_, _, actualTypeArgs) = t
+        val constructorSymbol = t.member(nme.CONSTRUCTOR)
+        val params = if (constructorSymbol.isMethod) {
+          constructorSymbol.asMethod.paramss
+        } else {
+          // Find the primary constructor, and use its parameter ordering.
+          val primaryConstructorSymbol: Option[Symbol] =
+            constructorSymbol.asTerm.alternatives.find(s =>
+              s.isMethod && s.asMethod.isPrimaryConstructor)
+
+          if (primaryConstructorSymbol.isEmpty) {
+            sys.error("Internal SQL error: Product object did not have a primary constructor.")
+          } else {
+            primaryConstructorSymbol.get.asMethod.paramss
+          }
+        }
+
+        val className: String = t.erasure.typeSymbol.asClass.fullName
+        val cls = Utils.classForName(className)
+
+        val arguments = params.head.zipWithIndex.map { case (p, i) =>
+          val fieldName = p.name.toString
+          val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+          val dataType = schemaFor(fieldType).dataType
+
+          // For tuples, we based grab the inner fields by ordinal instead of name.
+          if (className startsWith "scala.Tuple") {
+            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+          } else {
+            constructorFor(fieldType, Some(addToPath(fieldName)))
+          }
+        }
+
+        val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+
+        if (path.nonEmpty) {
+          expressions.If(
+            IsNull(getPath),
+            expressions.Literal.create(null, ObjectType(cls)),
+            newInstance
+          )
+        } else {
+          newInstance
+        }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0b42130..e0be896 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -119,9 +119,17 @@ object RowEncoder {
       CreateStruct(convertedFields)
   }
 
-  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+  /**
+   * Returns true if the value of this data type is same between internal and external.
+   */
+  def isNativeType(dt: DataType): Boolean = dt match {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
-         FloatType | DoubleType | BinaryType => dt
+         FloatType | DoubleType | BinaryType => true
+    case _ => false
+  }
+
+  private def externalDataTypeFor(dt: DataType): DataType = dt match {
+    case _ if isNativeType(dt) => dt
     case TimestampType => ObjectType(classOf[java.sql.Timestamp])
     case DateType => ObjectType(classOf[java.sql.Date])
     case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -137,13 +145,13 @@ object RowEncoder {
       If(
         IsNull(field),
         Literal.create(null, externalDataTypeFor(f.dataType)),
-        constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
+        constructorFor(BoundReference(i, f.dataType, f.nullable))
       )
     }
     CreateExternalRow(fields)
   }
 
-  private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
+  private def constructorFor(input: Expression): Expression = input.dataType match {
     case BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType => input
 
@@ -170,7 +178,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(constructorFor(_, et), input, et),
+          MapObjects(constructorFor, input, et),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
@@ -181,10 +189,10 @@ object RowEncoder {
 
     case MapType(kt, vt, valueNullable) =>
       val keyArrayType = ArrayType(kt, false)
-      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
+      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
 
       val valueArrayType = ArrayType(vt, valueNullable)
-      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
+      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
 
       StaticInvoke(
         ArrayBasedMapData,
@@ -197,42 +205,8 @@ object RowEncoder {
         If(
           Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, externalDataTypeFor(f.dataType)),
-          constructorFor(getField(input, i, f.dataType), f.dataType))
+          constructorFor(GetInternalRowField(input, i, f.dataType)))
       }
       CreateExternalRow(convertedFields)
   }
-
-  private def getField(
-     row: Expression,
-     ordinal: Int,
-     dataType: DataType): Expression = dataType match {
-    case BooleanType =>
-      Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
-    case ByteType =>
-      Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
-    case ShortType =>
-      Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
-    case IntegerType | DateType =>
-      Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
-    case LongType | TimestampType =>
-      Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
-    case FloatType =>
-      Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
-    case DoubleType =>
-      Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
-    case t: DecimalType =>
-      Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
-    case StringType =>
-      Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
-    case BinaryType =>
-      Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
-    case CalendarIntervalType =>
-      Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
-    case t: StructType =>
-      Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
-    case _: ArrayType =>
-      Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
-    case _: MapType =>
-      Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f5fff90..deff8a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -110,7 +110,7 @@ object DateTimeUtils {
   }
 
   def stringToTime(s: String): java.util.Date = {
-    var indexOfGMT = s.indexOf("GMT");
+    val indexOfGMT = s.indexOf("GMT")
     if (indexOfGMT != -1) {
       // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00)
       val s0 = s.substring(0, indexOfGMT)

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index e9bf7b3..96588bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class GenericArrayData(val array: Array[Any]) extends ArrayData {
 
-  def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
+  def this(seq: Seq[Any]) = this(seq.toArray)
 
   // TODO: This is boxing.  We should specialize.
   def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index b0dacf7..9fe64b4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -17,232 +17,27 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.runtime.universe._
+import java.util.Arrays
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{StructField, ArrayType}
-
-case class RepeatedStruct(s: Seq[PrimitiveData])
-
-case class NestedArray(a: Array[Array[Int]])
-
-case class BoxedData(
-    intField: java.lang.Integer,
-    longField: java.lang.Long,
-    doubleField: java.lang.Double,
-    floatField: java.lang.Float,
-    shortField: java.lang.Short,
-    byteField: java.lang.Byte,
-    booleanField: java.lang.Boolean)
-
-case class RepeatedData(
-    arrayField: Seq[Int],
-    arrayFieldContainsNull: Seq[java.lang.Integer],
-    mapField: scala.collection.Map[Int, Long],
-    mapFieldNull: scala.collection.Map[Int, java.lang.Long],
-    structField: PrimitiveData)
-
-case class SpecificCollection(l: List[Int])
-
-class ExpressionEncoderSuite extends SparkFunSuite {
-
-  encodeDecodeTest(1)
-  encodeDecodeTest(1L)
-  encodeDecodeTest(1.toDouble)
-  encodeDecodeTest(1.toFloat)
-  encodeDecodeTest(true)
-  encodeDecodeTest(false)
-  encodeDecodeTest(1.toShort)
-  encodeDecodeTest(1.toByte)
-  encodeDecodeTest("hello")
-
-  encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
-
-  // TODO: Support creating specific subclasses of Seq.
-  ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
-
-  encodeDecodeTest(
-    OptionalData(
-      Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
-      Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
-
-  encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
-
-  encodeDecodeTest(
-    BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
-
-  encodeDecodeTest(
-    BoxedData(null, null, null, null, null, null, null))
-
-  encodeDecodeTest(
-    RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
-
-  encodeDecodeTest(
-    RepeatedData(
-      Seq(1, 2),
-      Seq(new Integer(1), null, new Integer(2)),
-      Map(1 -> 2L),
-      Map(1 -> null),
-      PrimitiveData(1, 1, 1, 1, 1, 1, true)))
-
-  encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
-
-  encodeDecodeTest(("Seq[(String, String)]",
-    Seq(("a", "b"))))
-  encodeDecodeTest(("Seq[(Int, Int)]",
-    Seq((1, 2))))
-  encodeDecodeTest(("Seq[(Long, Long)]",
-    Seq((1L, 2L))))
-  encodeDecodeTest(("Seq[(Float, Float)]",
-    Seq((1.toFloat, 2.toFloat))))
-  encodeDecodeTest(("Seq[(Double, Double)]",
-    Seq((1.toDouble, 2.toDouble))))
-  encodeDecodeTest(("Seq[(Short, Short)]",
-    Seq((1.toShort, 2.toShort))))
-  encodeDecodeTest(("Seq[(Byte, Byte)]",
-    Seq((1.toByte, 2.toByte))))
-  encodeDecodeTest(("Seq[(Boolean, Boolean)]",
-    Seq((true, false))))
-
-  // TODO: Decoding/encoding of complex maps.
-  ignore("complex maps") {
-    encodeDecodeTest(("Map[Int, (String, String)]",
-      Map(1 ->("a", "b"))))
-  }
-
-  encodeDecodeTest(("ArrayBuffer[(String, String)]",
-    ArrayBuffer(("a", "b"))))
-  encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
-    ArrayBuffer((1, 2))))
-  encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
-    ArrayBuffer((1L, 2L))))
-  encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
-    ArrayBuffer((1.toFloat, 2.toFloat))))
-  encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
-    ArrayBuffer((1.toDouble, 2.toDouble))))
-  encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
-    ArrayBuffer((1.toShort, 2.toShort))))
-  encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
-    ArrayBuffer((1.toByte, 2.toByte))))
-  encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
-    ArrayBuffer((true, false))))
-
-  encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
-    Seq(Seq((1, 2)))))
-
-  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
-    Array(Array((1, 2)))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
-    Array(Array(Array((1, 2))))))
-  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
-    Array(Array(Array(Array((1, 2)))))))
-  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
-    Array(Array(Array(Array(Array((1, 2))))))))
-  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
-
-  encodeDecodeTestCustom(("Array[Array[Integer]]",
-    Array(Array[Integer](1))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Int]]",
-    Array(Array(1))))
-  { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Int]]",
-    Array(Array(Array(1)))))
-  { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
-    Array(Array(Array(Array(1))))))
-  { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
-  encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
-    Array(Array(Array(Array(Array(1)))))))
-  { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
-  encodeDecodeTest(("Array[Byte] null",
-    null: Array[Byte]))
-  encodeDecodeTestCustom(("Array[Byte]",
-    Array[Byte](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Int] null",
-    null: Array[Int]))
-  encodeDecodeTestCustom(("Array[Int]",
-    Array[Int](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Long] null",
-    null: Array[Long]))
-  encodeDecodeTestCustom(("Array[Long]",
-    Array[Long](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Double] null",
-    null: Array[Double]))
-  encodeDecodeTestCustom(("Array[Double]",
-    Array[Double](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Float] null",
-    null: Array[Float]))
-  encodeDecodeTestCustom(("Array[Float]",
-    Array[Float](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Boolean] null",
-    null: Array[Boolean]))
-  encodeDecodeTestCustom(("Array[Boolean]",
-    Array[Boolean](true, false)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTest(("Array[Short] null",
-    null: Array[Short]))
-  encodeDecodeTestCustom(("Array[Short]",
-    Array[Short](1, 2, 3)))
-    { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
-  encodeDecodeTestCustom(("java.sql.Timestamp",
-    new java.sql.Timestamp(1)))
-    { (l, r) => l._2.toString == r._2.toString }
-
-  encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
-    { (l, r) => l._2.toString == r._2.toString }
-
-  /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
-  protected def encodeDecodeTest[T : TypeTag](inputData: T) =
-    encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
-
-  /**
-   * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
-   * matches the original.
-   */
-  protected def encodeDecodeTestCustom[T : TypeTag](
-      inputData: T)(
-      c: (T, T) => Boolean) = {
-    test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
-      val encoder = try ExpressionEncoder[T]() catch {
-        case e: Exception =>
-          fail(s"Exception thrown generating encoder", e)
-      }
-      val convertedData = encoder.toRow(inputData)
+import org.apache.spark.sql.types.ArrayType
+
+abstract class ExpressionEncoderSuite extends SparkFunSuite {
+  protected def encodeDecodeTest[T](
+      input: T,
+      encoder: ExpressionEncoder[T],
+      testName: String): Unit = {
+    test(s"encode/decode for $testName: $input") {
+      val row = encoder.toRow(input)
       val schema = encoder.schema.toAttributes
       val boundEncoder = encoder.resolve(schema).bind(schema)
-      val convertedBack = try boundEncoder.fromRow(convertedData) catch {
+      val convertedBack = try boundEncoder.fromRow(row) catch {
         case e: Exception =>
           fail(
            s"""Exception thrown while decoding
-              |Converted: $convertedData
+              |Converted: $row
               |Schema: ${schema.mkString(",")}
               |${encoder.schema.treeString}
               |
@@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite {
             """.stripMargin, e)
       }
 
-      if (!c(inputData, convertedBack)) {
+      val isCorrect = (input, convertedBack) match {
+        case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
+        case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)
+        case (b1: Array[Array[_]], b2: Array[Array[_]]) =>
+          Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+        case (b1: Array[_], b2: Array[_]) =>
+          Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+        case _ => input == convertedBack
+      }
+
+      if (!isCorrect) {
         val types = convertedBack match {
           case c: Product =>
             c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
           case other => other.getClass.getName
         }
 
-
         val encodedData = try {
-          convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
-            case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
-              a.toArray[Any](at.elementType).toSeq
+          row.toSeq(encoder.schema).zip(schema).map {
+            case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) =>
+              a.toArray[Any](et).toSeq
             case (other, _) =>
               other
           }.mkString("[", ",", "]")
@@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
         fail(
           s"""Encoded/Decoded data does not match input data
              |
-             |in:  $inputData
+             |in:  $input
              |out: $convertedBack
              |types: $types
              |
@@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite {
              |Schema: ${schema.mkString(",")}
              |${encoder.schema.treeString}
              |
-             |Extract Expressions:
-             |$boundEncoder
+             |fromRow Expressions:
+             |${boundEncoder.fromRowExpression.treeString}
          """.stripMargin)
-        }
       }
-
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
new file mode 100644
index 0000000..55821c4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.sql.{Date, Timestamp}
+/** Round-trip (encode, then decode) tests for encoders built with [[FlatEncoder]]. */
+class FlatEncoderSuite extends ExpressionEncoderSuite {
+  encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
+  encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte")
+  encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short")
+  encodeDecodeTest(-3, FlatEncoder[Int], "primitive int")
+  encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long")
+  encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float")
+  encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double")
+  // Boxed java.lang counterparts of the primitive cases above.
+  encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean")
+  encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte")
+  encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short")
+  encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int")
+  encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long")
+  encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float")
+  encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double")
+  // Decimals; the java.math.BigDecimal case is disabled (precision may be lost on round-trip).
+  encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal")
+  type JDecimal = java.math.BigDecimal
+  // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal")
+  // Strings, dates, timestamps and binary (Array[Byte]).
+  encodeDecodeTest("hello", FlatEncoder[String], "string")
+  encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date")
+  encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp")
+  encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary")
+  // Seqs of ints and strings, including null elements and empty seqs.
+  encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int")
+  encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string")
+  encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null")
+  encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int")
+  encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string")
+  // Nested seqs, including null inner seqs.
+  encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)),
+    FlatEncoder[Seq[Seq[Int]]], "seq of seq of int")
+  encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")),
+    FlatEncoder[Seq[Seq[String]]], "seq of seq of string")
+  // Arrays; the base suite compares them by content, not by reference.
+  encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int")
+  encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string")
+  encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null")
+  encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int")
+  encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string")
+  // Nested arrays, including null inner arrays.
+  encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)),
+    FlatEncoder[Array[Array[Int]]], "array of array of int")
+  encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")),
+    FlatEncoder[Array[Array[String]]], "array of array of string")
+  // Maps, including null values and nested maps.
+  encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map")
+  encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
+  encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
+    FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
new file mode 100644
index 0000000..fda978e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
+
+/** Wraps a Seq of PrimitiveData structs. */
+case class RepeatedStruct(s: Seq[PrimitiveData])
+
+/** Wraps a 2-D array; equals/hashCode compare contents, since arrays use reference equality. */
+case class NestedArray(a: Array[Array[Int]]) {
+  private def boxed: Array[AnyRef] = a.asInstanceOf[Array[AnyRef]]
+  override def equals(other: Any): Boolean = other match {
+    case that: NestedArray => java.util.Arrays.deepEquals(boxed, that.boxed)
+    case _ => false
+  }
+  override def hashCode: Int = java.util.Arrays.deepHashCode(boxed)
+}
+
+/** Every field uses a boxed java.lang type, so each one can be null. */
+case class BoxedData(
+    intField: java.lang.Integer, longField: java.lang.Long, doubleField: java.lang.Double,
+    floatField: java.lang.Float, shortField: java.lang.Short, byteField: java.lang.Byte,
+    booleanField: java.lang.Boolean)
+
+/** Seq, Map and nested-struct fields; the boxed-element variants can hold null. */
+case class RepeatedData(
+    arrayField: Seq[Int],
+    arrayFieldContainsNull: Seq[java.lang.Integer],
+    mapField: scala.collection.Map[Int, Long],
+    mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+    structField: PrimitiveData)
+
+/** Declares a concrete List field instead of the general Seq. */
+case class SpecificCollection(l: List[Int])
+
+class ProductEncoderSuite extends ExpressionEncoderSuite {
+  // Round-trip tests for ProductEncoder; each productTest call registers one test.
+  productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
+  // Option fields, all populated (one holds a nested case class).
+  productTest(
+    OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+      Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
+  // Option fields, all None.
+  productTest(OptionalData(None, None, None, None, None, None, None, None))
+  // Boxed java.lang fields holding values.
+  productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
+  // Boxed java.lang fields holding null.
+  productTest(BoxedData(null, null, null, null, null, null, null))
+  // Seq of case classes.
+  productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
+  // Tuple mixing an Int, a String and a case class.
+  productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+  // Seqs and maps with and without null entries, plus a nested struct.
+  productTest(
+    RepeatedData(
+      Seq(1, 2),
+      Seq(new Integer(1), null, new Integer(2)),
+      Map(1 -> 2L),
+      Map(1 -> null),
+      PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+  // 2-D array with a null row; NestedArray overrides equals to compare contents.
+  productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6))))
+  // (label, Seq of tuples) pairs, one per primitive element type.
+  productTest(("Seq[(String, String)]",
+    Seq(("a", "b"))))
+  productTest(("Seq[(Int, Int)]",
+    Seq((1, 2))))
+  productTest(("Seq[(Long, Long)]",
+    Seq((1L, 2L))))
+  productTest(("Seq[(Float, Float)]",
+    Seq((1.toFloat, 2.toFloat))))
+  productTest(("Seq[(Double, Double)]",
+    Seq((1.toDouble, 2.toDouble))))
+  productTest(("Seq[(Short, Short)]",
+    Seq((1.toShort, 2.toShort))))
+  productTest(("Seq[(Byte, Byte)]",
+    Seq((1.toByte, 2.toByte))))
+  productTest(("Seq[(Boolean, Boolean)]",
+    Seq((true, false))))
+  // The same cases, using ArrayBuffer as the collection type.
+  productTest(("ArrayBuffer[(String, String)]",
+    ArrayBuffer(("a", "b"))))
+  productTest(("ArrayBuffer[(Int, Int)]",
+    ArrayBuffer((1, 2))))
+  productTest(("ArrayBuffer[(Long, Long)]",
+    ArrayBuffer((1L, 2L))))
+  productTest(("ArrayBuffer[(Float, Float)]",
+    ArrayBuffer((1.toFloat, 2.toFloat))))
+  productTest(("ArrayBuffer[(Double, Double)]",
+    ArrayBuffer((1.toDouble, 2.toDouble))))
+  productTest(("ArrayBuffer[(Short, Short)]",
+    ArrayBuffer((1.toShort, 2.toShort))))
+  productTest(("ArrayBuffer[(Byte, Byte)]",
+    ArrayBuffer((1.toByte, 2.toByte))))
+  productTest(("ArrayBuffer[(Boolean, Boolean)]",
+    ArrayBuffer((true, false))))
+  // Seq of tuples nested inside a Seq.
+  productTest(("Seq[Seq[(Int, Int)]]",
+    Seq(Seq((1, 2)))))
+  /** Registers a round-trip test for `input`, named after its runtime class's simple name. */
+  private def productTest[T <: Product : TypeTag](input: T): Unit = {
+    encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 9c16940..ebcf4c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
@@ -56,9 +56,6 @@ class GroupedDataset[K, T] private[sql](
   private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
   private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
 
-  /** Encoders for built in aggregations. */
-  private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
-
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
@@ -211,7 +208,7 @@ class GroupedDataset[K, T] private[sql](
    * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
    * for that key.
    */
-  def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+  def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long]))
 
   /**
    * Applies the given function to each cogrouped data.  For each unique group, the function will

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 6da46a5..8471eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String
 abstract class SQLImplicits {
   protected def _sqlContext: SQLContext
 
-  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
+  implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
 
-  implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
-  implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
-  implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
-  implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
-  implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
-  implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
-  implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
-  implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
+  implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int]
+  implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long]
+  implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double]
+  implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float]
+  implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte]
+  implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short]
+  implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean]
+  implicit def newStringEncoder: Encoder[String] = FlatEncoder[String]
 
+  /**
+   * Creates a [[Dataset]] from an RDD.
+   * @since 1.6.0
+   */
   implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {
     DatasetHolder(_sqlContext.createDataset(rdd))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b2b97a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 53cc6e0..95158de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.FlatEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
@@ -267,7 +267,7 @@ object functions extends LegacyFunctions {
    * @since 1.3.0
    */
   def count(columnName: String): TypedColumn[Any, Long] =
-    count(Column(columnName)).as(ExpressionEncoder[Long](flat = true))
+    count(Column(columnName)).as(FlatEncoder[Long])
 
   /**
    * Aggregate function: returns the number of distinct items in a group.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message