spark-commits mailing list archives

From l...@apache.org
Subject spark git commit: [SQL] [MINOR] correct semanticEquals logic
Date Fri, 12 Jun 2015 08:38:46 GMT
Repository: spark
Updated Branches:
  refs/heads/master e428b3a95 -> c19c78577


[SQL] [MINOR] correct semanticEquals logic

It's a follow-up of https://github.com/apache/spark/pull/6173. For expressions like `Coalesce`
that hold a `Seq[Expression]`, when we do a semantic equality check we need to check all of
the children semantically as well, rather than comparing the sequences with `==`.
We can also use `Seq[(Expression, NamedExpression)]` instead of `Map[Expression, NamedExpression]`,
as we only ever search it with `find`.
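A minimal, self-contained illustration of that last point (a hypothetical `Expr` class stands in
for Catalyst expressions, with case-insensitive names mimicking a purely cosmetic difference):
a keyed `Map` lookup goes through `==`/`hashCode` and misses semantic matches, so since every
real lookup has to be a linear scan anyway, a plain `Seq` of pairs is the simpler structure.

```scala
// Hypothetical stand-in for Catalyst expressions: names compare
// case-insensitively, mimicking semanticEquals.
case class Expr(name: String) {
  def semanticEquals(other: Expr): Boolean = name.equalsIgnoreCase(other.name)
}

val groups: Seq[(Expr, String)] = Seq(Expr("key") -> "PartialGroup1")

// A keyed Map lookup uses ==/hashCode, so a cosmetic variant is missed:
assert(groups.toMap.get(Expr("KEY")).isEmpty)

// Every real lookup must be a linear scan with semanticEquals, which a
// Seq supports just as well as a Map:
val hit = groups.collectFirst { case (e, n) if e semanticEquals Expr("KEY") => n }
assert(hit.contains("PartialGroup1"))
```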

chenghao-intel, I agree that we can probably never define `semanticEquals` in a fully general
way, but I think we have already done something similar in `TreeNode`, so we can use the same
kind of logic here. That lets us handle expressions like `Coalesce(children: Seq[Expression])`
correctly.
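To make the `Coalesce` case concrete, here is a runnable toy version of the recursive check
(simplified stand-ins, not the actual Catalyst classes):

```scala
// A Seq-valued constructor field must be compared element-wise with
// semanticEquals, not with ==.
sealed trait Expr extends Product {
  def semanticEquals(other: Expr): Boolean = {
    def check(xs: Seq[Any], ys: Seq[Any]): Boolean =
      xs.length == ys.length && xs.zip(ys).forall {
        case (e1: Expr, e2: Expr)                     => e1 semanticEquals e2
        case (t1: Traversable[_], t2: Traversable[_]) => check(t1.toSeq, t2.toSeq)
        case (x, y)                                   => x == y
      }
    this.getClass == other.getClass &&
      check(this.productIterator.toSeq, other.productIterator.toSeq)
  }
}
case class Attr(name: String) extends Expr {
  override def semanticEquals(other: Expr): Boolean = other match {
    case Attr(n) => name.equalsIgnoreCase(n) // cosmetic difference only
    case _       => false
  }
}
case class Coalesce(children: Seq[Expr]) extends Expr

// `children` lives in a Seq, so a pairwise `==` on the product elements
// would return false here; the Traversable branch recurses and succeeds.
assert(Coalesce(Seq(Attr("a"))) semanticEquals Coalesce(Seq(Attr("A"))))
```

With the pre-patch logic, the `Seq` children fall through to the `==` branch and the check
fails even though the elements only differ cosmetically.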

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6261 from cloud-fan/tmp and squashes the following commits:

4daef88 [Wenchen Fan] address comments
dd8fbd9 [Wenchen Fan] correct semanticEquals


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c19c7857
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c19c7857
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c19c7857

Branch: refs/heads/master
Commit: c19c78577a211eefe1112ebd4670a4ce7c3cc3be
Parents: e428b3a
Author: Wenchen Fan <cloud0fan@outlook.com>
Authored: Fri Jun 12 16:38:28 2015 +0800
Committer: Cheng Lian <lian@databricks.com>
Committed: Fri Jun 12 16:38:28 2015 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala     | 13 +++++++++----
 .../spark/sql/catalyst/planning/patterns.scala    | 18 ++++++++----------
 .../spark/sql/execution/GeneratedAggregate.scala  | 14 +++++++-------
 .../org/apache/spark/sql/SQLQuerySuite.scala      |  2 +-
 4 files changed, 25 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8c1e4d7..0b9f621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -136,12 +136,17 @@ abstract class Expression extends TreeNode[Expression] {
    * cosmetically (i.e. capitalization of names in attributes may be different).
    */
   def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    def checkSemantic(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
+      elements1.length == elements2.length && elements1.zip(elements2).forall {
+        case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+        case (Some(e1: Expression), Some(e2: Expression)) => e1 semanticEquals e2
+        case (t1: Traversable[_], t2: Traversable[_]) => checkSemantic(t1.toSeq, t2.toSeq)
+        case (i1, i2) => i1 == i2
+      }
+    }
     val elements1 = this.productIterator.toSeq
     val elements2 = other.asInstanceOf[Product].productIterator.toSeq
-    elements1.length == elements2.length && elements1.zip(elements2).forall {
-      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
-      case (i1, i2) => i1 == i2
-    }
+    checkSemantic(elements1, elements2)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 1dd75a8..3b6f8bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -143,11 +143,11 @@ object PartialAggregation {
         // We need to pass all grouping expressions though so the grouping can happen a second
         // time. However some of them might be unnamed so we alias them allowing them to be
         // referenced in the second aggregation.
-        val namedGroupingExpressions: Map[Expression, NamedExpression] =
+        val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
           groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
             case n: NamedExpression => (n, n)
             case other => (other, Alias(other, "PartialGroup")())
-          }.toMap
+          }
 
         // Replace aggregations with a new expression that computes the result from the already
         // computed partial evaluations and grouping values.
@@ -160,17 +160,15 @@ object PartialAggregation {
             // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
             // (Should we just turn `GetField` into a `NamedExpression`?)
             val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
-            namedGroupingExpressions
-              .find { case (k, v) => k semanticEquals trimmed }
-              .map(_._2.toAttribute)
-              .getOrElse(e)
+            namedGroupingExpressions.collectFirst {
+              case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+            }.getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]
 
-        val partialComputation =
-          (namedGroupingExpressions.values ++
-            partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq
+        val partialComputation = namedGroupingExpressions.map(_._2) ++
+          partialEvaluations.values.flatMap(_.partialEvaluations)
 
-        val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
+        val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
 
         Some(
           (namedGroupingAttributes,
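
For reference, the `find`-then-`map` chain replaced above and `collectFirst` behave
identically; a tiny standalone check (toy values, not Catalyst types):

```scala
val pairs = Seq(1 -> "a", 2 -> "b", 3 -> "c")

val viaFind         = pairs.find { case (k, _) => k == 2 }.map(_._2)
val viaCollectFirst = pairs.collectFirst { case (k, v) if k == 2 => v }

assert(viaFind == viaCollectFirst) // both Some("b")
```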

http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index af37917..1c40a92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -214,18 +214,18 @@ case class GeneratedAggregate(
       }.toMap
 
     val namedGroups = groupingExpressions.zipWithIndex.map {
-      case (ne: NamedExpression, _) => (ne, ne)
-      case (e, i) => (e, Alias(e, s"GroupingExpr$i")())
+      case (ne: NamedExpression, _) => (ne, ne.toAttribute)
+      case (e, i) => (e, Alias(e, s"GroupingExpr$i")().toAttribute)
     }
 
-    val groupMap: Map[Expression, Attribute] =
-      namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap
-
      // The set of expressions that produce the final output given the aggregation buffer and the
      // grouping expressions.
     val resultExpressions = aggregateExpressions.map(_.transform {
       case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
-      case e: Expression if groupMap.contains(e) => groupMap(e)
+      case e: Expression =>
+        namedGroups.collectFirst {
+          case (expr, attr) if expr semanticEquals e => attr
+        }.getOrElse(e)
     })
 
     val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
@@ -265,7 +265,7 @@ case class GeneratedAggregate(
       val resultProjectionBuilder =
         newMutableProjection(
           resultExpressions,
-          (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
+          namedGroups.map(_._2) ++ computationSchema)
       log.info(s"Result Projection: ${resultExpressions.mkString(",")}")
 
       val joinedRow = new JoinedRow3
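
The behavioral nuance here is that the old `groupMap.contains(e)` used plain `==` on a `Map`
key, while the new `collectFirst` matches semantically and falls back to the original
expression. A toy version of that rewrite step (plain strings standing in for expressions,
case-insensitive comparison standing in for `semanticEquals`):

```scala
val namedGroups: Seq[(String, String)] = Seq("key" -> "GroupingExpr0")

def rewrite(e: String): String =
  namedGroups.collectFirst {
    case (expr, attr) if expr.equalsIgnoreCase(e) => attr
  }.getOrElse(e)

assert(rewrite("KEY")   == "GroupingExpr0") // semantic match is replaced
assert(rewrite("other") == "other")         // no match: fall back to e
```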

http://git-wip-us.apache.org/repos/asf/spark/blob/c19c7857/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 14ecd4e..6898d58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -697,7 +697,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
         row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
   }
 
-  ignore("cartesian product join") {
+  test("cartesian product join") {
     checkAnswer(
       testData3.join(testData3),
       Row(1, null, 1, null) ::



