spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject git commit: [SQL] Pass SQLContext instead of SparkContext into physical operators.
Date Sat, 21 Jun 2014 05:49:58 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.0 36668662f -> 1829ec411


[SQL] Pass SQLContext instead of SparkContext into physical operators.

This makes it easier to use config options in operators.

Author: Reynold Xin <rxin@apache.org>

Closes #1164 from rxin/sqlcontext and squashes the following commits:

797b2fd [Reynold Xin] Pass SQLContext instead of SparkContext into physical operators.

(cherry picked from commit ca5d8b5904dc6dd5b691af506d3a842e508b3673)
Signed-off-by: Reynold Xin <rxin@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1829ec41
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1829ec41
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1829ec41

Branch: refs/heads/branch-1.0
Commit: 1829ec4111d80fe4fb177faef33b46a6b80fd88d
Parents: 3666866
Author: Reynold Xin <rxin@apache.org>
Authored: Fri Jun 20 22:49:48 2014 -0700
Committer: Reynold Xin <rxin@apache.org>
Committed: Fri Jun 20 22:49:55 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLContext.scala |  4 +++-
 .../apache/spark/sql/execution/Aggregate.scala  |  5 +++--
 .../spark/sql/execution/SparkStrategies.scala   | 22 ++++++++++----------
 .../spark/sql/execution/basicOperators.scala    | 20 ++++++++++--------
 .../org/apache/spark/sql/execution/joins.scala  | 21 ++++++++++---------
 .../sql/parquet/ParquetTableOperations.scala    | 21 ++++++++++---------
 .../spark/sql/parquet/ParquetQuerySuite.scala   |  2 +-
 7 files changed, 51 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ab376e5..c60af28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   }
 
   protected[sql] class SparkPlanner extends SparkStrategies {
-    val sparkContext = self.sparkContext
+    val sparkContext: SparkContext = self.sparkContext
+
+    val sqlContext: SQLContext = self
 
     def numPartitions = self.numShufflePartitions
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 34d88fe..d85d2d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.SQLContext
 
 /**
  * :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sc: SparkContext)
+    child: SparkPlan)(@transient sqlContext: SQLContext)
   extends UnaryNode with NoBind {
 
   override def requiredChildDistribution =
@@ -55,7 +56,7 @@ case class Aggregate(
       }
     }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   // HACK: Generators don't correctly preserve their output through serializations so we grab
   // out child's output attributes statically here.

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4694f25..bd8ae4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               partial = true,
               groupingExpressions,
               partialComputation,
-              planLater(child))(sparkContext))(sparkContext) :: Nil
+              planLater(child))(sqlContext))(sqlContext) :: Nil
         } else {
           Nil
         }
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
+        execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val relation =
           ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
         // Note: overwrite=false because otherwise the metadata we just created will be deleted
-        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
+        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-        InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
+        InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
         val prunePushedDownFilters =
           if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           projectList,
           filters,
           prunePushedDownFilters,
-          ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+          ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
 
       case _ => Nil
     }
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
-          partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+          partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
       case logical.Sort(sortExprs, child) =>
         // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
         execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+        execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
+        execution.Limit(limit, planLater(child))(sqlContext) :: Nil
       case Unions(unionChildren) =>
-        execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+        execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>
         execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
       case logical.NoRelation =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 8969794..18f4a58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
   override def output = children.head.output
-  override def execute() = sc.union(children.map(_.execute()))
+  override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 }
 
 /**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 @DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+  extends UnaryNode {
   // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
   // partition local limit -> exchange into one partition -> partition local limit again
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
  */
 @DeveloperApi
 case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-                      (@transient sc: SparkContext) extends UnaryNode {
-  override def otherCopyArgs = sc :: Nil
+                      (@transient sqlContext: SQLContext) extends UnaryNode {
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 8d7a5ba..84bdde3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.SparkContext
-
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
 @DeveloperApi
 case class LeftSemiJoinBNL(
     streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output
 
@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
 @DeveloperApi
 case class BroadcastNestedLoopJoin(
     streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output ++ right.output
 
@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
       }
 
     // TODO: Breaks lineage.
-    sc.union(
-      streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+    sqlContext.sparkContext.union(
+      streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 624f2e2..ade823b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType
 
-import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
-import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 
 /**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
     output: Seq[Attribute],
     relation: ParquetRelation,
     columnPruningPred: Seq[Expression])(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends LeafNode {
 
   override def execute(): RDD[Row] = {
+    val sc = sqlContext.sparkContext
     val job = new Job(sc.hadoopConfiguration)
     ParquetInputFormat.setReadSupportClass(
       job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
       .filter(_ != null) // Parquet's record filters may produce null values
   }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Applies a (candidate) projection.
@@ -104,7 +105,7 @@ case class ParquetTableScan(
   def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
     val success = validateProjection(prunedAttributes)
     if (success) {
-      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
     } else {
       sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
       this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
     relation: ParquetRelation,
     child: SparkPlan,
     overwrite: Boolean = false)(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends UnaryNode with SparkHadoopMapReduceUtil {
 
   /**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
     val childRdd = child.execute()
     assert(childRdd != null)
 
-    val job = new Job(sc.hadoopConfiguration)
+    val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
 
     val writeSupport =
       if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(
 
   override def output = child.output
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
     val wrappedConf = new SerializableWritable(job.getConfiguration)
     val formatter = new SimpleDateFormat("yyyyMMddHHmm")
     val jobtrackerID = formatter.format(new Date())
-    val stageId = sc.newRddId()
+    val stageId = sqlContext.sparkContext.newRddId()
 
     val taskIdOffset =
       if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
     val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
     val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
     jobCommitter.setupJob(jobTaskContext)
-    sc.runJob(rdd, writeShard _)
+    sqlContext.sparkContext.runJob(rdd, writeShard _)
     jobCommitter.commitJob(jobTaskContext)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1829ec41/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 7714eb1..2ca0c1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val scanner = new ParquetTableScan(
       ParquetTestData.testData.output,
       ParquetTestData.testData,
-      Seq())(TestSQLContext.sparkContext)
+      Seq())(TestSQLContext)
     val projected = scanner.pruneColumns(ParquetTypesConverter
       .convertToAttributes(MessageTypeParser
       .parseMessageType(ParquetTestData.subTestSchema)))


Mime
View raw message