spark-commits mailing list archives

From: marmb...@apache.org
Subject: spark git commit: [SPARK-7133] [SQL] Implement struct, array, and map field accessor
Date: Fri, 08 May 2015 18:49:54 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.4 0b2c252d0 -> f8468c451


[SPARK-7133] [SQL] Implement struct, array, and map field accessor

This is the first step: generalize UnresolvedGetField to support extraction from map, struct, and array types.
TODO: add `apply` in Scala and `__getitem__` in Python, and unify the `getItem` and `getField`
methods into a single API (or should we keep them both for compatibility?).
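
For illustration, a minimal sketch of the accessor behavior this patch enables. The case
classes and context setup below are hypothetical, not part of this commit; the accessor
calls mirror the tests added in DataFrameSuite and python/pyspark/sql/tests.py:

    // assumes a SparkContext `sc` and a SQLContext `sqlContext`
    import sqlContext.implicits._

    case class Inner(a: Int, b: String)
    case class Record(l: Seq[Int], d: Map[String, String], r: Inner)
    val df = sc.parallelize(Seq(Record(Seq(1), Map("k" -> "v"), Inner(1, "b")))).toDF()

    df.select(df("l")(0))             // array ordinal -> 1
    df.select(df("d")("k"))           // map key -> "v"
    df.select(df("r")("a"))           // struct field -> 1
    df.select(df("r").getField("b"))  // still supported; same resolution path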

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5744 from cloud-fan/generalize and squashes the following commits:

715c589 [Wenchen Fan] address comments
7ea5b31 [Wenchen Fan] fix python test
4f0833a [Wenchen Fan] add python test
f515d69 [Wenchen Fan] add apply method and test cases
8df6199 [Wenchen Fan] fix python test
239730c [Wenchen Fan] fix test compile
2a70526 [Wenchen Fan] use _bin_op in dataframe.py
6bf72bc [Wenchen Fan] address comments
3f880c3 [Wenchen Fan] add java doc
ab35ab5 [Wenchen Fan] fix python test
b5961a9 [Wenchen Fan] fix style
c9d85f5 [Wenchen Fan] generalize UnresolvedGetField to support all map, struct, and array

(cherry picked from commit 2d05f325dc3c70349bd17ed399897f22d967c687)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f8468c45
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f8468c45
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f8468c45

Branch: refs/heads/branch-1.4
Commit: f8468c4511caf72856a61a57473a3c948d632d09
Parents: 0b2c252
Author: Wenchen Fan <cloud0fan@outlook.com>
Authored: Fri May 8 11:49:38 2015 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Fri May 8 11:49:49 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 |  24 +--
 python/pyspark/sql/tests.py                     |   7 +
 .../apache/spark/sql/catalyst/SqlParser.scala   |   4 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   4 +-
 .../sql/catalyst/analysis/unresolved.scala      |  14 +-
 .../apache/spark/sql/catalyst/dsl/package.scala |   7 +-
 .../sql/catalyst/expressions/ExtractValue.scala | 206 +++++++++++++++++++
 .../sql/catalyst/expressions/complexTypes.scala | 131 ------------
 .../sql/catalyst/optimizer/Optimizer.scala      |   6 +-
 .../spark/sql/catalyst/planning/patterns.scala  |   2 +-
 .../catalyst/plans/logical/LogicalPlan.scala    |   3 +-
 .../expressions/ExpressionEvaluationSuite.scala |  69 +++++--
 .../optimizer/ConstantFoldingSuite.scala        |   8 +-
 .../scala/org/apache/spark/sql/Column.scala     |  19 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   |  10 +-
 .../org/apache/spark/sql/hive/HiveQl.scala      |   4 +-
 16 files changed, 327 insertions(+), 191 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cee804f..a969799 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1275,7 +1275,7 @@ class Column(object):
 
     # container operators
     __contains__ = _bin_op("contains")
-    __getitem__ = _bin_op("getItem")
+    __getitem__ = _bin_op("apply")
 
     # bitwise operators
     bitwiseOR = _bin_op("bitwiseOR")
@@ -1308,19 +1308,19 @@ class Column(object):
         >>> from pyspark.sql import Row
         >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
         >>> df.select(df.r.getField("b")).show()
-        +---+
-        |r.b|
-        +---+
-        |  b|
-        +---+
+        +----+
+        |r[b]|
+        +----+
+        |   b|
+        +----+
         >>> df.select(df.r.a).show()
-        +---+
-        |r.a|
-        +---+
-        |  1|
-        +---+
+        +----+
+        |r[a]|
+        +----+
+        |   1|
+        +----+
         """
-        return Column(self._jc.getField(name))
+        return self[name]
 
     def __getattr__(self, item):
         if item.startswith("__"):

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 45dfedc..7e63f4d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -519,6 +519,13 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual("v", df.select(df.d["k"]).first()[0])
         self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
 
+    def test_field_accessor(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.r["a"]).first()[0])
+        self.assertEqual("b", df.select(df.r["b"]).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+
     def test_infer_long_type(self):
         longrow = [Row(f1='a', f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b06bfb2..fc36b9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -375,9 +375,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
   protected lazy val primary: PackratParser[Expression] =
     ( literal
     | expression ~ ("[" ~> expression <~ "]") ^^
-      { case base ~ ordinal => GetItem(base, ordinal) }
+      { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
     | (expression <~ ".") ~ ident ^^
-      { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
+      { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
     | cast
     | "(" ~> expression <~ ")"
     | function

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bb7913e..ecbac57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -348,8 +348,8 @@ class Analyzer(
               withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
             logDebug(s"Resolving $u to $result")
             result
-          case UnresolvedGetField(child, fieldName) if child.resolved =>
-            GetField(child, fieldName, resolver)
+          case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
+            ExtractValue(child, fieldExpr, resolver)
         }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index eb736ac..2999c2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -184,7 +184,17 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
 }
 
-case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+/**
+ * Extracts a value or values from an Expression
+ *
+ * @param child The expression to extract value from,
+ *              can be Map, Array, Struct or array of Structs.
+ * @param extraction The expression to describe the extraction,
+ *                   can be key of Map, index of Array, field name of Struct.
+ */
+case class UnresolvedExtractValue(child: Expression, extraction: Expression)
+  extends UnaryExpression {
+
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
@@ -193,5 +203,5 @@ case class UnresolvedGetField(child: Expression, fieldName: String) extends Unar
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 
-  override def toString: String = s"$child.$fieldName"
+  override def toString: String = s"$child[$extraction]"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index fa6cc7a..4c0d702 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -100,8 +100,9 @@ package object dsl {
     def isNull: Predicate = IsNull(expr)
     def isNotNull: Predicate = IsNotNull(expr)
 
-    def getItem(ordinal: Expression): Expression = GetItem(expr, ordinal)
-    def getField(fieldName: String): UnresolvedGetField = UnresolvedGetField(expr, fieldName)
+    def getItem(ordinal: Expression): UnresolvedExtractValue = UnresolvedExtractValue(expr, ordinal)
+    def getField(fieldName: String): UnresolvedExtractValue =
+      UnresolvedExtractValue(expr, Literal(fieldName))
 
     def cast(to: DataType): Expression = Cast(expr, to)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
new file mode 100644
index 0000000..e05926c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.collection.Map
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.types._
+
+object ExtractValue {
+  /**
+   * Returns the resolved `ExtractValue`. It will return one kind of concrete `ExtractValue`,
+   * depending on the types of `child` and `extraction`.
+   *
+   *   `child`      |    `extraction`    |    concrete `ExtractValue`
+   * ----------------------------------------------------------------
+   *    Struct      |   Literal String   |        GetStructField
+   * Array[Struct]  |   Literal String   |     GetArrayStructFields
+   *    Array       |   Integral type    |         GetArrayItem
+   *     Map        |      Any type      |         GetMapValue
+   */
+  def apply(
+      child: Expression,
+      extraction: Expression,
+      resolver: Resolver): ExtractValue = {
+
+    (child.dataType, extraction) match {
+      case (StructType(fields), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        GetStructField(child, fields(ordinal), ordinal)
+      case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
+        val ordinal = findField(fields, fieldName.toString, resolver)
+        GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+      case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
+        GetArrayItem(child, extraction)
+      case (_: MapType, _) =>
+        GetMapValue(child, extraction)
+      case (otherType, _) =>
+        val errorMsg = otherType match {
+          case StructType(_) | ArrayType(StructType(_), _) =>
+            s"Field name should be String Literal, but it's $extraction"
+          case _: ArrayType =>
+            s"Array index should be integral type, but it's ${extraction.dataType}"
+          case other =>
+            s"Can't extract value from $child"
+        }
+        throw new AnalysisException(errorMsg)
+    }
+  }
+
+  def unapply(g: ExtractValue): Option[(Expression, Expression)] = {
+    g match {
+      case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal))
+      case _ => Some((g.child, null))
+    }
+  }
+
+  /**
+   * Finds the ordinal of the requested StructField, reporting an error if no matching
+   * field is found or if more than one field matches.
+   */
+  private def findField(fields: Array[StructField], fieldName: String, resolver: Resolver): Int = {
+    val checkField = (f: StructField) => resolver(f.name, fieldName)
+    val ordinal = fields.indexWhere(checkField)
+    if (ordinal == -1) {
+      throw new AnalysisException(
+        s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+    } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+      throw new AnalysisException(
+        s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+    } else {
+      ordinal
+    }
+  }
+}
+
+trait ExtractValue extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+}
+
+/**
+ * Returns the value of the field at `ordinal` in the Struct `child`.
+ */
+case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+  extends ExtractValue {
+
+  override def dataType: DataType = field.dataType
+  override def nullable: Boolean = child.nullable || field.nullable
+  override def foldable: Boolean = child.foldable
+  override def toString: String = s"$child.${field.name}"
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Row]
+    if (baseValue == null) null else baseValue(ordinal)
+  }
+}
+
+/**
+ * Returns an array containing the value of the field at `ordinal` in each Struct of the Array `child`.
+ */
+case class GetArrayStructFields(
+    child: Expression,
+    field: StructField,
+    ordinal: Int,
+    containsNull: Boolean) extends ExtractValue {
+
+  override def dataType: DataType = ArrayType(field.dataType, containsNull)
+  override def nullable: Boolean = child.nullable
+  override def foldable: Boolean = child.foldable
+  override def toString: String = s"$child.${field.name}"
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
+}
+
+abstract class ExtractValueWithOrdinal extends ExtractValue {
+  self: Product =>
+
+  def ordinal: Expression
+
+  /** `Null` is returned for invalid ordinals. */
+  override def nullable: Boolean = true
+  override def foldable: Boolean = child.foldable && ordinal.foldable
+  override def toString: String = s"$child[$ordinal]"
+  override def children: Seq[Expression] = child :: ordinal :: Nil
+
+  override def eval(input: Row): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      null
+    } else {
+      val o = ordinal.eval(input)
+      if (o == null) {
+        null
+      } else {
+        evalNotNull(value, o)
+      }
+    }
+  }
+
+  protected def evalNotNull(value: Any, ordinal: Any): Any
+}
+
+/**
+ * Returns the element at `ordinal` in the Array `child`.
+ */
+case class GetArrayItem(child: Expression, ordinal: Expression)
+  extends ExtractValueWithOrdinal {
+
+  override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
+
+  override lazy val resolved = childrenResolved &&
+    child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    // TODO: consider using Array[_] for ArrayType child to avoid
+    // boxing of primitives
+    val baseValue = value.asInstanceOf[Seq[_]]
+    val index = ordinal.asInstanceOf[Int]
+    if (index >= baseValue.size || index < 0) {
+      null
+    } else {
+      baseValue(index)
+    }
+  }
+}
+
+/**
+ * Returns the value for the key `ordinal` in the Map `child`.
+ */
+case class GetMapValue(child: Expression, ordinal: Expression)
+  extends ExtractValueWithOrdinal {
+
+  override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
+
+  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
+
+  protected def evalNotNull(value: Any, ordinal: Any) = {
+    val baseValue = value.asInstanceOf[Map[Any, _]]
+    baseValue.get(ordinal).orNull
+  }
+}
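
For reviewers, a minimal sketch (not part of this commit) of the dispatch table above,
calling ExtractValue.apply directly; `_ == _` is the same case-sensitive resolver shortcut
the updated ExpressionEvaluationSuite uses:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types._

    val struct = StructType(StructField("a", IntegerType) :: Nil)
    val s = BoundReference(0, struct, nullable = true)
    ExtractValue(s, Literal("a"), _ == _)    // => GetStructField

    val arr = BoundReference(1, ArrayType(StringType), nullable = true)
    ExtractValue(arr, Literal(0), _ == _)    // => GetArrayItem

    val m = BoundReference(2, MapType(StringType, StringType), nullable = true)
    ExtractValue(m, Literal("k"), _ == _)    // => GetMapValue

    // any other combination throws AnalysisException, e.g. a non-literal
    // field name on a struct or a non-integral array index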

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index fc1f696..956a242 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -17,139 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.Map
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.types._
 
-/**
- * Returns the item at `ordinal` in the Array `child` or the Key `ordinal` in Map `child`.
- */
-case class GetItem(child: Expression, ordinal: Expression) extends Expression {
-  type EvaluatedType = Any
-
-  val children: Seq[Expression] = child :: ordinal :: Nil
-  /** `Null` is returned for invalid ordinals. */
-  override def nullable: Boolean = true
-  override def foldable: Boolean = child.foldable && ordinal.foldable
-
-  override def dataType: DataType = child.dataType match {
-    case ArrayType(dt, _) => dt
-    case MapType(_, vt, _) => vt
-  }
-  override lazy val resolved =
-    childrenResolved &&
-    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
-
-  override def toString: String = s"$child[$ordinal]"
-
-  override def eval(input: Row): Any = {
-    val value = child.eval(input)
-    if (value == null) {
-      null
-    } else {
-      val key = ordinal.eval(input)
-      if (key == null) {
-        null
-      } else {
-        if (child.dataType.isInstanceOf[ArrayType]) {
-          // TODO: consider using Array[_] for ArrayType child to avoid
-          // boxing of primitives
-          val baseValue = value.asInstanceOf[Seq[_]]
-          val o = key.asInstanceOf[Int]
-          if (o >= baseValue.size || o < 0) {
-            null
-          } else {
-            baseValue(o)
-          }
-        } else {
-          val baseValue = value.asInstanceOf[Map[Any, _]]
-          baseValue.get(key).orNull
-        }
-      }
-    }
-  }
-}
-
-
-trait GetField extends UnaryExpression {
-  self: Product =>
-
-  type EvaluatedType = Any
-  override def foldable: Boolean = child.foldable
-  override def toString: String = s"$child.${field.name}"
-
-  def field: StructField
-}
-
-object GetField {
-  /**
-   * Returns the resolved `GetField`, and report error if no desired field or over one
-   * desired fields are found.
-   */
-  def apply(
-      expr: Expression,
-      fieldName: String,
-      resolver: Resolver): GetField = {
-    def findField(fields: Array[StructField]): Int = {
-      val checkField = (f: StructField) => resolver(f.name, fieldName)
-      val ordinal = fields.indexWhere(checkField)
-      if (ordinal == -1) {
-        throw new AnalysisException(
-          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
-        throw new AnalysisException(
-          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
-      } else {
-        ordinal
-      }
-    }
-    expr.dataType match {
-      case StructType(fields) =>
-        val ordinal = findField(fields)
-        StructGetField(expr, fields(ordinal), ordinal)
-      case ArrayType(StructType(fields), containsNull) =>
-        val ordinal = findField(fields)
-        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
-      case otherType =>
-        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
-    }
-  }
-}
-
-/**
- * Returns the value of fields in the Struct `child`.
- */
-case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
-
-  override def dataType: DataType = field.dataType
-  override def nullable: Boolean = child.nullable || field.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Row]
-    if (baseValue == null) null else baseValue(ordinal)
-  }
-}
-
-/**
- * Returns the array of value of fields in the Array of Struct `child`.
- */
-case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
-  extends GetField {
-
-  override def dataType: DataType = ArrayType(field.dataType, containsNull)
-  override def nullable: Boolean = child.nullable
-
-  override def eval(input: Row): Any = {
-    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
-    if (baseValue == null) null else {
-      baseValue.map { row =>
-        if (row == null) null else row(ordinal)
-      }
-    }
-  }
-}
 
 /**
  * Returns an Array containing the evaluation of all children expressions.

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e4a60f5..d7b2f20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -227,10 +227,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
       case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
       case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
-      case e @ GetItem(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ StructGetField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal.create(null, e.dataType)
+      case e @ ExtractValue(Literal(null, _), _) => Literal.create(null, e.dataType)
+      case e @ ExtractValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 4574934..cd54d04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -160,7 +160,7 @@ object PartialAggregation {
             // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
             // (Should we just turn `GetField` into a `NamedExpression`?)
             namedGroupingExpressions
-              .get(e.transform { case Alias(g: GetField, _) => g })
+              .get(e.transform { case Alias(g: ExtractValue, _) => g })
               .map(_.toAttribute)
               .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ae4620a..dbb12d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -209,7 +209,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
           // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
           // Then this will add GetField("c", GetField("b", a)), and alias
           // the final expression as "c".
-          val fieldExprs = nestedFields.foldLeft(a: Expression)(GetField(_, _, resolver))
+          val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
+            ExtractValue(expr, Literal(fieldName), resolver))
           val aliasName = nestedFields.last
           Some(Alias(fieldExprs, aliasName)())
         } catch {

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 88d36d1..04fd261 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.FunSuite
 import org.scalatest.Matchers._
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.mathfuncs._
 import org.apache.spark.sql.types._
@@ -880,7 +880,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     val row = create_row(
       "^Ba*n",                                // 0
       null.asInstanceOf[UTF8String],          // 1
-      create_row("aa", "bb"),     // 2
+      create_row("aa", "bb"),                 // 2
       Map("aa"->"bb"),                        // 3
       Seq("aa", "bb")                         // 4
     )
@@ -891,54 +891,79 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
     val typeMap = MapType(StringType, StringType)
     val typeArray = ArrayType(StringType)
 
-    checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
       Literal("aa")), "bb", row)
-    checkEvaluation(GetItem(Literal.create(null, typeMap), Literal("aa")), null, row)
+    checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row)
     checkEvaluation(
-      GetItem(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
-    checkEvaluation(GetItem(BoundReference(3, typeMap, true),
+      GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row)
+    checkEvaluation(GetMapValue(BoundReference(3, typeMap, true),
       Literal.create(null, StringType)), null, row)
 
-    checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
       Literal(1)), "bb", row)
-    checkEvaluation(GetItem(Literal.create(null, typeArray), Literal(1)), null, row)
+    checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row)
     checkEvaluation(
-      GetItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
-    checkEvaluation(GetItem(BoundReference(4, typeArray, true),
+      GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row)
+    checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true),
       Literal.create(null, IntegerType)), null, row)
 
-    def quickBuildGetField(expr: Expression, fieldName: String): StructGetField = {
+    def getStructField(expr: Expression, fieldName: String): ExtractValue = {
       expr.dataType match {
         case StructType(fields) =>
           val field = fields.find(_.name == fieldName).get
-          StructGetField(expr, field, fields.indexOf(field))
+          GetStructField(expr, field, fields.indexOf(field))
       }
     }
 
-    def quickResolve(u: UnresolvedGetField): StructGetField = {
-      quickBuildGetField(u.child, u.fieldName)
+    def quickResolve(u: UnresolvedExtractValue): ExtractValue = {
+      ExtractValue(u.child, u.extraction, _ == _)
     }
 
-    checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
-    checkEvaluation(quickBuildGetField(Literal.create(null, typeS), "a"), null, row)
+    checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
+    checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row)
 
     val typeS_notNullable = StructType(
       StructField("a", StringType, nullable = false)
         :: StructField("b", StringType, nullable = false) :: Nil
     )
 
-    assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
-    assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
+    assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
+    assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable
       === false)
 
-    assert(quickBuildGetField(Literal.create(null, typeS), "a").nullable === true)
-    assert(quickBuildGetField(Literal.create(null, typeS_notNullable), "a").nullable === true)
+    assert(getStructField(Literal.create(null, typeS), "a").nullable === true)
+    assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true)
 
-    checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
-    checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
+    checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row)
+    checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row)
     checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
   }
 
+  test("error message of ExtractValue") {
+    val structType = StructType(StructField("a", StringType, true) :: Nil)
+    val arrayStructType = ArrayType(structType)
+    val arrayType = ArrayType(StringType)
+    val otherType = StringType
+
+    def checkErrorMessage(
+        childDataType: DataType,
+        fieldDataType: DataType,
+        errorMessage: String): Unit = {
+      val e = intercept[org.apache.spark.sql.AnalysisException] {
+        ExtractValue(
+          Literal.create(null, childDataType),
+          Literal.create(null, fieldDataType),
+          _ == _)
+      }
+      assert(e.getMessage().contains(errorMessage))
+    }
+
+    checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
+    checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
+    checkErrorMessage(arrayType, StringType, "Array index should be integral type")
+    checkErrorMessage(otherType, StringType, "Can't extract value from")
+  }
+
   test("arithmetic") {
     val row = create_row(1, 2, 3, null)
     val c1 = 'a.int.at(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 18f9215..6b7d9a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -180,10 +180,10 @@ class ConstantFoldingSuite extends PlanTest {
       IsNull(Literal(null)) as 'c1,
       IsNotNull(Literal(null)) as 'c2,
 
-      GetItem(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
-      GetItem(
+      UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3,
+      UnresolvedExtractValue(
         Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4,
-      UnresolvedGetField(
+      UnresolvedExtractValue(
         Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))),
         "a") as 'c5,
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bbe11b..e6e475b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
 import org.apache.spark.sql.types._
 
 
@@ -68,6 +68,19 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   override def hashCode: Int = this.expr.hashCode
 
   /**
+   * Extracts a value or values from a complex type.
+   * The following types of extraction are supported:
+   * - Given an Array, an integer ordinal can be used to retrieve a single value.
+   * - Given a Map, a key of the correct type can be used to retrieve an individual value.
+   * - Given a Struct, a string fieldName can be used to extract that field.
+   * - Given an Array of Structs, a string fieldName can be used to extract that field
+   *   from every struct in the array, returning an Array of fields.
+   *
+   * @group expr_ops
+   */
+  def apply(field: Any): Column = UnresolvedExtractValue(expr, Literal(field))
+
+  /**
    * Unary minus, i.e. negate the expression.
    * {{{
    *   // Scala: select the amount column and negates all values.
@@ -529,14 +542,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    *
    * @group expr_ops
    */
-  def getItem(key: Any): Column = GetItem(expr, Literal(key))
+  def getItem(key: Any): Column = UnresolvedExtractValue(expr, Literal(key))
 
   /**
    * An expression that gets a field by name in a [[StructType]].
    *
    * @group expr_ops
    */
-  def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
+  def getField(fieldName: String): Column = UnresolvedExtractValue(expr, Literal(fieldName))
 
   /**
    * An expression that returns a substring.
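
A short sketch (not part of this commit) of how the public Column entry points now relate;
`df` and the column name below are hypothetical:

    val r = df("r")        // a struct column
    r("a")                 // Column.apply, added by this patch
    r.getField("a")        // getField
    r.getItem("a")         // getItem
    // all three build UnresolvedExtractValue(r.expr, Literal("a")) and
    // therefore resolve identically during analysis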

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1515e9b..d2ca8dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -449,7 +449,7 @@ class DataFrameSuite extends QueryTest {
       testData.collect().map { case Row(key: Int, value: String) =>
         Row(key, value, key + 1)
       }.toSeq)
-    assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
   test("replace column using withColumn") {
@@ -484,7 +484,7 @@ class DataFrameSuite extends QueryTest {
       testData.collect().map { case Row(key: Int, value: String) =>
         Row(key, value, key + 1)
       }.toSeq)
-    assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol"))
+    assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
   test("randomSplit") {
@@ -593,4 +593,10 @@ class DataFrameSuite extends QueryTest {
       Row(new java.math.BigDecimal(2.0)))
     TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
   }
+
+  test("SPARK-7133: Implement struct, array, and map field accessor") {
+    assert(complexData.filter(complexData("a")(0) === 2).count() == 1)
+    assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
+    assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f8468c45/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f30b196..04d40bb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1204,7 +1204,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       nodeToExpr(qualifier) match {
         case UnresolvedAttribute(qualifierName) =>
           UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr))
-        case other => UnresolvedGetField(other, attr)
+        case other => UnresolvedExtractValue(other, Literal(attr))
       }
 
     /* Stars (*) */
@@ -1329,7 +1329,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
 
     /* Complex datatype manipulation */
     case Token("[", child :: ordinal :: Nil) =>
-      GetItem(nodeToExpr(child), nodeToExpr(ordinal))
+      UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
 
     /* Other functions */
     case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>



