spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yh...@apache.org
Subject spark git commit: [SPARK-8992][SQL] Add pivot to dataframe api
Date Thu, 12 Nov 2015 00:23:43 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 4151afbf5 -> 5940fc71d


[SPARK-8992][SQL] Add pivot to dataframe api

This adds a pivot method to the dataframe api.

Following the lead of cube and rollup, this adds a Pivot operator that is translated into an
Aggregate by the analyzer.

Currently the syntax is like:
~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~

~~Would we be interested in the following syntax also/alternatively? and~~

    courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
    //or
    courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))

Later we can add it to `SQLParser`, but as Hive doesn't support it we can't add it there, right?

~~Also what would be the suggested Java friendly method signature for this?~~

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #7841 from aray/sql-pivot.

(cherry picked from commit b8ff6888e76b437287d7d6bf2d4b9c759710a195)
Signed-off-by: Yin Huai <yhuai@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5940fc71
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5940fc71
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5940fc71

Branch: refs/heads/branch-1.6
Commit: 5940fc71d2a245cc6e50edb455c3dd3dbb8de43a
Parents: 4151afb
Author: Andrew Ray <ray.andrew@gmail.com>
Authored: Wed Nov 11 16:23:24 2015 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Wed Nov 11 16:23:39 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  42 ++++++++
 .../catalyst/plans/logical/basicOperators.scala |  14 +++
 .../org/apache/spark/sql/GroupedData.scala      | 103 +++++++++++++++++--
 .../scala/org/apache/spark/sql/SQLConf.scala    |   7 ++
 .../apache/spark/sql/DataFramePivotSuite.scala  |  87 ++++++++++++++++
 .../org/apache/spark/sql/test/SQLTestData.scala |  12 +++
 6 files changed, 255 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a9cd9a7..2f4670b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveRelations ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
+      ResolvePivot ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
       case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
         g.withNewAggs(assignAliases(g.aggregations))
 
+      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
+        if child.resolved && hasUnresolvedAlias(groupByExprs) =>
+        Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+
       case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
         Project(assignAliases(projectList), child)
     }
@@ -248,6 +253,43 @@ class Analyzer(
     }
   }
 
