spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yh...@apache.org
Subject spark git commit: [SPARK-11329][SQL] Support star expansion for structs.
Date Tue, 03 Nov 2015 04:32:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2cef1bb0b -> 9cb5c731d


[SPARK-11329][SQL] Support star expansion for structs.

1. Supporting expanding structs in Projections, i.e.
  "SELECT s.*" where s is a struct type.
  This is fixed by allowing the expand function to handle structs in addition to tables.

2. Supporting expanding * inside aggregate functions of structs.
   "SELECT max(struct(col1, structCol.*))"
   This requires recursively expanding the expressions. In this case, it is the aggregate
   expression "max(...)" and we need to recursively expand its child inputs.

Author: Nong Li <nongli@gmail.com>

Closes #9343 from nongli/spark-11329.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9cb5c731
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9cb5c731
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9cb5c731

Branch: refs/heads/master
Commit: 9cb5c731dadff9539126362827a258d6b65754bb
Parents: 2cef1bb
Author: Nong Li <nongli@gmail.com>
Authored: Mon Nov 2 20:32:08 2015 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Mon Nov 2 20:32:08 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/SqlParser.scala   |   6 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  46 ++++---
 .../sql/catalyst/analysis/unresolved.scala      |  78 ++++++++---
 .../scala/org/apache/spark/sql/Column.scala     |   3 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 133 +++++++++++++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |   2 +-
 6 files changed, 230 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 0fef043..d7567e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -466,9 +466,9 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
 
   protected lazy val baseExpression: Parser[Expression] =
     ( "*" ^^^ UnresolvedStar(None)
-    | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }
-    | primary
-    )
+    | (ident <~ "."). + <~ "*" ^^ { case target => { UnresolvedStar(Option(target))
}
+    } | primary
+   )
 
   protected lazy val signedPrimary: Parser[Expression] =
     sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index beabacf..912c967 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -279,6 +279,24 @@ class Analyzer(
    * a logical plan node's children.
    */
   object ResolveReferences extends Rule[LogicalPlan] {
+    /**
+     * Foreach expression, expands the matching attribute.*'s in `child`'s input for the
subtree
+     * rooted at each expression.
+     */
+    def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression]
= {
+      exprs.flatMap {
+        case s: Star => s.expand(child, resolver)
+        case e =>
+          e.transformDown {
+            case f1: UnresolvedFunction if containsStar(f1.children) =>
+              f1.copy(children = f1.children.flatMap {
+                case s: Star => s.expand(child, resolver)
+                case o => o :: Nil
+              })
+          } :: Nil
+      }
+    }
+
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
@@ -286,44 +304,42 @@ class Analyzer(
       case p @ Project(projectList, child) if containsStar(projectList) =>
         Project(
           projectList.flatMap {
-            case s: Star => s.expand(child.output, resolver)
+            case s: Star => s.expand(child, resolver)
             case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args)
=>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child.output, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+              val newChildren = expandStarExpressions(args, child)
+              UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
+            case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
+              val newChildren = expandStarExpressions(args, child)
+              Alias(child = f.copy(children = newChildren), name)() :: Nil
             case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child.output, resolver)
+                case s: Star => s.expand(child, resolver)
                 case o => o :: Nil
               }
               UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
             case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
               val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child.output, resolver)
+                case s: Star => s.expand(child, resolver)
                 case o => o :: Nil
               }
               UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
             case o => o :: Nil
           },
           child)
+
       case t: ScriptTransformation if containsStar(t.input) =>
         t.copy(
           input = t.input.flatMap {
-            case s: Star => s.expand(t.child.output, resolver)
+            case s: Star => s.expand(t.child, resolver)
             case o => o :: Nil
           }
         )
 
       // If the aggregate function argument contains Stars, expand it.
       case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        a.copy(
-          aggregateExpressions = a.aggregateExpressions.flatMap {
-            case s: Star => s.expand(a.child.output, resolver)
-            case o => o :: Nil
-          }
-        )
+        val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
+            .map(_.asInstanceOf[NamedExpression])
+        a.copy(aggregateExpressions = expanded)
 
       // Special handling for cases when self-join introduce duplicate expression ids.
       case j @ Join(left, right, _, _) if !j.selfJoinResolved =>

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c973650..6975662 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,12 +18,12 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.{TableIdentifier, errors}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
 import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.{TableIdentifier, errors}
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * Thrown when an invalid attempt is made to access a property of a tree that has yet to
be fully
@@ -158,7 +158,7 @@ abstract class Star extends LeafExpression with NamedExpression {
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
   override lazy val resolved = false
 
-  def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
+  def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
 }
 
 
