spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SQL][DataFrame] Fix column computability bug.
Date Wed, 11 Feb 2015 03:50:59 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.3 7fa0d5f5c -> e477e91e3


[SQL][DataFrame] Fix column computability bug.

Do not recursively strip out projects. Only strip the first level project.

```scala
df("colA") + df("colB").as("colC")
```

Previously, the above would construct an invalid plan.

Author: Reynold Xin <rxin@databricks.com>

Closes #4519 from rxin/computability and squashes the following commits:

87ff763 [Reynold Xin] Code review feedback.
015c4fc [Reynold Xin] [SQL][DataFrame] Fix column computability.

(cherry picked from commit 7e24249af1e2f896328ef0402fa47db78cb6f9ec)
Signed-off-by: Michael Armbrust <michael@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e477e91e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e477e91e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e477e91e

Branch: refs/heads/branch-1.3
Commit: e477e91e3b65a6feb5f8d5593a2e69f3c715497a
Parents: 7fa0d5f
Author: Reynold Xin <rxin@databricks.com>
Authored: Tue Feb 10 19:50:44 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Tue Feb 10 19:50:56 2015 -0800

----------------------------------------------------------------------
 .../MatrixFactorizationModel.scala              |  2 +-
 .../scala/org/apache/spark/sql/Column.scala     | 35 +++++++++++++++-----
 .../scala/org/apache/spark/sql/SQLContext.scala |  4 +--
 .../spark/sql/ColumnExpressionSuite.scala       | 13 ++++++--
 4 files changed, 39 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e477e91e/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9ff06ac..16979c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -180,7 +180,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
-      import sqlContext.implicits.createDataFrame
+      import sqlContext.implicits._
       val metadata = (thisClassName, thisFormatVersion, model.rank)
       val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
       metadataRDD.toJSON.saveAsTextFile(metadataPath(path))

http://git-wip-us.apache.org/repos/asf/spark/blob/e477e91e/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b0e9590..9d5d6e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -66,27 +66,44 @@ trait Column extends DataFrame {
    */
   def isComputable: Boolean
 
+  /** Removes the top project so we can get to the underlying plan. */
+  private def stripProject(p: LogicalPlan): LogicalPlan = p match {
+    case Project(_, child) => child
+    case p => sys.error("Unexpected logical plan (expected Project): " + p)
+  }
+
   private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
-    val plan = Project(Seq(expr match {
+    val namedExpr = expr match {
       case named: NamedExpression => named
       case unnamed: Expression => Alias(unnamed, "col")()
-    }), baseCol.plan)
+    }
+    val plan = Project(Seq(namedExpr), stripProject(baseCol.plan))
     Column(baseCol.sqlContext, plan, expr)
   }
 
+  /**
+   * Construct a new column based on the expression and the other column value.
+   *
+   * There are two cases that can happen here:
+   * If otherValue is a constant, it is first turned into a Column.
+   * If otherValue is a Column, then:
+   *   - If this column and otherValue are both computable and come from the same logical plan,
+   *     then we can construct a ComputableColumn by applying a Project on top of the base plan.
+   *   - If this column is not computable, but otherValue is computable, then we can construct
+   *     a ComputableColumn based on otherValue's base plan.
+   *   - If this column is computable, but otherValue is not, then we can construct a
+   *     ComputableColumn based on this column's base plan.
+   *   - If neither columns are computable, then we create an IncomputableColumn.
+   */
   private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
-    // Removes all the top level projection and subquery so we can get to the underlying plan.
-    @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
-      case Project(_, child) => stripProject(child)
-      case Subquery(_, child) => stripProject(child)
-      case _ => p
-    }
-
+    // lit(otherValue) returns a Column always.
     (this, lit(otherValue)) match {
       case (left: ComputableColumn, right: ComputableColumn) =>
         if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
           computableCol(right, newExpr(right))
         } else {
+          // We don't want to throw an exception here because "df1("a") === df2("b")" can be
+          // a valid expression for join conditions, even though standalone they are not valid.
           Column(newExpr(right))
         }
       case (left: ComputableColumn, right) => computableCol(left, newExpr(right))

http://git-wip-us.apache.org/repos/asf/spark/blob/e477e91e/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 523911d..05ac162 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -183,14 +183,14 @@ class SQLContext(@transient val sparkContext: SparkContext)
      *
      * @group userf
      */
-    implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+    implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
       self.createDataFrame(rdd)
     }
 
     /**
      * Creates a DataFrame from a local Seq of Product.
      */
-    implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+    implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
       self.createDataFrame(data)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e477e91e/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 1d71039..e3e6f65 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
 
 
@@ -44,10 +45,10 @@ class ColumnExpressionSuite extends QueryTest {
     shouldBeComputable(-testData2("a"))
     shouldBeComputable(!testData2("a"))
 
-    shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
-    shouldBeComputable(
+    shouldNotBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
+    shouldNotBeComputable(
       testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
-    shouldBeComputable(
+    shouldNotBeComputable(
       testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
 
     // Literals and unresolved columns should not be computable.
@@ -66,6 +67,12 @@ class ColumnExpressionSuite extends QueryTest {
     shouldNotBeComputable(sum(testData2("a")))
   }
 
+  test("collect on column produced by a binary operator") {
+    val df = Seq((1, 2, 3)).toDataFrame("a", "b", "c")
+    checkAnswer(df("a") + df("b"), Seq(Row(3)))
+    checkAnswer(df("a") + df("b").as("c"), Seq(Row(3)))
+  }
+
   test("star") {
     checkAnswer(testData.select($"*"), testData.collect().toSeq)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message