spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-13320][SQL] Support Star in CreateStruct/CreateArray and Error Handling when DataFrame/DataSet Functions using Star
Date Tue, 22 Mar 2016 00:21:13 GMT
Repository: spark
Updated Branches:
  refs/heads/master b5f1ab701 -> 3f49e0766


[SPARK-13320][SQL] Support Star in CreateStruct/CreateArray and Error Handling when DataFrame/DataSet
Functions using Star

This PR resolves two issues:

First, it supports expanding `*` inside `struct`/`array` expressions (CreateStruct/CreateArray), including when they are arguments of aggregate functions, in the DataFrame/Dataset APIs.
For example,
```scala
structDf.groupBy($"a").agg(min(struct($"record.*")))
```

Second, it improves the error messages reported for invalid star usage in the DataFrame/Dataset
APIs. For example,
```scala
pagecounts4PartitionsDS
  .map(line => (line._1, line._3))
  .toDF()
  .groupBy($"_1")
  .agg(sum("*") as "sumOccurances")
```
Before the fix, this invalid usage produced a confusing error message, such as:
```
org.apache.spark.sql.AnalysisException: cannot resolve '_1' given input columns _1, _2;
```
After the fix, the message is:
```
org.apache.spark.sql.AnalysisException: Invalid usage of '*' in expression 'sum'
```
cc: rxin nongli cloud-fan

Author: gatorsmile <gatorsmile@gmail.com>

Closes #11208 from gatorsmile/sumDataSetResolution.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f49e076
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f49e076
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f49e076

Branch: refs/heads/master
Commit: 3f49e0766f3a369a44e14632de68c657773b7a27
Parents: b5f1ab7
Author: gatorsmile <gatorsmile@gmail.com>
Authored: Tue Mar 22 08:21:02 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Mar 22 08:21:02 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 138 ++++++++++---------
 .../catalyst/analysis/AnalysisErrorSuite.scala  |   5 +
 .../org/apache/spark/sql/DataFrameSuite.scala   |  35 ++++-
 .../sql/hive/execution/SQLQuerySuite.scala      |  12 +-
 4 files changed, 113 insertions(+), 77 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3f49e076/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7a08c7d..5951a70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -80,6 +80,7 @@ class Analyzer(
       EliminateUnions),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
+      ResolveStar ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
@@ -373,28 +374,83 @@ class Analyzer(
   }
 
   /**
-   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
-   * a logical plan node's children.
+   * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's
output.
    */
-  object ResolveReferences extends Rule[LogicalPlan] {
+  object ResolveStar extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      // If the projection list contains Stars, expand it.
+      case p: Project if containsStar(p.projectList) =>
+        val expanded = p.projectList.flatMap {
+          case s: Star => s.expand(p.child, resolver)
+          case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct,
_) =>
+            UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
+          case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+            a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)
+              .asInstanceOf[Alias] :: Nil
+          case o => o :: Nil
+        }
+        Project(projectList = expanded, p.child)
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        val expanded = a.aggregateExpressions.flatMap {
+          case s: Star => s.expand(a.child, resolver)
+          case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
+          case o => o :: Nil
+        }.map(_.asInstanceOf[NamedExpression])
+        a.copy(aggregateExpressions = expanded)
+      // If the script transformation input contains Stars, expand it.
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child, resolver)
+            case o => o :: Nil
+          }
+        )
+      case g: Generate if containsStar(g.generator.children) =>
+        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+    }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
     /**
-     * Foreach expression, expands the matching attribute.*'s in `child`'s input for the
subtree
-     * rooted at each expression.
+     * Expands the matching attribute.*'s in `child`'s output.
      */
-    def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression]
= {
-      exprs.flatMap {
-        case s: Star => s.expand(child, resolver)
-        case e =>
-          e.transformDown {
-            case f1: UnresolvedFunction if containsStar(f1.children) =>
-              f1.copy(children = f1.children.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              })
-          } :: Nil
+    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+      expr.transformUp {
+        case f1: UnresolvedFunction if containsStar(f1.children) =>
+          f1.copy(children = f1.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateStruct if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateArray if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        // count(*) has been replaced by count(1)
+        case o if containsStar(o.children) =>
+          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
       }
     }
+  }
 
