spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject [3/4] spark git commit: [SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to evaluate AggregateExpression1s
Date Tue, 10 Nov 2015 19:06:37 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
deleted file mode 100644
index 9b22ce2..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala
+++ /dev/null
@@ -1,467 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions.aggregate
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types._
-
-/**
- * Utility functions used by the query planner to convert our plan to new aggregation code path.
- */
-object Utils {
-
-  // Check if the DataType given cannot be part of a group by clause.
-  private def isUnGroupable(dt: DataType): Boolean = dt match {
-    case _: ArrayType | _: MapType => true
-    case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType))
-    case _ => false
-  }
-
-  // Right now, we do not support complex types in the grouping key schema.
-  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean =
-    !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType))
-
-  private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
-    case p: Aggregate if supportsGroupingKeySchema(p) =>
-
-      val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown {
-        case expressions.Average(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Average(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Count(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Count(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.CountDistinct(children) =>
-          val child = if (children.size > 1) {
-            DropAnyNull(CreateStruct(children))
-          } else {
-            children.head
-          }
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Count(child),
-            mode = aggregate.Complete,
-            isDistinct = true)
-
-        case expressions.First(child, ignoreNulls) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.First(child, ignoreNulls),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Kurtosis(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Kurtosis(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Last(child, ignoreNulls) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Last(child, ignoreNulls),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Max(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Max(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Min(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Min(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Skewness(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Skewness(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.StddevPop(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.StddevPop(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.StddevSamp(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.StddevSamp(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Sum(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Sum(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.SumDistinct(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Sum(child),
-            mode = aggregate.Complete,
-            isDistinct = true)
-
-        case expressions.Corr(left, right) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Corr(left, right),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.ApproxCountDistinct(child, rsd) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.VariancePop(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.VariancePop(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.VarianceSamp(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.VarianceSamp(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-      })
-
-      // Check if there is any expressions.AggregateExpression1 left.
-      // If so, we cannot convert this plan.
-      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
-        // For every expressions, check if it contains AggregateExpression1.
-        expr.find {
-          case agg: expressions.AggregateExpression1 => true
-          case other => false
-        }.isDefined
-      }
-
-      // Check if there are multiple distinct columns.
-      // TODO remove this.
-      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
-        expr.collect {
-          case agg: AggregateExpression2 => agg
-        }
-      }.toSet.toSeq
-      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
-      val hasMultipleDistinctColumnSets =
-        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
-          true
-        } else {
-          false
-        }
-
-      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
-    case other => None
-  }
-
-  def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
-    // If the plan cannot be converted, we will do a final round check to see if the original
-    // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
-    // we need to throw an exception.
-    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
-      expr.collect {
-        case agg: AggregateExpression2 => agg.aggregateFunction
-      }
-    }.distinct
-    if (aggregateFunction2s.nonEmpty) {
-      // For functions implemented based on the new interface, prepare a list of function names.
-      val invalidFunctions = {
-        if (aggregateFunction2s.length > 1) {
-          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
-            s"and ${aggregateFunction2s.head.nodeName} are"
-        } else {
-          s"${aggregateFunction2s.head.nodeName} is"
-        }
-      }
-      val errorMessage =
-        s"${invalidFunctions} implemented based on the new Aggregate Function " +
-          s"interface and it cannot be used with functions implemented based on " +
-          s"the old Aggregate Function interface."
-      throw new AnalysisException(errorMessage)
-    }
-  }
-
-  def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
-    case p: Aggregate =>
-      val converted = doConvert(p)
-      if (converted.isDefined) {
-        converted
-      } else {
-        checkInvalidAggregateFunction2(p)
-        None
-      }
-    case other => None
-  }
-}
-
-/**
- * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double
- * aggregation in which the regular aggregation expressions and every distinct clause is aggregated
- * in a separate group. The results are then combined in a second aggregate.
- *
- * For example (in scala):
- * {{{
- *   val data = Seq(
- *     ("a", "ca1", "cb1", 10),
- *     ("a", "ca1", "cb2", 5),
- *     ("b", "ca1", "cb1", 13))
- *     .toDF("key", "cat1", "cat2", "value")
- *   data.registerTempTable("data")
- *
- *   val agg = data.groupBy($"key")
- *     .agg(
- *       countDistinct($"cat1").as("cat1_cnt"),
- *       countDistinct($"cat2").as("cat2_cnt"),
- *       sum($"value").as("total"))
- * }}}
- *
- * This translates to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [COUNT(DISTINCT 'cat1),
- *                 COUNT(DISTINCT 'cat2),
- *                 sum('value)]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   LocalTableScan [...]
- * }}}
- *
- * This rule rewrites this logical plan to the following (pseudo) logical plan:
- * {{{
- * Aggregate(
- *    key = ['key]
- *    functions = [count(if (('gid = 1)) 'cat1 else null),
- *                 count(if (('gid = 2)) 'cat2 else null),
- *                 first(if (('gid = 0)) 'total else null) ignore nulls]
- *    output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
- *   Aggregate(
- *      key = ['key, 'cat1, 'cat2, 'gid]
- *      functions = [sum('value)]
- *      output = ['key, 'cat1, 'cat2, 'gid, 'total])
- *     Expand(
- *        projections = [('key, null, null, 0, cast('value as bigint)),
- *                       ('key, 'cat1, null, 1, null),
- *                       ('key, null, 'cat2, 2, null)]
- *        output = ['key, 'cat1, 'cat2, 'gid, 'value])
- *       LocalTableScan [...]
- * }}}
- *
- * The rule does the following things here:
- * 1. Expand the data. There are three aggregation groups in this query:
- *    i. the non-distinct group;
- *    ii. the distinct 'cat1 group;
- *    iii. the distinct 'cat2 group.
- *    An expand operator is inserted to expand the child data for each group. The expand will null
- *    out all unused columns for the given group; this must be done in order to ensure correctness
- *    later on. Groups can by identified by a group id (gid) column added by the expand operator.
- * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of
- *    this aggregate consists of the original group by clause, all the requested distinct columns
- *    and the group id. Both de-duplication of distinct column and the aggregation of the
- *    non-distinct group take advantage of the fact that we group by the group id (gid) and that we
- *    have nulled out all non-relevant columns for the the given group.
- * 3. Aggregating the distinct groups and combining this with the results of the non-distinct
- *    aggregation. In this step we use the group id to filter the inputs for the aggregate
- *    functions. The result of the non-distinct group are 'aggregated' by using the first operator,
- *    it might be more elegant to use the native UDAF merge mechanism for this in the future.
- *
- * This rule duplicates the input data by two or more times (# distinct groups + an optional
- * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
- * exchange operators. Keeping the number of distinct groups as low a possible should be priority,
- * we could improve this in the current rule by applying more advanced expression cannocalization
- * techniques.
- */
-object MultipleDistinctRewriter extends Rule[LogicalPlan] {
-
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case a: Aggregate => rewrite(a)
-    case p => p
-  }
-
-  def rewrite(a: Aggregate): Aggregate = {
-
-    // Collect all aggregate expressions.
-    val aggExpressions = a.aggregateExpressions.flatMap { e =>
-      e.collect {
-        case ae: AggregateExpression2 => ae
-      }
-    }
-
-    // Extract distinct aggregate expressions.
-    val distinctAggGroups = aggExpressions
-      .filter(_.isDistinct)
-      .groupBy(_.aggregateFunction.children.toSet)
-
-    // Only continue to rewrite if there is more than one distinct group.
-    if (distinctAggGroups.size > 1) {
-      // Create the attributes for the grouping id and the group by clause.
-      val gid = new AttributeReference("gid", IntegerType, false)()
-      val groupByMap = a.groupingExpressions.collect {
-        case ne: NamedExpression => ne -> ne.toAttribute
-        case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)()
-      }
-      val groupByAttrs = groupByMap.map(_._2)
-
-      // Functions used to modify aggregate functions and their inputs.
-      def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e))
-      def patchAggregateFunctionChildren(
-          af: AggregateFunction2)(
-          attrs: Expression => Expression): AggregateFunction2 = {
-        af.withNewChildren(af.children.map {
-          case afc => attrs(afc)
-        }).asInstanceOf[AggregateFunction2]
-      }
-
-      // Setup unique distinct aggregate children.
-      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
-      val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
-
-      // Setup expand & aggregate operators for distinct aggregate expressions.
-      val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
-        case ((group, expressions), i) =>
-          val id = Literal(i + 1)
-
-          // Expand projection
-          val projection = distinctAggChildren.map {
-            case e if group.contains(e) => e
-            case e => nullify(e)
-          } :+ id
-
-          // Final aggregate
-          val operators = expressions.map { e =>
-            val af = e.aggregateFunction
-            val naf = patchAggregateFunctionChildren(af) { x =>
-              evalWithinGroup(id, distinctAggChildAttrMap(x))
-            }
-            (e, e.copy(aggregateFunction = naf, isDistinct = false))
-          }
-
-          (projection, operators)
-      }
-
-      // Setup expand for the 'regular' aggregate expressions.
-      val regularAggExprs = aggExpressions.filter(!_.isDistinct)
-      val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct
-      val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair)
-
-      // Setup aggregates for 'regular' aggregate expressions.
-      val regularGroupId = Literal(0)
-      val regularAggChildAttrLookup = regularAggChildAttrMap.toMap
-      val regularAggOperatorMap = regularAggExprs.map { e =>
-        // Perform the actual aggregation in the initial aggregate.
-        val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup)
-        val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)()
-
-        // Select the result of the first aggregate in the last aggregate.
-        val result = AggregateExpression2(
-          aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
-          mode = Complete,
-          isDistinct = false)
-
-        // Some aggregate functions (COUNT) have the special property that they can return a
-        // non-null result without any input. We need to make sure we return a result in this case.
-        val resultWithDefault = af.defaultResult match {
-          case Some(lit) => Coalesce(Seq(result, lit))
-          case None => result
-        }
-
-        // Return a Tuple3 containing:
-        // i. The original aggregate expression (used for look ups).
-        // ii. The actual aggregation operator (used in the first aggregate).
-        // iii. The operator that selects and returns the result (used in the second aggregate).
-        (e, operator, resultWithDefault)
-      }
-
-      // Construct the regular aggregate input projection only if we need one.
-      val regularAggProjection = if (regularAggExprs.nonEmpty) {
-        Seq(a.groupingExpressions ++
-          distinctAggChildren.map(nullify) ++
-          Seq(regularGroupId) ++
-          regularAggChildren)
-      } else {
-        Seq.empty[Seq[Expression]]
-      }
-
-      // Construct the distinct aggregate input projections.
-      val regularAggNulls = regularAggChildren.map(nullify)
-      val distinctAggProjections = distinctAggOperatorMap.map {
-        case (projection, _) =>
-          a.groupingExpressions ++
-            projection ++
-            regularAggNulls
-      }
-
-      // Construct the expand operator.
-      val expand = Expand(
-        regularAggProjection ++ distinctAggProjections,
-        groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2),
-        a.child)
-
-      // Construct the first aggregate operator. This de-duplicates the all the children of
-      // distinct operators, and applies the regular aggregate operators.
-      val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid
-      val firstAggregate = Aggregate(
-        firstAggregateGroupBy,
-        firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2),
-        expand)
-
-      // Construct the second aggregate
-      val transformations: Map[Expression, Expression] =
-        (distinctAggOperatorMap.flatMap(_._2) ++
-          regularAggOperatorMap.map(e => (e._1, e._3))).toMap
-
-      val patchedAggExpressions = a.aggregateExpressions.map { e =>
-        e.transformDown {
-          case e: Expression =>
-            // The same GROUP BY clauses can have different forms (different names for instance) in
-            // the groupBy and aggregate expressions of an aggregate. This makes a map lookup
-            // tricky. So we do a linear search for a semantically equal group by expression.
-            groupByMap
-              .find(ge => e.semanticEquals(ge._1))
-              .map(_._2)
-              .getOrElse(transformations.getOrElse(e, e))
-        }.asInstanceOf[NamedExpression]
-      }
-      Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
-    } else {
-      a
-    }
-  }
-
-  private def nullify(e: Expression) = Literal.create(null, e.dataType)
-
-  private def expressionAttributePair(e: Expression) =
-    // We are creating a new reference here instead of reusing the attribute in case of a
-    // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
-    // children, in this case attribute reuse causes the input of the regular aggregate to bound to
-    // the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.prettyString, e.dataType, true)()
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
index ec63534..ede2da2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala
@@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression,
     inputAggBufferOffset: Int = 0)
   extends CentralMomentAgg(child) {
 
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 
@@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression,
   }
 }
 
-case class VariancePop(child: Expression,
+case class VariancePop(
+    child: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends CentralMomentAgg(child) {
 
+  def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 5c5b3d1..3b441de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,23 +17,24 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
-/** The mode of an [[AggregateFunction2]]. */
+/** The mode of an [[AggregateFunction]]. */
 private[sql] sealed trait AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
+ * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
  * This function updates the given aggregation buffer with the original input of this
  * function. When it has processed all input rows, the aggregation buffer is returned.
  */
 private[sql] case object Partial extends AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
  * containing intermediate results for this function.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the aggregation buffer is returned.
@@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
 private[sql] case object PartialMerge extends AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers
  * containing intermediate results for this function and then generate final result.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the final result of this function is returned.
@@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
 private[sql] case object Final extends AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
+ * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly
  * from original input rows without any partial aggregation.
  * This function updates the given aggregation buffer with the original input of this
  * function. When it has processed all input rows, the final result of this function is returned.
@@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable {
 }
 
 /**
- * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
+ * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
  * (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
  */
-private[sql] case class AggregateExpression2(
-    aggregateFunction: AggregateFunction2,
+private[sql] case class AggregateExpression(
+    aggregateFunction: AggregateFunction,
     mode: AggregateMode,
-    isDistinct: Boolean) extends AggregateExpression {
+    isDistinct: Boolean)
+  extends Expression
+  with Unevaluable {
 
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
@@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2(
     AttributeSet(childReferences)
   }
 
+  override def prettyString: String = aggregateFunction.prettyString
+
   override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)"
 }
 
@@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2(
  * combined aggregation buffer which concatenates the aggregation buffers of the individual
  * aggregate functions.
  *
- * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of
+ * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
  * aggregate functions.
  */
-sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
 
   /** An aggregate function is not foldable. */
   final override def foldable: Boolean = false
@@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+  /**
+   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because
+   * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
+   * and the flag indicating if this aggregation is distinct aggregation or not.
+   * An [[AggregateFunction]] should not be used without being wrapped in
+   * an [[AggregateExpression]].
+   */
+  def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false)
+
+  /**
+   * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and sets the isDistinct
+   * field of the [[AggregateExpression]] to the given value because
+   * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode,
+   * and the flag indicating if this aggregation is distinct aggregation or not.
+   * An [[AggregateFunction]] should not be used without being wrapped in
+   * an [[AggregateExpression]].
+   */
+  def toAggregateExpression(isDistinct: Boolean): AggregateExpression = {
+    AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct)
+  }
 }
 
 /**
@@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp
  * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
  * and `inputAggBufferAttributes`.
  */
-abstract class ImperativeAggregate extends AggregateFunction2 {
+abstract class ImperativeAggregate extends AggregateFunction {
 
   /**
    * The offset of this function's first buffer value in the underlying shared mutable aggregation
@@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 {
  * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You
  * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and
  * `evaluateExpressions`.
+ *
+ * Please note that children of an aggregate function can be unresolved (this happens when
+ * the function is created through the DataFrame API). So, if any fields in
+ * the implemented class need to access fields of its children, please make
+ * those fields `lazy val`s.
  */
 abstract class DeclarativeAggregate
-  extends AggregateFunction2
+  extends AggregateFunction
   with Serializable
   with Unevaluable {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
deleted file mode 100644
index 3dcf791..0000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ /dev/null
@@ -1,1073 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import com.clearspring.analytics.stream.cardinality.HyperLogLog
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils}
-import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
-
-
-trait AggregateExpression extends Expression with Unevaluable
-
-trait AggregateExpression1 extends AggregateExpression {
-
-  /**
-   * Aggregate expressions should not be foldable.
-   */
-  override def foldable: Boolean = false
-
-  /**
-   * Creates a new instance that can be used to compute this aggregate expression for a group
-   * of input rows/
-   */
-  def newInstance(): AggregateFunction1
-}
-
-/**
- * Represents an aggregation that has been rewritten to be performed in two steps.
- *
- * @param finalEvaluation an aggregate expression that evaluates to same final result as the
- *                        original aggregation.
- * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial
- *                           data sets and are required to compute the `finalEvaluation`.
- */
-case class SplitEvaluation(
-    finalEvaluation: Expression,
-    partialEvaluations: Seq[NamedExpression])
-
-/**
- * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
- * These partial evaluations can then be combined to compute the actual answer.
- */
-trait PartialAggregate1 extends AggregateExpression1 {
-
-  /**
-   * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
-   */
-  def asPartial: SplitEvaluation
-}
-
-/**
- * A specific implementation of an aggregate function. Used to wrap a generic
- * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
- */
-abstract class AggregateFunction1 extends LeafExpression with Serializable {
-
-  /** Base should return the generic aggregate expression that this function is computing */
-  val base: AggregateExpression1
-
-  override def nullable: Boolean = base.nullable
-  override def dataType: DataType = base.dataType
-
-  def update(input: InternalRow): Unit
-
-  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    throw new UnsupportedOperationException(
-      "AggregateFunction1 should not be used for generated aggregates")
-  }
-}
-
-case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-
-  override def asPartial: SplitEvaluation = {
-    val partialMin = Alias(Min(child), "PartialMin")()
-    SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
-  }
-
-  override def newInstance(): MinFunction = new MinFunction(child, this)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForOrderingExpr(child.dataType, "function min")
-}
-
-case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
-
-  val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
-  val cmp = GreaterThan(currentMin, expr)
-
-  override def update(input: InternalRow): Unit = {
-    if (currentMin.value == null) {
-      currentMin.value = expr.eval(input)
-    } else if (cmp.eval(input) == true) {
-      currentMin.value = expr.eval(input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = currentMin.value
-}
-
-case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-
-  override def asPartial: SplitEvaluation = {
-    val partialMax = Alias(Max(child), "PartialMax")()
-    SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
-  }
-
-  override def newInstance(): MaxFunction = new MaxFunction(child, this)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForOrderingExpr(child.dataType, "function max")
-}
-
-case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
-
-  val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
-  val cmp = LessThan(currentMax, expr)
-
-  override def update(input: InternalRow): Unit = {
-    if (currentMax.value == null) {
-      currentMax.value = expr.eval(input)
-    } else if (cmp.eval(input) == true) {
-      currentMax.value = expr.eval(input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = currentMax.value
-}
-
-case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  override def nullable: Boolean = false
-  override def dataType: LongType.type = LongType
-
-  override def asPartial: SplitEvaluation = {
-    val partialCount = Alias(Count(child), "PartialCount")()
-    SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil)
-  }
-
-  override def newInstance(): CountFunction = new CountFunction(child, this)
-}
-
-case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
-
-  var count: Long = _
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      count += 1L
-    }
-  }
-
-  override def eval(input: InternalRow): Any = count
-}
-
-case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = expressions
-
-  override def nullable: Boolean = false
-  override def dataType: DataType = LongType
-  override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})"
-  override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this)
-
-  override def asPartial: SplitEvaluation = {
-    val partialSet = Alias(CollectHashSet(expressions), "partialSets")()
-    SplitEvaluation(
-      CombineSetsAndCount(partialSet.toAttribute),
-      partialSet :: Nil)
-  }
-}
-
-case class CountDistinctFunction(
-    @transient expr: Seq[Expression],
-    @transient base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  val seen = new OpenHashSet[Any]()
-
-  @transient
-  val distinctValue = new InterpretedProjection(expr)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = distinctValue(input)
-    if (!evaluatedExpr.anyNull) {
-      seen.add(evaluatedExpr)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = expressions
-  override def nullable: Boolean = false
-  override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType)
-  override def toString: String = s"AddToHashSet(${expressions.mkString(",")})"
-  override def newInstance(): CollectHashSetFunction =
-    new CollectHashSetFunction(expressions, this)
-}
-
-case class CollectHashSetFunction(
-    @transient expr: Seq[Expression],
-    @transient base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  val seen = new OpenHashSet[Any]()
-
-  @transient
-  val distinctValue = new InterpretedProjection(expr)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = distinctValue(input)
-    if (!evaluatedExpr.anyNull) {
-      seen.add(evaluatedExpr)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    seen
-  }
-}
-
-case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = inputSet :: Nil
-  override def nullable: Boolean = false
-  override def dataType: DataType = LongType
-  override def toString: String = s"CombineAndCount($inputSet)"
-  override def newInstance(): CombineSetsAndCountFunction = {
-    new CombineSetsAndCountFunction(inputSet, this)
-  }
-}
-
-case class CombineSetsAndCountFunction(
-    @transient inputSet: Expression,
-    @transient base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  val seen = new OpenHashSet[Any]()
-
-  override def update(input: InternalRow): Unit = {
-    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
-    val inputIterator = inputSetEval.iterator
-    while (inputIterator.hasNext) {
-      seen.add(inputIterator.next)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */
-private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
-
-  override def sqlType: DataType = BinaryType
-
-  /** Since we are using HyperLogLog internally, usually it will not be called. */
-  override def serialize(obj: Any): Array[Byte] =
-    obj.asInstanceOf[HyperLogLog].getBytes
-
-
-  /** Since we are using HyperLogLog internally, usually it will not be called. */
-  override def deserialize(datum: Any): HyperLogLog =
-    HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]])
-
-  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
-}
-
-case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression1 {
-
-  override def nullable: Boolean = false
-  override def dataType: DataType = HyperLogLogUDT
-  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
-  override def newInstance(): ApproxCountDistinctPartitionFunction = {
-    new ApproxCountDistinctPartitionFunction(child, this, relativeSD)
-  }
-}
-
-case class ApproxCountDistinctPartitionFunction(
-    expr: Expression,
-    base: AggregateExpression1,
-    relativeSD: Double)
-  extends AggregateFunction1 {
-  def this() = this(null, null, 0) // Required for serialization.
-
-  private val hyperLogLog = new HyperLogLog(relativeSD)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      hyperLogLog.offer(evaluatedExpr)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
-  extends UnaryExpression with AggregateExpression1 {
-
-  override def nullable: Boolean = false
-  override def dataType: LongType.type = LongType
-  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
-  override def newInstance(): ApproxCountDistinctMergeFunction = {
-    new ApproxCountDistinctMergeFunction(child, this, relativeSD)
-  }
-}
-
-case class ApproxCountDistinctMergeFunction(
-    expr: Expression,
-    base: AggregateExpression1,
-    relativeSD: Double)
-  extends AggregateFunction1 {
-  def this() = this(null, null, 0) // Required for serialization.
-
-  private val hyperLogLog = new HyperLogLog(relativeSD)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
-  }
-
-  override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
-}
-
-case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
-  extends UnaryExpression with PartialAggregate1 {
-
-  override def nullable: Boolean = false
-  override def dataType: LongType.type = LongType
-  override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)"
-
-  override def asPartial: SplitEvaluation = {
-    val partialCount =
-      Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")()
-
-    SplitEvaluation(
-      ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD),
-      partialCount :: Nil)
-  }
-
-  override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
-}
-
-case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  override def prettyName: String = "avg"
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = child.dataType match {
-    case DecimalType.Fixed(precision, scale) =>
-      // Add 4 digits after decimal point, like Hive
-      DecimalType.bounded(precision + 4, scale + 4)
-    case _ =>
-      DoubleType
-  }
-
-  override def asPartial: SplitEvaluation = {
-    child.dataType match {
-      case DecimalType.Fixed(precision, scale) =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
-        val partialCount = Alias(Count(child), "PartialCount")()
-
-        // partialSum already increase the precision by 10
-        val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
-        val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType)
-        SplitEvaluation(
-          Cast(Divide(castedSum, castedCount), dataType),
-          partialCount :: partialSum :: Nil)
-
-      case _ =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
-        val partialCount = Alias(Count(child), "PartialCount")()
-
-        val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
-        val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
-        SplitEvaluation(
-          Divide(castedSum, castedCount),
-          partialCount :: partialSum :: Nil)
-    }
-  }
-
-  override def newInstance(): AverageFunction = new AverageFunction(child, this)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
-}
-
-case class AverageFunction(expr: Expression, base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(precision, scale) =>
-        DecimalType.bounded(precision + 10, scale)
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private var count: Long = _
-  private val sum = MutableLiteral(zero.eval(null), calcType)
-
-  private def addFunction(value: Any) = Add(sum,
-    Cast(Literal.create(value, expr.dataType), calcType))
-
-  override def eval(input: InternalRow): Any = {
-    if (count == 0L) {
-      null
-    } else {
-      expr.dataType match {
-        case DecimalType.Fixed(precision, scale) =>
-          val dt = DecimalType.bounded(precision + 14, scale + 4)
-          Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
-        case _ =>
-          Divide(
-            Cast(sum, dataType),
-            Cast(Literal(count), dataType)).eval(null)
-      }
-    }
-  }
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      count += 1
-      sum.update(addFunction(evaluatedExpr), input)
-    }
-  }
-}
-
-case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  override def nullable: Boolean = true
-
-  override def dataType: DataType = child.dataType match {
-    case DecimalType.Fixed(precision, scale) =>
-      // Add 10 digits left of decimal point, like Hive
-      DecimalType.bounded(precision + 10, scale)
-    case _ =>
-      child.dataType
-  }
-
-  override def asPartial: SplitEvaluation = {
-    child.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
-        SplitEvaluation(
-          Cast(Sum(partialSum.toAttribute), dataType),
-          partialSum :: Nil)
-
-      case _ =>
-        val partialSum = Alias(Sum(child), "PartialSum")()
-        SplitEvaluation(
-          Sum(partialSum.toAttribute),
-          partialSum :: Nil)
-    }
-  }
-
-  override def newInstance(): SumFunction = new SumFunction(child, this)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
-}
-
-case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(precision, scale) =>
-        DecimalType.bounded(precision + 10, scale)
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private val sum = MutableLiteral(null, calcType)
-
-  private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum))
-
-  override def update(input: InternalRow): Unit = {
-    sum.update(addFunction, input)
-  }
-
-  override def eval(input: InternalRow): Any = {
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        Cast(sum, dataType).eval(null)
-      case _ => sum.eval(null)
-    }
-  }
-}
-
-case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
-
-  def this() = this(null)
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType match {
-    case DecimalType.Fixed(precision, scale) =>
-      // Add 10 digits left of decimal point, like Hive
-      DecimalType.bounded(precision + 10, scale)
-    case _ =>
-      child.dataType
-  }
-  override def toString: String = s"sum(distinct $child)"
-  override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this)
-
-  override def asPartial: SplitEvaluation = {
-    val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
-    SplitEvaluation(
-      CombineSetsAndSum(partialSet.toAttribute, this),
-      partialSet :: Nil)
-  }
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
-}
-
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val seen = new scala.collection.mutable.HashSet[Any]()
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      seen += evaluatedExpr
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    if (seen.size == 0) {
-      null
-    } else {
-      Cast(Literal(
-        seen.reduceLeft(
-          dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
-        dataType).eval(null)
-    }
-  }
-}
-
-case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
-  def this() = this(null, null)
-
-  override def children: Seq[Expression] = inputSet :: Nil
-  override def nullable: Boolean = true
-  override def dataType: DataType = base.dataType
-  override def toString: String = s"CombineAndSum($inputSet)"
-  override def newInstance(): CombineSetsAndSumFunction = {
-    new CombineSetsAndSumFunction(inputSet, this)
-  }
-}
-
-case class CombineSetsAndSumFunction(
-    @transient inputSet: Expression,
-    @transient base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null) // Required for serialization.
-
-  val seen = new OpenHashSet[Any]()
-
-  override def update(input: InternalRow): Unit = {
-    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
-    val inputIterator = inputSetEval.iterator
-    while (inputIterator.hasNext) {
-      seen.add(inputIterator.next())
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    val casted = seen.asInstanceOf[OpenHashSet[InternalRow]]
-    if (casted.size == 0) {
-      null
-    } else {
-      Cast(Literal(
-        casted.iterator.map(f => f.get(0, null)).reduceLeft(
-          base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
-        base.dataType).eval(null)
-    }
-  }
-}
-
-case class First(
-    child: Expression,
-    ignoreNullsExpr: Expression)
-  extends UnaryExpression with PartialAggregate1 {
-
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})"
-
-  override def asPartial: SplitEvaluation = {
-    val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
-    SplitEvaluation(
-      First(partialFirst.toAttribute, ignoreNulls),
-      partialFirst :: Nil)
-  }
-  override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
-}
-
-object First {
-  def apply(child: Expression): First = First(child, ignoreNulls = false)
-
-  def apply(child: Expression, ignoreNulls: Boolean): First =
-    First(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class FirstFunction(
-    expr: Expression,
-    ignoreNulls: Boolean,
-    base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
-  private[this] var result: Any = null
-
-  private[this] var valueSet: Boolean = false
-
-  override def update(input: InternalRow): Unit = {
-    if (!valueSet) {
-      val value = expr.eval(input)
-      // When we have not set the result, we will set the result if we respect nulls
-      // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
-      if (!ignoreNulls || (ignoreNulls && value != null)) {
-        result = value
-        valueSet = true
-      }
-    }
-  }
-
-  override def eval(input: InternalRow): Any = result
-}
-
-case class Last(
-    child: Expression,
-    ignoreNullsExpr: Expression)
-  extends UnaryExpression with PartialAggregate1 {
-
-  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
-
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
-  override def references: AttributeSet = child.references
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}"
-
-  override def asPartial: SplitEvaluation = {
-    val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
-    SplitEvaluation(
-      Last(partialLast.toAttribute, ignoreNulls),
-      partialLast :: Nil)
-  }
-  override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
-}
-
-object Last {
-  def apply(child: Expression): Last = Last(child, ignoreNulls = false)
-
-  def apply(child: Expression, ignoreNulls: Boolean): Last =
-    Last(child, Literal.create(ignoreNulls, BooleanType))
-}
-
-case class LastFunction(
-    expr: Expression,
-    ignoreNulls: Boolean,
-    base: AggregateExpression1)
-  extends AggregateFunction1 {
-
-  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
-
-  var result: Any = null
-
-  override def update(input: InternalRow): Unit = {
-    val value = expr.eval(input)
-    if (!ignoreNulls || (ignoreNulls && value != null)) {
-      result = value
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    result
-  }
-}
-
-/**
- * Calculate Pearson Correlation Coefficient for the given columns.
- * Only support AggregateExpression2.
- *
- */
-case class Corr(left: Expression, right: Expression)
-    extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes {
-  override def nullable: Boolean = false
-  override def dataType: DoubleType.type = DoubleType
-  override def toString: String = s"corr($left, $right)"
-  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException(
-      "Corr only supports the new AggregateExpression2 and can only be used " +
-        "when spark.sql.useAggregate2 = true")
-  }
-}
-
-// Compute standard deviation based on online algorithm specified here:
-// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
-  override def nullable: Boolean = true
-  override def dataType: DataType = DoubleType
-
-  def isSample: Boolean
-
-  override def asPartial: SplitEvaluation = {
-    val partialStd = Alias(ComputePartialStd(child), "PartialStddev")()
-    SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil)
-  }
-
-  override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function stddev")
-
-}
-
-// Compute the population standard deviation of a column
-case class StddevPop(child: Expression) extends StddevAgg1(child) {
-
-  override def toString: String = s"stddev_pop($child)"
-  override def isSample: Boolean = false
-}
-
-// Compute the sample standard deviation of a column
-case class StddevSamp(child: Expression) extends StddevAgg1(child) {
-
-  override def toString: String = s"stddev_samp($child)"
-  override def isSample: Boolean = true
-}
-
-case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
-  def this() = this(null)
-
-  override def children: Seq[Expression] = child :: Nil
-  override def nullable: Boolean = false
-  override def dataType: DataType = ArrayType(DoubleType)
-  override def toString: String = s"computePartialStddev($child)"
-  override def newInstance(): ComputePartialStdFunction =
-    new ComputePartialStdFunction(child, this)
-}
-
-case class ComputePartialStdFunction (
-    expr: Expression,
-    base: AggregateExpression1
-  ) extends AggregateFunction1 {
-
-  def this() = this(null, null)  // Required for serialization
-
-  private val computeType = DoubleType
-  private val zero = Cast(Literal(0), computeType)
-  private var partialCount: Long = 0L
-
-  // the mean of data processed so far
-  private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
-  // update average based on this formula:
-  // avg = avg + (value - avg)/count
-  private def avgAddFunction (value: Literal): Expression = {
-    val delta = Subtract(Cast(value, computeType), partialAvg)
-    Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType)))
-  }
-
-  // the sum of squares of difference from mean
-  private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType)
-
-  // update sum of square of difference from mean based on following formula:
-  // Mk = Mk + (value - preAvg) * (value - updatedAvg)
-  private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = {
-    val delta1 = Subtract(Cast(value, computeType), prePartialAvg)
-    val delta2 = Subtract(Cast(value, computeType), partialAvg)
-    Add(partialMk, Multiply(delta1, delta2))
-  }
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      val exprValue = Literal.create(evaluatedExpr, expr.dataType)
-      val prePartialAvg = partialAvg.copy()
-      partialCount += 1
-      partialAvg.update(avgAddFunction(exprValue), input)
-      partialMk.update(mkAddFunction(exprValue, prePartialAvg), input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null),
-        partialAvg.eval(null),
-        partialMk.eval(null)))
-  }
-}
-
-case class MergePartialStd(
-    child: Expression,
-    isSample: Boolean
-) extends UnaryExpression with AggregateExpression1 {
-  def this() = this(null, false) // required for serialization
-
-  override def children: Seq[Expression] = child:: Nil
-  override def nullable: Boolean = false
-  override def dataType: DataType = DoubleType
-  override def toString: String = s"MergePartialStd($child)"
-  override def newInstance(): MergePartialStdFunction = {
-    new MergePartialStdFunction(child, this, isSample)
-  }
-}
-
-case class MergePartialStdFunction(
-    expr: Expression,
-    base: AggregateExpression1,
-    isSample: Boolean
-) extends AggregateFunction1 {
-  def this() = this (null, null, false) // Required for serialization
-
-  private val computeType = DoubleType
-  private val zero = Cast(Literal(0), computeType)
-  private val combineCount = MutableLiteral(zero.eval(null), computeType)
-  private val combineAvg = MutableLiteral(zero.eval(null), computeType)
-  private val combineMk = MutableLiteral(zero.eval(null), computeType)
-
-  private def avgUpdateFunction(preCount: Expression,
-                                partialCount: Expression,
-                                partialAvg: Expression): Expression = {
-    Divide(Add(Multiply(combineAvg, preCount),
-               Multiply(partialAvg, partialCount)),
-           Add(preCount, partialCount))
-  }
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData]
-
-    if (evaluatedExpr != null) {
-      val exprValue = evaluatedExpr.toArray(computeType)
-      val (partialCount, partialAvg, partialMk) =
-        (Literal.create(exprValue(0), computeType),
-         Literal.create(exprValue(1), computeType),
-         Literal.create(exprValue(2), computeType))
-
-      if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) {
-        val preCount = combineCount.copy()
-        combineCount.update(Add(combineCount, partialCount), input)
-
-        val preAvg = combineAvg.copy()
-        val avgDelta = Subtract(partialAvg, preAvg)
-        val mkDelta = Multiply(Multiply(avgDelta, avgDelta),
-                               Divide(Multiply(preCount, partialCount),
-                                      combineCount))
-
-        // update average based on following formula
-        // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount)
-        combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input)
-
-        // update sum of square differences from mean based on following formula
-        // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount)
-        combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input)
-      }
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long]
-
-    if (count == 0) null
-    else if (count < 2) zero.eval(null)
-    else {
-      // when total count > 2
-      // stddev_samp = sqrt (combineMk/(combineCount -1))
-      // stddev_pop = sqrt (combineMk/combineCount)
-      val varCol = {
-        if (isSample) {
-          Divide(combineMk, Cast(Literal(count - 1), computeType))
-        }
-        else {
-          Divide(combineMk, Cast(Literal(count), computeType))
-        }
-      }
-      Sqrt(varCol).eval(null)
-    }
-  }
-}
-
-case class StddevFunction(
-    expr: Expression,
-    base: AggregateExpression1,
-    isSample: Boolean
-) extends AggregateFunction1 {
-
-  def this() = this(null, null, false) // Required for serialization
-
-  private val computeType = DoubleType
-  private var curCount: Long = 0L
-  private val zero = Cast(Literal(0), computeType)
-  private val curAvg = MutableLiteral(zero.eval(null), computeType)
-  private val curMk = MutableLiteral(zero.eval(null), computeType)
-
-  private def curAvgAddFunction(value: Literal): Expression = {
-    val delta = Subtract(Cast(value, computeType), curAvg)
-    Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType)))
-  }
-  private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = {
-    val delta1 = Subtract(Cast(value, computeType), preAvg)
-    val delta2 = Subtract(Cast(value, computeType), curAvg)
-    Add(curMk, Multiply(delta1, delta2))
-  }
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      val preAvg: MutableLiteral = curAvg.copy()
-      val exprValue = Literal.create(evaluatedExpr, expr.dataType)
-      curCount += 1L
-      curAvg.update(curAvgAddFunction(exprValue), input)
-      curMk.update(curMkAddFunction(exprValue, preAvg), input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    if (curCount == 0) null
-    else if (curCount < 2) zero.eval(null)
-    else {
-      // when total count > 2,
-      // stddev_samp = sqrt(curMk/(curCount - 1))
-      // stddev_pop = sqrt(curMk/curCount)
-      val varCol = {
-        if (isSample) {
-          Divide(curMk, Cast(Literal(curCount - 1), computeType))
-        }
-        else {
-          Divide(curMk, Cast(Literal(curCount), computeType))
-        }
-      }
-      Sqrt(varCol).eval(null)
-    }
-  }
-}
-
-// placeholder
-case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
-      "please set spark.sql.useAggregate2 = true")
-  }
-
-  override def nullable: Boolean = false
-
-  override def dataType: DoubleType.type = DoubleType
-
-  override def foldable: Boolean = false
-
-  override def prettyName: String = "kurtosis"
-}
-
-// placeholder
-case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
-      "please set spark.sql.useAggregate2 = true")
-  }
-
-  override def nullable: Boolean = false
-
-  override def dataType: DoubleType.type = DoubleType
-
-  override def foldable: Boolean = false
-
-  override def prettyName: String = "skewness"
-}
-
-// placeholder
-case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
-      "please set spark.sql.useAggregate2 = true")
-  }
-
-  override def nullable: Boolean = false
-
-  override def dataType: DoubleType.type = DoubleType
-
-  override def foldable: Boolean = false
-
-  override def prettyName: String = "var_pop"
-}
-
-// placeholder
-case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
-      "please set spark.sql.useAggregate2 = true")
-  }
-
-  override def nullable: Boolean = false
-
-  override def dataType: DoubleType.type = DoubleType
-
-  override def foldable: Boolean = false
-
-  override def prettyName: String = "var_samp"
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d222dfa..f4dba67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.immutable.HashSet
 import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.plans.LeftOuter
