spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject [1/2] spark git commit: [SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up)
Date Thu, 06 Aug 2015 22:04:56 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 980687206 -> 272e88342


http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 960be08..80816a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -17,20 +17,41 @@
 
 package org.apache.spark.sql.execution.aggregate
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code
path.
  */
 object Utils {
+  def supportsTungstenAggregate(
+      groupingExpressions: Seq[Expression],
+      aggregateBufferAttributes: Seq[Attribute]): Boolean = {
+    val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
+
+    UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
&&
+      UnsafeProjection.canSupport(groupingExpressions)
+  }
+
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
+    // Check if we can use TungstenAggregate.
+    val usesTungstenAggregate =
+      child.sqlContext.conf.unsafeEnabled &&
+      aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate]) &&
+      supportsTungstenAggregate(
+        groupingExpressions,
+        aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
+
     // 1. Create an Aggregate Operator for partial aggregations.
     val namedGroupingExpressions = groupingExpressions.map {
       case ne: NamedExpression => ne -> ne
@@ -44,11 +65,23 @@ object Utils {
     val groupExpressionMap = namedGroupingExpressions.toMap
     val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
     val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
-    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
-      agg.aggregateFunction.bufferAttributes
-    }
-    val partialAggregate =
-      Aggregate(
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
+    val partialResultExpressions =
+      namedGroupingAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+
+    val partialAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = namedGroupingExpressions.map(_._2),
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialResultExpressions,
+        child = child)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
         groupingExpressions = namedGroupingExpressions.map(_._2),
         nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -56,29 +89,57 @@ object Utils {
         completeAggregateExpressions = Nil,
         completeAggregateAttributes = Nil,
         initialInputBufferOffset = 0,
-        resultExpressions = namedGroupingAttributes ++ partialAggregateAttributes,
+        resultExpressions = partialResultExpressions,
         child = child)
+    }
 
     // 2. Create an Aggregate Operator for final aggregations.
     val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
       }
-    val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transformDown {
-        case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression =>
-          // We do not rely on the equality check at here since attributes may
-          // different cosmetically. Instead, we use semanticEquals.
-          groupExpressionMap.collectFirst {
-            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-          }.getOrElse(expression)
-      }.asInstanceOf[NamedExpression]
-    }
-    val finalAggregate =
-      Aggregate(
+
+    val finalAggregate = if (usesTungstenAggregate) {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case agg: AggregateExpression2 =>
+            // aggregateFunctionMap contains unique aggregate functions.
+            val aggregateFunction =
+              aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1
+            aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+          case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = namedGroupingAttributes.length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialAggregate)
+    } else {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transformDown {
+          case agg: AggregateExpression2 =>
+            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+          case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -88,6 +149,7 @@ object Utils {
         initialInputBufferOffset = namedGroupingAttributes.length,
         resultExpressions = rewrittenResultExpressions,
         child = partialAggregate)
+    }
 
     finalAggregate :: Nil
   }
@@ -96,10 +158,18 @@ object Utils {
       groupingExpressions: Seq[Expression],
       functionsWithDistinct: Seq[AggregateExpression2],
       functionsWithoutDistinct: Seq[AggregateExpression2],
-      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
+    val aggregateExpressions = functionsWithDistinct ++ functionsWithoutDistinct
+    val usesTungstenAggregate =
+      child.sqlContext.conf.unsafeEnabled &&
+        aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
&&
+        supportsTungstenAggregate(
+          groupingExpressions,
+          aggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
+
     // 1. Create an Aggregate Operator for partial aggregations.
     // The grouping expressions are original groupingExpressions and
     // distinct columns. For example, for avg(distinct value) ... group by key
@@ -129,19 +199,26 @@ object Utils {
     val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
     val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
 
-    val partialAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, Partial, false)
-    }
-    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
-      agg.aggregateFunction.bufferAttributes
-    }
+    val partialAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+    val partialAggregateAttributes =
+      partialAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
     val partialAggregateGroupingExpressions =
       (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2)
     val partialAggregateResult =
-      namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes
-    val partialAggregate =
-      Aggregate(
+      namedGroupingAttributes ++
+        distinctColumnAttributes ++
+        partialAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+    val partialAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = None: Option[Seq[Expression]],
+        groupingExpressions = partialAggregateGroupingExpressions,
+        nonCompleteAggregateExpressions = partialAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = 0,
+        resultExpressions = partialAggregateResult,
+        child = child)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = None: Option[Seq[Expression]],
         groupingExpressions = partialAggregateGroupingExpressions,
         nonCompleteAggregateExpressions = partialAggregateExpressions,
@@ -151,20 +228,27 @@ object Utils {
         initialInputBufferOffset = 0,
         resultExpressions = partialAggregateResult,
         child = child)
+    }
 
     // 2. Create an Aggregate Operator for partial merge aggregations.