+  /**
+   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
+   * a logical plan node's children.
+   */
+  object ResolveReferences extends Rule[LogicalPlan] {
     /**
      * Generate a new logical plan for the right child with different expression IDs
      * for all conflicting attributes.
@@ -456,48 +512,6 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
-      // If the projection list contains Stars, expand it.
-      case p @ Project(projectList, child) if containsStar(projectList) =>
-        Project(
-          projectList.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args)
=>
-              val newChildren = expandStarExpressions(args, child)
-              UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
-            case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args)
=>
-              val newChildren = expandStarExpressions(args, child)
-              Alias(child = f.copy(children = newChildren), name)(
-                isGenerated = a.isGenerated) :: Nil
-            case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case o => o :: Nil
-          },
-          child)
-
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child, resolver)
-            case o => o :: Nil
-          }
-        )
-
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
-            .map(_.asInstanceOf[NamedExpression])
-        a.copy(aggregateExpressions = expanded)
-
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
@@ -592,12 +606,6 @@ class Analyzer(
     def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
       AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
     }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
 
   protected[sql] def resolveExpression(
@@ -923,8 +931,6 @@ class Analyzer(
    */
   object ResolveGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
-        failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
       case p: Generate if !p.child.resolved || !p.generator.resolved => p
       case g: Generate if !g.resolved =>
         g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))

http://git-wip-us.apache.org/repos/asf/spark/blob/3f49e076/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 1b29752..c87a2e2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -190,6 +190,11 @@ class AnalysisErrorSuite extends AnalysisTest {
     "cannot resolve" :: "havingCondition" :: Nil)
 
   errorTest(
+    "unresolved star expansion in max",
+    testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
+    "Invalid usage of '*'" :: "in expression 'sum'" :: Nil)
+
+  errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
   "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/3f49e076/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index df2a287..d03597e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -166,22 +166,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("SPARK-8930: explode should fail with a meaningful message if it takes a star") {
+  test("Star Expansion - CreateStruct and CreateArray") {
+    val structDf = testData2.select("a", "b").as("record")
+    // CreateStruct and CreateArray in aggregateExpressions
+    assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3,
1)))
+    assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1)))
+
+    // CreateStruct and CreateArray in project list (unresolved alias)
+    assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
+    assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1))
+
+    // CreateStruct and CreateArray in project list (alias)
+    assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1)))
+    assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1,
1))
+  }
+
+  test("Star Expansion - explode should fail with a meaningful message if it takes a star")
{
     val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
     val e = intercept[AnalysisException] {
       df.explode($"*") { case Row(prefix: String, csv: String) =>
         csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
       }.queryExecution.assertAnalyzed()
     }
-    assert(e.getMessage.contains(
-      "Cannot explode *, explode can only be applied on a specific column."))
+    assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF"))
 
-    df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
-      csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
-    }.queryExecution.assertAnalyzed()
+    checkAnswer(
+      df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
+        csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
+      },
+      Row("1", "1,2", "1:1") ::
+        Row("1", "1,2", "1:2") ::
+        Row("2", "4", "2:4") ::
+        Row("3", "7,8,9", "3:7") ::
+        Row("3", "7,8,9", "3:8") ::
+        Row("3", "7,8,9", "3:9") :: Nil)
   }
 
-  test("explode alias and star") {
+  test("Star Expansion - explode alias and star") {
     val df = Seq((Array("a"), 1)).toDF("a", "b")
 
     checkAnswer(

http://git-wip-us.apache.org/repos/asf/spark/blob/3f49e076/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6bedead..2806b87 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -738,20 +738,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton
{
       .queryExecution.analyzed
   }
 
+  test("Star Expansion - script transform") {
+    val data = (1 to 100000).map { i => (i, i, i) }
+    data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
+    assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count())
+  }
+
   test("test script transform for stdout") {
     val data = (1 to 100000).map { i => (i, i, i) }
     data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
     assert(100000 ===
-      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
-        .queryExecution.toRdd.count())
+      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count())
   }
 
   test("test script transform for stderr") {
     val data = (1 to 100000).map { i => (i, i, i) }
     data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
     assert(0 ===
-      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans")
-        .queryExecution.toRdd.count())
+      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count())
   }
 
   test("test script transform data type") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message