spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yh...@apache.org
Subject spark git commit: [SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation
Date Mon, 02 May 2016 18:13:08 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.0 eb7336a75 -> 08ae32e61


[SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation

## What changes were proposed in this pull request?

The existing implementation of pivot translates into a single aggregation with one aggregate
per distinct pivot value. When the number of distinct pivot values is large (say 1000+) this
can get extremely slow since each input value gets evaluated on every aggregate even though
it only affects the value of one of them.

I'm proposing an alternate strategy for when there are 10+ (somewhat arbitrary threshold)
distinct pivot values. We do two phases of aggregation. In the first we group by the grouping
columns plus the pivot column and perform the specified aggregations (one or sometimes more).
In the second aggregation we group by the grouping columns and use the new (non public) PivotFirst
aggregate that rearranges the outputs of the first aggregation into an array indexed by the
pivot value. Finally we do a project to extract the array entries into the appropriate output
column.

## How was this patch tested?

Additional unit tests in DataFramePivotSuite and manual larger scale testing.

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #11583 from aray/fast-pivot.

(cherry picked from commit 99274418684ebae5b98d15b4686b95c1ac029e94)
Signed-off-by: Yin Huai <yhuai@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/08ae32e6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/08ae32e6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/08ae32e6

Branch: refs/heads/branch-2.0
Commit: 08ae32e6104e998b3c9a4822e563e63aeae55578
Parents: eb7336a
Author: Andrew Ray <ray.andrew@gmail.com>
Authored: Mon May 2 11:12:55 2016 -0700
Committer: Yin Huai <yhuai@databricks.com>
Committed: Mon May 2 11:13:04 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  85 +++++++----
 .../expressions/aggregate/PivotFirst.scala      | 152 +++++++++++++++++++
 .../apache/spark/sql/DataFramePivotSuite.scala  |  92 ++++++++++-
 3 files changed, 296 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/08ae32e6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e98036a..2f8ab3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -363,43 +363,68 @@ class Analyzer(
 
   object ResolvePivot extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p
+      case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
+        | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
       case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
         val singleAgg = aggregates.size == 1
-        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
-          def ifExpr(expr: Expression) = {
-            If(EqualTo(pivotColumn, value), expr, Literal(null))
+        def outputName(value: Literal, aggregate: Expression): String = {
+          if (singleAgg) value.toString else value + "_" + aggregate.sql
+        }
+        if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
+          // Since evaluating |pivotValues| if statements for each input row can get slow this is an
+          // alternate plan that instead uses two steps of aggregation.
+          val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
+          val namedPivotCol = pivotColumn match {
+            case n: NamedExpression => n
+            case _ => Alias(pivotColumn, "__pivot_col")()
+          }
+          val bigGroup = groupByExprs :+ namedPivotCol
+          val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
+          val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
+          val pivotAggs = namedAggExps.map { a =>
+            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
+              .toAggregateExpression()
+            , "__pivot_" + a.sql)()
+          }
+          val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
+          val pivotAggAttribute = pivotAggs.map(_.toAttribute)
+          val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
+            aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
+              Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
+            }
           }
-          aggregates.map { aggregate =>
-            val filteredAggregate = aggregate.transformDown {
-              // Assumption is the aggregate function ignores nulls. This is true for all current
-              // AggregateFunction's with the exception of First and Last in their default mode
-              // (which we handle) and possibly some Hive UDAF's.
-              case First(expr, _) =>
-                First(ifExpr(expr), Literal(true))
-              case Last(expr, _) =>
-                Last(ifExpr(expr), Literal(true))
-              case a: AggregateFunction =>
-                a.withNewChildren(a.children.map(ifExpr))
-            }.transform {
-              // We are duplicating aggregates that are now computing a different value for each
-              // pivot value.
-              // TODO: Don't construct the physical container until after analysis.
-              case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
+          Project(groupByExprs ++ pivotOutputs, secondAgg)
+        } else {
+          val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+            def ifExpr(expr: Expression) = {
+              If(EqualTo(pivotColumn, value), expr, Literal(null))
             }
-            if (filteredAggregate.fastEquals(aggregate)) {
-              throw new AnalysisException(
-                s"Aggregate expression required for pivot, found '$aggregate'")
+            aggregates.map { aggregate =>
+              val filteredAggregate = aggregate.transformDown {
+                // Assumption is the aggregate function ignores nulls. This is true for all current
+                // AggregateFunction's with the exception of First and Last in their default mode
+                // (which we handle) and possibly some Hive UDAF's.
+                case First(expr, _) =>
+                  First(ifExpr(expr), Literal(true))
+                case Last(expr, _) =>
+                  Last(ifExpr(expr), Literal(true))
+                case a: AggregateFunction =>
+                  a.withNewChildren(a.children.map(ifExpr))
+              }.transform {
+                // We are duplicating aggregates that are now computing a different value for each
+                // pivot value.
+                // TODO: Don't construct the physical container until after analysis.
+                case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
+              }
+              if (filteredAggregate.fastEquals(aggregate)) {
+                throw new AnalysisException(
+                  s"Aggregate expression required for pivot, found '$aggregate'")
+              }
+              Alias(filteredAggregate, outputName(value, aggregate))()
             }
-            val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
-            Alias(filteredAggregate, name)()
           }
+          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
         }
-        val newGroupByExprs = groupByExprs.map {
-          case UnresolvedAlias(e, _) => e
-          case e => e
-        }
-        Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/08ae32e6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
new file mode 100644
index 0000000..9154e96
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.types._
+
+object PivotFirst {
+
+  def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType)
+
+  // Currently UnsafeRow does not support the generic update method (throws
+  // UnsupportedOperationException), so we need to explicitly support each DataType.
+  private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
+    case DoubleType =>
+      (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
+    case IntegerType =>
+      (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
+    case LongType =>
+      (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
+    case FloatType =>
+      (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
+    case BooleanType =>
+      (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
+    case ShortType =>
+      (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
+    case ByteType =>
+      (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
+    case d: DecimalType =>
+      (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
+  }
+}
+
+/**
+ * PivotFirst is a aggregate function used in the second phase of a two phase pivot to do the
+ * required rearrangement of values into pivoted form.
+ *
+ * For example on an input of
+ * A | B
+ * --+--
+ * x | 1
+ * y | 2
+ * z | 3
+ *
+ * with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2].
+ *
+ * @param pivotColumn column that determines which output position to put valueColumn in.
+ * @param valueColumn the column that is being rearranged.
+ * @param pivotColumnValues the list of pivotColumn values in the order of desired output. Values
+ *                          not listed here will be ignored.
+ */
+case class PivotFirst(
+  pivotColumn: Expression,
+  valueColumn: Expression,
+  pivotColumnValues: Seq[Any],
+  mutableAggBufferOffset: Int = 0,
+  inputAggBufferOffset: Int = 0) extends ImperativeAggregate {
+
+  override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
+
+  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
+
+  override val nullable: Boolean = false
+
+  val valueDataType = valueColumn.dataType
+
+  override val dataType: DataType = ArrayType(valueDataType)
+
+  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
+
+  val indexSize = pivotIndex.size
+
+  private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)
+
+  override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
+    val pivotColValue = pivotColumn.eval(inputRow)
+    if (pivotColValue != null) {
+      // We ignore rows whose pivot column value is not in the list of pivot column values.
+      val index = pivotIndex.getOrElse(pivotColValue, -1)
+      if (index >= 0) {
+        val value = valueColumn.eval(inputRow)
+        if (value != null) {
+          updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
+        }
+      }
+    }
+  }
+
+  override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
+    for (i <- 0 until indexSize) {
+      if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
+        val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
+        updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
+      }
+    }
+  }
+
+  override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
+    case d: DecimalType =>
+      // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
+      for (i <- 0 until indexSize) {
+        mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
+      }
+    case _ =>
+      for (i <- 0 until indexSize) {
+        mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
+      }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = new Array[Any](indexSize)
+    for (i <- 0 until indexSize) {
+      result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
+    }
+    new GenericArrayData(result)
+  }
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] =
+    pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())
+
+  override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
+
+  override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
+    aggBufferAttributes.map(_.newInstance())
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/08ae32e6/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 368aa5c..b17284a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -17,14 +17,16 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 class DataFramePivotSuite extends QueryTest with SharedSQLContext{
   import testImplicits._
 
-  test("pivot courses with literals") {
+  test("pivot courses") {
     checkAnswer(
       courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
@@ -32,14 +34,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
     )
   }
 
-  test("pivot year with literals") {
+  test("pivot year") {
     checkAnswer(
       courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
-  test("pivot courses with literals and multiple aggregations") {
+  test("pivot courses with multiple aggregations") {
     checkAnswer(
       courseSales.groupBy($"year")
         .pivot("course", Seq("dotNET", "Java"))
@@ -94,4 +96,88 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
       Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     )
   }
+
+  // Tests for optimized pivot (with PivotFirst) below
+
+  test("optimized pivot planned") {
+    val df = courseSales.groupBy("year")
+      // pivot with extra columns to trigger optimization
+      .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+      .agg(sum($"earnings"))
+    val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
+    assert(queryExecution.simpleString.contains("pivotfirst"))
+  }
+
+
+  test("optimized pivot courses with literals") {
+    checkAnswer(
+      courseSales.groupBy("year")
+        // pivot with extra columns to trigger optimization
+        .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+        .agg(sum($"earnings"))
+        .select("year", "dotNET", "Java"),
+      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot year with literals") {
+    checkAnswer(
+      courseSales.groupBy($"course")
+        // pivot with extra columns to trigger optimization
+        .pivot("year", Seq(2012, 2013) ++ (1 to 10))
+        .agg(sum($"earnings"))
+        .select("course", "2012", "2013"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot year with string values (cast)") {
+    checkAnswer(
+      courseSales.groupBy("course")
+        // pivot with extra columns to trigger optimization
+        .pivot("year", Seq("2012", "2013") ++ (1 to 10).map(_.toString))
+        .sum("earnings")
+        .select("course", "2012", "2013"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("optimized pivot DecimalType") {
+    val df = courseSales.select($"course", $"year", $"earnings".cast(DecimalType(10, 2)))
+      .groupBy("year")
+      // pivot with extra columns to trigger optimization
+      .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+      .agg(sum($"earnings"))
+      .select("year", "dotNET", "Java")
+
+    assertResult(IntegerType)(df.schema("year").dataType)
+    assertResult(DecimalType(20, 2))(df.schema("Java").dataType)
+    assertResult(DecimalType(20, 2))(df.schema("dotNET").dataType)
+
+    checkAnswer(df, Row(2012, BigDecimal(1500000, 2), BigDecimal(2000000, 2)) ::
+      Row(2013, BigDecimal(4800000, 2), BigDecimal(3000000, 2)) :: Nil)
+  }
+
+  test("PivotFirst supported datatypes") {
+    val supportedDataTypes: Seq[DataType] = DoubleType :: IntegerType :: LongType :: FloatType ::
+      BooleanType :: ShortType :: ByteType :: Nil
+    for (datatype <- supportedDataTypes) {
+      assertResult(true)(PivotFirst.supportsDataType(datatype))
+    }
+    assertResult(true)(PivotFirst.supportsDataType(DecimalType(10, 1)))
+    assertResult(false)(PivotFirst.supportsDataType(null))
+    assertResult(false)(PivotFirst.supportsDataType(ArrayType(IntegerType)))
+  }
+
+  test("optimized pivot with multiple aggregations") {
+    checkAnswer(
+      courseSales.groupBy($"year")
+        // pivot with extra columns to trigger optimization
+        .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
+        .agg(sum($"earnings"), avg($"earnings")),
+      Row(Seq(2012, 15000.0, 7500.0, 20000.0, 20000.0) ++ Seq.fill(20)(null): _*) ::
+        Row(Seq(2013, 48000.0, 48000.0, 30000.0, 30000.0) ++ Seq.fill(20)(null): _*) :: Nil
+    )
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message