Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id A1E9418454 for ; Thu, 6 Aug 2015 22:04:56 +0000 (UTC) Received: (qmail 65495 invoked by uid 500); 6 Aug 2015 22:04:56 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 65458 invoked by uid 500); 6 Aug 2015 22:04:56 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 65444 invoked by uid 99); 6 Aug 2015 22:04:56 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 06 Aug 2015 22:04:56 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 698FDE6B7F; Thu, 6 Aug 2015 22:04:56 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: rxin@apache.org To: commits@spark.apache.org Date: Thu, 06 Aug 2015 22:04:56 -0000 Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: [1/2] spark git commit: [SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up) Repository: spark Updated Branches: refs/heads/branch-1.5 980687206 -> 272e88342 http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 960be08..80816a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,20 +17,41 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} +import org.apache.spark.sql.types.StructType /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { + def supportsTungstenAggregate( + groupingExpressions: Seq[Expression], + aggregateBufferAttributes: Seq[Attribute]): Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeProjection.canSupport(groupingExpressions) + } + def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + // Check if we can use TungstenAggregate. + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + + // 1. Create an Aggregate Operator for partial aggregations. val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -44,11 +65,23 @@ object Utils { val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } - val partialAggregate = - Aggregate( + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) + val partialResultExpressions = + namedGroupingAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = namedGroupingExpressions.map(_._2), + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialResultExpressions, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = namedGroupingExpressions.map(_._2), nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -56,29 +89,57 @@ object Utils { completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, initialInputBufferOffset = 0, - resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes, + resultExpressions = partialResultExpressions, child = child) + } // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAggregate = - Aggregate( + + val finalAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + // aggregateFunctionMap contains unique aggregate functions. + val aggregateFunction = + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = namedGroupingAttributes.length, + resultExpressions = rewrittenResultExpressions, + child = partialAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -88,6 +149,7 @@ object Utils { initialInputBufferOffset = namedGroupingAttributes.length, resultExpressions = rewrittenResultExpressions, child = partialAggregate) + } finalAggregate :: Nil } @@ -96,10 +158,18 @@ object Utils { groupingExpressions: Seq[Expression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { + val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct + val usesTungstenAggregate = + child.sqlContext.conf.unsafeEnabled && + aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) && + supportsTungstenAggregate( + groupingExpressions, + aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)) + // 1. Create an Aggregate Operator for partial aggregations. // The grouping expressions are original groupingExpressions and // distinct columns. For example, for avg(distinct value) ... group by key @@ -129,19 +199,26 @@ object Utils { val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Partial, false) - } - val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val partialAggregateAttributes = + partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialAggregateGroupingExpressions = (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes - val partialAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = None: Option[Seq[Expression]], + groupingExpressions = partialAggregateGroupingExpressions, + nonCompleteAggregateExpressions = partialAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = 0, + resultExpressions = partialAggregateResult, + child = child) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, @@ -151,20 +228,27 @@ object Utils { initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) + } // 2. Create an Aggregate Operator for partial merge aggregations. - val partialMergeAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, PartialMerge, false) - } + val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) val partialMergeAggregateAttributes = - partialMergeAggregateExpressions.flatMap { agg => - agg.aggregateFunction.bufferAttributes - } + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes - val partialMergeAggregate = - Aggregate( + namedGroupingAttributes ++ + distinctColumnAttributes ++ + partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes) + val partialMergeAggregate = if (usesTungstenAggregate) { + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + completeAggregateExpressions = Nil, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = partialMergeAggregateResult, + child = partialAggregate) + } else { + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, @@ -174,48 +258,91 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) + } // 3. Create an Aggregate Operator for partial merge aggregations. - val finalAggregateExpressions = functionsWithoutDistinct.map { - case AggregateExpression2(aggregateFunction, mode, _) => - AggregateExpression2(aggregateFunction, Final, false) - } + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 } + // Create a map to store those rewritten aggregate functions. We always need to use + // both function and its corresponding isDistinct flag as the key because function itself + // does not knows if it is has distinct keyword or now. + val rewrittenAggregateFunctions = + mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => + case agg @ AggregateExpression2(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] + // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions + // to track the old version and the new version of this function. + rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. + // We just keep the isDistinct setting to true, so when users look at the query plan, + // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, false) + AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) + val aggregateFunctionAttribute = + aggregateFunctionMap(agg.aggregateFunction, true)._2 (rewrittenAggregateExpression -> aggregateFunctionAttribute) }.unzip - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - val finalAndCompleteAggregate = - Aggregate( + val finalAndCompleteAggregate = if (usesTungstenAggregate) { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + val function = agg.aggregateFunction + val isDistinct = agg.isDistinct + val aggregateFunction = + if (rewrittenAggregateFunctions.contains(function, isDistinct)) { + // If this function has been rewritten, we get the rewritten version from + // rewrittenAggregateFunctions. + rewrittenAggregateFunctions(function, isDistinct) + } else { + // Oterwise, we get it from aggregateFunctionMap, which contains unique + // aggregate functions that have not been rewritten. + aggregateFunctionMap(function, isDistinct)._1 + } + aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + TungstenAggregate( + requiredChildDistributionExpressions = Some(namedGroupingAttributes), + groupingExpressions = namedGroupingAttributes, + nonCompleteAggregateExpressions = finalAggregateExpressions, + completeAggregateExpressions = completeAggregateExpressions, + initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = rewrittenResultExpressions, + child = partialMergeAggregate) + } else { + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + SortBasedAggregate( requiredChildDistributionExpressions = Some(namedGroupingAttributes), groupingExpressions = namedGroupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, @@ -225,6 +352,7 @@ object Utils { initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, resultExpressions = rewrittenResultExpressions, child = partialMergeAggregate) + } finalAndCompleteAggregate :: Nil } http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cef40dd..c64aa7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.executedPlan - .collect { case _: aggregate.Aggregate => true } + .collect { case _: aggregate.TungstenAggregate => true } .nonEmpty if (!hasGeneratedAgg) { fail( http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 4b35c8f..7b5aa47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row} +import org.apache.spark.sql._ import org.scalatest.BeforeAndAfterAll -import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} +import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { @@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Nil) } + test("null literal") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(null), + | COUNT(null), + | FIRST(null), + | LAST(null), + | MAX(null), + | MIN(null), + | SUM(null) + """.stripMargin), + Row(null, 0, null, null, null, null, null) :: Nil) + } + test("only do grouping") { checkAnswer( sqlContext.sql( @@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT avg(value) FROM agg1 """.stripMargin), Row(11.125) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT avg(null) - """.stripMargin), - Row(null) :: Nil) } test("udaf") { @@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( sqlContext.sql( @@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: Row(3, null, 3.0, null, null, null) :: Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) } test("test count") { @@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |FROM agg1 |GROUP BY key """.stripMargin).queryExecution.executedPlan.collect { - case agg: aggregate.Aggregate => agg + case agg: aggregate.SortBasedAggregate => agg + case agg: aggregate.TungstenAggregate => agg } val message = "We should fallback to the old aggregation code path if " + @@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite { sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) } } + +class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { + + var originalUnsafeEnabled: Boolean = _ + + override def beforeAll(): Unit = { + originalUnsafeEnabled = sqlContext.conf.unsafeEnabled + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true") + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString) + sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt") + } + + override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit = { + (0 to 2).foreach { fallbackStartsAt => + sqlContext.setConf( + "spark.sql.TungstenAggregate.testFallbackStartsAt", + fallbackStartsAt.toString) + + // Create a new df to make sure its physical operator picks up + // spark.sql.TungstenAggregate.testFallbackStartsAt. + val newActual = DataFrame(sqlContext, actual.logicalPlan) + + QueryTest.checkAnswer(newActual, expectedAnswer) match { + case Some(errorMessage) => + val newErrorMessage = + s""" + |The following aggregation query failed when using TungstenAggregate with + |controlled fallback (it falls back to sort-based aggregation once it has processed + |$fallbackStartsAt input rows). The query is + |${actual.queryExecution} + | + |$errorMessage + """.stripMargin + + fail(newErrorMessage) + case None => + } + } + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = { + checkAnswer(df, Seq(expectedAnswer)) + } + + // Override it to make sure we call the actually overridden checkAnswer. + override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org