-    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, PartialMerge, false)
-    }
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
     val partialMergeAggregateAttributes =
-      partialMergeAggregateExpressions.flatMap { agg =>
-        agg.aggregateFunction.bufferAttributes
-      }
+      partialMergeAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes)
     val partialMergeAggregateResult =
-      namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes
-    val partialMergeAggregate =
-      Aggregate(
+      namedGroupingAttributes ++
+        distinctColumnAttributes ++
+        partialMergeAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
+    val partialMergeAggregate = if (usesTungstenAggregate) {
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
+        nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
+        completeAggregateExpressions = Nil,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = partialMergeAggregateResult,
+        child = partialAggregate)
+    } else {
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes,
         nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
@@ -174,48 +258,91 @@ object Utils {
         initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
         resultExpressions = partialMergeAggregateResult,
         child = partialAggregate)
+    }
 
    // 3. Create an Aggregate Operator for the final and complete aggregations.
-    val finalAggregateExpressions = functionsWithoutDistinct.map {
-      case AggregateExpression2(aggregateFunction, mode, _) =>
-        AggregateExpression2(aggregateFunction, Final, false)
-    }
+    val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
       }
+    // Create a map to store those rewritten aggregate functions. We always need to use
+    // both the function and its corresponding isDistinct flag as the key because the function
+    // itself does not know whether it has the distinct keyword or not.
+    val rewrittenAggregateFunctions =
+      mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2]
     val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map
{
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
-      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+      case agg @ AggregateExpression2(aggregateFunction, mode, true) =>
         val rewrittenAggregateFunction = aggregateFunction.transformDown {
           case expr if distinctColumnExpressionMap.contains(expr) =>
             distinctColumnExpressionMap(expr).toAttribute
         }.asInstanceOf[AggregateFunction2]
+        // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions
+        // to track the old version and the new version of this function.
+        rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction
         // We rewrite the aggregate function to a non-distinct aggregation because
         // its input will have distinct arguments.
+        // We just keep the isDistinct flag set to true, so when users look at the query plan,
+        // they can still see distinct aggregations.
         val rewrittenAggregateExpression =
-          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+          AggregateExpression2(rewrittenAggregateFunction, Complete, true)
 
-        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        val aggregateFunctionAttribute =
+          aggregateFunctionMap(agg.aggregateFunction, true)._2
         (rewrittenAggregateExpression -> aggregateFunctionAttribute)
     }.unzip
 
