spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-9293] [SPARK-9813] Analysis should check that set operations are only performed on tables with equal numbers of columns
Date Tue, 25 Aug 2015 07:04:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master bf03fe68d -> 82268f07a


[SPARK-9293] [SPARK-9813] Analysis should check that set operations are only performed on
tables with equal numbers of columns

This patch adds an analyzer rule to ensure that set operations (union, intersect, and except)
are only applied to tables with the same number of columns. Without this rule, there are scenarios
where invalid queries can return incorrect results instead of failing with error messages;
SPARK-9813 provides one example of this problem. In other cases, the invalid query can crash
at runtime with extremely confusing exceptions.

I also performed a bit of cleanup to refactor some of those logical operators' code into a
common `SetOperation` base class.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7631 from JoshRosen/SPARK-9293.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/82268f07
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/82268f07
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/82268f07

Branch: refs/heads/master
Commit: 82268f07abfa658869df2354ae72f8d6ddd119e8
Parents: bf03fe6
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Tue Aug 25 00:04:10 2015 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Tue Aug 25 00:04:10 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  6 ++++
 .../catalyst/analysis/HiveTypeCoercion.scala    | 14 +++-----
 .../catalyst/plans/logical/basicOperators.scala | 38 +++++++++-----------
 .../catalyst/analysis/AnalysisErrorSuite.scala  | 18 ++++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala   |  2 +-
 .../hive/execution/InsertIntoHiveTable.scala    |  2 +-
 6 files changed, 48 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 39f554c..7701fd0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -137,6 +137,12 @@ trait CheckAnalysis {
               }
             }
 
+          case s @ SetOperation(left, right) if left.output.length != right.output.length =>
+            failAnalysis(
+              s"${s.nodeName} can only be performed on tables with the same number of columns, " +
+               s"but the left table has ${left.output.length} columns and the right has " +
+               s"${right.output.length}")
+
           case _ => // Fallbacks to the following checks
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2cb067f..a1aa2a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -203,6 +203,7 @@ object HiveTypeCoercion {
         planName: String,
         left: LogicalPlan,
         right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
+      require(left.output.length == right.output.length)
 
       val castedTypes = left.output.zip(right.output).map {
         case (lhs, rhs) if lhs.dataType != rhs.dataType =>
@@ -229,15 +230,10 @@ object HiveTypeCoercion {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p if p.analyzed => p
 
-      case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
-        val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
-        Union(newLeft, newRight)
-      case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
-        val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
-        Except(newLeft, newRight)
-      case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
-        val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
-        Intersect(newLeft, newRight)
+      case s @ SetOperation(left, right) if s.childrenResolved
+          && left.output.length == right.output.length && !s.resolved =>
+        val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right)
+        s.makeCopy(Array(newLeft, newRight))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 73b8261..722f69c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -89,13 +89,21 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 }
 
-case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   // TODO: These aren't really the same attributes as nullability etc might change.
-  override def output: Seq[Attribute] = left.output
+  final override def output: Seq[Attribute] = left.output
 
-  override lazy val resolved: Boolean =
+  final override lazy val resolved: Boolean =
     childrenResolved &&
-    left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+      left.output.length == right.output.length &&
+      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+}
+
+private[sql] object SetOperation {
+  def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
+}
+
+case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
 
   override def statistics: Statistics = {
     val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
@@ -103,6 +111,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
   }
 }
 
+case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+
+case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+
 case class Join(
   left: LogicalPlan,
   right: LogicalPlan,
@@ -142,15 +154,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 }
 
-
-case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
-  override def output: Seq[Attribute] = left.output
-
-  override lazy val resolved: Boolean =
-    childrenResolved &&
-      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}
-
 case class InsertIntoTable(
     table: LogicalPlan,
     partition: Map[String, Option[String]],
@@ -160,7 +163,7 @@ case class InsertIntoTable(
   extends LogicalPlan {
 
   override def children: Seq[LogicalPlan] = child :: Nil
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = Seq.empty
 
   assert(overwrite || !ifNotExists)
   override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall {
@@ -440,10 +443,3 @@ case object OneRowRelation extends LeafNode {
   override def statistics: Statistics = Statistics(sizeInBytes = 1)
 }
 
-case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
-  override def output: Seq[Attribute] = left.output
-
-  override lazy val resolved: Boolean =
-    childrenResolved &&
-      left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 7065adc..fbdd3a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -146,6 +146,24 @@ class AnalysisErrorSuite extends AnalysisTest {
     "unresolved" :: Nil)
 
   errorTest(
+    "union with unequal number of columns",
+    testRelation.unionAll(testRelation2),
+    "union" :: "number of columns" :: testRelation2.output.length.toString ::
+      testRelation.output.length.toString :: Nil)
+
+  errorTest(
+    "intersect with unequal number of columns",
+    testRelation.intersect(testRelation2),
+    "intersect" :: "number of columns" :: testRelation2.output.length.toString ::
+      testRelation.output.length.toString :: Nil)
+
+  errorTest(
+    "except with unequal number of columns",
+    testRelation.except(testRelation2),
+    "except" :: "number of columns" :: testRelation2.output.length.toString ::
+      testRelation.output.length.toString :: Nil)
+
+  errorTest(
     "SPARK-9955: correct error message for aggregate",
     // When parse SQL string, we will wrap aggregate expressions with UnresolvedAlias.
     testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),

http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index bbe8c19..98d21aa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -751,7 +751,7 @@ private[hive] case class InsertIntoHiveTable(
   extends LogicalPlan {
 
   override def children: Seq[LogicalPlan] = child :: Nil
-  override def output: Seq[Attribute] = child.output
+  override def output: Seq[Attribute] = Seq.empty
 
   val numDynamicPartitions = partition.values.count(_.isEmpty)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/82268f07/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 12c667e..62efda6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -61,7 +61,7 @@ case class InsertIntoHiveTable(
     serializer
   }
 
-  def output: Seq[Attribute] = child.output
+  def output: Seq[Attribute] = Seq.empty
 
   def saveAsHiveFile(
       rdd: RDD[InternalRow],


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message