spark-commits mailing list archives

From r...@apache.org
Subject [2/3] spark git commit: [SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement
Date Wed, 22 Jul 2015 06:26:17 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
new file mode 100644
index 0000000..ce1cbdc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -0,0 +1,749 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.NullType
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An iterator used to evaluate aggregate functions. It assumes that input rows
+ * are already grouped by values of `groupingExpressions`.
+ */
+private[sql] abstract class SortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends Iterator[InternalRow] {
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Static fields for this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected val aggregateFunctions: Array[AggregateFunction2] = {
+    var bufferOffset = initialBufferOffset
+    val functions = new Array[AggregateFunction2](aggregateExpressions.length)
+    var i = 0
+    while (i < aggregateExpressions.length) {
+      val func = aggregateExpressions(i).aggregateFunction
+      val funcWithBoundReferences = aggregateExpressions(i).mode match {
+        case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // AlgebraicAggregate (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, inputAttributes)
+        case _ => func
+      }
+      // Set bufferOffset for this function. It is important that setting bufferOffset
+      // happens after all potential bindReference operations because bindReference
+      // will create a new instance of the function.
+      funcWithBoundReferences.bufferOffset = bufferOffset
+      bufferOffset += funcWithBoundReferences.bufferSchema.length
+      functions(i) = funcWithBoundReferences
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions.
+  protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    aggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // Positions of those non-algebraic aggregate functions in aggregateFunctions.
+  // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
+  // func2 and func3 are non-algebraic aggregate functions.
+  // nonAlgebraicAggregateFunctionPositions will be [1, 2].
+  protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      aggregateFunctions(i) match {
+        case agg: AlgebraicAggregate =>
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
+  }
+
+  // This projection is used to evaluate the grouping expressions.
+  protected val groupGenerator =
+    newMutableProjection(groupingExpressions, inputAttributes)()
+
+  // The underlying buffer shared by all aggregate functions.
+  protected val buffer: MutableRow = {
+    // The number of elements of the underlying buffer of this operator.
+    // All aggregate functions share this underlying buffer and find their
+    // buffer values through their bufferOffsets.
+    var size = initialBufferOffset
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      size += aggregateFunctions(i).bufferSchema.length
+      i += 1
+    }
+    new GenericMutableRow(size)
+  }
+
+  protected val joinedRow = new JoinedRow4
+
+  protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+  // This projection is used to initialize buffer values for all AlgebraicAggregates.
+  protected val algebraicInitialProjection = {
+    val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(initExpressions, Nil)().target(buffer)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The grouping key of the current group.
+  protected var currentGroupingKey: InternalRow = _
+  // The grouping key of the next group.
+  protected var nextGroupingKey: InternalRow = _
+  // The first row of the next group.
+  protected var firstRowInNextGroup: InternalRow = _
+  // Indicates if we have a new group of rows to process.
+  protected var hasNewGroup: Boolean = true
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Private methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Initializes buffer values for all aggregate functions. */
+  protected def initializeBuffer(): Unit = {
+    algebraicInitialProjection(EmptyRow)
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).initialize(buffer)
+      i += 1
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy of its result here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      // The input iterator is empty.
+      hasNewGroup = false
+    }
+  }
+
+  /** Processes rows in the current group. It will stop when it finds a new group. */
+  private def processCurrentGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we will start to find all rows belonging to this group.
+    // We use a variable to track whether we have seen the next group.
+    var foundNewGroup = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(firstRowInNextGroup)
+    // The search will stop when we see the next group or there is no
+    // input row left in the iterator.
+    while (inputIter.hasNext && !foundNewGroup) {
+      val currentRow = inputIter.next()
+      // Get the grouping key based on the grouping expressions.
+      // For the equality check below, we do not need to make a copy of groupingKey.
+      val groupingKey = groupGenerator(currentRow)
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(currentRow)
+      } else {
+        // We have found a new group.
+        foundNewGroup = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means there is no input row left in the
+    // iterator. The current group is the last group of the iterator.
+    if (!foundNewGroup) {
+      hasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = hasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentGroup()
+      // Generate output row for the current group.
+      val outputRow = generateOutput()
+      // Initialize buffer values for the next group.
+      initializeBuffer()
+
+      outputRow
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods that need to be implemented
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected def initialBufferOffset: Int
+
+  protected def processRow(row: InternalRow): Unit
+
+  protected def generateOutput(): InternalRow
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initialize this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+    groupingExpressions: Seq[NamedExpression],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    Nil,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val resultProjection =
+    newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Since we only do grouping, there is nothing to do here.
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    resultProjection(currentGroupingKey)
+  }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // This projection is used to update buffer values for all AlgebraicAggregates.
+  private val algebraicUpdateProjection = {
+    val bufferSchema = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.bufferAttributes
+      case agg: AggregateFunction2 => agg.bufferAttributes
+    }
+    val updateExpressions = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.updateExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicUpdateProjection(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We just output the grouping expressions and the underlying buffer.
+    joinedRow(currentGroupingKey, buffer).copy()
+  }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialMergeSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val placeholderAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to extract aggregation buffers from the underlying buffer.
+  // We need it because the underlying buffer has placeholders at its beginning.
+  private val extractsBufferValues = {
+    val expressions = aggregateFunctions.flatMap {
+      case agg => agg.bufferAttributes
+    }
+
+    newMutableProjection(expressions, inputAttributes)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We output grouping expressions and aggregation buffers.
+    joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+  }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer have the same
+ * length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
+  private val resultProjection =
+    newMutableProjection(
+      resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to evaluate all AlgebraicAggregates.
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy of its result here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because
+        // merging two buffers with initial values will generate a row that
+        // still stores initial values, we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
+
+/**
+ * An iterator used to do both final aggregations (for those aggregate functions with mode
+ * Final) and complete aggregations (for those aggregate functions with mode Complete).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN|
+ * col1 to colM are columns used by aggregate functions with Complete mode.
+ * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with
+ * Final mode.
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)|
+ * The first N placeholders represent slots of grouping expressions.
+ * Then, the next M placeholders represent slots of col1 to colM.
+ * For aggregation buffers, the first N aggregation buffers are used by the N aggregate
+ * functions with mode Final. Then, the last M aggregation buffers are used by the M
+ * aggregate functions with mode Complete. The reason we have placeholders here is to
+ * make the underlying buffer have the same length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalAndCompleteSortAggregationIterator(
+    override protected val initialBufferOffset: Int,
+    groupingExpressions: Seq[NamedExpression],
+    finalAggregateExpressions: Seq[AggregateExpression2],
+    finalAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    // TODO: document the ordering
+    finalAggregateExpressions ++ completeAggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow =
+    new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
+  private val resultProjection = {
+    val inputSchema =
+      groupingExpressions.map(_.toAttribute) ++
+        finalAggregateAttributes ++
+        completeAggregateAttributes
+    newMutableProjection(resultExpressions, inputSchema)()
+  }
+
+  private val offsetAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // All aggregate functions with mode Final.
+  private val finalAggregateFunctions: Array[AggregateFunction2] = {
+    val functions = new Array[AggregateFunction2](finalAggregateExpressions.length)
+    var i = 0
+    while (i < finalAggregateExpressions.length) {
+      functions(i) = aggregateFunctions(i)
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions with mode Final.
+  private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    finalAggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // All aggregate functions with mode Complete.
+  private val completeAggregateFunctions: Array[AggregateFunction2] = {
+    val functions = new Array[AggregateFunction2](completeAggregateExpressions.length)
+    var i = 0
+    while (i < completeAggregateExpressions.length) {
+      functions(i) = aggregateFunctions(finalAggregateFunctions.length + i)
+      i += 1
+    }
+    functions
+  }
+
+  // All non-algebraic aggregate functions with mode Complete.
+  private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+    completeAggregateFunctions.collect {
+      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
+    }.toArray
+  }
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates with mode
+  // Final.
+  private val finalAlgebraicMergeProjection = {
+    val numCompleteOffsetAttributes =
+      completeAggregateFunctions.map(_.bufferAttributes.length).sum
+    val completeOffsetAttributes =
+      Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)())
+    val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
+
+    val bufferSchemata =
+      offsetAttributes ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      } ++ completeOffsetAttributes
+    val mergeExpressions =
+      placeholderExpressions ++ finalAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.mergeExpressions
+        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+      } ++ completeOffsetExpressions
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to update buffer values for all AlgebraicAggregates with mode
+  // Complete.
+  private val completeAlgebraicUpdateProjection = {
+    val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum
+    val finalOffsetAttributes =
+      Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)())
+    val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
+
+    val bufferSchema =
+      offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      }
+    val updateExpressions =
+      placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.updateExpressions
+        case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+      }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  // This projection is used to evaluate all AlgebraicAggregates.
+  private val algebraicEvalProjection = {
+    val bufferSchemata =
+      offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val evalExpressions = aggregateFunctions.map {
+      case ae: AlgebraicAggregate => ae.evaluateExpression
+      case agg: AggregateFunction2 => NoOp
+    }
+
+    newMutableProjection(evalExpressions, bufferSchemata)()
+  }
+
+  override def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy of its result here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      if (groupingExpressions.isEmpty) {
+        // If there is no grouping expression, we need to generate a single row as the output.
+        initializeBuffer()
+        // Right now, the buffer only contains initial buffer values. Because
+        // merging two buffers with initial values will generate a row that
+        // still stores initial values, we set currentRow to a copy of the current buffer.
+        val currentRow = buffer.copy()
+        nextGroupingKey = groupGenerator(currentRow).copy()
+        firstRowInNextGroup = currentRow
+      } else {
+        // The input iterator is empty.
+        hasNewGroup = false
+      }
+    }
+  }
+
+  override protected def processRow(row: InternalRow): Unit = {
+    val input = joinedRow(buffer, row)
+    // For all aggregate functions with mode Complete, update buffers.
+    completeAlgebraicUpdateProjection(input)
+    var i = 0
+    while (i < completeNonAlgebraicAggregateFunctions.length) {
+      completeNonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+
+    // For all aggregate functions with mode Final, merge buffers.
+    finalAlgebraicMergeProjection.target(buffer)(input)
+    i = 0
+    while (i < finalNonAlgebraicAggregateFunctions.length) {
+      finalNonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // Generate results for all algebraic aggregate functions.
+    algebraicEvalProjection.target(aggregateResult)(buffer)
+    // Generate results for all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      aggregateResult.update(
+        nonAlgebraicAggregateFunctionPositions(i),
+        nonAlgebraicAggregateFunctions(i).eval(buffer))
+      i += 1
+    }
+
+    resultProjection(joinedRow(currentGroupingKey, aggregateResult))
+  }
+}
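All of these iterators rely on the same contract: input rows arrive already sorted by the grouping key, so a single forward pass can detect group boundaries by comparing consecutive keys. The standalone sketch below is not part of this commit; the names and the (Int, Long) row shape are illustrative. It mirrors the pattern of processCurrentGroup: seed the buffer with the stashed first row of the group, consume rows until the key changes, stash the first row of the next group, and emit one output row per group.

// A minimal sketch of the sort-based grouping pass, assuming rows are
// (key, value) pairs pre-sorted by key. Illustrative only.
class SortedSumIterator(input: Iterator[(Int, Long)])
  extends Iterator[(Int, Long)] {

  private var nextKey: Option[Int] = None
  private var firstRowInNextGroup: (Int, Long) = _

  // Prime the iterator with the first row, like initialize().
  if (input.hasNext) {
    val row = input.next()
    nextKey = Some(row._1)
    firstRowInNextGroup = row
  }

  override def hasNext: Boolean = nextKey.isDefined

  override def next(): (Int, Long) = {
    if (nextKey.isEmpty) throw new NoSuchElementException
    val currentKey = nextKey.get
    var sum = firstRowInNextGroup._2  // "buffer" seeded with the stashed first row
    nextKey = None
    var done = false
    while (input.hasNext && !done) {
      val row = input.next()
      if (row._1 == currentKey) {
        sum += row._2                 // same group: update the buffer
      } else {
        nextKey = Some(row._1)        // new group found: stash its first row
        firstRowInNextGroup = row
        done = true
      }
    }
    (currentKey, sum)
  }
}

// e.g. new SortedSumIterator(Iterator((1, 2L), (1, 3L), (2, 5L))).toList
// returns List((1, 5L), (2, 5L))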

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
new file mode 100644
index 0000000..1cb2771
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
+ */
+object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
+  private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate if supportsGroupingKeySchema(p) =>
+      val converted = p.transformExpressionsDown {
+        case expressions.Average(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Average(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Count(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        // We do not support multiple COUNT DISTINCT columns for now.
+        case expressions.CountDistinct(children) if children.length == 1 =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(children.head),
+            mode = aggregate.Complete,
+            isDistinct = true)
+
+        case expressions.First(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.First(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Last(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Last(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Max(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Max(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Min(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Min(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Sum(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.SumDistinct(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = true)
+      }
+      // Check if there is any expressions.AggregateExpression1 left.
+      // If so, we cannot convert this plan.
+      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+        // For every expression, check if it contains an AggregateExpression1.
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
+
+  private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we do a final round of checks to see if the
+    // original logical.Aggregate contains both AggregateExpression1 and
+    // AggregateExpression2. If so, we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and it cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(
+      plan: LogicalPlan,
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
+      val converted = tryConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case p: Aggregate =>
+      checkInvalidAggregateFunction2(p)
+      None
+    case other => None
+  }
+
+  def planAggregateWithoutDistinct(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+    // 1. Create an Aggregate Operator for partial aggregations.
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Partial, isDistinct)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        namedGroupingExpressions.map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
+    val finalAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Final, isDistinct)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAggregate = Aggregate2Sort(
+      Some(namedGroupingAttributes),
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      rewrittenResultExpressions,
+      partialAggregate)
+
+    finalAggregate :: Nil
+  }
+
+  def planAggregateWithOneDistinct(
+      groupingExpressions: Seq[Expression],
+      functionsWithDistinct: Seq[AggregateExpression2],
+      functionsWithoutDistinct: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+    // The grouping expressions are original groupingExpressions and
+    // distinct columns. For example, for avg(distinct value) ... group by key
+    // the grouping expressions of this Aggregate Operator will be [key, value].
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    // It is safe to call head here since functionsWithDistinct has at least one
+    // AggregateExpression2.
+    val distinctColumnExpressions =
+      functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+    val partialAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Partial, false)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, PartialMerge, false)
+    }
+    val partialMergeAggregateAttributes =
+      partialMergeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val partialMergeAggregate =
+      Aggregate2Sort(
+        Some(namedGroupingAttributes),
+        namedGroupingAttributes ++ distinctColumnAttributes,
+        partialMergeAggregateExpressions,
+        partialMergeAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+        partialAggregate)
+
+    // 3. Create an Aggregate Operator for the final and complete aggregations.
+    val finalAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Final, false)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // Children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children with
+      // AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function to a non-distinct aggregation because
+        // its input will have distinct arguments.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transform {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+      namedGroupingAttributes ++ distinctColumnAttributes,
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      rewrittenResultExpressions,
+      partialMergeAggregate)
+
+    finalAndCompleteAggregate :: Nil
+  }
+}
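To make planAggregateWithOneDistinct concrete: for a query such as SELECT key, sum(value), count(DISTINCT value) FROM t GROUP BY key, the three steps above chain the new operators roughly as follows (a hand-drawn sketch of the plan shape, not actual planner output):

  FinalAndCompleteAggregate2Sort          // step 3: Final(sum), Complete(count)
    grouping: [key]
    +- Aggregate2Sort (PartialMerge)      // step 2: merge partial sums
         grouping: [key, value]
         +- Aggregate2Sort (Partial)      // step 1: partial sum per (key, value)
              grouping: [key, value]
              +- child

Because step 2 leaves exactly one row per (key, value) pair, the Complete-mode count in step 3 already sees distinct values of value, which is why the rewritten aggregate expression can drop the DISTINCT flag.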

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000..6c49a90
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate function.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] represents data types of input arguments of this aggregate function.
+   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+   * with types [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] represents data types of values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with types [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleBuffer", DoubleType), StructField("longBuffer", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /** Indicates if this function is deterministic. */
+  def deterministic: Boolean
+
+  /**
+   *  Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   *  the condition that when merging two buffers with initial values, the new buffer should
+   *  still store initial values.
+   */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer.
+   */
+  def evaluate(buffer: Row): Any
+}
+
+private[sql] abstract class AggregationBuffer(
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int)
+  extends Row {
+
+  override def length: Int = toCatalystConverters.length
+
+  protected val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+}
+
+/**
+ * A mutable [[Row]] representing a mutable aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingBuffer(offsets(i)))
+  }
+
+  def update(i: Int, value: Any): Unit = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not update ${i}th value in this buffer because it only has $length values.")
+    }
+    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+  }
+
+  override def copy(): MutableAggregationBuffer = {
+    new MutableAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingBuffer)
+  }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: Row)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+  }
+
+  override def copy(): InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingInputBuffer)
+  }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` into the
+ * internal aggregation code path.
+ * @param children input expressions of this UDAF
+ * @param udaf the user-defined aggregate function
+ */
+case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+      s"but ${children.length} are provided.")
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, falling back to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+    new MutableAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    udaf.evaluate(inputAggregateBuffer)
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = udaf.getClass.getSimpleName
+}
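
For reference, a Scala UserDefinedAggregateFunction that ScalaUDAF could wrap might look like the sketch below. It mirrors the Java test UDAFs added later in this commit and assumes the org.apache.spark.sql.expressions.aggregate API this patch introduces; it is illustrative, not part of the patch.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.aggregate.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    class MyScalaSum extends UserDefinedAggregateFunction {
      def inputSchema: StructType = StructType(StructField("inputDouble", DoubleType, true) :: Nil)
      def bufferSchema: StructType = StructType(StructField("bufferDouble", DoubleType, true) :: Nil)
      def returnDataType: DataType = DoubleType
      def deterministic: Boolean = true
      // A null buffer value means "no non-null input seen yet".
      def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, null)
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
          val current = if (buffer.isNullAt(0)) 0.0 else buffer.getDouble(0)
          buffer.update(0, current + input.getDouble(0))
        }
      }
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        if (!buffer2.isNullAt(0)) {
          val current = if (buffer1.isNullAt(0)) 0.0 else buffer1.getDouble(0)
          buffer1.update(0, current + buffer2.getDouble(0))
        }
      }
      def evaluate(buffer: Row): Any = if (buffer.isNullAt(0)) null else buffer.getDouble(0)
    }

    // The planner would then wrap an instance as ScalaUDAF(inputExprs, new MyScalaSum).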

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 28159cb..bfeecbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
    * @since 1.5.0
    */
   def callUDF(udfName: String, cols: Column*): Column = {
-    UnresolvedFunction(udfName, cols.map(_.expr))
+    UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
 
   /**
@@ -2449,7 +2449,7 @@ object functions {
       exprs(i) = cols(i).expr
       i += 1
     }
-    UnresolvedFunction(udfName, exprs)
+    UnresolvedFunction(udfName, exprs, isDistinct = false)
   }
 
 }
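
A hedged usage sketch of the changed callUDF overload; the DataFrame `df` and the registered UDF name "myUpper" are hypothetical:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._

    // Resolves "myUpper" by name at analysis time; after this change the
    // resulting UnresolvedFunction is explicitly marked non-DISTINCT.
    def selectWithNamedUdf(df: DataFrame): DataFrame =
      df.select(callUDF("myUpper", col("name")))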

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee101..ab8dce6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       var hasGeneratedAgg = false
       df.queryExecution.executedPlan.foreach {
         case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+        case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
         case _ =>
       }
       if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT  sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
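
The updated expectation reflects standard SQL semantics: SUM over inputs that are all NULL yields NULL rather than 0 (the string literal 'a' does not cast to a number), while COUNT(null) is still 0. A sketch of what the test now asserts, with `sqlContext` and the temp table assumed from the suite:

    val row = sqlContext.sql("SELECT sum('a'), avg('a'), count(null) FROM testData").head
    // row == Row(null, null, 0)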

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd2413..3d71deb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
 
 
 class PlannerSuite extends SparkFunSuite {
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could query play aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, the planner produces three aggregate operators for
+    // distinct aggregations (and two otherwise).
+    assert(
+      aggregations.size == 2 || aggregations.size == 3,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
   test("unions are collapsed") {
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    assert(aggregations.size === 2)
+    testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3..24a758f 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
     "windowing_adjust_rowcontainer_sz"
   )
 
+  // Only run the query tests listed in realWhiteList (do not try other ignored query files).
   override def testCases: Seq[(String, File)] = super.testCases.filter {
     case (name, _) => realWhiteList.contains(name)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567..1fe4fe9 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive.execution
 
+import java.io.File
+
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive
 
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "join_reorder4",
     "join_star"
   )
+
+  // Only run the query tests listed in realWhiteList (do not try other ignored query files).
+  override def testCases: Seq[(String, File)] = super.testCases.filter {
+    case (name, _) => realWhiteList.contains(name)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cec7685..4cdb83c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging {
       DataSinks,
       Scripts,
       HashAggregation,
+      Aggregation,
       LeftSemiJoin,
       HashJoin,
       BasicOperators,

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f557450..8518e33 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* UDFs - Must be last otherwise will preempt built in functions */
     case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
-      UnresolvedFunction(name, args.map(nodeToExpr))
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
+    // Aggregate function with DISTINCT keyword.
+    case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
+      UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
     case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
-      UnresolvedFunction(name, UnresolvedStar(None) :: Nil)
+      UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
 
     /* Literals */
     case Token("TOK_NULL", Nil) => Literal.create(null, NullType)

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 4d23c70..3259b50 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction(
 
 private[hive] case class HiveGenericUDAF(
     funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
+    children: Seq[Expression]) extends AggregateExpression1
   with HiveInspectors {
 
   type UDFType = AbstractGenericUDAFResolver
@@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF(
 /** It is used as a wrapper for the hive functions which uses UDAF interface */
 private[hive] case class HiveUDAF(
     funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
+    children: Seq[Expression]) extends AggregateExpression1
   with HiveInspectors {
 
   type UDFType = UDAF
@@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF(
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
     exprs: Seq[Expression],
-    base: AggregateExpression,
+    base: AggregateExpression1,
     isUDAFBridgeRequired: Boolean = false)
-  extends AggregateFunction
+  extends AggregateFunction1
   with HiveInspectors {
 
   def this() = this(null, null, null)

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
new file mode 100644
index 0000000..5c9d0e9
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleAvg() {
+    List<StructField> inputfields = new ArrayList<StructField>();
+    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputfields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+    bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+    buffer.update(1, 0L);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+        buffer.update(1, 1L);
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+        buffer.update(1, buffer.getLong(1) + 1L);
+      }
+    }
+  }
+
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+        buffer1.update(1, buffer2.getLong(1));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+        buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
+      // Returns the average plus 100.0 (a deliberate offset in this test UDAF).
+      return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+    }
+  }
+}
+
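
A hypothetical Scala usage of this test UDAF, assuming a registration hook of the form udf.register(name, udaf) as in the 1.5-era API, plus an existing `hiveContext` with a table "t" holding a double column "value":

    hiveContext.udf.register("mydoubleavg", new test.org.apache.spark.sql.hive.aggregate.MyDoubleAvg)
    hiveContext.sql("SELECT mydoubleavg(value) FROM t").show()
    // evaluate returns the average plus 100.0, so results are deliberately
    // offset from the builtin avg, making the custom code path easy to spot.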

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
new file mode 100644
index 0000000..1d4587a
--- /dev/null
+++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.hive.aggregate;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.Row;
+
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+  private StructType _inputDataType;
+
+  private StructType _bufferSchema;
+
+  private DataType _returnDataType;
+
+  public MyDoubleSum() {
+    List<StructField> inputfields = new ArrayList<StructField>();
+    inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+    _inputDataType = DataTypes.createStructType(inputfields);
+
+    List<StructField> bufferFields = new ArrayList<StructField>();
+    bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+    _bufferSchema = DataTypes.createStructType(bufferFields);
+
+    _returnDataType = DataTypes.DoubleType;
+  }
+
+  @Override public StructType inputSchema() {
+    return _inputDataType;
+  }
+
+  @Override public StructType bufferSchema() {
+    return _bufferSchema;
+  }
+
+  @Override public DataType returnDataType() {
+    return _returnDataType;
+  }
+
+  @Override public boolean deterministic() {
+    return true;
+  }
+
+  @Override public void initialize(MutableAggregationBuffer buffer) {
+    buffer.update(0, null);
+  }
+
+  @Override public void update(MutableAggregationBuffer buffer, Row input) {
+    if (!input.isNullAt(0)) {
+      if (buffer.isNullAt(0)) {
+        buffer.update(0, input.getDouble(0));
+      } else {
+        Double newValue = input.getDouble(0) + buffer.getDouble(0);
+        buffer.update(0, newValue);
+      }
+    }
+  }
+
+  @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+    if (!buffer2.isNullAt(0)) {
+      if (buffer1.isNullAt(0)) {
+        buffer1.update(0, buffer2.getDouble(0));
+      } else {
+        Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+        buffer1.update(0, newValue);
+      }
+    }
+  }
+
+  @Override public Object evaluate(Row buffer) {
+    if (buffer.isNullAt(0)) {
+      return null;
+    } else {
+      return buffer.getDouble(0);
+    }
+  }
+}
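
Both test UDAFs share the same null-propagation contract: a group that never sees a non-null input evaluates to null. A minimal, Spark-free Scala sketch of that contract, where `values` models one group's input column:

    def sumLikeMyDoubleSum(values: Seq[Option[Double]]): Option[Double] = {
      val nonNull = values.flatten
      if (nonNull.isEmpty) None      // buffer stayed null => evaluate returns null
      else Some(nonNull.sum)         // accumulated bufferDouble
    }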

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000..573541a
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada
@@ -0,0 +1 @@
+0

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
new file mode 100644
index 0000000..44b2a42
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47
@@ -0,0 +1 @@
+unhex(str) - Converts hexadecimal argument to binary

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
new file mode 100644
index 0000000..97af3b8
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4
@@ -0,0 +1,14 @@
+unhex(str) - Converts hexadecimal argument to binary
+Performs the inverse operation of HEX(str). That is, it interprets
+each pair of hexadecimal digits in the argument as a number and
+converts it to the byte representation of the number. The
+resulting characters are returned as a binary string.
+
+Example:
+> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1;
+'MySQL'
+
+The characters in the argument string must be legal hexadecimal
+digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters
+any nonhexadecimal digits in the argument, it returns NULL. Also,
+if there are an odd number of characters a leading 0 is appended.
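
A minimal Scala sketch of the UNHEX semantics described in this golden file (not Spark's implementation): pad odd-length input with a leading '0', decode each pair of hex digits into one byte, and return null on any non-hex digit.

    def unhex(s: String): Option[Array[Byte]] = {
      val padded = if (s.length % 2 == 1) "0" + s else s   // odd length: prepend a leading 0
      if (padded.forall(c => Character.digit(c, 16) >= 0)) // any non-hex digit => NULL
        Some(padded.grouped(2).map(p => Integer.parseInt(p, 16).toByte).toArray)
      else None
    }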

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
new file mode 100644
index 0000000..b4a6f2b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e
@@ -0,0 +1 @@
+MySQL	1267	a	-4	

http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
new file mode 100644
index 0000000..3a67ada
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
@@ -0,0 +1 @@
+NULL	NULL	NULL

