spark-commits mailing list archives

From yh...@apache.org
Subject spark git commit: [SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility
Date Thu, 30 Jul 2015 17:32:22 GMT
Repository: spark
Updated Branches:
  refs/heads/master 7bbf02f0b -> 5363ed715


[SPARK-9361] [SQL] Refactor new aggregation code to reduce the times of checking compatibility

JIRA: https://issues.apache.org/jira/browse/SPARK-9361

Currently, we call `aggregate.Utils.tryConvert` in many places to check whether the logical.Aggregate
can be run with the new aggregation code path. However, `aggregate.Utils.tryConvert` appears to take
considerable time to run. We should call `tryConvert` only once, keep its value in `logical.Aggregate`,
and reuse it.

In `org.apache.spark.sql.execution.aggregate.Utils`, the code involving `tryConvert`
should be moved to catalyst because it doesn't actually deal with execution details.
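
The caching idea described above can be sketched in isolation. This is a minimal illustration, not the Spark code: `Agg`, `tryConvert`, and the `conversions` counter below are hypothetical stand-ins for `logical.Aggregate` and `Utils.tryConvert`, and the compatibility rule is reduced to a check on grouping key type names.

```scala
object LazyCacheSketch {
  // Counts how many times the (hypothetical) expensive check actually runs.
  var conversions = 0

  // Hypothetical stand-in for logical.Aggregate; not Spark's class.
  case class Agg(groupingTypes: Seq[String]) {
    // Evaluated at most once per instance, on first access, then reused.
    lazy val newAggregation: Option[Agg] = tryConvert(this)
  }

  // Hypothetical stand-in for Utils.tryConvert: rejects complex grouping key types.
  def tryConvert(a: Agg): Option[Agg] = {
    conversions += 1
    val complexTypes = Set("array", "map", "struct")
    if (a.groupingTypes.exists(complexTypes.contains)) None else Some(a)
  }

  def main(args: Array[String]): Unit = {
    val a = Agg(Seq("int", "string"))
    // Three planner-style checks against the same plan node.
    val checks = Seq.fill(3)(a.newAggregation.isDefined)
    println(s"all convertible: ${checks.forall(identity)}, conversions: $conversions")
  }
}
```

A `lazy val` is cached per instance, so in this sketch a copied or rewritten node would run the check again; the saving applies to repeated checks against the same node.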

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7677 from viirya/refactor_aggregate and squashes the following commits:

babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate
9a589d7 [Liang-Chi Hsieh] Fix scala style.
0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5363ed71
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5363ed71
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5363ed71

Branch: refs/heads/master
Commit: 5363ed71568c3e7c082146d654a9c669d692d894
Parents: 7bbf02f
Author: Liang-Chi Hsieh <viirya@appier.com>
Authored: Thu Jul 30 10:30:37 2015 -0700
Committer: Yin Huai <yhuai@databricks.com>
Committed: Thu Jul 30 10:32:12 2015 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/interfaces.scala      |   4 +-
 .../catalyst/expressions/aggregate/utils.scala  | 167 +++++++++++++++++++
 .../catalyst/plans/logical/basicOperators.scala |   3 +
 .../spark/sql/execution/SparkStrategies.scala   |  34 ++--
 .../spark/sql/execution/aggregate/utils.scala   | 144 ----------------
 5 files changed, 188 insertions(+), 164 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9fb7623..d08f553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
 private[sql] case object PartialMerge extends AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
  * containing intermediate results for this function and then generate final result.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the final result of this function is returned.
@@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
 private[sql] case object Final extends AggregateMode
 
 /**
- * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
  * from original input rows without any partial aggregation.
  * This function updates the given aggregation buffer with the original input of this
  * function. When it has processed all input rows, the final result of this function is returned.

http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
new file mode 100644
index 0000000..4a43318
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
+
+/**
+ * Utility functions used by the query planner to convert our plan to new aggregation code path.
+ */
+object Utils {
+  // Right now, we do not support complex types in the grouping key schema.
+  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
+    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
+      case array: ArrayType => true
+      case map: MapType => true
+      case struct: StructType => true
+      case _ => false
+    }
+
+    !hasComplexTypes
+  }
+
+  private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate if supportsGroupingKeySchema(p) =>
+      val converted = p.transformExpressionsDown {
+        case expressions.Average(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Average(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Count(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        // We do not support multiple COUNT DISTINCT columns for now.
+        case expressions.CountDistinct(children) if children.length == 1 =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Count(children.head),
+            mode = aggregate.Complete,
+            isDistinct = true)
+
+        case expressions.First(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.First(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Last(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Last(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Max(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Max(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Min(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Min(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.Sum(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = false)
+
+        case expressions.SumDistinct(child) =>
+          aggregate.AggregateExpression2(
+            aggregateFunction = aggregate.Sum(child),
+            mode = aggregate.Complete,
+            isDistinct = true)
+      }
+      // Check if there is any expressions.AggregateExpression1 left.
+      // If so, we cannot convert this plan.
+      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
+        // For every expressions, check if it contains AggregateExpression1.
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
+          true
+        } else {
+          false
+        }
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
+
+  def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we will do a final round check to see if the original
+    // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
+    // we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and it cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
+    case p: Aggregate =>
+      val converted = doConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case other => None
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ad5af19..a67f8de 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.Utils
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
@@ -219,6 +220,8 @@ case class Aggregate(
     expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
   }
 
+  lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this)
+
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f3ef066..52a9b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils}
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -193,11 +193,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => Nil
     }
 
-    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
-      aggregate.Utils.tryConvert(
-        plan,
-        sqlContext.conf.useSqlAggregate2,
-        sqlContext.conf.codegenEnabled).isDefined
+    def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match {
+      case a: logical.Aggregate =>
+        if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) {
+          a.newAggregation.isDefined
+        } else {
+          Utils.checkInvalidAggregateFunction2(a)
+          false
+        }
+      case _ => false
     }
 
     def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
@@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case p: logical.Aggregate =>
-        val converted =
-          aggregate.Utils.tryConvert(
-            p,
-            sqlContext.conf.useSqlAggregate2,
-            sqlContext.conf.codegenEnabled)
+      case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 &&
+          sqlContext.conf.codegenEnabled =>
+        val converted = p.newAggregation
         converted match {
           case None => Nil // Cannot convert to new aggregation code path.
           case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
@@ -377,17 +378,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case e @ logical.Expand(_, _, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
       case a @ logical.Aggregate(group, agg, child) => {
-        val useNewAggregation =
-          aggregate.Utils.tryConvert(
-            a,
-            sqlContext.conf.useSqlAggregate2,
-            sqlContext.conf.codegenEnabled).isDefined
-        if (useNewAggregation) {
+        val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled
+        if (useNewAggregation && a.newAggregation.isDefined) {
           // If this logical.Aggregate can be planned to use new aggregation code path
           // (i.e. it can be planned by the Strategy Aggregation), we will not use the old
           // aggregation code path.
           Nil
         } else {
+          Utils.checkInvalidAggregateFunction2(a)
           execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
         }
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5363ed71/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 6549c87..03635ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType}
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
  */
 object Utils {
-  // Right now, we do not support complex types in the grouping key schema.
-  private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = {
-    val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists {
-      case array: ArrayType => true
-      case map: MapType => true
-      case struct: StructType => true
-      case _ => false
-    }
-
-    !hasComplexTypes
-  }
-
-  private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match {
-    case p: Aggregate if supportsGroupingKeySchema(p) =>
-      val converted = p.transformExpressionsDown {
-        case expressions.Average(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Average(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Count(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Count(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        // We do not support multiple COUNT DISTINCT columns for now.
-        case expressions.CountDistinct(children) if children.length == 1 =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Count(children.head),
-            mode = aggregate.Complete,
-            isDistinct = true)
-
-        case expressions.First(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.First(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Last(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Last(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Max(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Max(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Min(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Min(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.Sum(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Sum(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
-        case expressions.SumDistinct(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Sum(child),
-            mode = aggregate.Complete,
-            isDistinct = true)
-      }
-      // Check if there is any expressions.AggregateExpression1 left.
-      // If so, we cannot convert this plan.
-      val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr =>
-        // For every expressions, check if it contains AggregateExpression1.
-        expr.find {
-          case agg: expressions.AggregateExpression1 => true
-          case other => false
-        }.isDefined
-      }
-
-      // Check if there are multiple distinct columns.
-      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
-        expr.collect {
-          case agg: AggregateExpression2 => agg
-        }
-      }.toSet.toSeq
-      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
-      val hasMultipleDistinctColumnSets =
-        if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
-          true
-        } else {
-          false
-        }
-
-      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
-
-    case other => None
-  }
-
-  private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
-    // If the plan cannot be converted, we will do a final round check to if the original
-    // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so,
-    // we need to throw an exception.
-    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
-      expr.collect {
-        case agg: AggregateExpression2 => agg.aggregateFunction
-      }
-    }.distinct
-    if (aggregateFunction2s.nonEmpty) {
-      // For functions implemented based on the new interface, prepare a list of function names.
-      val invalidFunctions = {
-        if (aggregateFunction2s.length > 1) {
-          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
-            s"and ${aggregateFunction2s.head.nodeName} are"
-        } else {
-          s"${aggregateFunction2s.head.nodeName} is"
-        }
-      }
-      val errorMessage =
-        s"${invalidFunctions} implemented based on the new Aggregate Function " +
-          s"interface and it cannot be used with functions implemented based on " +
-          s"the old Aggregate Function interface."
-      throw new AnalysisException(errorMessage)
-    }
-  }
-
-  def tryConvert(
-      plan: LogicalPlan,
-      useNewAggregation: Boolean,
-      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
-    case p: Aggregate if useNewAggregation && codeGenEnabled =>
-      val converted = tryConvert(p)
-      if (converted.isDefined) {
-        converted
-      } else {
-        checkInvalidAggregateFunction2(p)
-        None
-      }
-    case p: Aggregate =>
-      checkInvalidAggregateFunction2(p)
-      None
-    case other => None
-  }
-
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],