@@ -166,26 +166,68 @@ abstract class Star extends LeafExpression with NamedExpression {
  * Represents all of the input attributes to a given relational operator, for example in
  * "SELECT * FROM ...".
  *
- * @param table an optional table that should be the target of the expansion.  If omitted
all
- *              tables' columns are produced.
+ * This is also used to expand structs. For example:
+ * "SELECT record.* from (SELECT struct(a,b,c) as record ...)
+ *
+ * @param target an optional name that should be the target of the expansion.  If omitted
all
+ *              targets' columns are produced. This can either be a table name or struct
name. This
+ *              is a list of identifiers that is the path of the expansion.
  */
-case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable {
+case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable {
+
+  override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
 
-  override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] =
{
-    val expandedAttributes: Seq[Attribute] = table match {
+    // First try to expand assuming it is table.*.
+    val expandedAttributes: Seq[Attribute] = target match {
       // If there is no table specified, use all input attributes.
-      case None => input
+      case None => input.output
       // If there is a table, pick out attributes that are part of this table.
-      case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
+      case Some(t) => if (t.size == 1) {
+        input.output.filter(_.qualifiers.filter(resolver(_, t.head)).nonEmpty)
+      } else {
+        List()
+      }
     }
-    expandedAttributes.zip(input).map {
-      case (n: NamedExpression, _) => n
-      case (e, originalAttribute) =>
-        Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
+    if (!expandedAttributes.isEmpty) {
+      if (expandedAttributes.forall(_.isInstanceOf[NamedExpression])) {
+        return expandedAttributes
+      } else {
+        require(expandedAttributes.size == input.output.size)
+        expandedAttributes.zip(input.output).map {
+          case (e, originalAttribute) =>
+            Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers)
+        }
+      }
+      return expandedAttributes
+    }
+
+    require(target.isDefined)
+
+    // Try to resolve it as a struct expansion. If there is a conflict and both are possible,
+    // (i.e. [name].* is both a table and a struct), the struct path can always be qualified.
+    val attribute = input.resolve(target.get, resolver)
+    if (attribute.isDefined) {
+      // This target resolved to an attribute in child. It must be a struct. Expand it.
+      attribute.get.dataType match {
+        case s: StructType => {
+          s.fields.map( f => {
+            val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
+            Alias(extract, target.get + "." + f.name)()
+          })
+        }
+        case _ => {
+          throw new AnalysisException("Can only star expand struct data types. Attribute:
`" +
+            target.get + "`")
+        }
+      }
+    } else {
+      val from = input.inputSet.map(_.name).mkString(", ")
+      val targetString = target.get.mkString(".")
+      throw new AnalysisException(s"cannot resolve '$targetString.*' give input columns '$from'")
     }
   }
 
-  override def toString: String = table.map(_ + ".").getOrElse("") + "*"
+  override def toString: String = target.map(_ + ".").getOrElse("") + "*"
 }
 
 /**
@@ -225,7 +267,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
  * @param expressions Expressions to expand.
  */
 case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable
{
-  override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] =
expressions
+  override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e4f4cf1..3cde9d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -60,7 +60,8 @@ class Column(protected[sql] val expr: Expression) extends Logging {
 
   def this(name: String) = this(name match {
     case "*" => UnresolvedStar(None)
-    case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length
- 2)))
+    case _ if name.endsWith(".*") => UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(
+      name.substring(0, name.length - 2))))
     case _ => UnresolvedAttribute.quotedString(name)
   })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5413ef1..ee54bff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1932,4 +1932,137 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       assert(sampled.count() == sampledOdd.count() + sampledEven.count())
     }
   }
