spark-commits mailing list archives

From marmb...@apache.org
Subject spark git commit: [SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range
Date Mon, 21 Dec 2015 21:47:04 GMT
Repository: spark
Updated Branches:
  refs/heads/master 7634fe951 -> 4883a5087


[SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range

Based on the suggestions from marmbrus, added logical/physical operators for Range to improve performance.

Also added another API to resolve JIRA SPARK-12150.

Could you take a look at my implementation, marmbrus? If it's not good, I can rework it. :)

Thank you very much!

Author: gatorsmile <gatorsmile@gmail.com>

Closes #10335 from gatorsmile/rangeOperators.
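
For reference, a quick sketch of the new three-argument overload in action (values mirror the test added below; assumes a live `sqlContext` and `sum` from org.apache.spark.sql.functions):

    import org.apache.spark.sql.functions.sum

    val res = sqlContext.range(3, 15, 3)   // start = 3, end = 15 (exclusive), step = 3
    res.count()                            // 4 rows: 3, 6, 9, 12
    res.agg(sum("id")).collect()           // Array(Row(30))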


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4883a508
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4883a508
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4883a508

Branch: refs/heads/master
Commit: 4883a5087d481d4de5d3beabbd709853de01399a
Parents: 7634fe9
Author: gatorsmile <gatorsmile@gmail.com>
Authored: Mon Dec 21 13:46:58 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Mon Dec 21 13:46:58 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/SparkContext.scala   |  2 +-
 .../catalyst/plans/logical/basicOperators.scala | 32 ++++++++++
 .../scala/org/apache/spark/sql/SQLContext.scala | 23 +++++---
 .../spark/sql/execution/SparkStrategies.scala   |  2 +
 .../spark/sql/execution/basicOperators.scala    | 62 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  5 ++
 .../execution/ExchangeCoordinatorSuite.scala    |  1 +
 7 files changed, 119 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 194ecc0..81a4d0a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -759,7 +759,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     val numElements: BigInt = {
       val safeStart = BigInt(start)
       val safeEnd = BigInt(end)
-      if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) {
+      if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) {
         (safeEnd - safeStart) / step
       } else {
        // the remainder has the same sign as the range, so add 1 more
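
The `^` -> `!=` rewrite above is purely for readability: Scala's `^` binds looser than `>`, so the old expression already parsed as `(safeEnd > safeStart) ^ (step > 0)`, which for Booleans is the same test as `!=`. A standalone sketch of the element-count logic it guards (the `numElements` helper name here is illustrative only):

    def numElements(start: Long, end: Long, step: Long): BigInt = {
      require(step != 0, "step cannot be 0")
      val safeStart = BigInt(start)
      val safeEnd = BigInt(end)
      if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) {
        (safeEnd - safeStart) / step       // divides evenly, or start/end/step disagree in direction
      } else {
        (safeEnd - safeStart) / step + 1   // the remainder adds one more element
      }
    }

    // numElements(3, 15, 3)  == 4   (3, 6, 9, 12)
    // numElements(0, 10, 3)  == 4   (0, 3, 6, 9 -- remainder adds one)
    // numElements(15, 3, -3) == 4   (15, 12, 9, 6 -- negative step)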

http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index ec42b76..64ef4d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -210,6 +210,38 @@ case class Sort(
   override def output: Seq[Attribute] = child.output
 }
 
+/** Factory for constructing new `Range` nodes. */
+object Range {
+  def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+    val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
+    new Range(start, end, step, numSlices, output)
+  }
+}
+
+case class Range(
+    start: Long,
+    end: Long,
+    step: Long,
+    numSlices: Int,
+    output: Seq[Attribute]) extends LeafNode {
+  require(step != 0, "step cannot be 0")
+  val numElements: BigInt = {
+    val safeStart = BigInt(start)
+    val safeEnd = BigInt(end)
+    if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) {
+      (safeEnd - safeStart) / step
+    } else {
+      // the remainder has the same sign as the range, so add 1 more
+      (safeEnd - safeStart) / step + 1
+    }
+  }
+
+  override def statistics: Statistics = {
+    val sizeInBytes = LongType.defaultSize * numElements
+    Statistics(sizeInBytes = sizeInBytes)
+  }
+}
+
 case class Aggregate(
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],

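A quick sanity check on the new `statistics` override, using the factory above (`LongType.defaultSize` is 8 bytes):

    val plan = Range(start = 0, end = 1000, step = 1, numSlices = 4)
    assert(plan.numElements == 1000)
    assert(plan.statistics.sizeInBytes == 8 * 1000)   // defaultSize * numElements
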
http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index db286ea..eadf5cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.errors.DialectException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
 import org.apache.spark.sql.execution._
@@ -785,9 +785,20 @@ class SQLContext private[sql](
    */
   @Experimental
   def range(start: Long, end: Long): DataFrame = {
-    createDataFrame(
-      sparkContext.range(start, end).map(Row(_)),
-      StructType(StructField("id", LongType, nullable = false) :: Nil))
+    range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
+  }
+
+  /**
+    * :: Experimental ::
+    * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+    * in a range from `start` to `end` (exclusive) with a step value.
+    *
+    * @since 2.0.0
+    * @group dataframe
+    */
+  @Experimental
+  def range(start: Long, end: Long, step: Long): DataFrame = {
+    range(start, end, step, numPartitions = sparkContext.defaultParallelism)
   }
 
   /**
@@ -801,9 +812,7 @@ class SQLContext private[sql](
    */
   @Experimental
   def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = {
-    createDataFrame(
-      sparkContext.range(start, end, step, numPartitions).map(Row(_)),
-      StructType(StructField("id", LongType, nullable = false) :: Nil))
+    DataFrame(this, Range(start, end, step, numPartitions))
   }
 
   /**

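All of the public overloads now funnel into the four-argument version, which plans a logical Range node directly instead of materializing an RDD of Rows up front; roughly:

    // range(start, end)                 -> range(start, end, step = 1, defaultParallelism)
    // range(start, end, step)           -> range(start, end, step, defaultParallelism)
    // range(start, end, step, numParts) -> DataFrame(this, Range(start, end, step, numParts))
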
http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 688555c..183d9b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -358,6 +358,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil
+      case r @ logical.Range(start, end, step, numSlices, output) =>
+        execution.Range(start, step, numSlices, r.numElements, output) :: Nil
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         execution.Exchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
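
The new case trades the logical node's `end` for its precomputed `numElements` when building the physical operator; schematically (a sketch, not actual planner code):

    val r = logical.Range(start = 0L, end = 100L, step = 2L, numSlices = 4)
    val physical = execution.Range(0L, 2L, 4, r.numElements, r.output)   // r.numElements == 50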

http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b3e4688..21325be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.util.random.PoissonSampler
 import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -126,6 +127,67 @@ case class Sample(
   }
 }
 
+case class Range(
+    start: Long,
+    step: Long,
+    numSlices: Int,
+    numElements: BigInt,
+    output: Seq[Attribute])
+  extends LeafNode {
+
+  override def outputsUnsafeRows: Boolean = true
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    sqlContext
+      .sparkContext
+      .parallelize(0 until numSlices, numSlices)
+      .mapPartitionsWithIndex((i, _) => {
+        val partitionStart = (i * numElements) / numSlices * step + start
+        val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
+        def getSafeMargin(bi: BigInt): Long =
+          if (bi.isValidLong) {
+            bi.toLong
+          } else if (bi > 0) {
+            Long.MaxValue
+          } else {
+            Long.MinValue
+          }
+        val safePartitionStart = getSafeMargin(partitionStart)
+        val safePartitionEnd = getSafeMargin(partitionEnd)
+        val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
+        val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+
+        new Iterator[InternalRow] {
+          private[this] var number: Long = safePartitionStart
+          private[this] var overflow: Boolean = false
+
+          override def hasNext =
+            if (!overflow) {
+              if (step > 0) {
+                number < safePartitionEnd
+              } else {
+                number > safePartitionEnd
+              }
+            } else false
+
+          override def next() = {
+            val ret = number
+            number += step
+            if (number < ret ^ step < 0) {
+              // we have Long.MaxValue + Long.MaxValue < Long.MaxValue
+              // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
+              // back, we are pretty sure that we have an overflow.
+              overflow = true
+            }
+
+            unsafeRow.setLong(0, ret)
+            unsafeRow
+          }
+        }
+      })
+  }
+}
+
 /**
  * Union two plans, without a distinct. This is UNION ALL in SQL.
  */

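To see how `doExecute` splits the range across slices, a worked sketch of the boundary arithmetic (the `boundaries` helper is illustrative only):

    def boundaries(start: Long, step: Long, numElements: BigInt, numSlices: Int): Seq[(BigInt, BigInt)] =
      (0 until numSlices).map { i =>
        val lo = (i * numElements) / numSlices * step + start        // partitionStart
        val hi = ((i + 1) * numElements) / numSlices * step + start  // partitionEnd
        (lo, hi)
      }

    // boundaries(0, 2, 50, 4) == Seq((0,24), (24,50), (50,74), (74,100))
    // i.e. 12 + 13 + 12 + 13 = 50 rows in total across the 4 slices.

The overflow guard relies on Long wrap-around: with a positive step, stepping past Long.MaxValue yields a smaller number, so `(number < ret) ^ (step < 0)` flips to true and iteration stops.
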
http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1a0f1b6..ad478b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -769,6 +769,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
 
     val res11 = sqlContext.range(-1).select("id")
     assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = sqlContext.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
   }
 
   test("SPARK-8621: support empty string column name") {

http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 180050b..101cf50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -260,6 +260,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
         .set("spark.driver.allowMultipleContexts", "true")
         .set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
         .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+        .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
         .set(
           SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key,
           targetNumPostShufflePartitions.toString)