+  object ResolvePivot extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Pivot if !p.childrenResolved => p
+      case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+        val singleAgg = aggregates.size == 1
+        val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+          def ifExpr(expr: Expression) = {
+            If(EqualTo(pivotColumn, value), expr, Literal(null))
+          }
+          aggregates.map { aggregate =>
+            val filteredAggregate = aggregate.transformDown {
+              // Assumes the aggregate function ignores nulls. This is true for all current
+              // AggregateFunctions with the exception of First and Last in their default mode
+              // (which we handle) and possibly some Hive UDAFs.
+              case First(expr, _) =>
+                First(ifExpr(expr), Literal(true))
+              case Last(expr, _) =>
+                Last(ifExpr(expr), Literal(true))
+              case a: AggregateFunction =>
+                a.withNewChildren(a.children.map(ifExpr))
+            }
+            if (filteredAggregate.fastEquals(aggregate)) {
+              throw new AnalysisException(
+                s"Aggregate expression required for pivot, found '$aggregate'")
+            }
+            val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+            Alias(filteredAggregate, name)()
+          }
+        }
+        val newGroupByExprs = groupByExprs.map {
+          case UnresolvedAlias(e) => e
+          case e => e
+        }
+        Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+    }
+  }
+
   /**
    * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 597f03e..32b09b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -386,6 +386,20 @@ case class Rollup(
     this.copy(aggregations = aggs)
 }
 
+case class Pivot(
+    groupByExprs: Seq[NamedExpression],
+    pivotColumn: Expression,
+    pivotValues: Seq[Literal],
+    aggregates: Seq[Expression],
+    child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ (aggregates match {
+    case Seq(agg) => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+    case _ => pivotValues.flatMap{ value =>
+      aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+    }
+  }) // parens required: `a ++ b match {...}` would match on the whole concatenation
+}
+
 case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2c..63dd7fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
 
 
 /**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
       aggExprs
     }
 
-    val aliasedAgg = aggregates.map {
-      // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
-      // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
-      // make it a NamedExpression.
-      case u: UnresolvedAttribute => UnresolvedAlias(u)
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    }
+    val aliasedAgg = aggregates.map(alias)
+
     groupType match {
       case GroupedData.GroupByType =>
         DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
       case GroupedData.CubeType =>
         DataFrame(
           df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+      case GroupedData.PivotType(pivotCol, values) =>
+        val aliasedGrps = groupingExprs.map(alias)
+        DataFrame(
+          df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
     }
   }
 
+  // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+  // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+  // make it a NamedExpression.
+  private[this] def alias(expr: Expression): NamedExpression = expr match {
+    case u: UnresolvedAttribute => UnresolvedAlias(u)
+    case expr: NamedExpression => expr
+    case expr: Expression => Alias(expr, expr.prettyString)()
+  }
+
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
     : DataFrame = {
 
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames : _*)(Sum)
   }
+
+  /**
+   * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
+   * aggregation.
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
+   *   // Or without specifying column values
+   *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+   * }}}
+   * @param pivotColumn Column to pivot
+   * @param values Optional list of values of pivotColumn that will be translated to columns in the
+   *               output data frame. If values are not provided the method will do an immediate
+   *               call to .distinct() on the pivot column.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
+    case _: GroupedData.PivotType =>
+      throw new UnsupportedOperationException("repeated pivots are not supported")
+    case GroupedData.GroupByType =>
+      val pivotValues = if (values.nonEmpty) {
+        values.map {
+          case Column(literal: Literal) => literal
+          case other =>
+            throw new UnsupportedOperationException(
+              s"The values of a pivot must be literals, found $other")
+        }
+      } else {
+        // This is to prevent unintended OOM errors when the number of distinct values is large
+        val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+        // Get the distinct values of the column and sort them so the column order is deterministic
+        val distinctValues = df.select(pivotColumn)
+          .distinct()
+          .sort(pivotColumn)
+          .map(_.get(0))
+          .take(maxValues + 1)
+          .map(Literal(_)).toSeq
+        if (distinctValues.length > maxValues) {
+          throw new RuntimeException(
+            s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+              "this could indicate an error. " +
+              "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+              s"to at least the number of distinct values of the pivot column.")
+        }
+        distinctValues
+      }
+      new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
+    case _ =>
+      throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+  }
+
+  /**
+   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+   *   // Or without specifying column values
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   * @param pivotColumn Column to pivot
+   * @param values Optional list of values of pivotColumn that will be translated to columns in the
+   *               output data frame. If values are not provided the method will do an immediate
+   *               call to .distinct() on the pivot column.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def pivot(pivotColumn: String, values: Any*): GroupedData = {
+    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+  }
 }
 
 
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
    * To indicate it's the ROLLUP
    */
   private[sql] object RollupType extends GroupType
+
+  /**
+   * To indicate it's the PIVOT
+   */
+  private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502..41d28d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
     defaultValue = Some(true),
     isPublic = false)
 
+  val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+    "spark.sql.pivotMaxValues",
+    defaultValue = Some(10000),
+    doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
+      "number of (distinct) values that will be collected without error."
+  )
+
   val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
     defaultValue = Some(true),
     isPublic = false,

http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
new file mode 100644
index 0000000..0c23d14
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFramePivotSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("pivot courses with literals") {
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+        .agg(sum($"earnings")),
+      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with literals") {
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot courses with literals and multiple aggregations") {
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+        .agg(sum($"earnings"), avg($"earnings")),
+      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with string values (cast)") {
+    checkAnswer(
+      courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot year with int values") {
+    checkAnswer(
+      courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot courses with no values") {
+    // Note Java comes before dotNet in sorted order
+    checkAnswer(
+      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
+    )
+  }
+
+  test("pivot year with no values") {
+    checkAnswer(
+      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+    )
+  }
+
+  test("pivot max values enforced") {
+    // withSQLConf restores the previous value even when the assertion below fails
+    withSQLConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key -> "1") {
+      intercept[RuntimeException](
+        courseSales.groupBy($"year").pivot($"course")
+      )
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5940fc71/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 520dea7..abad0d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val courseSales: DataFrame = {
+    val sales = sqlContext.sparkContext.parallelize(Seq(
+      CourseSales("dotNET", 2012, 10000),
+      CourseSales("Java", 2012, 20000),
+      CourseSales("dotNET", 2012, 5000),
+      CourseSales("dotNET", 2013, 48000),
+      CourseSales("Java", 2013, 30000))).toDF()
+    sales.registerTempTable("courseSales")
+    sales
+  }
+
   /**
    * Initialize all test data such that all temp tables are properly registered.
    */
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
   case class Person(id: Int, name: String, age: Int)
   case class Salary(personId: Int, salary: Double)
   case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+  case class CourseSales(course: String, year: Int, earnings: Double)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message