spark-commits mailing list archives

From r...@apache.org
Subject [3/7] spark git commit: [SPARK-14855][SQL] Add "Exec" suffix to physical operators
Date Sat, 23 Apr 2016 00:44:03 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
deleted file mode 100644
index 785373b..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Performs a hash join of two child relations by first shuffling the data using the join keys.
- */
-case class ShuffledHashJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    buildSide: BuildSide,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan)
-  extends BinaryNode with HashJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
-    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
-    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
-
-  override def outputPartitioning: Partitioning = joinType match {
-    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    case LeftAnti => left.outputPartitioning
-    case LeftSemi => left.outputPartitioning
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case x =>
-      throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
-  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
-    val buildDataSize = longMetric("buildDataSize")
-    val buildTime = longMetric("buildTime")
-    val start = System.nanoTime()
-    val context = TaskContext.get()
-    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
-    buildTime += (System.nanoTime() - start) / 1000000
-    buildDataSize += relation.estimatedSize
-    // This relation is usually used until the end of the task.
-    context.addTaskCompletionListener(_ => relation.close())
-    relation
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      join(streamIter, hashed, numOutputRows)
-    }
-  }
-}
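
Both versions of this operator implement the classic build-then-probe pattern: the build side is first materialized into a hash table keyed on the join keys, and the streamed side then probes that table one partition at a time. Below is a minimal sketch of the per-partition logic in plain Scala; Row, the key extractors, and hashJoin are simplified stand-ins for Spark's InternalRow and HashedRelation machinery, not the actual API.

// A minimal per-partition hash-join sketch: build a map from the build-side
// rows, then probe it with each streamed row. Row and the key extractors are
// simplified stand-ins for Spark's InternalRow/HashedRelation machinery.
object HashJoinSketch {
  type Row = Seq[Any]

  def hashJoin(
      buildIter: Iterator[Row],
      streamIter: Iterator[Row],
      buildKey: Row => Any,
      streamKey: Row => Any): Iterator[Row] = {
    // Build phase: group the build-side rows by join key
    // (buildHashedRelation above does this via HashedRelation).
    val hashed: Map[Any, Seq[Row]] = buildIter.toSeq.groupBy(buildKey)
    // Probe phase: emit one concatenated row per matching build row.
    streamIter.flatMap { s =>
      hashed.getOrElse(streamKey(s), Seq.empty).map(b => s ++ b)
    }
  }

  def main(args: Array[String]): Unit = {
    val build  = Iterator(Seq(1, "a"), Seq(2, "b"))
    val stream = Iterator(Seq(1, "x"), Seq(3, "y"))
    hashJoin(build, stream, _.head, _.head).foreach(println) // List(1, x, 1, a)
  }
}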

http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
new file mode 100644
index 0000000..68cd3cb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Performs a hash join of two child relations by first shuffling the data using the join keys.
+ */
+case class ShuffledHashJoinExec(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    buildSide: BuildSide,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan)
+  extends BinaryExecNode with HashJoin {
+
+  override private[sql] lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+    "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
+    "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    case LeftAnti => left.outputPartitioning
+    case LeftSemi => left.outputPartitioning
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x =>
+      throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType")
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
+    val buildDataSize = longMetric("buildDataSize")
+    val buildTime = longMetric("buildTime")
+    val start = System.nanoTime()
+    val context = TaskContext.get()
+    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+    buildTime += (System.nanoTime() - start) / 1000000
+    buildDataSize += relation.estimatedSize
+    // This relation is usually used until the end of the task.
+    context.addTaskCompletionListener(_ => relation.close())
+    relation
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
+      val hashed = buildHashedRelation(buildIter)
+      join(streamIter, hashed, numOutputRows)
+    }
+  }
+}
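
The rename to ShuffledHashJoinExec does not change when the planner picks this operator. As a hedged usage sketch, assuming a Spark 2.x build where the spark.sql.join.preferSortMergeJoin flag exists (the planner also applies build-side size checks not shown here), one way to steer a query toward the shuffled hash join and inspect the chosen physical plan:

// Hedged usage sketch (assumed config name: spark.sql.join.preferSortMergeJoin;
// build-side size checks in the planner are not shown here).
import org.apache.spark.sql.SparkSession

object ShuffledHashJoinDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("shj-demo").getOrCreate()
    import spark.implicits._

    // Dis-prefer sort-merge join so the planner may consider the hash join.
    spark.conf.set("spark.sql.join.preferSortMergeJoin", "false")

    val left  = Seq((1, "a"), (2, "b")).toDF("id", "l")
    val right = Seq((1, "x"), (3, "y")).toDF("id", "r")
    val joined = left.join(right, "id")

    joined.explain() // inspect which physical join operator was chosen
    joined.show()
    spark.stop()
  }
}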

http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
deleted file mode 100644
index 4e45fd6..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ /dev/null
@@ -1,964 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
-import org.apache.spark.util.collection.BitSet
-
-/**
- * Performs a sort merge join of two child relations.
- */
-case class SortMergeJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode with CodegenSupport {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case Inner =>
-        left.output ++ right.output
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        (left.output ++ right.output).map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(
-          s"${getClass.getSimpleName} should not take $x as the JoinType")
-    }
-  }
-
-  override def outputPartitioning: Partitioning = joinType match {
-    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
-    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
-  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
-
-  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
-    // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
-    keys.map(SortOrder(_, Ascending))
-  }
-
-  private def createLeftKeyGenerator(): Projection =
-    UnsafeProjection.create(leftKeys, left.output)
-
-  private def createRightKeyGenerator(): Projection =
-    UnsafeProjection.create(rightKeys, right.output)
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      val boundCondition: (InternalRow) => Boolean = {
-        condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
-        }.getOrElse {
-          (r: InternalRow) => true
-        }
-      }
-      // An ordering that can be used to compare keys from both sides.
-      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
-      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
-
-      joinType match {
-        case Inner =>
-          new RowIterator {
-            // The projection used to extract keys from input rows of the left child.
-            private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
-
-            // The projection used to extract keys from input rows of the right child.
-            private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
-
-            // An ordering that can be used to compare keys from both sides.
-            private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
-            private[this] var currentLeftRow: InternalRow = _
-            private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
-            private[this] var currentMatchIdx: Int = -1
-            private[this] val smjScanner = new SortMergeJoinScanner(
-              leftKeyGenerator,
-              rightKeyGenerator,
-              keyOrdering,
-              RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter)
-            )
-            private[this] val joinRow = new JoinedRow
-            private[this] val resultProjection: (InternalRow) => InternalRow =
-              UnsafeProjection.create(schema)
-
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            }
-
-            override def advanceNext(): Boolean = {
-              while (currentMatchIdx >= 0) {
-                if (currentMatchIdx == currentRightMatches.length) {
-                  if (smjScanner.findNextInnerJoinRows()) {
-                    currentRightMatches = smjScanner.getBufferedMatches
-                    currentLeftRow = smjScanner.getStreamedRow
-                    currentMatchIdx = 0
-                  } else {
-                    currentRightMatches = null
-                    currentLeftRow = null
-                    currentMatchIdx = -1
-                    return false
-                  }
-                }
-                joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
-                currentMatchIdx += 1
-                if (boundCondition(joinRow)) {
-                  numOutputRows += 1
-                  return true
-                }
-              }
-              false
-            }
-
-            override def getRow: InternalRow = resultProjection(joinRow)
-          }.toScala
-
-        case LeftOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createLeftKeyGenerator(),
-            bufferedKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(leftIter),
-            bufferedIter = RowIterator.fromScala(rightIter)
-          )
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          new LeftOuterIterator(
-            smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
-
-        case RightOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createRightKeyGenerator(),
-            bufferedKeyGenerator = createLeftKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(rightIter),
-            bufferedIter = RowIterator.fromScala(leftIter)
-          )
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          new RightOuterIterator(
-            smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
-
-        case FullOuter =>
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          val smjScanner = new SortMergeFullOuterJoinScanner(
-            leftKeyGenerator = createLeftKeyGenerator(),
-            rightKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            leftIter = RowIterator.fromScala(leftIter),
-            rightIter = RowIterator.fromScala(rightIter),
-            boundCondition,
-            leftNullRow,
-            rightNullRow)
-
-          new FullOuterIterator(
-            smjScanner,
-            resultProj,
-            numOutputRows).toScala
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"SortMergeJoin should not take $x as the JoinType")
-      }
-
-    }
-  }
-
-  override def supportCodegen: Boolean = {
-    joinType == Inner
-  }
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    left.execute() :: right.execute() :: Nil
-  }
-
-  private def createJoinKey(
-      ctx: CodegenContext,
-      row: String,
-      keys: Seq[Expression],
-      input: Seq[Attribute]): Seq[ExprCode] = {
-    ctx.INPUT_ROW = row
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
-  }
-
-  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
-    vars.zipWithIndex.map { case (ev, i) =>
-      val value = ctx.freshName("value")
-      ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
-      val code =
-        s"""
-           |$value = ${ev.value};
-         """.stripMargin
-      ExprCode(code, "false", value)
-    }
-  }
-
-  private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
-    val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
-      s"""
-         |if (comp == 0) {
-         |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
-         |}
-       """.stripMargin.trim
-    }
-    s"""
-       |comp = 0;
-       |${comparisons.mkString("\n")}
-     """.stripMargin
-  }
-
-  /**
-   * Generates a function that scans both the left and right inputs to find a match, and returns
-   * the terms for the matched row from the left side and the buffered matching rows from the right.
-   */
-  private def genScanner(ctx: CodegenContext): (String, String) = {
-    // Create class member for next row from both sides.
-    val leftRow = ctx.freshName("leftRow")
-    ctx.addMutableState("InternalRow", leftRow, "")
-    val rightRow = ctx.freshName("rightRow")
-    ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
-
-    // Create variables for join keys from both sides.
-    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
-    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
-    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
-    val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
-    // Copy the right keys as class members so they can be used in the next function call.
-    val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
-
-    // A list to hold all matched rows from right side.
-    val matches = ctx.freshName("matches")
-    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
-    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
-    // Copy the left keys as class members so they can be used in the next function call.
-    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
-
-    ctx.addNewFunction("findNextInnerJoinRows",
-      s"""
-         |private boolean findNextInnerJoinRows(
-         |    scala.collection.Iterator leftIter,
-         |    scala.collection.Iterator rightIter) {
-         |  $leftRow = null;
-         |  int comp = 0;
-         |  while ($leftRow == null) {
-         |    if (!leftIter.hasNext()) return false;
-         |    $leftRow = (InternalRow) leftIter.next();
-         |    ${leftKeyVars.map(_.code).mkString("\n")}
-         |    if ($leftAnyNull) {
-         |      $leftRow = null;
-         |      continue;
-         |    }
-         |    if (!$matches.isEmpty()) {
-         |      ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
-         |      if (comp == 0) {
-         |        return true;
-         |      }
-         |      $matches.clear();
-         |    }
-         |
-         |    do {
-         |      if ($rightRow == null) {
-         |        if (!rightIter.hasNext()) {
-         |          ${matchedKeyVars.map(_.code).mkString("\n")}
-         |          return !$matches.isEmpty();
-         |        }
-         |        $rightRow = (InternalRow) rightIter.next();
-         |        ${rightKeyTmpVars.map(_.code).mkString("\n")}
-         |        if ($rightAnyNull) {
-         |          $rightRow = null;
-         |          continue;
-         |        }
-         |        ${rightKeyVars.map(_.code).mkString("\n")}
-         |      }
-         |      ${genComparision(ctx, leftKeyVars, rightKeyVars)}
-         |      if (comp > 0) {
-         |        $rightRow = null;
-         |      } else if (comp < 0) {
-         |        if (!$matches.isEmpty()) {
-         |          ${matchedKeyVars.map(_.code).mkString("\n")}
-         |          return true;
-         |        }
-         |        $leftRow = null;
-         |      } else {
-         |        $matches.add($rightRow.copy());
-         |        $rightRow = null;
-         |      }
-         |    } while ($leftRow != null);
-         |  }
-         |  return false; // unreachable
-         |}
-       """.stripMargin)
-
-    (leftRow, matches)
-  }
-
-  /**
-   * Creates variables for the left part of the result row.
-   *
-   * In order to defer column access until after the condition is evaluated, and to access each
-   * column only once in the loop, the variables are declared separately from the code that reads
-   * the columns, so we can't use the codegen of BoundReference here.
-   */
-  private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
-    ctx.INPUT_ROW = leftRow
-    left.output.zipWithIndex.map { case (a, i) =>
-      val value = ctx.freshName("value")
-      val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
-      // Declare it as a class member so the column can be accessed before or inside the loop.
-      ctx.addMutableState(ctx.javaType(a.dataType), value, "")
-      if (a.nullable) {
-        val isNull = ctx.freshName("isNull")
-        ctx.addMutableState("boolean", isNull, "")
-        val code =
-          s"""
-             |$isNull = $leftRow.isNullAt($i);
-             |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
-           """.stripMargin
-        ExprCode(code, isNull, value)
-      } else {
-        ExprCode(s"$value = $valueCode;", "false", value)
-      }
-    }
-  }
-
-  /**
-   * Creates the variables for the right part of the result row using BoundReference, since the
-   * right part is only accessed inside the loop.
-   */
-  private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
-    ctx.INPUT_ROW = rightRow
-    right.output.zipWithIndex.map { case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
-    }
-  }
-
-  /**
-   * Splits the variables based on whether they are used by the condition, and returns the code
-   * that evaluates them before and after the condition.
-   *
-   * Typically only a few columns are referenced by the condition, so we can skip reading the
-   * remaining columns for rows that the condition filters out.
-   */
-  private def splitVarsByCondition(
-      attributes: Seq[Attribute],
-      variables: Seq[ExprCode]): (String, String) = {
-    if (condition.isDefined) {
-      val condRefs = condition.get.references
-      val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
-        condRefs.contains(a)
-      }
-      val beforeCond = evaluateVariables(used.map(_._2))
-      val afterCond = evaluateVariables(notUsed.map(_._2))
-      (beforeCond, afterCond)
-    } else {
-      (evaluateVariables(variables), "")
-    }
-  }
-
-  override def doProduce(ctx: CodegenContext): String = {
-    ctx.copyResult = true
-    val leftInput = ctx.freshName("leftInput")
-    ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
-    val rightInput = ctx.freshName("rightInput")
-    ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
-
-    val (leftRow, matches) = genScanner(ctx)
-
-    // Create variables for row from both sides.
-    val leftVars = createLeftVars(ctx, leftRow)
-    val rightRow = ctx.freshName("rightRow")
-    val rightVars = createRightVar(ctx, rightRow)
-
-    val size = ctx.freshName("size")
-    val i = ctx.freshName("i")
-    val numOutput = metricTerm(ctx, "numOutputRows")
-    val (beforeLoop, condCheck) = if (condition.isDefined) {
-      // Split the code that creates the variables based on whether the condition uses them.
-      val loaded = ctx.freshName("loaded")
-      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
-      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-      // Generate code for condition
-      ctx.currentVars = leftVars ++ rightVars
-      val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
-      // Evaluate the columns used by the condition before the loop.
-      val before = s"""
-           |boolean $loaded = false;
-           |$leftBefore
-         """.stripMargin
-
-      val checking = s"""
-         |$rightBefore
-         |${cond.code}
-         |if (${cond.isNull} || !${cond.value}) continue;
-         |if (!$loaded) {
-         |  $loaded = true;
-         |  $leftAfter
-         |}
-         |$rightAfter
-     """.stripMargin
-      (before, checking)
-    } else {
-      (evaluateVariables(leftVars), "")
-    }
-
-    s"""
-       |while (findNextInnerJoinRows($leftInput, $rightInput)) {
-       |  int $size = $matches.size();
-       |  ${beforeLoop.trim}
-       |  for (int $i = 0; $i < $size; $i ++) {
-       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
-       |    ${condCheck.trim}
-       |    $numOutput.add(1);
-       |    ${consume(ctx, leftVars ++ rightVars)}
-       |  }
-       |  if (shouldStop()) return;
-       |}
-     """.stripMargin
-  }
-}
-
-/**
- * Helper class that is used to implement [[SortMergeJoin]].
- *
- * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]]
- * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false`
- * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return
- * the matching row from the streamed input and may call [[getBufferedMatches]] to return the
- * sequence of matching rows from the buffered input (in the case of an outer join, this will return
- * an empty sequence if there are no matches from the buffered input). For efficiency, both of these
- * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()`
- * methods.
- *
- * @param streamedKeyGenerator a projection that produces join keys from the streamed input.
- * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.
- * @param keyOrdering an ordering which can be used to compare join keys.
- * @param streamedIter an input whose rows will be streamed.
- * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
- *                     have the same join key.
- */
-private[joins] class SortMergeJoinScanner(
-    streamedKeyGenerator: Projection,
-    bufferedKeyGenerator: Projection,
-    keyOrdering: Ordering[InternalRow],
-    streamedIter: RowIterator,
-    bufferedIter: RowIterator) {
-  private[this] var streamedRow: InternalRow = _
-  private[this] var streamedRowKey: InternalRow = _
-  private[this] var bufferedRow: InternalRow = _
-  // Note: this is guaranteed to never have any null columns:
-  private[this] var bufferedRowKey: InternalRow = _
-  /**
-   * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty
-   */
-  private[this] var matchJoinKey: InternalRow = _
-  /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
-  private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
-
-  // Initialization (note: we do _not_ want to advance the streamed side here).
-  advancedBufferedToRowWithNullFreeJoinKey()
-
-  // --- Public methods ---------------------------------------------------------------------------
-
-  def getStreamedRow: InternalRow = streamedRow
-
-  def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
-
-  /**
-   * Advances both input iterators, stopping when we have found rows with matching join keys.
-   * @return true if matching rows have been found and false otherwise. If this returns true, then
-   *         [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
-   *         results.
-   */
-  final def findNextInnerJoinRows(): Boolean = {
-    while (advancedStreamed() && streamedRowKey.anyNull) {
-      // Advance the streamed side of the join until we find the next row whose join key contains
-      // no nulls or we hit the end of the streamed iterator.
-    }
-    if (streamedRow == null) {
-      // We have consumed the entire streamed iterator, so there can be no more matches.
-      matchJoinKey = null
-      bufferedMatches.clear()
-      false
-    } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
-      // The new streamed row has the same join key as the previous row, so return the same matches.
-      true
-    } else if (bufferedRow == null) {
-      // The streamed row's join key does not match the current batch of buffered rows and there are
-      // no more rows to read from the buffered iterator, so there can be no more matches.
-      matchJoinKey = null
-      bufferedMatches.clear()
-      false
-    } else {
-      // Advance both the streamed and buffered iterators to find the next pair of matching rows.
-      var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
-      do {
-        if (streamedRowKey.anyNull) {
-          advancedStreamed()
-        } else {
-          assert(!bufferedRowKey.anyNull)
-          comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
-          if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
-          else if (comp < 0) advancedStreamed()
-        }
-      } while (streamedRow != null && bufferedRow != null && comp != 0)
-      if (streamedRow == null || bufferedRow == null) {
-        // We have hit the end of one of the iterators, so there can be no more matches.
-        matchJoinKey = null
-        bufferedMatches.clear()
-        false
-      } else {
-        // The streamed row's join key matches the current buffered row's join key, so walk
-        // through the buffered iterator to buffer the rest of the matching rows.
-        assert(comp == 0)
-        bufferMatchingRows()
-        true
-      }
-    }
-  }
-
-  /**
-   * Advances the streamed input iterator and buffers all rows from the buffered input that
-   * have matching keys.
-   * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
-   *         then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
-   *         join results.
-   */
-  final def findNextOuterJoinRows(): Boolean = {
-    if (!advancedStreamed()) {
-      // We have consumed the entire streamed iterator, so there can be no more matches.
-      matchJoinKey = null
-      bufferedMatches.clear()
-      false
-    } else {
-      if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
-        // Matches the current group, so do nothing.
-      } else {
-        // The streamed row does not match the current group.
-        matchJoinKey = null
-        bufferedMatches.clear()
-        if (bufferedRow != null && !streamedRowKey.anyNull) {
-          // The buffered iterator could still contain matching rows, so we'll need to walk through
-          // it until we either find matches or pass where they would be found.
-          var comp = 1
-          do {
-            comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
-          } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
-          if (comp == 0) {
-            // We have found matches, so buffer them (this updates matchJoinKey)
-            bufferMatchingRows()
-          } else {
-            // We have overshot the position where the row would be found, hence no matches.
-          }
-        }
-      }
-      // If there is a streamed input then we always return true
-      true
-    }
-  }
-
-  // --- Private methods --------------------------------------------------------------------------
-
-  /**
-   * Advance the streamed iterator and compute the new row's join key.
-   * @return true if the streamed iterator returned a row and false otherwise.
-   */
-  private def advancedStreamed(): Boolean = {
-    if (streamedIter.advanceNext()) {
-      streamedRow = streamedIter.getRow
-      streamedRowKey = streamedKeyGenerator(streamedRow)
-      true
-    } else {
-      streamedRow = null
-      streamedRowKey = null
-      false
-    }
-  }
-
-  /**
-   * Advance the buffered iterator until we find a row with join key that does not contain nulls.
-   * @return true if the buffered iterator returned a row and false otherwise.
-   */
-  private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
-    var foundRow: Boolean = false
-    while (!foundRow && bufferedIter.advanceNext()) {
-      bufferedRow = bufferedIter.getRow
-      bufferedRowKey = bufferedKeyGenerator(bufferedRow)
-      foundRow = !bufferedRowKey.anyNull
-    }
-    if (!foundRow) {
-      bufferedRow = null
-      bufferedRowKey = null
-      false
-    } else {
-      true
-    }
-  }
-
-  /**
-   * Called when the streamed and buffered join keys match in order to buffer the matching rows.
-   */
-  private def bufferMatchingRows(): Unit = {
-    assert(streamedRowKey != null)
-    assert(!streamedRowKey.anyNull)
-    assert(bufferedRowKey != null)
-    assert(!bufferedRowKey.anyNull)
-    assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
-    // This join key may have been produced by a mutable projection, so we need to make a copy:
-    matchJoinKey = streamedRowKey.copy()
-    bufferedMatches.clear()
-    do {
-      bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
-      advancedBufferedToRowWithNullFreeJoinKey()
-    } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
-  }
-}
-
-/**
- * An iterator for outputting rows in left outer join.
- */
-private class LeftOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    rightNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
-  extends OneSideOuterIterator(
-    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
-
-  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
-  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
-}
-
-/**
- * An iterator for outputting rows in right outer join.
- */
-private class RightOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    leftNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
-  extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
-
-  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
-  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
-}
-
-/**
- * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
- *
- * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
- * streamed side produces zero or more output rows, one for each matching row on the buffered side.
- * If there are no matches, then the buffered side of the joined output will be a null row.
- *
- * In left outer join, the left is the streamed side and the right is the buffered side.
- * In right outer join, the right is the streamed side and the left is the buffered side.
- *
- * @param smjScanner a scanner that streams rows and buffers any matching rows
- * @param bufferedSideNullRow the default row to return when a streamed row has no matches
- * @param boundCondition an additional filter condition for buffered rows
- * @param resultProj how the output should be projected
- * @param numOutputRows an accumulator metric for the number of rows output
- */
-private abstract class OneSideOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    bufferedSideNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric) extends RowIterator {
-
-  // A row to store the joined result, reused many times
-  protected[this] val joinedRow: JoinedRow = new JoinedRow()
-
-  // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
-  private[this] var bufferIndex: Int = 0
-
-  // This iterator is initialized lazily so there should be no matches initially
-  assert(smjScanner.getBufferedMatches.length == 0)
-
-  // Set output methods to be overridden by subclasses
-  protected def setStreamSideOutput(row: InternalRow): Unit
-  protected def setBufferedSideOutput(row: InternalRow): Unit
-
-  /**
-   * Advance to the next row on the stream side and populate the buffer with matches.
-   * @return whether there are more rows in the stream to consume.
-   */
-  private def advanceStream(): Boolean = {
-    bufferIndex = 0
-    if (smjScanner.findNextOuterJoinRows()) {
-      setStreamSideOutput(smjScanner.getStreamedRow)
-      if (smjScanner.getBufferedMatches.isEmpty) {
-        // There are no matching rows in the buffer, so return the null row
-        setBufferedSideOutput(bufferedSideNullRow)
-      } else {
-        // Find the next row in the buffer that satisfies the bound condition
-        if (!advanceBufferUntilBoundConditionSatisfied()) {
-          setBufferedSideOutput(bufferedSideNullRow)
-        }
-      }
-      true
-    } else {
-      // Stream has been exhausted
-      false
-    }
-  }
-
-  /**
-   * Advance to the next row in the buffer that satisfies the bound condition.
-   * @return whether there is such a row in the current buffer.
-   */
-  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
-    var foundMatch: Boolean = false
-    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
-      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
-      foundMatch = boundCondition(joinedRow)
-      bufferIndex += 1
-    }
-    foundMatch
-  }
-
-  override def advanceNext(): Boolean = {
-    val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
-    if (r) numOutputRows += 1
-    r
-  }
-
-  override def getRow: InternalRow = resultProj(joinedRow)
-}
-
-private class SortMergeFullOuterJoinScanner(
-    leftKeyGenerator: Projection,
-    rightKeyGenerator: Projection,
-    keyOrdering: Ordering[InternalRow],
-    leftIter: RowIterator,
-    rightIter: RowIterator,
-    boundCondition: InternalRow => Boolean,
-    leftNullRow: InternalRow,
-    rightNullRow: InternalRow)  {
-  private[this] val joinedRow: JoinedRow = new JoinedRow()
-  private[this] var leftRow: InternalRow = _
-  private[this] var leftRowKey: InternalRow = _
-  private[this] var rightRow: InternalRow = _
-  private[this] var rightRowKey: InternalRow = _
-
-  private[this] var leftIndex: Int = 0
-  private[this] var rightIndex: Int = 0
-  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
-  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
-  private[this] var leftMatched: BitSet = new BitSet(1)
-  private[this] var rightMatched: BitSet = new BitSet(1)
-
-  advancedLeft()
-  advancedRight()
-
-  // --- Private methods --------------------------------------------------------------------------
-
-  /**
-   * Advance the left iterator and compute the new row's join key.
-   * @return true if the left iterator returned a row and false otherwise.
-   */
-  private def advancedLeft(): Boolean = {
-    if (leftIter.advanceNext()) {
-      leftRow = leftIter.getRow
-      leftRowKey = leftKeyGenerator(leftRow)
-      true
-    } else {
-      leftRow = null
-      leftRowKey = null
-      false
-    }
-  }
-
-  /**
-   * Advance the right iterator and compute the new row's join key.
-   * @return true if the right iterator returned a row and false otherwise.
-   */
-  private def advancedRight(): Boolean = {
-    if (rightIter.advanceNext()) {
-      rightRow = rightIter.getRow
-      rightRowKey = rightKeyGenerator(rightRow)
-      true
-    } else {
-      rightRow = null
-      rightRowKey = null
-      false
-    }
-  }
-
-  /**
-   * Populate the left and right buffers with rows matching the provided key.
-   * This consumes rows from both iterators until their keys are different from the matching key.
-   */
-  private def findMatchingRows(matchingKey: InternalRow): Unit = {
-    leftMatches.clear()
-    rightMatches.clear()
-    leftIndex = 0
-    rightIndex = 0
-
-    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
-      leftMatches += leftRow.copy()
-      advancedLeft()
-    }
-    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
-      rightMatches += rightRow.copy()
-      advancedRight()
-    }
-
-    if (leftMatches.size <= leftMatched.capacity) {
-      leftMatched.clear()
-    } else {
-      leftMatched = new BitSet(leftMatches.size)
-    }
-    if (rightMatches.size <= rightMatched.capacity) {
-      rightMatched.clear()
-    } else {
-      rightMatched = new BitSet(rightMatches.size)
-    }
-  }
-
-  /**
-   * Scan the left and right buffers for the next valid match.
-   *
-   * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
-   * If a left row has no valid matches on the right, or a right row has no valid matches on the
-   * left, then the row is joined with the null row and the result is considered a valid match.
-   *
-   * @return true if a valid match is found, false otherwise.
-   */
-  private def scanNextInBuffered(): Boolean = {
-    while (leftIndex < leftMatches.size) {
-      while (rightIndex < rightMatches.size) {
-        joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
-        if (boundCondition(joinedRow)) {
-          leftMatched.set(leftIndex)
-          rightMatched.set(rightIndex)
-          rightIndex += 1
-          return true
-        }
-        rightIndex += 1
-      }
-      rightIndex = 0
-      if (!leftMatched.get(leftIndex)) {
-        // the left row has never matched any right row, join it with null row
-        joinedRow(leftMatches(leftIndex), rightNullRow)
-        leftIndex += 1
-        return true
-      }
-      leftIndex += 1
-    }
-
-    while (rightIndex < rightMatches.size) {
-      if (!rightMatched.get(rightIndex)) {
-        // the right row has never matched any left row, join it with null row
-        joinedRow(leftNullRow, rightMatches(rightIndex))
-        rightIndex += 1
-        return true
-      }
-      rightIndex += 1
-    }
-
-    // There are no more valid matches in the left and right buffers
-    false
-  }
-
-  // --- Public methods --------------------------------------------------------------------------
-
-  def getJoinedRow(): JoinedRow = joinedRow
-
-  def advanceNext(): Boolean = {
-    // If we already buffered some matching rows, use them directly
-    if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
-      if (scanNextInBuffered()) {
-        return true
-      }
-    }
-
-    if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
-      joinedRow(leftRow.copy(), rightNullRow)
-      advancedLeft()
-      true
-    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
-      joinedRow(leftNullRow, rightRow.copy())
-      advancedRight()
-      true
-    } else if (leftRow != null && rightRow != null) {
-      // Both rows are present and neither has null values,
-      // so we populate the buffers with rows matching the next key
-      val comp = keyOrdering.compare(leftRowKey, rightRowKey)
-      if (comp <= 0) {
-        findMatchingRows(leftRowKey.copy())
-      } else {
-        findMatchingRows(rightRowKey.copy())
-      }
-      scanNextInBuffered()
-      true
-    } else {
-      // Both iterators have been consumed
-      false
-    }
-  }
-}
-
-private class FullOuterIterator(
-    smjScanner: SortMergeFullOuterJoinScanner,
-    resultProj: InternalRow => InternalRow,
-    numRows: LongSQLMetric) extends RowIterator {
-  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
-
-  override def advanceNext(): Boolean = {
-    val r = smjScanner.advanceNext()
-    if (r) numRows += 1
-    r
-  }
-
-  override def getRow: InternalRow = resultProj(joinedRow)
-}
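
The replacement file below is unchanged apart from the Exec rename and the BinaryExecNode import. For intuition about the scanner logic it carries over: SortMergeJoinScanner walks two inputs that are already sorted by join key, advances whichever side currently has the smaller key, and buffers the run of buffered-side rows sharing the streamed key. Here is a minimal sketch of that inner-join merge in plain Scala; sortMergeInner and its types are hypothetical helpers, not Spark code.

// Minimal inner sort-merge join over two key-sorted sequences (hypothetical
// helper, not Spark code). Mirrors SortMergeJoinScanner: advance the side
// with the smaller key; on a match, buffer the right-side run for that key
// and join every left row with the buffered group.
object SortMergeSketch {
  def sortMergeInner[K: Ordering, L, R](
      left: Seq[(K, L)], right: Seq[(K, R)]): Seq[(K, L, R)] = {
    val ord = implicitly[Ordering[K]]
    val out = scala.collection.mutable.ArrayBuffer.empty[(K, L, R)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1        // left key is smaller: advance the left side
      else if (c > 0) j += 1   // right key is smaller: advance the right side
      else {
        // Buffer the run of right rows sharing the current key
        // (getBufferedMatches in the real scanner).
        val key = left(i)._1
        var end = j
        while (end < right.length && ord.equiv(right(end)._1, key)) end += 1
        val matches = right.slice(j, end)
        // Join every left row with this key against the buffered matches.
        while (i < left.length && ord.equiv(left(i)._1, key)) {
          matches.foreach { case (_, r) => out += ((key, left(i)._2, r)) }
          i += 1
        }
        j = end
      }
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    val l = Seq(1 -> "a", 2 -> "b", 2 -> "c")
    val r = Seq(2 -> "x", 2 -> "y", 3 -> "z")
    sortMergeInner(l, r).foreach(println) // (2,b,x) (2,b,y) (2,c,x) (2,c,y)
  }
}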

http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
new file mode 100644
index 0000000..96b283a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -0,0 +1,964 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.util.collection.BitSet
+
+/**
+ * Performs a sort merge join of two child relations.
+ */
+case class SortMergeJoinExec(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryExecNode with CodegenSupport {
+
+  override private[sql] lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        (left.output ++ right.output).map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(
+          s"${getClass.getSimpleName} should not take $x as the JoinType")
+    }
+  }
+
+  override def outputPartitioning: Partitioning = joinType match {
+    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x =>
+      throw new IllegalArgumentException(
+        s"${getClass.getSimpleName} should not take $x as the JoinType")
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
+    // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
+    keys.map(SortOrder(_, Ascending))
+  }
+
+  private def createLeftKeyGenerator(): Projection =
+    UnsafeProjection.create(leftKeys, left.output)
+
+  private def createRightKeyGenerator(): Projection =
+    UnsafeProjection.create(rightKeys, right.output)
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
+      // An ordering that can be used to compare keys from both sides.
+      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
+
+      joinType match {
+        case Inner =>
+          new RowIterator {
+            // The projection used to extract keys from input rows of the left child.
+            private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
+
+            // The projection used to extract keys from input rows of the right child.
+            private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
+
+            // An ordering that can be used to compare keys from both sides.
+            private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+            private[this] var currentLeftRow: InternalRow = _
+            private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+            private[this] var currentMatchIdx: Int = -1
+            private[this] val smjScanner = new SortMergeJoinScanner(
+              leftKeyGenerator,
+              rightKeyGenerator,
+              keyOrdering,
+              RowIterator.fromScala(leftIter),
+              RowIterator.fromScala(rightIter)
+            )
+            private[this] val joinRow = new JoinedRow
+            private[this] val resultProjection: (InternalRow) => InternalRow =
+              UnsafeProjection.create(schema)
+
+            if (smjScanner.findNextInnerJoinRows()) {
+              currentRightMatches = smjScanner.getBufferedMatches
+              currentLeftRow = smjScanner.getStreamedRow
+              currentMatchIdx = 0
+            }
+
+            override def advanceNext(): Boolean = {
+              while (currentMatchIdx >= 0) {
+                if (currentMatchIdx == currentRightMatches.length) {
+                  if (smjScanner.findNextInnerJoinRows()) {
+                    currentRightMatches = smjScanner.getBufferedMatches
+                    currentLeftRow = smjScanner.getStreamedRow
+                    currentMatchIdx = 0
+                  } else {
+                    currentRightMatches = null
+                    currentLeftRow = null
+                    currentMatchIdx = -1
+                    return false
+                  }
+                }
+                joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+                currentMatchIdx += 1
+                if (boundCondition(joinRow)) {
+                  numOutputRows += 1
+                  return true
+                }
+              }
+              false
+            }
+
+            override def getRow: InternalRow = resultProjection(joinRow)
+          }.toScala
+
+        case LeftOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createLeftKeyGenerator(),
+            bufferedKeyGenerator = createRightKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(leftIter),
+            bufferedIter = RowIterator.fromScala(rightIter)
+          )
+          val rightNullRow = new GenericInternalRow(right.output.length)
+          new LeftOuterIterator(
+            smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+        case RightOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createRightKeyGenerator(),
+            bufferedKeyGenerator = createLeftKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(rightIter),
+            bufferedIter = RowIterator.fromScala(leftIter)
+          )
+          val leftNullRow = new GenericInternalRow(left.output.length)
+          new RightOuterIterator(
+            smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+        case FullOuter =>
+          val leftNullRow = new GenericInternalRow(left.output.length)
+          val rightNullRow = new GenericInternalRow(right.output.length)
+          val smjScanner = new SortMergeFullOuterJoinScanner(
+            leftKeyGenerator = createLeftKeyGenerator(),
+            rightKeyGenerator = createRightKeyGenerator(),
+            keyOrdering,
+            leftIter = RowIterator.fromScala(leftIter),
+            rightIter = RowIterator.fromScala(rightIter),
+            boundCondition,
+            leftNullRow,
+            rightNullRow)
+
+          new FullOuterIterator(
+            smjScanner,
+            resultProj,
+            numOutputRows).toScala
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"SortMergeJoin should not take $x as the JoinType")
+      }
+
+    }
+  }
+
+  override def supportCodegen: Boolean = {
+    joinType == Inner
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    left.execute() :: right.execute() :: Nil
+  }
+
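+  /**
+   * Binds the given key expressions to `input` and generates code that evaluates them
+   * against the row referenced by `row`.
+   */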
+  private def createJoinKey(
+      ctx: CodegenContext,
+      row: String,
+      keys: Seq[Expression],
+      input: Seq[Attribute]): Seq[ExprCode] = {
+    ctx.INPUT_ROW = row
+    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+  }
+
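+  /**
+   * Copies the given key variables into freshly declared mutable class members, so their
+   * values survive across calls to the generated functions.
+   */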
+  private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
+    vars.zipWithIndex.map { case (ev, i) =>
+      val value = ctx.freshName("value")
+      ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
+      val code =
+        s"""
+           |$value = ${ev.value};
+         """.stripMargin
+      ExprCode(code, "false", value)
+    }
+  }
+
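+  /**
+   * Generates code that compares two sets of join keys column by column, left to right,
+   * writing the result into a `comp` variable declared by the caller: after the generated
+   * code runs, `comp` is 0 if and only if all key columns are equal.
+   */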
+  private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
+    val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
+      s"""
+         |if (comp == 0) {
+         |  comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
+         |}
+       """.stripMargin.trim
+    }
+    s"""
+       |comp = 0;
+       |${comparisons.mkString("\n")}
+     """.stripMargin
+  }
+
+  /**
+   * Generates a function that scans both the left and right inputs to find matching rows,
+   * and returns the code terms for the matched row from the left side and for the buffered
+   * matching rows from the right side.
+   */
+  private def genScanner(ctx: CodegenContext): (String, String) = {
+    // Create class members for the next row from each side.
+    val leftRow = ctx.freshName("leftRow")
+    ctx.addMutableState("InternalRow", leftRow, "")
+    val rightRow = ctx.freshName("rightRow")
+    ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
+
+    // Create variables for join keys from both sides.
+    val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
+    val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+    val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
+    val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
+    // Copy the right keys into class members so they can be used in the next function call.
+    val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
+
+    // A list to hold all matched rows from right side.
+    val matches = ctx.freshName("matches")
+    val clsName = classOf[java.util.ArrayList[InternalRow]].getName
+    ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+    // Copy the left keys into class members so they can be used in the next function call.
+    val matchedKeyVars = copyKeys(ctx, leftKeyVars)
+
+    ctx.addNewFunction("findNextInnerJoinRows",
+      s"""
+         |private boolean findNextInnerJoinRows(
+         |    scala.collection.Iterator leftIter,
+         |    scala.collection.Iterator rightIter) {
+         |  $leftRow = null;
+         |  int comp = 0;
+         |  while ($leftRow == null) {
+         |    if (!leftIter.hasNext()) return false;
+         |    $leftRow = (InternalRow) leftIter.next();
+         |    ${leftKeyVars.map(_.code).mkString("\n")}
+         |    if ($leftAnyNull) {
+         |      $leftRow = null;
+         |      continue;
+         |    }
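+         |    // If matches buffered for the previous left key are still held, reuse them
+         |    // when the new left key is equal; otherwise discard them and rescan.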
+         |    if (!$matches.isEmpty()) {
         |      ${genComparison(ctx, leftKeyVars, matchedKeyVars)}
+         |      if (comp == 0) {
+         |        return true;
+         |      }
+         |      $matches.clear();
+         |    }
+         |
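+         |    // Scan the right side: skip past smaller keys, stop once a larger key is
+         |    // seen, and buffer a copy of every row whose key equals the current left key.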
+         |    do {
+         |      if ($rightRow == null) {
+         |        if (!rightIter.hasNext()) {
+         |          ${matchedKeyVars.map(_.code).mkString("\n")}
+         |          return !$matches.isEmpty();
+         |        }
+         |        $rightRow = (InternalRow) rightIter.next();
+         |        ${rightKeyTmpVars.map(_.code).mkString("\n")}
+         |        if ($rightAnyNull) {
+         |          $rightRow = null;
+         |          continue;
+         |        }
+         |        ${rightKeyVars.map(_.code).mkString("\n")}
+         |      }
         |      ${genComparison(ctx, leftKeyVars, rightKeyVars)}
+         |      if (comp > 0) {
+         |        $rightRow = null;
+         |      } else if (comp < 0) {
+         |        if (!$matches.isEmpty()) {
+         |          ${matchedKeyVars.map(_.code).mkString("\n")}
+         |          return true;
+         |        }
+         |        $leftRow = null;
+         |      } else {
+         |        $matches.add($rightRow.copy());
         |        $rightRow = null;
+         |      }
+         |    } while ($leftRow != null);
+         |  }
+         |  return false; // unreachable
+         |}
+       """.stripMargin)
+
+    (leftRow, matches)
+  }
+
+  /**
+   * Creates variables for the left part of the result row.
+   *
+   * In order to defer column access until after the condition is evaluated, while still
+   * accessing each column at most once per row, the variable declarations must be
+   * separated from the column accesses, so we can't reuse the codegen of BoundReference here.
+   */
+  private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = leftRow
+    left.output.zipWithIndex.map { case (a, i) =>
+      val value = ctx.freshName("value")
+      val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
+      // declare it as a class member so the column can be accessed both before and inside the loop.
+      ctx.addMutableState(ctx.javaType(a.dataType), value, "")
+      if (a.nullable) {
+        val isNull = ctx.freshName("isNull")
+        ctx.addMutableState("boolean", isNull, "")
+        val code =
+          s"""
+             |$isNull = $leftRow.isNullAt($i);
+             |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+           """.stripMargin
+        ExprCode(code, isNull, value)
+      } else {
+        ExprCode(s"$value = $valueCode;", "false", value)
+      }
+    }
+  }
+
+  /**
+   * Creates the variables for the right part of the result row using BoundReference, since
+   * the right part is accessed inside the loop.
+   */
+  private def createRightVars(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
+    ctx.INPUT_ROW = rightRow
+    right.output.zipWithIndex.map { case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+    }
+  }
+
+  /**
+   * Splits the variables based on whether each is used by the condition, and returns the
+   * code that creates these variables before the condition and after the condition,
+   * respectively.
+   *
+   * When only a few columns are used by the condition, we can skip accessing the remaining
+   * columns entirely for rows that the condition filters out.
+   */
+  private def splitVarsByCondition(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode]): (String, String) = {
+    if (condition.isDefined) {
+      val condRefs = condition.get.references
+      val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
+        condRefs.contains(a)
+      }
+      val beforeCond = evaluateVariables(used.map(_._2))
+      val afterCond = evaluateVariables(notUsed.map(_._2))
+      (beforeCond, afterCond)
+    } else {
+      (evaluateVariables(variables), "")
+    }
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    ctx.copyResult = true
+    val leftInput = ctx.freshName("leftInput")
+    ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
+    val rightInput = ctx.freshName("rightInput")
+    ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
+
+    val (leftRow, matches) = genScanner(ctx)
+
+    // Create variables for row from both sides.
+    val leftVars = createLeftVars(ctx, leftRow)
+    val rightRow = ctx.freshName("rightRow")
+    val rightVars = createRightVars(ctx, rightRow)
+
+    val size = ctx.freshName("size")
+    val i = ctx.freshName("i")
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the variable-creation code based on whether each variable is used by the condition.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, output).genCode(ctx)
+      // evaluate the columns used by the condition before the loop
+      val before = s"""
+           |boolean $loaded = false;
+           |$leftBefore
+         """.stripMargin
+
+      val checking = s"""
+         |$rightBefore
+         |${cond.code}
+         |if (${cond.isNull} || !${cond.value}) continue;
+         |if (!$loaded) {
+         |  $loaded = true;
+         |  $leftAfter
+         |}
+         |$rightAfter
+     """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
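+    // The generated loop: for each left row with buffered matches, iterate over the
+    // matches, evaluate any join condition, and feed the joined rows to the consumer.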
+    s"""
+       |while (findNextInnerJoinRows($leftInput, $rightInput)) {
+       |  int $size = $matches.size();
+       |  ${beforeLoop.trim}
       |  for (int $i = 0; $i < $size; $i++) {
+       |    InternalRow $rightRow = (InternalRow) $matches.get($i);
+       |    ${condCheck.trim}
+       |    $numOutput.add(1);
+       |    ${consume(ctx, leftVars ++ rightVars)}
+       |  }
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoinExec]].
+ *
+ * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false`
+ * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return
+ * the matching row from the streamed input and may call [[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an outer join, this will return
+ * an empty sequence if there are no matches from the buffered input). For efficiency, both of these
+ * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()`
+ * methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from the streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from the buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that
+ *                     have the same join key.
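+ *
+ * A minimal consumer sketch for the inner-join case (illustrative only; `emit` is a
+ * hypothetical callback, not part of this class):
+ * {{{
+ *   while (scanner.findNextInnerJoinRows()) {
+ *     val streamedRow = scanner.getStreamedRow
+ *     scanner.getBufferedMatches.foreach(buffered => emit(streamedRow, buffered))
+ *   }
+ * }}}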
+ */
+private[joins] class SortMergeJoinScanner(
+    streamedKeyGenerator: Projection,
+    bufferedKeyGenerator: Projection,
+    keyOrdering: Ordering[InternalRow],
+    streamedIter: RowIterator,
+    bufferedIter: RowIterator) {
+  private[this] var streamedRow: InternalRow = _
+  private[this] var streamedRowKey: InternalRow = _
+  private[this] var bufferedRow: InternalRow = _
+  // Note: this is guaranteed to never have any null columns:
+  private[this] var bufferedRowKey: InternalRow = _
+  /**
+   * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty.
+   */
+  private[this] var matchJoinKey: InternalRow = _
+  /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
+  private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+
+  // Initialization (note: do _not_ want to advance streamed here).
+  advancedBufferedToRowWithNullFreeJoinKey()
+
+  // --- Public methods ---------------------------------------------------------------------------
+
+  def getStreamedRow: InternalRow = streamedRow
+
+  def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+  /**
+   * Advances both input iterators, stopping when we have found rows with matching join keys.
+   * @return true if matching rows have been found and false otherwise. If this returns true, then
+   *         [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join
+   *         results.
+   */
+  final def findNextInnerJoinRows(): Boolean = {
+    while (advancedStreamed() && streamedRowKey.anyNull) {
+      // Advance the streamed side of the join until we find the next row whose join key contains
+      // no nulls or we hit the end of the streamed iterator.
+    }
+    if (streamedRow == null) {
+      // We have consumed the entire streamed iterator, so there can be no more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+      // The new streamed row has the same join key as the previous row, so return the same matches.
+      true
+    } else if (bufferedRow == null) {
+      // The streamed row's join key does not match the current batch of buffered rows and there are
+      // no more rows to read from the buffered iterator, so there can be no more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else {
+      // Advance both the streamed and buffered iterators to find the next pair of matching rows.
+      var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+      do {
+        if (streamedRowKey.anyNull) {
+          advancedStreamed()
+        } else {
+          assert(!bufferedRowKey.anyNull)
+          comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+          if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
+          else if (comp < 0) advancedStreamed()
+        }
+      } while (streamedRow != null && bufferedRow != null && comp != 0)
+      if (streamedRow == null || bufferedRow == null) {
+        // We have hit the end of one of the iterators, so there can be no more matches.
+        matchJoinKey = null
+        bufferedMatches.clear()
+        false
+      } else {
+        // The streamed row's join key matches the current buffered row's join key, so walk
+        // through the buffered iterator to buffer the rest of the matching rows.
+        assert(comp == 0)
+        bufferMatchingRows()
+        true
+      }
+    }
+  }
+
+  /**
+   * Advances the streamed input iterator and buffers all rows from the buffered input that
+   * have matching keys.
+   * @return true if the streamed iterator returned a row, false otherwise. If this returns true,
+   *         then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer
+   *         join results.
+   */
+  final def findNextOuterJoinRows(): Boolean = {
+    if (!advancedStreamed()) {
+      // We have consumed the entire streamed iterator, so there can be no more matches.
+      matchJoinKey = null
+      bufferedMatches.clear()
+      false
+    } else {
+      if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) {
+        // Matches the current group, so do nothing.
+      } else {
+        // The streamed row does not match the current group.
+        matchJoinKey = null
+        bufferedMatches.clear()
+        if (bufferedRow != null && !streamedRowKey.anyNull) {
+          // The buffered iterator could still contain matching rows, so we'll need to walk through
+          // it until we either find matches or pass where they would be found.
+          var comp = 1
+          do {
+            comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+          } while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
+          if (comp == 0) {
+            // We have found matches, so buffer them (this updates matchJoinKey)
+            bufferMatchingRows()
+          } else {
+            // We have overshot the position where the row would be found, hence no matches.
+          }
+        }
+      }
+      // In an outer join every streamed row produces output, so we always return true.
+      true
+    }
+  }
+
+  // --- Private methods --------------------------------------------------------------------------
+
+  /**
+   * Advance the streamed iterator and compute the new row's join key.
+   * @return true if the streamed iterator returned a row and false otherwise.
+   */
+  private def advancedStreamed(): Boolean = {
+    if (streamedIter.advanceNext()) {
+      streamedRow = streamedIter.getRow
+      streamedRowKey = streamedKeyGenerator(streamedRow)
+      true
+    } else {
+      streamedRow = null
+      streamedRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Advance the buffered iterator until we find a row with join key that does not contain nulls.
+   * @return true if the buffered iterator returned a row and false otherwise.
+   */
+  private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
+    var foundRow: Boolean = false
+    while (!foundRow && bufferedIter.advanceNext()) {
+      bufferedRow = bufferedIter.getRow
+      bufferedRowKey = bufferedKeyGenerator(bufferedRow)
+      foundRow = !bufferedRowKey.anyNull
+    }
+    if (!foundRow) {
+      bufferedRow = null
+      bufferedRowKey = null
+      false
+    } else {
+      true
+    }
+  }
+
+  /**
+   * Called when the streamed and buffered join keys match in order to buffer the matching rows.
+   */
+  private def bufferMatchingRows(): Unit = {
+    assert(streamedRowKey != null)
+    assert(!streamedRowKey.anyNull)
+    assert(bufferedRowKey != null)
+    assert(!bufferedRowKey.anyNull)
+    assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+    // This join key may have been produced by a mutable projection, so we need to make a copy:
+    matchJoinKey = streamedRowKey.copy()
+    bufferedMatches.clear()
+    do {
+      bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+      advancedBufferedToRowWithNullFreeJoinKey()
+    } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
+  }
+}
+
+/**
+ * An iterator for outputting rows in left outer join.
+ */
+private class LeftOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    rightNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(
+    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
+
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+}
+
+/**
+ * An iterator for outputting rows in right outer join.
+ */
+private class RightOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    leftNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+}
+
+/**
+ * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
+ *
+ * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
+ * streamed side will output zero or more rows, one for each matching row on the buffered side.
+ * If there are no matches, then the buffered side of the joined output will be a null row.
+ *
+ * In left outer join, the left is the streamed side and the right is the buffered side.
+ * In right outer join, the right is the streamed side and the left is the buffered side.
+ *
+ * @param smjScanner a scanner that streams rows and buffers any matching rows
+ * @param bufferedSideNullRow the default row to return when a streamed row has no matches
+ * @param boundCondition an additional filter condition for buffered rows
+ * @param resultProj how the output should be projected
+ * @param numOutputRows an accumulator metric for the number of rows output
+ */
+private abstract class OneSideOuterIterator(
+    smjScanner: SortMergeJoinScanner,
+    bufferedSideNullRow: InternalRow,
+    boundCondition: InternalRow => Boolean,
+    resultProj: InternalRow => InternalRow,
+    numOutputRows: LongSQLMetric) extends RowIterator {
+
+  // A row to store the joined result, reused many times
+  protected[this] val joinedRow: JoinedRow = new JoinedRow()
+
+  // Index into the buffered matches, reset to 0 whenever we advance to a new streamed row
+  private[this] var bufferIndex: Int = 0
+
+  // This iterator is initialized lazily so there should be no matches initially
+  assert(smjScanner.getBufferedMatches.length == 0)
+
+  // Setters for the two sides of the output row, overridden by subclasses
+  protected def setStreamSideOutput(row: InternalRow): Unit
+  protected def setBufferedSideOutput(row: InternalRow): Unit
+
+  /**
+   * Advance to the next row on the stream side and populate the buffer with matches.
+   * @return whether there are more rows in the stream to consume.
+   */
+  private def advanceStream(): Boolean = {
+    bufferIndex = 0
+    if (smjScanner.findNextOuterJoinRows()) {
+      setStreamSideOutput(smjScanner.getStreamedRow)
+      if (smjScanner.getBufferedMatches.isEmpty) {
+        // There are no matching rows in the buffer, so return the null row
+        setBufferedSideOutput(bufferedSideNullRow)
+      } else {
+        // Find the next row in the buffer that satisfies the bound condition
+        if (!advanceBufferUntilBoundConditionSatisfied()) {
+          setBufferedSideOutput(bufferedSideNullRow)
+        }
+      }
+      true
+    } else {
+      // Stream has been exhausted
+      false
+    }
+  }
+
+  /**
+   * Advance to the next row in the buffer that satisfies the bound condition.
+   * @return whether there is such a row in the current buffer.
+   */
+  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
+    var foundMatch: Boolean = false
+    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
+      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+      foundMatch = boundCondition(joinedRow)
+      bufferIndex += 1
+    }
+    foundMatch
+  }
+
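+  // First try the remaining buffered matches for the current streamed row; only when
+  // those are exhausted advance the stream, which pads the buffered side with the null
+  // row if none of the new matches satisfy the bound condition.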
+  override def advanceNext(): Boolean = {
+    val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
+    if (r) numOutputRows += 1
+    r
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}
+
+private class SortMergeFullOuterJoinScanner(
+    leftKeyGenerator: Projection,
+    rightKeyGenerator: Projection,
+    keyOrdering: Ordering[InternalRow],
+    leftIter: RowIterator,
+    rightIter: RowIterator,
+    boundCondition: InternalRow => Boolean,
+    leftNullRow: InternalRow,
+    rightNullRow: InternalRow) {
+  private[this] val joinedRow: JoinedRow = new JoinedRow()
+  private[this] var leftRow: InternalRow = _
+  private[this] var leftRowKey: InternalRow = _
+  private[this] var rightRow: InternalRow = _
+  private[this] var rightRowKey: InternalRow = _
+
+  private[this] var leftIndex: Int = 0
+  private[this] var rightIndex: Int = 0
+  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] var leftMatched: BitSet = new BitSet(1)
+  private[this] var rightMatched: BitSet = new BitSet(1)
+
+  advancedLeft()
+  advancedRight()
+
+  // --- Private methods --------------------------------------------------------------------------
+
+  /**
+   * Advance the left iterator and compute the new row's join key.
+   * @return true if the left iterator returned a row and false otherwise.
+   */
+  private def advancedLeft(): Boolean = {
+    if (leftIter.advanceNext()) {
+      leftRow = leftIter.getRow
+      leftRowKey = leftKeyGenerator(leftRow)
+      true
+    } else {
+      leftRow = null
+      leftRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Advance the right iterator and compute the new row's join key.
+   * @return true if the right iterator returned a row and false otherwise.
+   */
+  private def advancedRight(): Boolean = {
+    if (rightIter.advanceNext()) {
+      rightRow = rightIter.getRow
+      rightRowKey = rightKeyGenerator(rightRow)
+      true
+    } else {
+      rightRow = null
+      rightRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Populate the left and right buffers with rows matching the provided key.
+   * This consumes rows from both iterators until their keys are different from the matching key.
+   */
+  private def findMatchingRows(matchingKey: InternalRow): Unit = {
+    leftMatches.clear()
+    rightMatches.clear()
+    leftIndex = 0
+    rightIndex = 0
+
+    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
+      leftMatches += leftRow.copy()
+      advancedLeft()
+    }
+    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
+      rightMatches += rightRow.copy()
+      advancedRight()
+    }
+
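+    // Reuse the match-tracking bitsets when they are already large enough; otherwise
+    // allocate larger ones.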
+    if (leftMatches.size <= leftMatched.capacity) {
+      leftMatched.clear()
+    } else {
+      leftMatched = new BitSet(leftMatches.size)
+    }
+    if (rightMatches.size <= rightMatched.capacity) {
+      rightMatched.clear()
+    } else {
+      rightMatched = new BitSet(rightMatches.size)
+    }
+  }
+
+  /**
+   * Scan the left and right buffers for the next valid match.
+   *
+   * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
+   * If a left row has no valid matches on the right, or a right row has no valid matches on the
+   * left, then the row is joined with the null row and the result is considered a valid match.
+   *
+   * @return true if a valid match is found, false otherwise.
+   */
+  private def scanNextInBuffered(): Boolean = {
+    while (leftIndex < leftMatches.size) {
+      while (rightIndex < rightMatches.size) {
+        joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
+        if (boundCondition(joinedRow)) {
+          leftMatched.set(leftIndex)
+          rightMatched.set(rightIndex)
+          rightIndex += 1
+          return true
+        }
+        rightIndex += 1
+      }
+      rightIndex = 0
+      if (!leftMatched.get(leftIndex)) {
+        // the left row has never matched any right row, join it with null row
+        joinedRow(leftMatches(leftIndex), rightNullRow)
+        leftIndex += 1
+        return true
+      }
+      leftIndex += 1
+    }
+
+    while (rightIndex < rightMatches.size) {
+      if (!rightMatched.get(rightIndex)) {
+        // the right row has never matched any left row, join it with null row
+        joinedRow(leftNullRow, rightMatches(rightIndex))
+        rightIndex += 1
+        return true
+      }
+      rightIndex += 1
+    }
+
+    // There are no more valid matches in the left and right buffers
+    false
+  }
+
+  // --- Public methods --------------------------------------------------------------------------
+
+  def getJoinedRow(): JoinedRow = joinedRow
+
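+  // Emits rows in three phases: first drain any joined rows remaining in the match
+  // buffers, then pad rows whose join keys contain nulls (or whose counterpart input is
+  // exhausted) with the null row, and finally buffer the next group of key-equal rows
+  // from both sides and scan their cross product.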
+  def advanceNext(): Boolean = {
+    // If we already buffered some matching rows, use them directly
+    if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
+      if (scanNextInBuffered()) {
+        return true
+      }
+    }
+
+    if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
+      joinedRow(leftRow.copy(), rightNullRow)
+      advancedLeft()
+      true
+    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
+      joinedRow(leftNullRow, rightRow.copy())
+      advancedRight()
+      true
+    } else if (leftRow != null && rightRow != null) {
+      // Both rows are present and neither join key contains nulls,
+      // so we populate the buffers with rows matching the next key
+      val comp = keyOrdering.compare(leftRowKey, rightRowKey)
+      if (comp <= 0) {
+        findMatchingRows(leftRowKey.copy())
+      } else {
+        findMatchingRows(rightRowKey.copy())
+      }
+      scanNextInBuffered()
+      true
+    } else {
+      // Both iterators have been consumed
+      false
+    }
+  }
+}
+
+private class FullOuterIterator(
+    smjScanner: SortMergeFullOuterJoinScanner,
+    resultProj: InternalRow => InternalRow,
+    numRows: LongSQLMetric) extends RowIterator {
+  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
+
+  override def advanceNext(): Boolean = {
+    val r = smjScanner.advanceNext()
+    if (r) numRows += 1
+    r
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index c9a1459..b71f333 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchange
  * This operator will be used when a logical `Limit` operation is the final operator in a
  * logical plan, which happens when the user is collecting results back to the driver.
  */
-case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
+case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode {
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
@@ -46,9 +46,10 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
 }
 
 /**
- * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
+ * Helper trait which defines methods that are shared by both
+ * [[LocalLimitExec]] and [[GlobalLimitExec]].
  */
-trait BaseLimit extends UnaryNode with CodegenSupport {
+trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
   val limit: Int
   override def output: Seq[Attribute] = child.output
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -91,29 +92,29 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
 /**
  * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
  */
-case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+case class LocalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
 
 /**
  * Take the first `limit` elements of the child's single output partition.
  */
-case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
   override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
 }
 
 /**
  * Take the first limit elements as defined by the sortOrder, and do projection if needed.
- * This is logically equivalent to having a Limit operator after a [[Sort]] operator,
- * or having a [[Project]] operator between them.
+ * This is logically equivalent to having a Limit operator after a [[SortExec]] operator,
+ * or having a [[ProjectExec]] operator between them.
  * This could have been named TopK, but Spark's top operator does the opposite in ordering
  * so we name it TakeOrdered to avoid confusion.
  */
-case class TakeOrderedAndProject(
+case class TakeOrderedAndProjectExec(
     limit: Int,
     sortOrder: Seq[SortOrder],
     projectList: Option[Seq[NamedExpression]],
-    child: SparkPlan) extends UnaryNode {
+    child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = {
     projectList.map(_.map(_.toAttribute)).getOrElse(child.output)