@@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case a @ Aggregate(_, _, e @ Expand(_, _, child))
-      if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty =>
-      a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references)))
+      if (child.outputSet -- e.references -- a.references).nonEmpty =>
+      a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
 
     // Eliminate attributes that are not needed to calculate the specified aggregates.
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
@@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] {
 object NullPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+      case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+        Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
       case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
@@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ Count(expr) if !expr.nullable => Count(Literal(1))
+      case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+        // This rule should be only triggered when isDistinct field is false.
+        AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
 
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
@@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   private val MAX_DOUBLE_DIGITS = 15
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+    case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+      if prec + 10 <= MAX_LONG_DIGITS =>
+      MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
 
-    case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+    case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+      if prec + 4 <= MAX_DOUBLE_DIGITS =>
+      val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
       Cast(
-        Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)),
+        Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
         DecimalType(prec + 4, scale + 4))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b975b9..6f4f114 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -85,80 +85,6 @@ object PhysicalOperation extends PredicateHelper {
 }
 
 /**
- * Matches a logical aggregation that can be performed on distributed data in two steps.  The first
- * operates on the data in each partition performing partial aggregation for each group.  The second
- * occurs after the shuffle and completes the aggregation.
- *
- * This pattern will only match if all aggregate expressions can be computed partially and will
- * return the rewritten aggregation expressions for both phases.
- *
- * The returned values for this match are as follows:
- *  - Grouping attributes for the final aggregation.
- *  - Aggregates for the final aggregation.
- *  - Grouping expressions for the partial aggregation.
- *  - Partial aggregate expressions.
- *  - Input to the aggregation.
- */
-object PartialAggregation {
-  type ReturnType =
-    (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
-
-  def unapply(plan: LogicalPlan): Option[ReturnType] = plan match {
-    case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
-      // Collect all aggregate expressions.
-      val allAggregates =
-        aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
-      // Collect all aggregate expressions that can be computed partially.
-      val partialAggregates =
-        aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
-
-      // Only do partial aggregation if supported by all aggregate expressions.
-      if (allAggregates.size == partialAggregates.size) {
-        // Create a map of expressions to their partial evaluations for all aggregate expressions.
-        val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
-          partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
-
-        // We need to pass all grouping expressions though so the grouping can happen a second
-        // time. However some of them might be unnamed so we alias them allowing them to be
-        // referenced in the second aggregation.
-        val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
-          groupingExpressions.map {
-            case n: NamedExpression => (n, n)
-            case other => (other, Alias(other, "PartialGroup")())
-          }
-
-        // Replace aggregations with a new expression that computes the result from the already
-        // computed partial evaluations and grouping values.
-        val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
-          case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
-            partialEvaluations(new TreeNodeRef(e)).finalEvaluation
-
-          case e: Expression =>
-            namedGroupingExpressions.collectFirst {
-              case (expr, ne) if expr semanticEquals e => ne.toAttribute
-            }.getOrElse(e)
-        }).asInstanceOf[Seq[NamedExpression]]
-
-        val partialComputation = namedGroupingExpressions.map(_._2) ++
-          partialEvaluations.values.flatMap(_.partialEvaluations)
-
-        val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-
-        Some(
-          (namedGroupingAttributes,
-           rewrittenAggregateExpressions,
-           groupingExpressions,
-           partialComputation,
-           child))
-      } else {
-        None
-      }
-    case _ => None
-  }
-}
-
-
-/**
  * A pattern that finds joins with equality conditions that can be evaluated using equi-join.
  *
  * Null-safe equality will be transformed into equality as joining key (replace null with default

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0ec9f08..b9db783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
 
   /** Returns all of the expressions present in this query plan operator. */
   def expressions: Seq[Expression] = {
+    // Recursively find all expressions from a traversable.
+    def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
+      case e: Expression => e :: Nil
+      case s: Traversable[_] => seqToExpressions(s)
+      case other => Nil
+    }
+
     productIterator.flatMap {
       case e: Expression => e :: Nil
       case Some(e: Expression) => e :: Nil
-      case seq: Traversable[_] => seq.flatMap {
-        case e: Expression => e :: Nil
-        case other => Nil
-      }
+      case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d771088..764f8aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
@@ -219,8 +219,6 @@ case class Aggregate(
     !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
   }
 
-  lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
-
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index fbdd3a7..5a2368e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -171,16 +171,18 @@ class AnalysisErrorSuite extends AnalysisTest {
 
   test("SPARK-6452 regression test") {
     // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
+    // Since we manually construct the logical plan here and Sum only accepts
+    // LongType, DoubleType, and DecimalType, we use LongType as the type of a.
     val plan =
       Aggregate(
         Nil,
-        Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
+        Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil,
         LocalRelation(
-          AttributeReference("a", IntegerType)(exprId = ExprId(2))))
+          AttributeReference("a", LongType)(exprId = ExprId(2))))
 
     assert(plan.resolved)
 
-    assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil)
+    assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil)
   }
 
   test("error test for self-join") {
@@ -196,7 +198,7 @@ class AnalysisErrorSuite extends AnalysisTest {
     val plan =
       Aggregate(
         AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
-        Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
         LocalRelation(
           AttributeReference("a", BinaryType)(exprId = ExprId(2)),
           AttributeReference("b", IntegerType)(exprId = ExprId(1))))
@@ -207,13 +209,24 @@ class AnalysisErrorSuite extends AnalysisTest {
     val plan2 =
       Aggregate(
         AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
-        Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
         LocalRelation(
           AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
           AttributeReference("b", IntegerType)(exprId = ExprId(1))))
 
     assertAnalysisError(plan2,
       "map type expression a cannot be used in grouping expression" :: Nil)
+
+    val plan3 =
+      Aggregate(
+        AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
+        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        LocalRelation(
+          AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
+          AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+    assertAnalysisError(plan3,
+      "array type expression a cannot be used in grouping expression" :: Nil)
   }
 
   test("Join can't work on binary and map types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 71d2939..65f09b4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest {
     val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
     assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
 
-    assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
+    assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved)
   }
 
   test("analyze project") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 40c4ae7..fed591f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf}

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index c9bcc68..b902982 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types.{TypeCollection, StringType}
 
@@ -140,15 +141,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
   }
 
   test("check types for aggregates") {
+    // We use AggregateFunction directly here because the error will be thrown from it
+    // instead of from AggregateExpression, which is the wrapper of an AggregateFunction.
+
     // We will cast String to Double for sum and average
     assertSuccess(Sum('stringField))
-    assertSuccess(SumDistinct('stringField))
     assertSuccess(Average('stringField))
 
     assertError(Min('complexField), "min does not support ordering on type")
     assertError(Max('complexField), "max does not support ordering on type")
     assertError(Sum('booleanField), "function sum requires numeric type")
-    assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type")
     assertError(Average('booleanField), "function average requires numeric type")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0701c75/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index e676062..8aaefa8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest {
       testRelation
         .select(
           Rand(5L) + Literal(1) as Symbol("c1"),
-          Sum('a) as Symbol("c2"))
+          sum('a) as Symbol("c2"))
 
     val optimized = Optimize.execute(originalQuery.analyze)
 
@@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest {
       testRelation
         .select(
           Rand(5L) + Literal(1.0) as Symbol("c1"),
-          Sum('a) as Symbol("c2"))
+          sum('a) as Symbol("c2"))
         .analyze
 
     comparePlans(optimized, correctAnswer)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message