spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis
Date Sat, 15 Aug 2015 03:59:59 GMT
Repository: spark
Updated Branches:
  refs/heads/master 37586e544 -> ec29f2034


[SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis

Also, alias the ExtractValue instead of wrapping it with UnresolvedAlias when resolving
attributes in LogicalPlan, as this alias will be trimmed if it is unnecessary.

Based on #7957 without the changes to mllib, but instead maintaining earlier behavior when
using `withColumn` on expressions that already have metadata.
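
For illustration, a minimal sketch of the user-visible effect (the data and names here
are hypothetical; the orderBy case is exercised by the DataFrameSuite test below):

    val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
      """{"a": {"b": {"c": 1}}}""" :: Nil))
    df.select("a.b.c")  // now resolves to Alias(ExtractValue(...), "c"): column named "c"
    df.orderBy("a.b")   // SPARK-9323: nested column names work in sort orders too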

Author: Wenchen Fan <cloud0fan@outlook.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #8215 from marmbrus/pr/7957.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec29f203
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec29f203
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec29f203

Branch: refs/heads/master
Commit: ec29f2034a3306cc0afdc4c160b42c2eefa0897c
Parents: 37586e5
Author: Wenchen Fan <cloud0fan@outlook.com>
Authored: Fri Aug 14 20:59:54 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri Aug 14 20:59:54 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 75 ++++++++++++++++----
 .../expressions/complexTypeCreator.scala        |  2 -
 .../sql/catalyst/optimizer/Optimizer.scala      |  9 ++-
 .../catalyst/plans/logical/LogicalPlan.scala    |  8 +--
 .../catalyst/plans/logical/basicOperators.scala |  2 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 17 +++++
 .../scala/org/apache/spark/sql/Column.scala     | 16 ++++-
 .../spark/sql/ColumnExpressionSuite.scala       |  9 +++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  6 ++
 9 files changed, 120 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a684dbc..4bc1c1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -82,7 +82,9 @@ class Analyzer(
       HiveTypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Nondeterministic", Once,
-      PullOutNondeterministic)
+      PullOutNondeterministic),
+    Batch("Cleanup", fixedPoint,
+      CleanupAliases)
   )
 
   /**
@@ -146,8 +148,6 @@ class Analyzer(
           child match {
             case _: UnresolvedAttribute => u
             case ne: NamedExpression => ne
-            case g: GetStructField => Alias(g, g.field.name)()
-            case g: GetArrayStructFields => Alias(g, g.field.name)()
            case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
             case e if !e.resolved => u
             case other => Alias(other, s"_c$i")()
@@ -384,9 +384,7 @@ class Analyzer(
           case u @ UnresolvedAttribute(nameParts) =>
             // Leave unchanged if resolution fails.  Hopefully will be resolved next round.
             val result =
-              withPosition(u) {
-                q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
-              }
+              withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
             logDebug(s"Resolving $u to $result")
             result
           case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -412,11 +410,6 @@ class Analyzer(
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
 
-  private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
-    case UnresolvedAlias(child) => child
-    case other => other
-  }
-
   private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
     ordering.map { order =>
       // Resolve SortOrder in one round.
@@ -426,7 +419,7 @@ class Analyzer(
       try {
         val newOrder = order transformUp {
           case u @ UnresolvedAttribute(nameParts) =>
-            plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+            plan.resolve(nameParts, resolver).getOrElse(u)
           case UnresolvedExtractValue(child, fieldName) if child.resolved =>
             ExtractValue(child, fieldName, resolver)
         }
@@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
     case Subquery(_, child) => child
   }
 }
+
+/**
+ * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top-level
+ * expression in Project (project list), Aggregate (aggregate expressions), or
+ * Window (window expressions).
+ */
+object CleanupAliases extends Rule[LogicalPlan] {
+  private def trimAliases(e: Expression): Expression = {
+    var stop = false
+    e.transformDown {
+      // CreateStruct is a special case: we need to retain its top-level Aliases, as they
+      // decide the names of the StructFields. We also need to stop transforming down into
+      // this expression, or the Aliases under CreateStruct will be mistakenly trimmed.
+      case c: CreateStruct if !stop =>
+        stop = true
+        c.copy(children = c.children.map(trimNonTopLevelAliases))
+      case c: CreateStructUnsafe if !stop =>
+        stop = true
+        c.copy(children = c.children.map(trimNonTopLevelAliases))
+      case Alias(child, _) if !stop => child
+    }
+  }
+
+  def trimNonTopLevelAliases(e: Expression): Expression = e match {
+    case a: Alias =>
+      Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
+    case other => trimAliases(other)
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case Project(projectList, child) =>
+      val cleanedProjectList =
+        projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+      Project(cleanedProjectList, child)
+
+    case Aggregate(grouping, aggs, child) =>
+      val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+      Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+
+    case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+      val cleanedWindowExprs =
+        windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
+      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+        orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+
+    case other =>
+      var stop = false
+      other transformExpressionsDown {
+        case c: CreateStruct if !stop =>
+          stop = true
+          c.copy(children = c.children.map(trimNonTopLevelAliases))
+        case c: CreateStructUnsafe if !stop =>
+          stop = true
+          c.copy(children = c.children.map(trimNonTopLevelAliases))
+        case Alias(child, _) if !stop => child
+      }
+  }
+}
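
For illustration, a minimal sketch of the rule's effect, reusing the shapes from the
AnalysisSuite test added below (`testRelation` and `a` come from that test):

    val plan     = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
    val expected = testRelation.select((a + 1 + 2).as("col"))
    // Analysis now ends with the Cleanup batch, so the two plans are equivalent:
    // the inner Alias "a+1" names nothing visible and is trimmed.
    // CreateStruct is the exception: its top-level Aliases name StructFields, so
    //   testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
    // analyzes to itself.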

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4a071e6..298aee3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override lazy val resolved: Boolean = childrenResolved
-
   override lazy val dataType: StructType = {
     val fields = children.zipWithIndex.map { case (child, idx) =>
       child match {

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4ab5ac2..47b06ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.immutable.HashSet
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
         val substitutedProjection = projectList1.map(_.transform {
           case a: Attribute => aliasMap.getOrElse(a, a)
         }).asInstanceOf[Seq[NamedExpression]]
-
-        Project(substitutedProjection, child)
+        // Collapsing two Projects may introduce unnecessary Aliases; trim them here.
+        val cleanedProjection = substitutedProjection.map(p =>
+          CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+        )
+        Project(cleanedProjection, child)
       }
   }
 }
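
A hedged sketch of why the extra trim is needed (attribute names are hypothetical,
written with the Catalyst expression DSL):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    // What aliasMap substitution can produce when two Projects merge:
    val substituted = (('b + 1).as("a")).as("x")
    // The inner Alias "a" names nothing once the Projects are merged;
    // trimNonTopLevelAliases reduces it to ('b + 1).as("x").
    val cleaned = CleanupAliases.trimNonTopLevelAliases(substituted)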

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c290e6a..9bb466a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
         // The foldLeft adds ExtractValues for every remaining parts of the identifier,
-        // and wrap it with UnresolvedAlias which will be removed later.
+        // and aliases it with the last part of the name.
         // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
-        // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
+        // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+        // expression as "c".
         val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
           ExtractValue(expr, Literal(fieldName), resolver))
-        Some(UnresolvedAlias(fieldExprs))
+        Some(Alias(fieldExprs, nestedFields.last)())
 
       // No matches.
       case Seq() =>
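
Concretely, for the "a.b.c" example in the comment, the resolved result changes shape
as follows (a sketch; `a` is the resolved attribute):

    // before: an extra wrapper that a later rule had to strip
    Some(UnresolvedAlias(ExtractValue(ExtractValue(a, Literal("b"), resolver),
      Literal("c"), resolver)))
    // after: named immediately; CleanupAliases trims it later only if unnecessary
    Some(Alias(ExtractValue(ExtractValue(a, Literal("b"), resolver),
      Literal("c"), resolver), "c")())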

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7c40472..73b8261 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -228,7 +228,7 @@ case class Window(
     child: LogicalPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] =
-    (projectList ++ windowExpressions).map(_.toAttribute)
+    projectList ++ windowExpressions.map(_.toAttribute)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index c944bc6..1e0cc81 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
           Project(testRelation.output :+ projected, testRelation)))
     checkAnalysis(plan, expected)
   }
+
+  test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
+    val a = testRelation.output.head
+    var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
+    var expected = testRelation.select((a + 1 + 2).as("col"))
+    checkAnalysis(plan, expected)
+
+    plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
+    expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
+    checkAnalysis(plan, expected)
+
+    // CreateStruct is a special case: we should not trim the Aliases inside it.
+    plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
+    checkAnalysis(plan, plan)
+    plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+    checkAnalysis(plan, plan)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 27bd084..807bc8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *   df.select($"colA".as("colB"))
    * }}}
    *
+   * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+   *
    * @group expr_ops
    * @since 1.3.0
    */
-  def as(alias: String): Column = Alias(expr, alias)()
+  def as(alias: String): Column = expr match {
+    case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+    case other => Alias(other, alias)()
+  }
 
   /**
    * (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *   df.select($"colA".as('colB))
    * }}}
    *
+   * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+   *
    * @group expr_ops
    * @since 1.3.0
    */
-  def as(alias: Symbol): Column = Alias(expr, alias.name)()
+  def as(alias: Symbol): Column = expr match {
+    case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
+    case other => Alias(other, alias.name)()
+  }
 
   /**
    * Gives the column an alias with metadata.
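
A short usage sketch of the new propagation behavior (mirrors the
ColumnExpressionSuite test below):

    val metadata = new MetadataBuilder().putString("key", "value").build()
    val origCol = $"a".as("b", metadata)  // attach metadata explicitly
    val newCol = origCol.as("c")          // "key" -> "value" is carried over to "c"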

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index ee74e3e..37738ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.scalatest.Matchers._
 
 import org.apache.spark.sql.execution.{Project, TungstenProject}
@@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     assert(df.select(df("a").alias("b")).columns.head === "b")
   }
 
+  test("as propagates metadata") {
+    val metadata = new MetadataBuilder
+    metadata.putString("key", "value")
+    val origCol = $"a".as("b", metadata.build())
+    val newCol = origCol.as("c")
+    assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
+  }
+
   test("single explode") {
     val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/ec29f203/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 10bfa9b..cf22797 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
     assert(expected === actual)
   }
+
+  test("SPARK-9323: DataFrame.orderBy should support nested column name") {
+    val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+      """{"a": {"b": 1}}""" :: Nil))
+    checkAnswer(df.orderBy("a.b"), Row(Row(1)))
+  }
 }

