Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 32734200B6A for ; Mon, 22 Aug 2016 18:12:30 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id 2F885160AB3; Mon, 22 Aug 2016 16:12:30 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 77BC7160A87 for ; Mon, 22 Aug 2016 18:12:29 +0200 (CEST) Received: (qmail 61930 invoked by uid 500); 22 Aug 2016 16:12:28 -0000 Mailing-List: contact reviews-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list reviews@spark.apache.org Received: (qmail 61915 invoked by uid 99); 22 Aug 2016 16:12:28 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 22 Aug 2016 16:12:28 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id E0BC4E055D; Mon, 22 Aug 2016 16:12:27 +0000 (UTC) From: hvanhovell To: reviews@spark.apache.org Reply-To: reviews@spark.apache.org References: In-Reply-To: Subject: [GitHub] spark pull request #10896: [SPARK-12978][SQL] Skip unnecessary final group-b... Content-Type: text/plain Message-Id: <20160822161227.E0BC4E055D@git1-us-west.apache.org> Date: Mon, 22 Aug 2016 16:12:27 +0000 (UTC) archived-at: Mon, 22 Aug 2016 16:12:30 -0000 Github user hvanhovell commented on a diff in the pull request: https://github.com/apache/spark/pull/10896#discussion_r75707785 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala --- @@ -27,26 +27,87 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto */ object AggUtils { - def planAggregateWithoutPartial( + private[execution] def isAggregate(operator: SparkPlan): Boolean = { + operator.isInstanceOf[HashAggregateExec] || operator.isInstanceOf[SortAggregateExec] + } + + private[execution] def supportPartialAggregate(operator: SparkPlan): Boolean = { + assert(isAggregate(operator)) + def supportPartial(exprs: Seq[AggregateExpression]) = + exprs.map(_.aggregateFunction).forall(_.supportsPartial) + operator match { + case agg @ HashAggregateExec(_, _, aggregateExpressions, _, _, _, _) => + supportPartial(aggregateExpressions) + case agg @ SortAggregateExec(_, _, aggregateExpressions, _, _, _, _) => + supportPartial(aggregateExpressions) + } + } + + private def createPartialAggregateExec( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - resultExpressions: Seq[NamedExpression], - child: SparkPlan): Seq[SparkPlan] = { + child: SparkPlan): SparkPlan = { + val groupingAttributes = groupingExpressions.map(_.toAttribute) + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val partialAggregateExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, _, false, _) if functionsWithDistinct.length > 0 => + agg.copy(mode = PartialMerge) + case agg => + agg.copy(mode = Partial) + } + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + val partialResultExpressions = + groupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - SortAggregateExec( - requiredChildDistributionExpressions = Some(groupingExpressions), + createAggregateExec( + requiredChildDistributionExpressions = None, groupingExpressions = groupingExpressions, - aggregateExpressions = completeAggregateExpressions, - aggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = 0, - resultExpressions = resultExpressions, - child = child - ) :: Nil + aggregateExpressions = partialAggregateExpressions, + aggregateAttributes = partialAggregateAttributes, + initialInputBufferOffset = if (functionsWithDistinct.length > 0) { + groupingExpressions.length + functionsWithDistinct.head.aggregateFunction.children.length + } else { + 0 + }, + resultExpressions = partialResultExpressions, + child = child) + } + + private def updateMergeAggregateMode(aggregateExpressions: Seq[AggregateExpression]) = { + def updateMode(mode: AggregateMode) = mode match { + case Partial => PartialMerge + case Complete => Final + case mode => mode + } + aggregateExpressions.map(e => e.copy(mode = updateMode(e.mode))) + } + + private[execution] def createPartialAggregate(operator: SparkPlan) --- End diff -- Much better --- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastructure@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org