Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 5309411A73 for ; Wed, 30 Jul 2014 03:58:37 +0000 (UTC) Received: (qmail 5666 invoked by uid 500); 30 Jul 2014 03:58:37 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 5620 invoked by uid 500); 30 Jul 2014 03:58:37 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@spark.apache.org Delivered-To: mailing list commits@spark.apache.org Received: (qmail 5611 invoked by uid 99); 30 Jul 2014 03:58:37 -0000 Received: from tyr.zones.apache.org (HELO tyr.zones.apache.org) (140.211.11.114) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 30 Jul 2014 03:58:37 +0000 Received: by tyr.zones.apache.org (Postfix, from userid 65534) id B276E910F81; Wed, 30 Jul 2014 03:58:36 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: marmbrus@apache.org To: commits@spark.apache.org Date: Wed, 30 Jul 2014 03:58:36 -0000 Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: [1/2] [SPARK-2054][SQL] Code Generation for Expression Evaluation Repository: spark Updated Branches: refs/heads/master 22649b6cd -> 84467468d http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala new file mode 100644 index 0000000..245a2e1 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use code generation for evaluation. 
+ */ +class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + val plan = try { + GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + } catch { + case e: Throwable => + val evaluated = GenerateProjection.expressionEvaluator(expression) + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow).apply(0) + if(actual != expected) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + + + test("multithreaded eval") { + import scala.concurrent._ + import ExecutionContext.Implicits.global + import scala.concurrent.duration._ + + val futures = (1 to 20).map { _ => + future { + GeneratePredicate(EqualTo(Literal(1), Literal(1))) + GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + } + } + + futures.foreach(Await.result(_, 10.seconds)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala new file mode 100644 index 0000000..887aabb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ + +/** + * Overrides our expression evaluation tests to use generated code on mutable rows. 
+ */ +class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { + override def checkEvaluation( + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { + lazy val evaluated = GenerateProjection.expressionEvaluator(expression) + + val plan = try { + GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code.mkString("\n")} + |$e + """.stripMargin) + } + + val actual = plan(inputRow) + val expectedRow = new GenericRow(Array[Any](expected)) + if (actual.hashCode() != expectedRow.hashCode()) { + fail( + s""" + |Mismatched hashCodes for values: $actual, $expectedRow + |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} + |${evaluated.code.mkString("\n")} + """.stripMargin) + } + if (actual != expectedRow) { + val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 4896f1b..e2ae0d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -27,9 +27,9 @@ class CombiningLimitsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Combine Limit", FixedPoint(2), + Batch("Combine Limit", FixedPoint(10), CombineLimits) :: - Batch("Constant Folding", FixedPoint(3), + 
Batch("Constant Folding", FixedPoint(10), NullPropagation, ConstantFolding, BooleanSimplification) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5d85a0f..2d40707 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -24,8 +24,11 @@ import scala.collection.JavaConverters._ object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" + val AUTO_CONVERT_JOIN_SIZE = "spark.sql.auto.convert.join.size" + val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" + val JOIN_BROADCAST_TABLES = "spark.sql.join.broadcastTables" + val CODEGEN_ENABLED = "spark.sql.codegen" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -57,6 +60,18 @@ trait SQLConf { private[spark] def numShufflePartitions: Int = get(SHUFFLE_PARTITIONS, "200").toInt /** + * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode + * that evaluates expressions found in queries. In general this custom code runs much faster + * than interpreted evaluation, but there are significant start-up costs due to compilation. + * As a result codegen is only beneficial when queries run for a long time, or when the same + * expressions are used multiple times. + * + * Defaults to false as this feature is currently experimental.
+ */ + private[spark] def codegenEnabled: Boolean = + if (get(CODEGEN_ENABLED, "false") == "true") true else false + + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 * effectively disables auto conversion. @@ -111,5 +126,5 @@ trait SQLConf { private[spark] def clear() { settings.clear() } - } + http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c2bdef7..e4b6810 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration))) + new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. 
@@ -160,7 +160,8 @@ class SQLContext(@transient val sparkContext: SparkContext) conf: Configuration = new Configuration()): SchemaRDD = { new SchemaRDD( this, - ParquetRelation.createEmpty(path, ScalaReflection.attributesFor[A], allowExisting, conf)) + ParquetRelation.createEmpty( + path, ScalaReflection.attributesFor[A], allowExisting, conf, this)) } /** @@ -228,12 +229,14 @@ class SQLContext(@transient val sparkContext: SparkContext) val sqlContext: SQLContext = self + def codegenEnabled = self.codegenEnabled + def numPartitions = self.numShufflePartitions val strategies: Seq[Strategy] = CommandStrategy(self) :: TakeOrdered :: - PartialAggregation :: + HashAggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: @@ -291,27 +294,30 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[Row], 1) /** - * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and - * inserting shuffle operations as needed. + * Prepares a planned SparkPlan for execution by inserting shuffle operations as needed. */ @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: - Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil + Batch("Add exchange", Once, AddExchange(self)) :: Nil } /** + * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. */ + @DeveloperApi protected abstract class QueryExecution { def logical: LogicalPlan lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... 
- lazy val sparkPlan = planner(optimizedPlan).next() + lazy val sparkPlan = { + SparkPlan.currentContext.set(self) + planner(optimizedPlan).next() + } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) @@ -331,6 +337,9 @@ class SQLContext(@transient val sparkContext: SparkContext) |${stringOrError(optimizedPlan)} |== Physical Plan == |${stringOrError(executedPlan)} + |Code Generation: ${executedPlan.codegenEnabled} + |== RDD == + |${stringOrError(toRdd.toDebugString)} """.stripMargin.trim } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 806097c..85726ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -72,7 +72,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { conf: Configuration = new Configuration()): JavaSchemaRDD = { new JavaSchemaRDD( sqlContext, - ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf)) + ParquetRelation.createEmpty(path, getSchema(beanClass), allowExisting, conf, sqlContext)) } /** @@ -101,7 +101,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { def parquetFile(path: String): JavaSchemaRDD = new JavaSchemaRDD( sqlContext, - ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) + ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext)) /** * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. 
http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c1ced8b..463a1d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -42,8 +42,8 @@ case class Aggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sqlContext: SQLContext) - extends UnaryNode with NoBind { + child: SparkPlan) + extends UnaryNode { override def requiredChildDistribution = if (partial) { @@ -56,8 +56,6 @@ case class Aggregate( } } - override def otherCopyArgs = sqlContext :: Nil - // HACK: Generators don't correctly preserve their output through serializations so we grab // out child's output attributes statically here. 
private[this] val childOutput = child.output @@ -138,7 +136,7 @@ case class Aggregate( i += 1 } } - val resultProjection = new Projection(resultExpressions, computedSchema) + val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) val aggregateResults = new GenericMutableRow(computedAggregates.length) var i = 0 @@ -152,7 +150,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) var currentRow: Row = null while (iter.hasNext) { @@ -175,7 +173,8 @@ case class Aggregate( private[this] val hashTableIter = hashTable.entrySet().iterator() private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) private[this] val resultProjection = - new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + new InterpretedMutableProjection( + resultExpressions, computedSchema ++ namedGroups.map(_._2)) private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef..392a7f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -22,7 +22,7 @@ import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.sql.{SQLContext, Row} import 
org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{NoBind, MutableProjection, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.RowOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -31,7 +31,7 @@ import org.apache.spark.util.MutablePair * :: DeveloperApi :: */ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode with NoBind { +case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning @@ -42,7 +42,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new MutableProjection(expressions, child.output) + @transient val hashExpressions = + newMutableProjection(expressions, child.output)() + val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 47b3d00..c386fd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -47,23 +47,26 @@ case class Generate( } } - override def output = + // This must be a val since the generator output expr ids are not preserved by serialization. 
+ override val output = if (join) child.output ++ generatorOutput else generatorOutput + val boundGenerator = BindReferences.bindReference(generator, child.output) + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = - new Projection(child.output ++ nullValues, child.output) + newProjection(child.output ++ nullValues, child.output) val joinProjection = - new Projection(child.output ++ generator.output, child.output ++ generator.output) + newProjection(child.output ++ generator.output, child.output ++ generator.output) val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator.eval(row) + val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -72,7 +75,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) + child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala new file mode 100644 index 0000000..4a26934 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ + +case class AggregateEvaluation( + schema: Seq[Attribute], + initialValues: Seq[Expression], + update: Seq[Expression], + result: Expression) + +/** + * :: DeveloperApi :: + * Alternate version of aggregation that leverages projection and thus code generation. + * Aggregations are converted into a set of projections from an aggregation buffer tuple back onto + * itself. Currently only simple aggregations like SUM, COUNT, or AVERAGE are supported. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source.
+ */ +@DeveloperApi +case class GeneratedAggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def output = aggregateExpressions.map(_.toAttribute) + + override def execute() = { + val aggregatesToCompute = aggregateExpressions.flatMap { a => + a.collect { case agg: AggregateExpression => agg} + } + + val computeFunctions = aggregatesToCompute.map { + case c @ Count(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val initialValue = Literal(0L) + val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val result = currentCount + + AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case Sum(expr) => + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialValue = Cast(Literal(0L), expr.dataType) + + // Coalesce avoids double calculation... + // but really, common subexpression elimination would be better....
+ val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + val result = currentSum + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + + case a @ Average(expr) => + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() + val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val initialCount = Literal(0L) + val initialSum = Cast(Literal(0L), expr.dataType) + val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) + + val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + + AggregateEvaluation( + currentCount :: currentSum :: Nil, + initialCount :: initialSum :: Nil, + updateCount :: updateSum :: Nil, + result + ) + } + + val computationSchema = computeFunctions.flatMap(_.schema) + + val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map { + case (agg, func) => agg.id -> func.result + }.toMap + + val namedGroups = groupingExpressions.zipWithIndex.map { + case (ne: NamedExpression, _) => (ne, ne) + case (e, i) => (e, Alias(e, s"GroupingExpr$i")()) + } + + val groupMap: Map[Expression, Attribute] = + namedGroups.map { case (k, v) => k -> v.toAttribute}.toMap + + // The set of expressions that produce the final output given the aggregation buffer and the + // grouping expressions. + val resultExpressions = aggregateExpressions.map(_.transform { + case e: Expression if resultMap.contains(e.id) => resultMap(e.id) + case e: Expression if groupMap.contains(e) => groupMap(e) + }) + + child.execute().mapPartitions { iter => + // Builds a new custom class for holding the results of aggregation for a group. 
+ val initialValues = computeFunctions.flatMap(_.initialValues) + val newAggregationBuffer = newProjection(initialValues, child.output) + log.info(s"Initial values: ${initialValues.mkString(",")}") + + // A projection that computes the group given an input tuple. + val groupProjection = newProjection(groupingExpressions, child.output) + log.info(s"Grouping Projection: ${groupingExpressions.mkString(",")}") + + // A projection that is used to update the aggregate values for a group given a new tuple. + // This projection should be targeted at the current values for the group and then applied + // to a joined row of the current values with the new input row. + val updateExpressions = computeFunctions.flatMap(_.update) + val updateSchema = computeFunctions.flatMap(_.schema) ++ child.output + val updateProjection = newMutableProjection(updateExpressions, updateSchema)() + log.info(s"Update Expressions: ${updateExpressions.mkString(",")}") + + // A projection that produces the final result, given a computation. + val resultProjectionBuilder = + newMutableProjection( + resultExpressions, + (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) + log.info(s"Result Projection: ${resultExpressions.mkString(",")}") + + val joinedRow = new JoinedRow + + if (groupingExpressions.isEmpty) { + // TODO: Codegening anything other than the updateProjection is probably overkill.
+ val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + var currentRow: Row = null + updateProjection.target(buffer) + + while (iter.hasNext) { + currentRow = iter.next() + updateProjection(joinedRow(buffer, currentRow)) + } + + val resultProjection = resultProjectionBuilder() + Iterator(resultProjection(buffer)) + } else { + val buffers = new java.util.HashMap[Row, MutableRow]() + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupProjection(currentRow) + var currentBuffer = buffers.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow] + buffers.put(currentGroup, currentBuffer) + } + // Target the projection at the current aggregation buffer and then project the updated + // values. + updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val resultIterator = buffers.entrySet.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext = resultIterator.hasNext + + def next() = { + val currentGroup = resultIterator.next() + resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue)) + } + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 77c874d..21cbbc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -18,22 +18,55 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import 
org.apache.spark.sql.{Logging, Row, SQLContext} + + +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ + +object SparkPlan { + protected[sql] val currentContext = new ThreadLocal[SQLContext]() +} + /** * :: DeveloperApi :: */ @DeveloperApi -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { +abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { self: Product => + /** + * A handle to the SQL Context that was used to create this plan. Since many operators need + * access to the sqlContext for RDD operations or configuration this field is automatically + * populated by the query planning infrastructure. + */ + @transient + protected val sqlContext = SparkPlan.currentContext.get() + + protected def sparkContext = sqlContext.sparkContext + + // sqlContext will be null when we are being deserialized on the slaves. In this instance + // the value of codegenEnabled will be set by the deserializer after the constructor has run. + val codegenEnabled: Boolean = if (sqlContext != null) { + sqlContext.codegenEnabled + } else { + false + } + + /** Overridden make copy also propagates sqlContext to copied plan. */ + override def makeCopy(newArgs: Array[AnyRef]): this.type = { + SparkPlan.currentContext.set(sqlContext) + super.makeCopy(newArgs) + } + // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -51,8 +84,46 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { */ def executeCollect(): Array[Row] = execute().map(_.copy()).collect() - protected def buildRow(values: Seq[Any]): Row = - new GenericRow(values.toArray) + protected def newProjection( + expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { + log.debug( + s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if (codegenEnabled) { + GenerateProjection(expressions, inputSchema) + } else { + new InterpretedProjection(expressions, inputSchema) + } + } + + protected def newMutableProjection( + expressions: Seq[Expression], + inputSchema: Seq[Attribute]): () => MutableProjection = { + log.debug( + s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") + if(codegenEnabled) { + GenerateMutableProjection(expressions, inputSchema) + } else { + () => new InterpretedMutableProjection(expressions, inputSchema) + } + } + + + protected def newPredicate( + expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { + if (codegenEnabled) { + GeneratePredicate(expression, inputSchema) + } else { + InterpretedPredicate(expression, inputSchema) + } + } + + protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { + if (codegenEnabled) { + GenerateOrdering(order, inputSchema) + } else { + new RowOrdering(order, inputSchema) + } + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 404d48a..5f1fe99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.util.Try - import org.apache.spark.sql.{SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -41,7 +39,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // no predicate can be evaluated by matching hash keys case logical.Join(left, right, LeftSemi, condition) => execution.LeftSemiJoinBNL( - planLater(left), planLater(right), condition)(sqlContext) :: Nil + planLater(left), planLater(right), condition) :: Nil case _ => Nil } } @@ -60,6 +58,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * will instead be used to decide the build side in a [[execution.ShuffledHashJoin]]. */ object HashJoin extends Strategy with PredicateHelper { + private[this] def makeBroadcastHashJoin( leftKeys: Seq[Expression], rightKeys: Seq[Expression], @@ -68,24 +67,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition: Option[Expression], side: BuildSide) = { val broadcastHashJoin = execution.BroadcastHashJoin( - leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext) + leftKeys, rightKeys, side, planLater(left), planLater(right)) condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, 
left, right) - if Try(sqlContext.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold).getOrElse(false) => + if sqlContext.autoBroadcastJoinThreshold > 0 && + left.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft) case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = - if (Try(right.statistics.sizeInBytes <= left.statistics.sizeInBytes).getOrElse(false)) { + if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { BuildRight } else { BuildLeft @@ -99,65 +98,65 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object PartialAggregation extends Strategy { + object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a }) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p }) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[Long, SplitEvaluation] = - partialAggregates.map(a => (a.id, a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. 
- val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + // Aggregations that can be performed in two phases, before and after the shuffle. - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(e.id) => - partialEvaluations(e.id).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq - - // Construct two phased aggregation. - execution.Aggregate( + // Cases where all aggregates can be codegened. + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) + if canBeCodeGened( + allAggregates(partialComputation) ++ + allAggregates(rewrittenAggregateExpressions)) && + codegenEnabled => + execution.GeneratedAggregate( partial = false, - namedGroupingExpressions.values.map(_.toAttribute).toSeq, + namedGroupingAttributes, rewrittenAggregateExpressions, - execution.Aggregate( + execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, - planLater(child))(sqlContext))(sqlContext) :: Nil - } else { - Nil - } + planLater(child))) :: Nil + + // Cases where some aggregate can not be codegened + case PartialAggregation( + namedGroupingAttributes, + rewrittenAggregateExpressions, + groupingExpressions, + partialComputation, + child) => + execution.Aggregate( + partial = false, + namedGroupingAttributes, + rewrittenAggregateExpressions, + execution.Aggregate( + partial = true, + 
groupingExpressions, + partialComputation, + planLater(child))) :: Nil + case _ => Nil } + + def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { + case _: Sum | _: Count => false + case _ => true + } + + def allAggregates(exprs: Seq[Expression]) = + exprs.flatMap(_.collect { case a: AggregateExpression => a }) } object BroadcastNestedLoopJoin extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Join(left, right, joinType, condition) => execution.BroadcastNestedLoopJoin( - planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil + planLater(left), planLater(right), joinType, condition) :: Nil case _ => Nil } } @@ -176,16 +175,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - def convertToCatalyst(a: Any): Any = a match { - case s: Seq[Any] => s.map(convertToCatalyst) - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil + execution.TakeOrdered(limit, order, planLater(child)) :: Nil case _ => Nil } } @@ -195,11 +188,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. 
case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil + InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { @@ -228,7 +221,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil + ParquetTableScan(_, relation, filters)) :: Nil case _ => Nil } @@ -266,20 +259,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil + execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => - val dataAsRdd = - sparkContext.parallelize(data.map(r => - new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) - execution.ExistingRdd(output, dataAsRdd) :: Nil + ExistingRdd( + output, + 
ExistingRdd.productToRowRdd(sparkContext.parallelize(data, numPartitions))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => - execution.Limit(limit, planLater(child))(sqlContext) :: Nil + execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => - execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil - case logical.Except(left,right) => - execution.Except(planLater(left),planLater(right)) :: Nil + execution.Union(unionChildren.map(planLater)) :: Nil + case logical.Except(left, right) => + execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil case logical.Generate(generator, join, outer, _, child) => http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 966d8f9..174eda8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -37,9 +37,11 @@ import org.apache.spark.util.MutablePair case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override def output = projectList.map(_.toAttribute) - override def execute() = child.execute().mapPartitions { iter => - @transient val reusableProjection = new MutableProjection(projectList) - iter.map(reusableProjection) + @transient lazy val buildProjection = newMutableProjection(projectList, child.output) + + def execute() = child.execute().mapPartitions { iter => + val resuableProjection = buildProjection() + iter.map(resuableProjection) } } @@ -50,8 +52,10 @@ case class Project(projectList: Seq[NamedExpression], 
child: SparkPlan) extends case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def output = child.output - override def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.eval(_).asInstanceOf[Boolean]) + @transient lazy val conditionEvaluator = newPredicate(condition, child.output) + + def execute() = child.execute().mapPartitions { iter => + iter.filter(conditionEvaluator) } } @@ -72,12 +76,10 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: * :: DeveloperApi :: */ @DeveloperApi -case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan { +case class Union(children: Seq[SparkPlan]) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes override def output = children.head.output - override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) - - override def otherCopyArgs = sqlContext :: Nil + override def execute() = sparkContext.union(children.map(_.execute())) } /** @@ -89,13 +91,11 @@ case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) ex * repartition all the data to a single partition to compute the global limit. */ @DeveloperApi -case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext) +case class Limit(limit: Int, child: SparkPlan) extends UnaryNode { // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again - override def otherCopyArgs = sqlContext :: Nil - override def output = child.output /** @@ -161,20 +161,18 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. 
*/ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sqlContext: SQLContext) extends UnaryNode { - override def otherCopyArgs = sqlContext :: Nil +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output = child.output - @transient - lazy val ordering = new RowOrdering(sortOrder) + val ordering = new RowOrdering(sortOrder, child.output) + // TODO: Is this copying for no reason? override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(executeCollect(), 1) } /** @@ -189,15 +187,13 @@ case class Sort( override def requiredChildDistribution = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - @transient - lazy val ordering = new RowOrdering(sortOrder) override def execute() = attachTree(this, "sort") { - // TODO: Optimize sorting operation? 
child.execute() - .mapPartitions( - iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, - preservesPartitioning = true) + .mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) } override def output = child.output http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c6fbd6d..5ef46c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -41,13 +41,13 @@ package object debug { */ @DeveloperApi implicit class DebugQuery(query: SchemaRDD) { - def debug(implicit sc: SparkContext): Unit = { + def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[Long]() val debugPlan = plan transform { case s: SparkPlan if !visited.contains(s.id) => visited += s.id - DebugNode(sc, s) + DebugNode(s) } println(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { @@ -57,9 +57,7 @@ package object debug { } } - private[sql] case class DebugNode( - @transient sparkContext: SparkContext, - child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { def references = Set.empty def output = child.output http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 7d1f11c..2750ddb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -38,6 +38,8 @@ case object BuildLeft extends BuildSide case object BuildRight extends BuildSide trait HashJoin { + self: SparkPlan => + val leftKeys: Seq[Expression] val rightKeys: Seq[Expression] val buildSide: BuildSide @@ -56,9 +58,9 @@ trait HashJoin { def output = left.output ++ right.output - @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) @transient lazy val streamSideKeyGenerator = - () => new MutableProjection(streamedKeys, streamedPlan.output) + newMutableProjection(streamedKeys, streamedPlan.output) def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. @@ -217,9 +219,8 @@ case class BroadcastHashJoin( rightKeys: Seq[Expression], buildSide: BuildSide, left: SparkPlan, - right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin { + right: SparkPlan) extends BinaryNode with HashJoin { - override def otherCopyArgs = sqlContext :: Nil override def outputPartitioning: Partitioning = left.outputPartitioning @@ -228,7 +229,7 @@ case class BroadcastHashJoin( @transient lazy val broadcastFuture = future { - sqlContext.sparkContext.broadcast(buildPlan.executeCollect()) + sparkContext.broadcast(buildPlan.executeCollect()) } def execute() = { @@ -248,14 +249,11 @@ case class BroadcastHashJoin( @DeveloperApi case class LeftSemiJoinBNL( streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. 
override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - def output = left.output /** The Streamed Relation */ @@ -271,7 +269,7 @@ case class LeftSemiJoinBNL( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) streamed.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow @@ -300,8 +298,14 @@ case class LeftSemiJoinBNL( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { - case (l: Row, r: Row) => buildRow(l ++ r) + def execute() = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.cartesian(rightResults).mapPartitions { iter => + val joinedRow = new JoinedRow + iter.map(r => joinedRow(r._1, r._2)) + } } } @@ -311,14 +315,11 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod @DeveloperApi case class BroadcastNestedLoopJoin( streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression]) - (@transient sqlContext: SQLContext) extends BinaryNode { // TODO: Override requiredChildDistribution. 
override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def otherCopyArgs = sqlContext :: Nil - override def output = { joinType match { case LeftOuter => @@ -345,13 +346,14 @@ case class BroadcastNestedLoopJoin( def execute() = { val broadcastedRelation = - sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new ArrayBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val rightNulls = new GenericMutableRow(right.output.size) streamedIter.foreach { streamedRow => var i = 0 @@ -361,7 +363,7 @@ case class BroadcastNestedLoopJoin( // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { - matchedRows += buildRow(streamedRow ++ broadcastedRow) + matchedRows += joinedRow(streamedRow, broadcastedRow).copy() matched = true includedBroadcastTuples += i } @@ -369,7 +371,7 @@ case class BroadcastNestedLoopJoin( } if (!matched && (joinType == LeftOuter || joinType == FullOuter)) { - matchedRows += buildRow(streamedRow ++ Array.fill(right.output.size)(null)) + matchedRows += joinedRow(streamedRow, rightNulls).copy() } } Iterator((matchedRows, includedBroadcastTuples)) @@ -383,20 +385,20 @@ case class BroadcastNestedLoopJoin( streamedPlusMatches.map(_._2).reduce(_ ++ _) } + val leftNulls = new GenericMutableRow(left.output.size) val rightOuterMatches: Seq[Row] = if (joinType == RightOuter || joinType == FullOuter) { broadcastedRelation.value.zipWithIndex.filter { case (row, i) => !allIncludedBroadcastTuples.contains(i) }.map { - // TODO: Use projection. 
- case (row, _) => buildRow(Vector.fill(left.output.size)(null) ++ row) + case (row, _) => new JoinedRow(leftNulls, row) } } else { Vector() } // TODO: Breaks lineage. - sqlContext.sparkContext.union( - streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches)) + sparkContext.union( + streamedPlusMatches.flatMap(_._1), sparkContext.makeRDD(rightOuterMatches)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 8c7dbd5..b3bae5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} */ private[sql] case class ParquetRelation( path: String, - @transient conf: Option[Configuration] = None) + @transient conf: Option[Configuration], + @transient sqlContext: SQLContext) extends LeafNode with MultiInstanceRelation { self: Product => @@ -61,7 +62,7 @@ private[sql] case class ParquetRelation( /** Attributes */ override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - override def newInstance = ParquetRelation(path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, @@ -70,6 +71,9 @@ private[sql] case class ParquetRelation( p.path == path && p.output == output case _ => false } + + // TODO: Use data from the footers. 
+ override lazy val statistics = Statistics(sizeInBytes = sqlContext.defaultSizeInBytes) } private[sql] object ParquetRelation { @@ -106,13 +110,14 @@ private[sql] object ParquetRelation { */ def create(pathString: String, child: LogicalPlan, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } - createEmpty(pathString, child.output, false, conf) + createEmpty(pathString, child.output, false, conf, sqlContext) } /** @@ -127,14 +132,15 @@ private[sql] object ParquetRelation { def createEmpty(pathString: String, attributes: Seq[Attribute], allowExisting: Boolean, - conf: Configuration): ParquetRelation = { + conf: Configuration, + sqlContext: SQLContext): ParquetRelation = { val path = checkPath(pathString, allowExisting, conf) if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) } ParquetRelation.enableLogForwarding() ParquetTypesConverter.writeMetaData(attributes, path, conf) - new ParquetRelation(path.toString, Some(conf)) { + new ParquetRelation(path.toString, Some(conf), sqlContext) { override val output = attributes } } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ea74320..912a9f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -55,8 +55,7 @@ case class ParquetTableScan( 
// https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Seq[Expression])( - @transient val sqlContext: SQLContext) + columnPruningPred: Seq[Expression]) extends LeafNode { override def execute(): RDD[Row] = { @@ -99,8 +98,6 @@ case class ParquetTableScan( .filter(_ != null) // Parquet's record filters may produce null values } - override def otherCopyArgs = sqlContext :: Nil - /** * Applies a (candidate) projection. * @@ -110,7 +107,7 @@ case class ParquetTableScan( def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { val success = validateProjection(prunedAttributes) if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext) + ParquetTableScan(prunedAttributes, relation, columnPruningPred) } else { sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") this @@ -150,8 +147,7 @@ case class ParquetTableScan( case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, - overwrite: Boolean = false)( - @transient val sqlContext: SQLContext) + overwrite: Boolean = false) extends UnaryNode with SparkHadoopMapReduceUtil { /** @@ -171,7 +167,7 @@ case class InsertIntoParquetTable( val writeSupport = if (child.output.map(_.dataType).forall(_.isPrimitive)) { - logger.debug("Initializing MutableRowWriteSupport") + log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { classOf[org.apache.spark.sql.parquet.RowWriteSupport] @@ -203,8 +199,6 @@ case class InsertIntoParquetTable( override def output = child.output - override def otherCopyArgs = sqlContext :: Nil - /** * Stores the given Row RDD as a Hadoop file. 
* http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index d4599da..837ea76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -22,6 +22,7 @@ import java.io.File import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.Job +import org.apache.spark.sql.test.TestSQLContext import parquet.example.data.{GroupWriter, Group} import parquet.example.data.simple.SimpleGroup @@ -103,7 +104,7 @@ private[sql] object ParquetTestData { val testDir = Utils.createTempDir() val testFilterDir = Utils.createTempDir() - lazy val testData = new ParquetRelation(testDir.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString, None, TestSQLContext) val testNestedSchema1 = // based on blogpost example, source: @@ -202,8 +203,10 @@ private[sql] object ParquetTestData { val testNestedDir3 = Utils.createTempDir() val testNestedDir4 = Utils.createTempDir() - lazy val testNestedData1 = new ParquetRelation(testNestedDir1.toURI.toString) - lazy val testNestedData2 = new ParquetRelation(testNestedDir2.toURI.toString) + lazy val testNestedData1 = + new ParquetRelation(testNestedDir1.toURI.toString, None, TestSQLContext) + lazy val testNestedData2 = + new ParquetRelation(testNestedDir2.toURI.toString, None, TestSQLContext) def writeFile() = { testDir.delete() http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8e1e197..1fd8d27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -45,6 +45,7 @@ class QueryTest extends PlanTest { |${rdd.queryExecution} |== Exception == |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} """.stripMargin) } http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 215618e..76b1724 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -39,22 +39,22 @@ class PlannerSuite extends FunSuite { test("count is partially aggregated") { val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed - val planned = PartialAggregation(query).head - val aggregations = planned.collect { case a: Aggregate => a } + val planned = HashAggregation(query).head + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } assert(aggregations.size === 2) } test("count distinct is not partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } test("mixed aggregates are not partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed - val planned = PartialAggregation(query) + val planned = HashAggregation(query) assert(planned.isEmpty) } } 
http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index e55648b..2cab5e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.TestSQLContext._ * Note: this is only a rough example of how TGFs can be expressed, the final version will likely * involve a lot more sugar for cleaner use in Scala/Java/etc. */ -case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generator { +case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { def children = input protected def makeOutput() = 'nameAndAge.string :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3c911e9..561f5b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -25,6 +25,7 @@ import parquet.schema.MessageTypeParser import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job + import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} @@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -207,10 +209,11 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection of simple Parquet file") { + SparkPlan.currentContext.set(TestSQLContext) val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - Seq())(TestSQLContext) + Seq()) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 84d43ea..f0a6127 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -231,7 +231,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveTableScans, DataSinks, Scripts, - PartialAggregation, + HashAggregation, LeftSemiJoin, HashJoin, BasicOperators, http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c2b0b00..39033bd 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -131,7 +131,7 @@ case class InsertIntoHiveTable( conf, SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) - logger.debug("Saving as hadoop file of type " + valueClass.getSimpleName) + log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) writer.preSetup() http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8258ee5..0c8f676 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -67,7 +67,7 @@ case class ScriptTransformation( } } readerThread.start() - val outputProjection = new Projection(input) + val outputProjection = new InterpretedProjection(input, child.output) iter .map(outputProjection) // TODO: Use SerDe http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 057eb60..7582b47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -251,8 +251,10 @@ private[hive] case class HiveGenericUdtf( @transient protected 
lazy val function: GenericUDTF = createFunction() + @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + @transient protected lazy val outputInspectors = { val structInspector = function.initialize(inputInspectors.toArray) structInspector.getAllStructFieldRefs.map(_.getFieldObjectInspector) @@ -278,7 +280,7 @@ private[hive] case class HiveGenericUdtf( override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. - val inputProjection = new Projection(children) + val inputProjection = new InterpretedProjection(children) val collector = new UDTFCollector function.setCollector(collector) @@ -332,7 +334,7 @@ private[hive] case class HiveUdafFunction( override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient - val inputProjection = new Projection(exprs) + val inputProjection = new InterpretedProjection(exprs) def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d new file mode 100644 index 0000000..00750ed --- /dev/null +++ b/sql/hive/src/test/resources/golden/case else null-0-8ef2f741400830ef889a9dd0c817fe3d @@ -0,0 +1 @@ +3 http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 
new file mode 100644 index 0000000..00750ed --- /dev/null +++ b/sql/hive/src/test/resources/golden/double case-0-f513687d17dcb18546fefa75000a52f2 @@ -0,0 +1 @@ +3 http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 new file mode 100644 index 0000000..d00491f --- /dev/null +++ b/sql/hive/src/test/resources/golden/single case-0-c264e319c52f1840a32959d552b99e73 @@ -0,0 +1 @@ +1 http://git-wip-us.apache.org/repos/asf/spark/blob/84467468/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index aadfd2e..89cc589 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.{Row, SchemaRDD} @@ -30,6 +32,15 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("single case", + """SELECT case when true then 1 else 2 end FROM src LIMIT 1""") + + createQueryTest("double case", + """SELECT case when 1 = 2 then 1 when 2 = 2 then 3 else 2 end FROM src LIMIT 1""") + + createQueryTest("case else null", + """SELECT 
case when 1 = 2 then 1 when 2 = 2 then 3 else null end FROM src LIMIT 1""") + createQueryTest("having no references", "SELECT key FROM src GROUP BY key HAVING COUNT(*) > 1")