+
+  test("Struct Star Expansion") {
+    val structDf = testData2.select("a", "b").as("record")
+
+    checkAnswer(
+      structDf.select($"record.a", $"record.b"),
+      Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+    checkAnswer(
+      structDf.select($"record.*"),
+      Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+    checkAnswer(
+      structDf.select($"record.*", $"record.*"),
+      Row(1, 1, 1, 1) :: Row(1, 2, 1, 2) :: Row(2, 1, 2, 1) :: Row(2, 2, 2, 2) ::
+        Row(3, 1, 3, 1) :: Row(3, 2, 3, 2) :: Nil)
+
+    checkAnswer(
+      sql("select struct(a, b) as r1, struct(b, a) as r2 from testData2").select($"r1.*",
$"r2.*"),
+      Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) ::
+        Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil)
+
+    // Try with a registered table.
+    sql("select struct(a, b) as record from testData2").registerTempTable("structTable")
+    checkAnswer(sql("SELECT record.* FROM structTable"),
+      Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+    checkAnswer(sql(
+      """
+        | SELECT min(struct(record.*)) FROM
+        |   (select struct(a,b) as record from testData2) tmp
+      """.stripMargin),
+      Row(Row(1, 1)) :: Nil)
+
+    // Try with an alias on the select list
+    checkAnswer(sql(
+      """
+        | SELECT max(struct(record.*)) as r FROM
+        |   (select struct(a,b) as record from testData2) tmp
+      """.stripMargin).select($"r.*"),
+      Row(3, 2) :: Nil)
+
+    // With GROUP BY
+    checkAnswer(sql(
+      """
+        | SELECT min(struct(record.*)) FROM
+        |   (select a as a, struct(a,b) as record from testData2) tmp
+        | GROUP BY a
+      """.stripMargin),
+      Row(Row(1, 1)) :: Row(Row(2, 1)) :: Row(Row(3, 1)) :: Nil)
+
+    // With GROUP BY and alias
+    checkAnswer(sql(
+      """
+        | SELECT max(struct(record.*)) as r FROM
+        |   (select a as a, struct(a,b) as record from testData2) tmp
+        | GROUP BY a
+      """.stripMargin).select($"r.*"),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil)
+
+    // With GROUP BY and alias and additional fields in the struct
+    checkAnswer(sql(
+      """
+        | SELECT max(struct(a, record.*, b)) as r FROM
+        |   (select a as a, b as b, struct(a,b) as record from testData2) tmp
+        | GROUP BY a
+      """.stripMargin).select($"r.*"),
+      Row(1, 1, 2, 2) :: Row(2, 2, 2, 2) :: Row(3, 3, 2, 2) :: Nil)
+
+    // Create a data set that contains nested structs.
+    val nestedStructData = sql(
+      """
+        | SELECT struct(r1, r2) as record FROM
+        |   (SELECT struct(a, b) as r1, struct(b, a) as r2 FROM testData2) tmp
+      """.stripMargin)
+
+    checkAnswer(nestedStructData.select($"record.*"),
+      Row(Row(1, 1), Row(1, 1)) :: Row(Row(1, 2), Row(2, 1)) :: Row(Row(2, 1), Row(1, 2))
::
+        Row(Row(2, 2), Row(2, 2)) :: Row(Row(3, 1), Row(1, 3)) :: Row(Row(3, 2), Row(2, 3))
:: Nil)
+    checkAnswer(nestedStructData.select($"record.r1"),
+      Row(Row(1, 1)) :: Row(Row(1, 2)) :: Row(Row(2, 1)) :: Row(Row(2, 2)) ::
+        Row(Row(3, 1)) :: Row(Row(3, 2)) :: Nil)
+    checkAnswer(
+      nestedStructData.select($"record.r1.*"),
+      Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+
+    // Try with a registered table
+    nestedStructData.registerTempTable("nestedStructTable")
+    checkAnswer(sql("SELECT record.* FROM nestedStructTable"),
+      nestedStructData.select($"record.*"))
+    checkAnswer(sql("SELECT record.r1 FROM nestedStructTable"),
+      nestedStructData.select($"record.r1"))
+    checkAnswer(sql("SELECT record.r1.* FROM nestedStructTable"),
+      nestedStructData.select($"record.r1.*"))
+
+    // Create paths with unusual characters.
+    val specialCharacterPath = sql(
+      """
+        | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM
+        |   (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp
+      """.stripMargin)
+    specialCharacterPath.registerTempTable("specialCharacterTable")
+    checkAnswer(specialCharacterPath.select($"`r&&b.c`.*"),
+      nestedStructData.select($"record.*"))
+    checkAnswer(sql("SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"),
+      nestedStructData.select($"record.r1"))
+    checkAnswer(sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"),
+      nestedStructData.select($"record.r2"))
+    checkAnswer(sql("SELECT `r&&b.c`.`col$.a_`.* FROM specialCharacterTable"),
+      nestedStructData.select($"record.r1.*"))
+
+    // Try star expanding a scalar. This should fail.
+    assert(intercept[AnalysisException](sql("select a.* from testData2")).getMessage.contains(
+      "Can only star expand struct data types."))
+
+    // Try resolving something not there.
+    assert(intercept[AnalysisException](sql("SELECT abc.* FROM nestedStructTable"))
+      .getMessage.contains("cannot resolve"))
+  }
+
+
+  test("Struct Star Expansion - Name conflict") {
+    // Create a data set that contains a naming conflict
+    val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2")
+    nameConflict.registerTempTable("nameConflict")
+    // Unqualified should resolve to table.
+    checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"),
+      Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) ::
+        Row(Row(3, 1), 3) :: Row(Row(3, 2), 3) :: Nil)
+    // Qualify the struct type with the table name.
+    checkAnswer(sql("SELECT nameConflict.nameConflict.* FROM nameConflict"),
+      Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/9cb5c731/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 3697761..ab88c1e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1505,7 +1505,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will
only
     // has a single child which is tableName.
     case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
-      UnresolvedStar(Some(name))
+      UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
 
     /* Aggregate Functions */
     case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message