spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-13440][SQL] ObjectType should accept any ObjectType, If should not care about nullability
Date Tue, 23 Feb 2016 19:20:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9f4263392 -> c5bfe5d2a


[SPARK-13440][SQL] ObjectType should accept any ObjectType, If should not care about nullability

The type-checking functions of `If` and `UnwrapOption` are fixed to eliminate spurious failures.
`UnwrapOption` was checking for an input of `ObjectType`, but `ObjectType`'s `acceptsType` method
was hard-coded to return `false`. `If`'s type check reported a spurious failure whenever the types
of its two branches differed only by nullability.

Tests added:
 - An end-to-end regression test is added to `DatasetSuite` for the reported failure.
 - All the unit tests in `ExpressionEncoderSuite` are augmented to also confirm successful
analysis. These tests are what uncovered the additional issue with `If` resolution.

Author: Michael Armbrust <michael@databricks.com>

Closes #11316 from marmbrus/datasetOptions.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5bfe5d2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5bfe5d2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5bfe5d2

Branch: refs/heads/master
Commit: c5bfe5d2a22e0e66b27aa28a19785c3aac5d9f2e
Parents: 9f42633
Author: Michael Armbrust <michael@databricks.com>
Authored: Tue Feb 23 11:20:27 2016 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Tue Feb 23 11:20:27 2016 -0800

----------------------------------------------------------------------
 .../expressions/conditionalExpressions.scala     |  2 +-
 .../catalyst/plans/logical/LocalRelation.scala   |  3 +++
 .../org/apache/spark/sql/types/ObjectType.scala  |  6 ++++--
 .../sql/catalyst/analysis/AnalysisTest.scala     | 13 ++++++++++++-
 .../encoders/ExpressionEncoderSuite.scala        | 19 +++++++++++++++----
 .../org/apache/spark/sql/DatasetSuite.scala      |  8 ++++++++
 6 files changed, 43 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 200c6a0..c3e9fa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -34,7 +34,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
     if (predicate.dataType != BooleanType) {
       TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
-    } else if (trueValue.dataType != falseValue.dataType) {
+    } else if (trueValue.dataType.asNullable != falseValue.dataType.asNullable) {
       TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
         s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index d3b5879..f9f1f88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -45,6 +45,9 @@ object LocalRelation {
 case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
   extends LeafNode with analysis.MultiInstanceRelation {
 
+  // A local relation must have resolved output.
+  require(output.forall(_.resolved), "Unresolved attributes found when constructing LocalRelation.")
+
   /**
    * Returns an identical copy of this relation with new exprIds for all attributes.  Different
    * attributes are required when a relation is going to be included multiple times in the same

http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
index fca0b79..06ee0fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala
@@ -23,8 +23,10 @@ private[sql] object ObjectType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType =
     throw new UnsupportedOperationException("null literals can't be casted to ObjectType")
 
-  // No casting or comparison is supported.
-  override private[sql] def acceptsType(other: DataType): Boolean = false
+  override private[sql] def acceptsType(other: DataType): Boolean = other match {
+    case ObjectType(_) => true
+    case _ => false
+  }
 
   override private[sql] def simpleString: String = "Object"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index e0a95ba..ef825e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -60,7 +60,18 @@ trait AnalysisTest extends PlanTest {
       inputPlan: LogicalPlan,
       caseSensitive: Boolean = true): Unit = {
     val analyzer = getAnalyzer(caseSensitive)
-    analyzer.checkAnalysis(analyzer.execute(inputPlan))
+    val analysisAttempt = analyzer.execute(inputPlan)
+    try analyzer.checkAnalysis(analysisAttempt) catch {
+      case a: AnalysisException =>
+        fail(
+          s"""
+            |Failed to Analyze Plan
+            |$inputPlan
+            |
+            |Partial Analysis
+            |$analysisAttempt
+          """.stripMargin, a)
+    }
   }
 
   protected def assertAnalysisError(

http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index e00060f..cca320f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -23,12 +23,14 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, StructType}
+import org.apache.spark.sql.types.{ArrayType, ObjectType, StructType}
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -74,7 +76,7 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
-class ExpressionEncoderSuite extends SparkFunSuite {
+class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
   implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
@@ -305,6 +307,15 @@ class ExpressionEncoderSuite extends SparkFunSuite {
             """.stripMargin, e)
       }
 
+      // Test the correct resolution of serialization / deserialization.
+      val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
+      val inputPlan = LocalRelation(attr)
+      val plan =
+        Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+          Project(encoder.namedExpressions,
+            inputPlan))
+      assertAnalysisSuccess(plan)
+
       val isCorrect = (input, convertedBack) match {
         case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
         case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)

http://git-wip-us.apache.org/repos/asf/spark/blob/c5bfe5d2/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 498f007..14fc37b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -613,6 +613,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
         " - Input schema: struct<a:string,b:int>\n" +
         " - Target schema: struct<_1:string>")
   }
+
+  test("SPARK-13440: Resolving option fields") {
+    val df = Seq(1, 2, 3).toDS()
+    val ds = df.as[Option[Int]]
+    checkAnswer(
+      ds.filter(_ => true),
+      Some(1), Some(2), Some(3))
+  }
 }
 
 class OuterClass extends Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message