spark-commits mailing list archives

From: yh...@apache.org
Subject: spark git commit: [SPARK-12024][SQL] More efficient multi-column counting.
Date: Sun, 29 Nov 2015 22:13:26 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 5601d8fd0 -> abd31515d


[SPARK-12024][SQL] More efficient multi-column counting.

In https://github.com/apache/spark/pull/9409 we enabled multi-column counting. The approach
taken in that PR introduces some overhead: it first creates a row only to check whether all
of the columns are non-null.

This PR removes that technical debt: Count now takes multiple columns as its input directly.
To make this work, I have also added support for multiple columns in the single-distinct
code path.
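
As an illustration of the resulting semantics, a minimal sketch (assuming a Spark 1.6
sqlContext; the table name and data are invented):

    import sqlContext.implicits._

    val df = Seq(("x", "p"), (null, "q"), ("y", null)).toDF("a", "b")
    df.registerTempTable("t")

    // Rows with a null in any counted column are skipped.
    sqlContext.sql("SELECT COUNT(a, b), COUNT(DISTINCT a, b) FROM t").collect()
    // expected: Row(1, 1) -- only ("x", "p") has no nulls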

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10015 from hvanhovell/SPARK-12024.

(cherry picked from commit 3d28081e53698ed77e93c04299957c02bcaba9bf)
Signed-off-by: Yin Huai <yhuai@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/abd31515
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/abd31515
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/abd31515

Branch: refs/heads/branch-1.6
Commit: abd31515dc15f6f28fc8d7f9d538a351d65a74b5
Parents: 5601d8f
Author: Herman van Hovell <hvanhovell@questtec.nl>
Authored: Sun Nov 29 14:13:11 2015 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Sun Nov 29 14:13:22 2015 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Count.scala  | 21 ++---------
 .../expressions/conditionalExpressions.scala    | 27 --------------
 .../sql/catalyst/optimizer/Optimizer.scala      | 14 ++++---
 .../ConditionalExpressionSuite.scala            | 14 -------
 .../spark/sql/execution/aggregate/utils.scala   | 39 ++++++++++----------
 .../spark/sql/expressions/WindowSpec.scala      |  4 +-
 6 files changed, 33 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 09a1da9..441f52a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
-case class Count(child: Expression) extends DeclarativeAggregate {
-  override def children: Seq[Expression] = child :: Nil
+case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
 
   override def nullable: Boolean = false
 
@@ -30,7 +29,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   override def dataType: DataType = LongType
 
   // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
 
   private lazy val count = AttributeReference("count", LongType)()
 
@@ -41,7 +40,7 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   )
 
   override lazy val updateExpressions = Seq(
-    /* count = */ If(IsNull(child), count, count + 1L)
+    /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L)
   )
 
   override lazy val mergeExpressions = Seq(
@@ -54,17 +53,5 @@ case class Count(child: Expression) extends DeclarativeAggregate {
 }
 
 object Count {
-  def apply(children: Seq[Expression]): Count = {
-    // This is used to deal with COUNT DISTINCT. When we have multiple
-    // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row).
-    // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any
-    // null in the arguments, we will not count that row. So, we use DropAnyNull at here
-    // to return a null when any field of the created STRUCT is null.
-    val child = if (children.size > 1) {
-      DropAnyNull(CreateStruct(children))
-    } else {
-      children.head
-    }
-    Count(child)
-  }
+  def apply(child: Expression): Count = Count(child :: Nil)
 }
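
To make the generalized update expression concrete: the per-column null checks are simply
OR'ed together, so no intermediate row is built. A small sketch using the Catalyst DSL
imported at the top of this file (attribute names invented):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions._

    val cols: Seq[Expression] = Seq('a.string, 'b.int)
    // IsNull(a) OR IsNull(b): when true, the count is left unchanged for that row.
    val skipRow: Expression = cols.map(IsNull).reduce(Or)

With a single child the reduce degenerates to a plain IsNull, so one-column counting
behaves exactly as before.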

http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 694a2a7..40b1eec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -426,30 +426,3 @@ case class Greatest(children: Seq[Expression]) extends Expression {
   }
 }
 
-/** Operator that drops a row when it contains any nulls. */
-case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes {
-  override def nullable: Boolean = true
-  override def dataType: DataType = child.dataType
-  override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
-
-  protected override def nullSafeEval(input: Any): InternalRow = {
-    val row = input.asInstanceOf[InternalRow]
-    if (row.anyNull) {
-      null
-    } else {
-      row
-    }
-  }
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, eval => {
-      s"""
-        if ($eval.anyNull()) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = $eval;
-        }
-      """
-    })
-  }
-}
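
For reference, a plain-Scala sketch of what the removed operator computed (the real class
operated on Catalyst's InternalRow): return the row unchanged unless any field is null, in
which case return null so the enclosing COUNT skips it.

    def dropAnyNull(row: Seq[Any]): Seq[Any] =
      if (row.exists(_ == null)) null else row

    assert(dropAnyNull(Seq("a", "q")) != null)   // row kept
    assert(dropAnyNull(Seq("b", null)) == null)  // row dropped

With Count checking each child directly, neither the CreateStruct wrapper nor this
null-dropping step is needed.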

http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2901d8f..06d14fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -362,9 +362,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * Null value propagation from bottom to top of the expression tree.
  */
 object NullPropagation extends Rule[LogicalPlan] {
+  def nonNullLiteral(e: Expression): Boolean = e match {
+    case Literal(null, _) => false
+    case _ => true
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
-      case e @ AggregateExpression(Count(Literal(null, _)), _, _) =>
+      case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
         Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -377,16 +382,13 @@ object NullPropagation extends Rule[LogicalPlan] {
         Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
-      case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable =>
+      case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
         AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
 
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
-        val newChildren = children.filter {
-          case Literal(null, _) => false
-          case _ => true
-        }
+        val newChildren = children.filter(nonNullLiteral)
         if (newChildren.length == 0) {
           Literal.create(null, e.dataType)
         } else if (newChildren.length == 1) {
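
The guards of the two generalized Count rules above, sketched as standalone predicates
(only the guard logic is shown; the rewrites themselves appear in the hunk):

    import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

    def nonNullLiteral(e: Expression): Boolean = e match {
      case Literal(null, _) => false
      case _ => true
    }

    // COUNT(...) folds to literal 0L when every argument is a null literal ...
    def foldsToZero(exprs: Seq[Expression]): Boolean = !exprs.exists(nonNullLiteral)
    // ... and a non-distinct COUNT(...) becomes COUNT(1) when no argument is nullable.
    def becomesCountOne(exprs: Seq[Expression]): Boolean = !exprs.exists(_.nullable)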

http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index c1e3c17..0df673b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -231,18 +231,4 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2)
     }
   }
-
-  test("function dropAnyNull") {
-    val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1))))
-    val a = create_row("a", "q")
-    val nullStr: String = null
-    checkEvaluation(drop, a, a)
-    checkEvaluation(drop, null, create_row("b", nullStr))
-    checkEvaluation(drop, null, create_row(nullStr, nullStr))
-
-    val row = 'r.struct(
-      StructField("a", StringType, false),
-      StructField("b", StringType, true)).at(0)
-    checkEvaluation(DropAnyNull(row), null, create_row(null))
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index a70e414..76b938c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -146,20 +146,16 @@ object Utils {
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
 
    // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
-    // DISTINCT aggregate function, all of those functions will have the same column expression.
+    // DISTINCT aggregate function, all of those functions will have the same column expressions.
     // For example, it would be valid for functionsWithDistinct to be
    // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
     // disallowed because those two distinct aggregates have different column expressions.
-    val distinctColumnExpression: Expression = {
-      val allDistinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
-      assert(allDistinctColumnExpressions.length == 1)
-      allDistinctColumnExpressions.head
-    }
-    val namedDistinctColumnExpression: NamedExpression = distinctColumnExpression match {
+    val distinctColumnExpressions = functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
       case ne: NamedExpression => ne
       case other => Alias(other, other.toString)()
     }
-    val distinctColumnAttribute: Attribute = namedDistinctColumnExpression.toAttribute
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
     // 1. Create an Aggregate Operator for partial aggregations.
@@ -170,10 +166,11 @@ object Utils {
      // We will group by the original grouping expression, plus an additional expression for the
       // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
       // expressions will be [key, value].
-      val partialAggregateGroupingExpressions = groupingExpressions :+ namedDistinctColumnExpression
+      val partialAggregateGroupingExpressions =
+        groupingExpressions ++ namedDistinctColumnExpressions
       val partialAggregateResult =
         groupingAttributes ++
-          Seq(distinctColumnAttribute) ++
+          distinctColumnAttributes ++
           partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
       if (usesTungstenAggregate) {
         TungstenAggregate(
@@ -208,28 +205,28 @@ object Utils {
         partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
       val partialMergeAggregateResult =
         groupingAttributes ++
-          Seq(distinctColumnAttribute) ++
+          distinctColumnAttributes ++
           partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
       if (usesTungstenAggregate) {
         TungstenAggregate(
           requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+          groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
           nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
           nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
           completeAggregateExpressions = Nil,
           completeAggregateAttributes = Nil,
-          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
           resultExpressions = partialMergeAggregateResult,
           child = partialAggregate)
       } else {
         SortBasedAggregate(
           requiredChildDistributionExpressions = Some(groupingAttributes),
-          groupingExpressions = groupingAttributes :+ distinctColumnAttribute,
+          groupingExpressions = groupingAttributes ++ distinctColumnAttributes,
           nonCompleteAggregateExpressions = partialMergeAggregateExpressions,
           nonCompleteAggregateAttributes = partialMergeAggregateAttributes,
           completeAggregateExpressions = Nil,
           completeAggregateAttributes = Nil,
-          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
           resultExpressions = partialMergeAggregateResult,
           child = partialAggregate)
       }
@@ -244,14 +241,16 @@ object Utils {
         expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
       }
 
+      val distinctColumnAttributeLookup =
+        distinctColumnExpressions.zip(distinctColumnAttributes).toMap
      val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
         // Children of an AggregateFunction with DISTINCT keyword has already
         // been evaluated. At here, we need to replace original children
         // to AttributeReferences.
         case agg @ AggregateExpression(aggregateFunction, mode, true) =>
-          val rewrittenAggregateFunction = aggregateFunction.transformDown {
-            case expr if expr == distinctColumnExpression => distinctColumnAttribute
-          }.asInstanceOf[AggregateFunction]
+          val rewrittenAggregateFunction = aggregateFunction
+            .transformDown(distinctColumnAttributeLookup)
+            .asInstanceOf[AggregateFunction]
           // We rewrite the aggregate function to a non-distinct aggregation because
           // its input will have distinct arguments.
          // We just keep the isDistinct setting to true, so when users look at the query plan,
@@ -270,7 +269,7 @@ object Utils {
           nonCompleteAggregateAttributes = finalAggregateAttributes,
           completeAggregateExpressions = completeAggregateExpressions,
           completeAggregateAttributes = completeAggregateAttributes,
-          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
           resultExpressions = resultExpressions,
           child = partialMergeAggregate)
       } else {
@@ -281,7 +280,7 @@ object Utils {
           nonCompleteAggregateAttributes = finalAggregateAttributes,
           completeAggregateExpressions = completeAggregateExpressions,
           completeAggregateAttributes = completeAggregateAttributes,
-          initialInputBufferOffset = (groupingAttributes :+ distinctColumnAttribute).length,
+          initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length,
           resultExpressions = resultExpressions,
           child = partialMergeAggregate)
       }
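
A note on the distinctColumnAttributeLookup trick above: transformDown takes a
PartialFunction[Expression, Expression], and a Scala Map is itself a PartialFunction, so
the zipped expression-to-attribute map can be passed to transformDown directly, rewriting
every distinct column expression in one pass. A minimal sketch, with strings standing in
for expressions:

    val exprs = Seq("expr1", "expr2")
    val attrs = Seq("attr1", "attr2")
    val lookup: PartialFunction[String, String] = exprs.zip(attrs).toMap

    assert(lookup.isDefinedAt("expr1"))   // will be rewritten to its attribute
    assert(!lookup.isDefinedAt("other"))  // everything else is left untouched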

http://git-wip-us.apache.org/repos/asf/spark/blob/abd31515/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index fc873c0..893e800 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -152,8 +152,8 @@ class WindowSpec private[sql](
           case Sum(child) => WindowExpression(
             UnresolvedWindowFunction("sum", child :: Nil),
             WindowSpecDefinition(partitionSpec, orderSpec, frame))
-          case Count(child) => WindowExpression(
-            UnresolvedWindowFunction("count", child :: Nil),
+          case Count(children) => WindowExpression(
+            UnresolvedWindowFunction("count", children),
             WindowSpecDefinition(partitionSpec, orderSpec, frame))
           case First(child, ignoreNulls) => WindowExpression(
             // TODO this is a hack for Hive UDAF first_value
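
User-facing window behavior is unchanged; the match above now simply forwards Count's
children (a Seq) to the window function. A quick sketch (again assuming a Spark 1.6
sqlContext; column names invented):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    val df = Seq(("x", 1), ("x", 2), ("y", 3)).toDF("k", "v")
    df.select($"k", count($"v").over(Window.partitionBy($"k"))).show()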