-    val rewrittenResultExpressions = resultExpressions.map { expr =>
-      expr.transform {
-        case agg: AggregateExpression2 =>
-          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
-        case expression =>
-          // We do not rely on the equality check at here since attributes may
-          // different cosmetically. Instead, we use semanticEquals.
-          groupExpressionMap.collectFirst {
-            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
-          }.getOrElse(expression)
-      }.asInstanceOf[NamedExpression]
-    }
-    val finalAndCompleteAggregate =
-      Aggregate(
+    val finalAndCompleteAggregate = if (usesTungstenAggregate) {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transform {
+          case agg: AggregateExpression2 =>
+            val function = agg.aggregateFunction
+            val isDistinct = agg.isDistinct
+            val aggregateFunction =
+              if (rewrittenAggregateFunctions.contains(function, isDistinct)) {
+                // If this function has been rewritten, we get the rewritten version from
+                // rewrittenAggregateFunctions.
+                rewrittenAggregateFunctions(function, isDistinct)
+              } else {
+                // Otherwise, we get it from aggregateFunctionMap, which contains unique
+                // aggregate functions that have not been rewritten.
+                aggregateFunctionMap(function, isDistinct)._1
+              }
+            aggregateFunction.asInstanceOf[AlgebraicAggregate].evaluateExpression
+          case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+
+      TungstenAggregate(
+        requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+        groupingExpressions = namedGroupingAttributes,
+        nonCompleteAggregateExpressions = finalAggregateExpressions,
+        completeAggregateExpressions = completeAggregateExpressions,
+        initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
+        resultExpressions = rewrittenResultExpressions,
+        child = partialMergeAggregate)
+    } else {
+      val rewrittenResultExpressions = resultExpressions.map { expr =>
+        expr.transform {
+          case agg: AggregateExpression2 =>
+            aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+          case expression =>
+            // We do not rely on the equality check here since attributes may
+            // differ cosmetically. Instead, we use semanticEquals.
+            groupExpressionMap.collectFirst {
+              case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+            }.getOrElse(expression)
+        }.asInstanceOf[NamedExpression]
+      }
+      SortBasedAggregate(
         requiredChildDistributionExpressions = Some(namedGroupingAttributes),
         groupingExpressions = namedGroupingAttributes,
         nonCompleteAggregateExpressions = finalAggregateExpressions,
@@ -225,6 +352,7 @@ object Utils {
         initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length,
         resultExpressions = rewrittenResultExpressions,
         child = partialMergeAggregate)
+    }
 
     finalAndCompleteAggregate :: Nil
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cef40dd..c64aa7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -262,7 +262,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils
{
     val df = sql(sqlText)
     // First, check if we have GeneratedAggregate.
     val hasGeneratedAgg = df.queryExecution.executedPlan
-      .collect { case _: aggregate.Aggregate => true }
+      .collect { case _: aggregate.TungstenAggregate => true }
       .nonEmpty
     if (!hasGeneratedAgg) {
       fail(

http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4b35c8f..7b5aa47 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -21,9 +21,9 @@ import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql.{SQLConf, AnalysisException, QueryTest, Row}
+import org.apache.spark.sql._
 import org.scalatest.BeforeAndAfterAll
-import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
+import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
{
 
@@ -141,6 +141,22 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils
with Be
       Nil)
   }
 
+  test("null literal") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(null),
+          |  COUNT(null),
+          |  FIRST(null),
+          |  LAST(null),
+          |  MAX(null),
+          |  MIN(null),
+          |  SUM(null)
+        """.stripMargin),
+      Row(null, 0, null, null, null, null, null) :: Nil)
+  }
+
   test("only do grouping") {
     checkAnswer(
       sqlContext.sql(
@@ -266,13 +282,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils
with Be
           |SELECT avg(value) FROM agg1
         """.stripMargin),
       Row(11.125) :: Nil)
-
-    checkAnswer(
-      sqlContext.sql(
-        """
-          |SELECT avg(null)
-        """.stripMargin),
-      Row(null) :: Nil)
   }
 
   test("udaf") {
@@ -364,7 +373,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils
with Be
           |  max(distinct value1)
           |FROM agg2
         """.stripMargin),
-      Row(-60, 70.0, 101.0/9.0, 5.6, 100.0))
+      Row(-60, 70.0, 101.0/9.0, 5.6, 100))
 
     checkAnswer(
       sqlContext.sql(
@@ -402,6 +411,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils
with Be
         Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) ::
         Row(3, null, 3.0, null, null, null) ::
         Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  count(value1),
+          |  count(*),
+          |  count(1),
+          |  count(DISTINCT value1),
+          |  key
+          |FROM agg2
+          |GROUP BY key
+        """.stripMargin),
+      Row(3, 3, 3, 2, 1) ::
+        Row(3, 4, 4, 2, 2) ::
+        Row(0, 2, 2, 0, 3) ::
+        Row(3, 4, 4, 3, null) :: Nil)
   }
 
   test("test count") {
@@ -496,7 +522,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils
with Be
           |FROM agg1
           |GROUP BY key
         """.stripMargin).queryExecution.executedPlan.collect {
-        case agg: aggregate.Aggregate => agg
+        case agg: aggregate.SortBasedAggregate => agg
+        case agg: aggregate.TungstenAggregate => agg
       }
       val message =
         "We should fallback to the old aggregation code path if " +
@@ -537,3 +564,58 @@ class TungstenAggregationQuerySuite extends AggregationQuerySuite {
     sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
   }
 }
+
+class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
+
+  var originalUnsafeEnabled: Boolean = _
+
+  override def beforeAll(): Unit = {
+    originalUnsafeEnabled = sqlContext.conf.unsafeEnabled
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, "true")
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    sqlContext.setConf(SQLConf.UNSAFE_ENABLED.key, originalUnsafeEnabled.toString)
+    sqlContext.conf.unsetConf("spark.sql.TungstenAggregate.testFallbackStartsAt")
+  }
+
+  override protected def checkAnswer(actual: DataFrame, expectedAnswer: Seq[Row]): Unit =
{
+    (0 to 2).foreach { fallbackStartsAt =>
+      sqlContext.setConf(
+        "spark.sql.TungstenAggregate.testFallbackStartsAt",
+        fallbackStartsAt.toString)
+
+      // Create a new df to make sure its physical operator picks up
+      // spark.sql.TungstenAggregate.testFallbackStartsAt.
+      val newActual = DataFrame(sqlContext, actual.logicalPlan)
+
+      QueryTest.checkAnswer(newActual, expectedAnswer) match {
+        case Some(errorMessage) =>
+          val newErrorMessage =
+            s"""
+              |The following aggregation query failed when using TungstenAggregate with
+              |controlled fallback (it falls back to sort-based aggregation once it has processed
+              |$fallbackStartsAt input rows). The query is
+              |${actual.queryExecution}
+              |
+              |$errorMessage
+            """.stripMargin
+
+          fail(newErrorMessage)
+        case None =>
+      }
+    }
+  }
+
+  // Override it to make sure we call the actually overridden checkAnswer.
+  override protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+    checkAnswer(df, Seq(expectedAnswer))
+  }
+
+  // Override it to make sure we call the actually overridden checkAnswer.
+  override protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = {
+    checkAnswer(df, expectedAnswer.collect())
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message