spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-9425] [SQL] support DecimalType in UnsafeRow
Date Fri, 31 Jul 2015 00:18:38 GMT
Repository: spark
Updated Branches:
  refs/heads/master e7a0976e9 -> 0b1a464b6


[SPARK-9425] [SQL] support DecimalType in UnsafeRow

This PR adds support for DecimalType in UnsafeRow. For precision <= 18 the field is settable;
otherwise it is not settable.

Author: Davies Liu <davies@databricks.com>

Closes #7758 from davies/unsafe_decimal and squashes the following commits:

478b1ba [Davies Liu] address comments
536314c [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal
7c2e77a [Davies Liu] fix JoinedRow
76d6fa4 [Davies Liu] fix tests
99d3151 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal
d49c6ae [Davies Liu] support DecimalType in UnsafeRow


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0b1a464b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0b1a464b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0b1a464b

Branch: refs/heads/master
Commit: 0b1a464b6e061580a75b99a91b042069d76bbbfd
Parents: e7a0976
Author: Davies Liu <davies@databricks.com>
Authored: Thu Jul 30 17:18:32 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Thu Jul 30 17:18:32 2015 -0700

----------------------------------------------------------------------
 .../expressions/SpecializedGetters.java         |   2 +-
 .../UnsafeFixedWidthAggregationMap.java         |  22 ++--
 .../sql/catalyst/expressions/UnsafeRow.java     |  53 ++++++---
 .../catalyst/expressions/UnsafeRowWriters.java  |  42 +++++++
 .../sql/catalyst/CatalystTypeConverters.scala   |   9 +-
 .../apache/spark/sql/catalyst/InternalRow.scala |   4 +-
 .../sql/catalyst/expressions/Projection.scala   |   7 +-
 .../expressions/codegen/CodeGenerator.scala     |   9 +-
 .../codegen/GenerateUnsafeProjection.scala      | 115 ++++++++++---------
 .../spark/sql/catalyst/expressions/rows.scala   |   3 +-
 .../org/apache/spark/sql/types/Decimal.scala    |   6 +-
 .../spark/sql/types/GenericArrayData.scala      |   2 +-
 .../sql/catalyst/expressions/CastSuite.scala    |   5 +-
 .../expressions/DateExpressionsSuite.scala      |   2 +-
 .../UnsafeFixedWidthAggregationMapSuite.scala   |   8 +-
 .../expressions/UnsafeRowConverterSuite.scala   |  17 +--
 .../spark/sql/columnar/ColumnBuilder.scala      |   2 +-
 .../apache/spark/sql/columnar/ColumnStats.scala |   4 +-
 .../apache/spark/sql/columnar/ColumnType.scala  |   2 +-
 .../sql/execution/GeneratedAggregate.scala      |   2 +-
 .../sql/execution/SparkSqlSerializer2.scala     |   2 +-
 .../spark/sql/parquet/ParquetTableSupport.scala |   4 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala   |  40 ++++++-
 23 files changed, 237 insertions(+), 125 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index f7cea13..e3d3ba7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -41,7 +41,7 @@ public interface SpecializedGetters {
 
   double getDouble(int ordinal);
 
-  Decimal getDecimal(int ordinal);
+  Decimal getDecimal(int ordinal, int precision, int scale);
 
   UTF8String getUTF8String(int ordinal);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 03f4c3e..f3b4627 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.util.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.PlatformDependent;
@@ -62,25 +64,17 @@ public final class UnsafeFixedWidthAggregationMap {
   private final boolean enablePerfMetrics;
 
   /**
-   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given
schema,
-   *         false otherwise.
-   */
-  public static boolean supportsGroupKeySchema(StructType schema) {
-    for (StructField field: schema.fields()) {
-      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  /**
    * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the
given
    *         schema, false otherwise.
    */
   public static boolean supportsAggregationBufferSchema(StructType schema) {
     for (StructField field: schema.fields()) {
-      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
+      if (field.dataType() instanceof DecimalType) {
+        DecimalType dt = (DecimalType) field.dataType();
+        if (dt.precision() > Decimal.MAX_LONG_DIGITS()) {
+          return false;
+        }
+      } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
         return false;
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6d684ba..e7088ed 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions;
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.math.BigDecimal;
+import java.math.BigInteger;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
@@ -65,12 +67,7 @@ public final class UnsafeRow extends MutableRow {
    */
   public static final Set<DataType> settableFieldTypes;
 
-  /**
-   * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
-   */
-  public static final Set<DataType> readableFieldTypes;
-
-  // TODO: support DecimalType
+  // DecimalType(precision <= 18) is settable
   static {
     settableFieldTypes = Collections.unmodifiableSet(
       new HashSet<>(
@@ -86,16 +83,6 @@ public final class UnsafeRow extends MutableRow {
           DateType,
           TimestampType
         })));
-
-    // We support get() on a superset of the types for which we support set():
-    final Set<DataType> _readableFieldTypes = new HashSet<>(
-      Arrays.asList(new DataType[]{
-        StringType,
-        BinaryType,
-        CalendarIntervalType
-      }));
-    _readableFieldTypes.addAll(settableFieldTypes);
-    readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
   }
 
   //////////////////////////////////////////////////////////////////////////////
@@ -233,6 +220,21 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
+  public void setDecimal(int ordinal, Decimal value, int precision) {
+    assertIndexIsValid(ordinal);
+    if (value == null) {
+      setNullAt(ordinal);
+    } else {
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        setLong(ordinal, value.toUnscaledLong());
+      } else {
+        // TODO(davies): support updating decimals (hold a bounded space even if it's null)
+        throw new UnsupportedOperationException();
+      }
+    }
+  }
+
+  @Override
   public Object get(int ordinal) {
     throw new UnsupportedOperationException();
   }
@@ -256,7 +258,8 @@ public final class UnsafeRow extends MutableRow {
     } else if (dataType instanceof DoubleType) {
       return getDouble(ordinal);
     } else if (dataType instanceof DecimalType) {
-      return getDecimal(ordinal);
+      DecimalType dt = (DecimalType) dataType;
+      return getDecimal(ordinal, dt.precision(), dt.scale());
     } else if (dataType instanceof DateType) {
       return getInt(ordinal);
     } else if (dataType instanceof TimestampType) {
@@ -323,6 +326,22 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
+  public Decimal getDecimal(int ordinal, int precision, int scale) {
+    assertIndexIsValid(ordinal);
+    if (isNullAt(ordinal)) {
+      return null;
+    }
+    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      return Decimal.apply(getLong(ordinal), precision, scale);
+    } else {
+      byte[] bytes = getBinary(ordinal);
+      BigInteger bigInteger = new BigInteger(bytes);
+      BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+      return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale);
+    }
+  }
+
+  @Override
   public UTF8String getUTF8String(int ordinal) {
     assertIndexIsValid(ordinal);
     return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index c3259e2..f43a285 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions;
 
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
@@ -30,6 +31,47 @@ import org.apache.spark.unsafe.types.UTF8String;
  */
 public class UnsafeRowWriters {
 
+  /** Writer for Decimal with precision <= 18. */
+  public static class CompactDecimalWriter {
+
+    public static int getSize(Decimal input) {
+      return 0;
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+      target.setLong(ordinal, input.toUnscaledLong());
+      return 0;
+    }
+  }
+
+  /** Writer for Decimal with precision larger than 18. */
+  public static class DecimalWriter {
+
+    public static int getSize(Decimal input) {
+      // bounded size
+      return 16;
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      final int numBytes = bytes.length;
+      assert(numBytes <= 16);
+
+      // zero-out the bytes
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L);
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L);
+
+      // Write the bytes to the variable length portion.
+      PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET,
+        target.getBaseObject(), offset, numBytes);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return 16;
+    }
+  }
+
   /** Writer for UTF8String. */
   public static class UTF8StringWriter {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 22452c0..7ca20fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -68,7 +68,7 @@ object CatalystTypeConverters {
       case StringType => StringConverter
       case DateType => DateConverter
       case TimestampType => TimestampConverter
-      case dt: DecimalType => BigDecimalConverter
+      case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
       case ByteType => ByteConverter
       case ShortType => ShortConverter
@@ -306,7 +306,8 @@ object CatalystTypeConverters {
       DateTimeUtils.toJavaTimestamp(row.getLong(column))
   }
 
-  private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal]
{
+  private class DecimalConverter(dataType: DecimalType)
+    extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] {
     override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match {
       case d: BigDecimal => Decimal(d)
       case d: JavaBigDecimal => Decimal(d)
@@ -314,9 +315,11 @@ object CatalystTypeConverters {
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal
     override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal =
-      row.getDecimal(column).toJavaBigDecimal
+      row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal
   }
 
+  private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT)
+
   private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any]
{
     final override def toScala(catalystValue: Any): Any = catalystValue
     final override def toCatalystImpl(scalaValue: T): Any = scalaValue

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 486ba03..b19bf43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters
{
 
   override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType)
 
-  override def getDecimal(ordinal: Int): Decimal =
-    getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal =
+    getAs[Decimal](ordinal, DecimalType(precision, scale))
 
   override def getInterval(ordinal: Int): CalendarInterval =
     getAs[CalendarInterval](ordinal, CalendarIntervalType)

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b3beb7e..7c7664e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection}
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.types.{Decimal, StructType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -225,6 +225,11 @@ class JoinedRow extends InternalRow {
   override def getFloat(i: Int): Float =
     if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields)
 
+  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
+    if (i < row1.numFields) row1.getDecimal(i, precision, scale)
+    else row2.getDecimal(i - row1.numFields, precision, scale)
+  }
+
   override def getStruct(i: Int, numFields: Int): InternalRow = {
     if (i < row1.numFields) {
       row1.getStruct(i, numFields)

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index c39e0df..60e2863 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -106,6 +106,7 @@ class CodeGenContext {
     val jt = javaType(dataType)
     dataType match {
       case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
+      case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})"
       case StringType => s"$getter.getUTF8String($ordinal)"
       case BinaryType => s"$getter.getBinary($ordinal)"
       case CalendarIntervalType => s"$getter.getInterval($ordinal)"
@@ -120,10 +121,10 @@ class CodeGenContext {
    */
   def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = {
     val jt = javaType(dataType)
-    if (isPrimitiveType(jt)) {
-      s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
-    } else {
-      s"$row.update($ordinal, $value)"
+    dataType match {
+      case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
+      case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
+      case _ => s"$row.update($ordinal, $value)"
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index a662357..1d22398 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -35,6 +35,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
   private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
   private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
   private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
+  private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName
+  private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
 
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
@@ -42,9 +44,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression],
UnsafePro
     case _: CalendarIntervalType => true
     case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case NullType => true
+    case t: DecimalType => true
     case _ => false
   }
 
+  def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
+    case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+      s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))"
+    case StringType =>
+      s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
+    case BinaryType =>
+      s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
+    case CalendarIntervalType =>
+      s" + (${ev.isNull} ? 0 : 16)"
+    case _: StructType =>
+      s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+    case _ => ""
+  }
+
+  def genFieldWriter(
+      ctx: CodeGenContext,
+      fieldType: DataType,
+      ev: GeneratedExpressionCode,
+      primitive: String,
+      index: Int,
+      cursor: String): String = fieldType match {
+    case _ if ctx.isPrimitiveType(fieldType) =>
+      s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}"
+    case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+      s"""
+       // make sure Decimal object has the same scale as DecimalType
+       if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
+         $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+       } else {
+         $primitive.setNullAt($index);
+       }
+       """
+    case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+      s"""
+       // make sure Decimal object has the same scale as DecimalType
+       if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
+         $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive});
+       } else {
+         $primitive.setNullAt($index);
+       }
+       """
+    case StringType =>
+      s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+    case BinaryType =>
+      s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+    case CalendarIntervalType =>
+      s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+    case t: StructType =>
+      s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})"
+    case NullType => ""
+    case _ =>
+      throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
+  }
+
   /**
    * Generates the code to create an [[UnsafeRow]] object based on the input expressions.
    * @param ctx context for code generation
@@ -69,36 +126,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression],
UnsafePro
     val allExprs = exprs.map(_.code).mkString("\n")
 
     val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
-    val additionalSize = expressions.zipWithIndex.map { case (e, i) =>
-      e.dataType match {
-        case StringType =>
-          s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
-        case BinaryType =>
-          s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
-        case CalendarIntervalType =>
-          s" + (${exprs(i).isNull} ? 0 : 16)"
-        case _: StructType =>
-          s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))"
-        case _ => ""
-      }
+    val additionalSize = expressions.zipWithIndex.map {
+      case (e, i) => genAdditionalSize(e.dataType, exprs(i))
     }.mkString("")
 
     val writers = expressions.zipWithIndex.map { case (e, i) =>
-      val update = e.dataType match {
-        case dt if ctx.isPrimitiveType(dt) =>
-          s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}"
-        case StringType =>
-          s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
-        case BinaryType =>
-          s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
-        case CalendarIntervalType =>
-          s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
-        case t: StructType =>
-          s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})"
-        case NullType => ""
-        case _ =>
-          throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
-      }
+      val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor)
       s"""if (${exprs(i).isNull}) {
             $ret.setNullAt($i);
           } else {
@@ -168,35 +201,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression],
UnsafePro
 
     val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length)
     val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) =>
-      dt match {
-        case StringType =>
-          s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))"
-        case BinaryType =>
-          s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))"
-        case CalendarIntervalType =>
-          s" + (${ev.isNull} ? 0 : 16)"
-        case _: StructType =>
-          s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
-        case _ => ""
-      }
+      genAdditionalSize(dt, ev)
     }.mkString("")
 
     val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev),
i) =>
-      val update = dt match {
-        case _ if ctx.isPrimitiveType(dt) =>
-          s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}"
-        case StringType =>
-          s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
-        case BinaryType =>
-          s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
-        case CalendarIntervalType =>
-          s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
-        case t: StructType =>
-          s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
-        case NullType => ""
-        case _ =>
-          throw new UnsupportedOperationException(s"Not supported DataType: $dt")
-      }
+      val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor)
       s"""
           if (${exprs(i).isNull}) {
             $primitive.setNullAt($i);

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index b7c4ece..df6ea58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.{DataType, StructType, AtomicType}
+import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow {
   def setShort(i: Int, value: Short): Unit = { update(i, value) }
   def setByte(i: Int, value: Byte): Unit = { update(i, value) }
   def setFloat(i: Int, value: Float): Unit = { update(i, value) }
+  def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) }
   def setString(i: Int, value: String): Unit = {
     update(i, UTF8String.fromString(value))
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index bc68981..c0155ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
    * @return true if successful, false if overflow would occur
    */
   def changePrecision(precision: Int, scale: Int): Boolean = {
+    // fast path for UnsafeProjection
+    if (precision == this.precision && scale == this.scale) {
+      return true
+    }
     // First, update our longVal if we can, or transfer over to using a BigDecimal
     if (decimalVal.eq(null)) {
       if (scale < _scale) {
@@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       decimalVal = newVal
     } else {
       // We're still using Longs, but we should check whether we match the new precision
-      val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
       if (longVal <= -p || longVal >= p) {
         // Note that we shouldn't have been able to fix this by switching to BigDecimal
         return false

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index 7992ba9..35ace67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -43,7 +43,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData {
 
   override def getDouble(ordinal: Int): Double = getAs(ordinal)
 
-  override def getDecimal(ordinal: Int): Decimal = getAs(ordinal)
+  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
 
   override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 4f35b65..1ad7073 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
     checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
-    checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0))
+    checkEvaluation(cast(123L, DecimalType(3, 1)), null)
 
-    // TODO: Fix the following bug and re-enable it.
-    // checkEvaluation(cast(123L, DecimalType(2, 0)), null)
+    checkEvaluation(cast(123L, DecimalType(2, 0)), null)
   }
 
   test("cast from boolean") {

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index fd1d6c1..887e436 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.sql.{Timestamp, Date}
+import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 import java.util.Calendar
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 6a90729..c6b4c72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -55,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite
   }
 
   test("supported schemas") {
+    assert(supportsAggregationBufferSchema(
+      StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
+    assert(!supportsAggregationBufferSchema(
+      StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil)))
     assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
-    assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
-
     assert(
       !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
-    assert(
-      !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
   }
 
   test("empty map") {

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index b7bc17f..a0e1701 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
 
-    // We can copy UnsafeRows as long as they don't reference ObjectPools
     val unsafeRowCopy = unsafeRow.copy()
     assert(unsafeRowCopy.getLong(0) === 0)
     assert(unsafeRowCopy.getLong(1) === 1)
@@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       FloatType,
       DoubleType,
       StringType,
-      BinaryType
-      // DecimalType.Default,
+      BinaryType,
+      DecimalType.USER_DEFAULT
       // ArrayType(IntegerType)
     )
     val converter = UnsafeProjection.create(fieldTypes)
@@ -150,7 +149,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(createdFromNull.getDouble(7) === 0.0d)
     assert(createdFromNull.getUTF8String(8) === null)
     assert(createdFromNull.getBinary(9) === null)
-    // assert(createdFromNull.get(10) === null)
+    assert(createdFromNull.getDecimal(10, 10, 0) === null)
     // assert(createdFromNull.get(11) === null)
 
     // If we have an UnsafeRow with columns that are initially non-null and we null out those
@@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       r.setDouble(7, 700)
       r.update(8, UTF8String.fromString("hello"))
       r.update(9, "world".getBytes)
-      // r.update(10, Decimal(10))
+      r.setDecimal(10, Decimal(10), 10)
       // r.update(11, Array(11))
       r
     }
@@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
     assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
     assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9))
-    // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+    assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+      rowWithNoNullColumns.getDecimal(10, 10, 0))
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
 
     for (i <- fieldTypes.indices) {
@@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     setToNullAfterCreation.setDouble(7, 700)
     // setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
     // setToNullAfterCreation.update(9, "world".getBytes)
-    // setToNullAfterCreation.update(10, Decimal(10))
+    setToNullAfterCreation.setDecimal(10, Decimal(10), 10)
     // setToNullAfterCreation.update(11, Array(11))
 
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
@@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
     // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
     // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
-    // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+    assert(setToNullAfterCreation.getDecimal(10, 10, 0) ===
+      rowWithNoNullColumns.getDecimal(10, 10, 0))
     // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index 454b7b9..1620fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder(
     precision: Int,
     scale: Int)
   extends NativeColumnBuilder(
-    new FixedDecimalColumnStats,
+    new FixedDecimalColumnStats(precision, scale),
     FIXED_DECIMAL(precision, scale))
 
 // TODO (lian) Add support for array, struct and map

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index 32a84b2..af1a8ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats {
     InternalRow(null, null, nullCount, count, sizeInBytes)
 }
 
-private[sql] class FixedDecimalColumnStats extends ColumnStats {
+private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
   protected var upper: Decimal = null
   protected var lower: Decimal = null
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     super.gatherStats(row, ordinal)
     if (!row.isNullAt(ordinal)) {
-      val value = row.getDecimal(ordinal)
+      val value = row.getDecimal(ordinal, precision, scale)
       if (upper == null || value.compareTo(upper) > 0) upper = value
       if (lower == null || value.compareTo(lower) < 0) lower = value
       sizeInBytes += FIXED_DECIMAL.defaultSize

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2863f6c..30f8fe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
   }
 
   override def getField(row: InternalRow, ordinal: Int): Decimal = {
-    row.getDecimal(ordinal)
+    row.getDecimal(ordinal, precision, scale)
   }
 
   override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b85aada..d851eae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -202,7 +202,7 @@ case class GeneratedAggregate(
 
     val schemaSupportsUnsafe: Boolean = {
       UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
-        UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+        UnsafeProjection.canSupport(groupKeySchema)
     }
 
     child.execute().mapPartitions { iter =>

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index c808442..e5bbd0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 {
                 out.writeByte(NULL)
               } else {
                 out.writeByte(NOT_NULL)
-                val value = row.getDecimal(i)
+                val value = row.getDecimal(i, decimal.precision, decimal.scale)
                 val javaBigDecimal = value.toJavaBigDecimal
                 // First, write out the unscaled value.
                 val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 79dd16b..ec8da38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
         writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
       case BinaryType =>
         writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
-      case DecimalType.Fixed(precision, _) =>
-        writeDecimal(record.getDecimal(index), precision)
+      case DecimalType.Fixed(precision, scale) =>
+        writeDecimal(record.getDecimal(index, precision, scale), precision)
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/0b1a464b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index 4499a72..66014dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite {
   testColumnStats(classOf[DoubleColumnStats], DOUBLE,
     InternalRow(Double.MaxValue, Double.MinValue, 0))
   testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0))
-  testColumnStats(classOf[FixedDecimalColumnStats],
-    FIXED_DECIMAL(15, 10), InternalRow(null, null, 0))
+  testDecimalColumnStats(InternalRow(null, null, 0))
 
   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite {
     }
 
     test(s"$columnStatsName: non-empty") {
-      import ColumnarTestUtils._
+      import org.apache.spark.sql.columnar.ColumnarTestUtils._
 
       val columnStats = columnStatsClass.newInstance()
       val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
@@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) {
+
+    val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName
+    val columnType = FIXED_DECIMAL(15, 10)
+
+    test(s"$columnStatsName: empty") {
+      val columnStats = new FixedDecimalColumnStats(15, 10)
+      columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"$columnStatsName: non-empty") {
+      import org.apache.spark.sql.columnar.ColumnarTestUtils._
+
+      val columnStats = new FixedDecimalColumnStats(15, 10)
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(columnStats.gatherStats(_, 0))
+
+      val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType])
+      val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
+      val stats = columnStats.collectedStatistics
+
+      assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
+      assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
+      assertResult(10, "Wrong null count")(stats.genericGet(2))
+      assertResult(20, "Wrong row count")(stats.genericGet(3))
+      assertResult(stats.genericGet(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message