spark-commits mailing list archives

From ues...@apache.org
Subject spark git commit: [SPARK-23927][SQL] Add "sequence" expression
Date Wed, 27 Jun 2018 02:52:38 GMT
Repository: spark
Updated Branches:
  refs/heads/master d08f53dc6 -> 2669b4de3


[SPARK-23927][SQL] Add "sequence" expression

## What changes were proposed in this pull request?
The PR adds the SQL function ```sequence```.
https://issues.apache.org/jira/browse/SPARK-23927

The behavior of the function is based on Presto's implementation.
Ref: https://prestodb.io/docs/current/functions/array.html

- ```sequence(start, stop) → array<bigint>```
Generates a sequence of integers from ```start``` to ```stop```, incrementing by ```1``` if ```start``` is less than or equal to ```stop```, otherwise by ```-1```.
- ```sequence(start, stop, step) → array<bigint>```
Generates a sequence of integers from ```start``` to ```stop```, incrementing by ```step```.
- ```sequence(start_date, stop_date) → array<date>```
Generates a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```interval 1 day``` if ```start_date``` is less than or equal to ```stop_date```, otherwise by ```- interval 1 day```.
- ```sequence(start_date, stop_date, step_interval) → array<date>```
Generates a sequence of dates from ```start_date``` to ```stop_date```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```.
- ```sequence(start_timestamp, stop_timestamp) → array<timestamp>```
Generates a sequence of timestamps from ```start_timestamp``` to ```stop_timestamp```, incrementing by ```interval 1 day``` if ```start_timestamp``` is less than or equal to ```stop_timestamp```, otherwise by ```- interval 1 day```.
- ```sequence(start_timestamp, stop_timestamp, step_interval) → array<timestamp>```
Generates a sequence of timestamps from ```start_timestamp``` to ```stop_timestamp```, incrementing by ```step_interval```. The type of ```step_interval``` is ```CalendarInterval```.
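The integral case above (inclusive ```stop```, default step of ```1``` or ```-1``` depending on direction, and rejection of steps that point away from ```stop```) can be sketched as a small standalone Scala snippet. This is an illustrative sketch of the described semantics only, not the actual Spark implementation (see the diff below for that):

```scala
// Hypothetical standalone sketch of the integral `sequence` semantics:
// inclusive stop, default step +1/-1 chosen by direction, and a check
// that the step actually moves from start toward stop.
object SequenceSketch {
  def sequence(start: Long, stop: Long, stepOpt: Option[Long] = None): Array[Long] = {
    val step = stepOpt.getOrElse(if (start <= stop) 1L else -1L)
    require(
      (step > 0 && start <= stop) ||
        (step < 0 && start >= stop) ||
        (step == 0 && start == stop),
      s"Illegal sequence boundaries: $start to $stop by $step")
    // Inclusive length: one element when start == stop, otherwise
    // 1 + the number of whole steps between the boundaries.
    val len = (if (start == stop) 1L else 1L + (stop - start) / step).toInt
    Array.tabulate(len)(i => start + step * i)
  }

  def main(args: Array[String]): Unit = {
    println(sequence(1L, 5L).mkString(","))            // prints 1,2,3,4,5
    println(sequence(5L, 1L).mkString(","))            // prints 5,4,3,2,1
    println(sequence(1L, 9L, Some(2L)).mkString(","))  // prints 1,3,5,7,9
  }
}
```

Note that the sequence is truncated at the last reachable element, so ```sequence(1, 9, 2)``` ends at ```9``` while ```sequence(1, 8, 2)``` would end at ```7```.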

## How was this patch tested?

Added unit tests.

Author: Vayda, Oleksandr: IT (PRG) <Oleksandr.Vayda@barclayscapital.com>

Closes #21155 from wajda/feature/array-api-sequence.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2669b4de
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2669b4de
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2669b4de

Branch: refs/heads/master
Commit: 2669b4de3b336dde84b698c20dbc73b30abf79d4
Parents: d08f53d
Author: Vayda, Oleksandr: IT (PRG) <Oleksandr.Vayda@barclayscapital.com>
Authored: Wed Jun 27 11:52:31 2018 +0900
Committer: Takuya UESHIN <ueshin@databricks.com>
Committed: Wed Jun 27 11:52:31 2018 +0900

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../sql/catalyst/analysis/TypeCoercion.scala    |   7 +
 .../expressions/collectionOperations.scala      | 402 ++++++++++++++++++-
 .../CollectionExpressionsSuite.scala            | 292 ++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |  21 +
 .../spark/sql/DataFrameFunctionsSuite.scala     |  56 +++
 6 files changed, 777 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 8abc616..a574d8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -432,6 +432,7 @@ object FunctionRegistry {
     expression[Reverse]("reverse"),
     expression[Concat]("concat"),
     expression[Flatten]("flatten"),
+    expression[Sequence]("sequence"),
     expression[ArrayRepeat]("array_repeat"),
     expression[ArrayRemove]("array_remove"),
     expression[ArrayDistinct]("array_distinct"),

http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 6379239..3ebab43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -544,6 +544,13 @@ object TypeCoercion {
           case None => aj
         }
 
+      case s @ Sequence(_, _, _, timeZoneId) if !haveSameType(s.coercibleChildren) =>
+        val types = s.coercibleChildren.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(widerDataType) => s.castChildrenTo(widerDataType)
+          case None => s
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index abd6c88..0395e1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -16,9 +16,10 @@
  */
 package org.apache.spark.sql.catalyst.expressions
 
-import java.util.Comparator
+import java.util.{Comparator, TimeZone}
 
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
@@ -26,11 +27,13 @@ import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.collection.OpenHashSet
 
 /**
@@ -2313,6 +2316,401 @@ case class Flatten(child: Expression) extends UnaryExpression {
   override def prettyName: String = "flatten"
 }
 
+@ExpressionDescription(
+  usage = """
+    _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive),
+      incrementing by step. The type of the returned elements is the same as the type of
+      argument expressions.
+
+      Supported types are: byte, short, integer, long, date, timestamp.
+
+      The start and stop expressions must resolve to the same type.
+      If start and stop expressions resolve to the 'date' or 'timestamp' type
+      then the step expression must resolve to the 'interval' type, otherwise to the same
+      type as the start and stop expressions.
+  """,
+  arguments = """
+    Arguments:
+      * start - an expression. The start of the range.
+      * stop - an expression. The end of the range (inclusive).
+      * step - an optional expression. The step of the range.
+          By default step is 1 if start is less than or equal to stop, otherwise -1.
+          For the temporal sequences it's 1 day and -1 day respectively.
+          If start is greater than stop then the step must be negative, and vice versa.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(1, 5);
+       [1, 2, 3, 4, 5]
+      > SELECT _FUNC_(5, 1);
+       [5, 4, 3, 2, 1]
+      > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month);
+       [2018-01-01, 2018-02-01, 2018-03-01]
+  """,
+  since = "2.4.0"
+)
+case class Sequence(
+    start: Expression,
+    stop: Expression,
+    stepOpt: Option[Expression],
+    timeZoneId: Option[String] = None)
+  extends Expression
+  with TimeZoneAwareExpression {
+
+  import Sequence._
+
+  def this(start: Expression, stop: Expression) =
+    this(start, stop, None, None)
+
+  def this(start: Expression, stop: Expression, step: Expression) =
+    this(start, stop, Some(step), None)
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Some(timeZoneId))
+
+  override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val startType = start.dataType
+    def stepType = stepOpt.get.dataType
+    val typesCorrect =
+      startType.sameType(stop.dataType) &&
+        (startType match {
+          case TimestampType | DateType =>
+            stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType)
+          case _: IntegralType =>
+            stepOpt.isEmpty || stepType.sameType(startType)
+          case _ => false
+        })
+
+    if (typesCorrect) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName only supports integral, timestamp or date types")
+    }
+  }
+
+  def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType)
+
+  def castChildrenTo(widerType: DataType): Expression = Sequence(
+    Cast(start, widerType),
+    Cast(stop, widerType),
+    stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
+    timeZoneId)
+
+  private lazy val impl: SequenceImpl = dataType.elementType match {
+    case iType: IntegralType =>
+      type T = iType.InternalType
+      val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
+      new IntegralSequenceImpl(iType)(ct, iType.integral)
+
+    case TimestampType =>
+      new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
+
+    case DateType =>
+      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val startVal = start.eval(input)
+    if (startVal == null) return null
+    val stopVal = stop.eval(input)
+    if (stopVal == null) return null
+    val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal))
+    if (stepVal == null) return null
+
+    ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal))
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val startGen = start.genCode(ctx)
+    val stopGen = stop.genCode(ctx)
+    val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse(
+      impl.defaultStep.genCode(ctx, startGen, stopGen))
+
+    val resultType = CodeGenerator.javaType(dataType)
+    val resultCode = {
+      val arr = ctx.freshName("arr")
+      val arrElemType = CodeGenerator.javaType(dataType.elementType)
+      s"""
+         |final $arrElemType[] $arr = null;
+         |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)}
+         |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr);
+       """.stripMargin
+    }
+
+    if (nullable) {
+      val nullSafeEval =
+        startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) {
+          stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) {
+            stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) {
+              s"""
+                 |${ev.isNull} = false;
+                 |$resultCode
+               """.stripMargin
+            }
+          }
+        }
+      ev.copy(code =
+        code"""
+           |boolean ${ev.isNull} = true;
+           |$resultType ${ev.value} = null;
+           |$nullSafeEval
+         """.stripMargin)
+
+    } else {
+      ev.copy(code =
+        code"""
+           |${startGen.code}
+           |${stopGen.code}
+           |${stepGen.code}
+           |$resultType ${ev.value} = null;
+           |$resultCode
+         """.stripMargin,
+        isNull = FalseLiteral)
+    }
+  }
+}
+
+object Sequence {
+
+  private type LessThanOrEqualFn = (Any, Any) => Boolean
+
+  private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) {
+    private val negativeOne = UnaryMinus(Literal(one)).eval()
+
+    def apply(start: Any, stop: Any): Any = {
+      if (lteq(start, stop)) one else negativeOne
+    }
+
+    def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = {
+      val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value)
+      ExprCode.forNonNullValue(JavaCode.expression(
+        s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal",
+        stepType))
+    }
+  }
+
+  private trait SequenceImpl {
+    def eval(start: Any, stop: Any, step: Any): Any
+
+    def genCode(
+        ctx: CodegenContext,
+        start: String,
+        stop: String,
+        step: String,
+        arr: String,
+        elemType: String): String
+
+    val defaultStep: DefaultStep
+  }
+
+  private class IntegralSequenceImpl[T: ClassTag]
+    (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl {
+
+    override val defaultStep: DefaultStep = new DefaultStep(
+      (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+      elemType,
+      num.one)
+
+    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
+      import num._
+
+      val start = input1.asInstanceOf[T]
+      val stop = input2.asInstanceOf[T]
+      val step = input3.asInstanceOf[T]
+
+      var i: Int = getSequenceLength(start, stop, step)
+      val arr = new Array[T](i)
+      while (i > 0) {
+        i -= 1
+        arr(i) = start + step * num.fromInt(i)
+      }
+      arr
+    }
+
+    override def genCode(
+        ctx: CodegenContext,
+        start: String,
+        stop: String,
+        step: String,
+        arr: String,
+        elemType: String): String = {
+      val i = ctx.freshName("i")
+      s"""
+         |${genSequenceLengthCode(ctx, start, stop, step, i)}
+         |$arr = new $elemType[$i];
+         |while ($i > 0) {
+         |  $i--;
+         |  $arr[$i] = ($elemType) ($start + $step * $i);
+         |}
+         """.stripMargin
+    }
+  }
+
+  private class TemporalSequenceImpl[T: ClassTag]
+      (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
+      (implicit num: Integral[T]) extends SequenceImpl {
+
+    override val defaultStep: DefaultStep = new DefaultStep(
+      (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+      CalendarIntervalType,
+      new CalendarInterval(0, MICROS_PER_DAY))
+
+    private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
+    private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY
+
+    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
+      val start = input1.asInstanceOf[T]
+      val stop = input2.asInstanceOf[T]
+      val step = input3.asInstanceOf[CalendarInterval]
+      val stepMonths = step.months
+      val stepMicros = step.microseconds
+
+      if (stepMonths == 0) {
+        backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale))
+
+      } else {
+        // To estimate the resulted array length we need to make assumptions
+        // about a month length in microseconds
+        val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth
+        val startMicros: Long = num.toLong(start) * scale
+        val stopMicros: Long = num.toLong(stop) * scale
+        val maxEstimatedArrayLength =
+          getSequenceLength(startMicros, stopMicros, intervalStepInMicros)
+
+        val stepSign = if (stopMicros > startMicros) +1 else -1
+        val exclusiveItem = stopMicros + stepSign
+        val arr = new Array[T](maxEstimatedArrayLength)
+        var t = startMicros
+        var i = 0
+
+        while (t < exclusiveItem ^ stepSign < 0) {
+          arr(i) = fromLong(t / scale)
+          t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
+          i += 1
+        }
+
+        // truncate array to the correct length
+        if (arr.length == i) arr else arr.slice(0, i)
+      }
+    }
+
+    override def genCode(
+        ctx: CodegenContext,
+        start: String,
+        stop: String,
+        step: String,
+        arr: String,
+        elemType: String): String = {
+      val stepMonths = ctx.freshName("stepMonths")
+      val stepMicros = ctx.freshName("stepMicros")
+      val stepScaled = ctx.freshName("stepScaled")
+      val intervalInMicros = ctx.freshName("intervalInMicros")
+      val startMicros = ctx.freshName("startMicros")
+      val stopMicros = ctx.freshName("stopMicros")
+      val arrLength = ctx.freshName("arrLength")
+      val stepSign = ctx.freshName("stepSign")
+      val exclusiveItem = ctx.freshName("exclusiveItem")
+      val t = ctx.freshName("t")
+      val i = ctx.freshName("i")
+      val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName)
+
+      val sequenceLengthCode =
+        s"""
+           |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L;
+           |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
+          """.stripMargin
+
+      val timestampAddIntervalCode =
+        s"""
+           |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
+           |  $t, $stepMonths, $stepMicros, $genTimeZone);
+          """.stripMargin
+
+      s"""
+         |final int $stepMonths = $step.months;
+         |final long $stepMicros = $step.microseconds;
+         |
+         |if ($stepMonths == 0) {
+         |  final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L);
+         |  ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)};
+         |
+         |} else {
+         |  final long $startMicros = $start * ${scale}L;
+         |  final long $stopMicros = $stop * ${scale}L;
+         |
+         |  $sequenceLengthCode
+         |
+         |  final int $stepSign = $stopMicros > $startMicros ? +1 : -1;
+         |  final long $exclusiveItem = $stopMicros + $stepSign;
+         |
+         |  $arr = new $elemType[$arrLength];
+         |  long $t = $startMicros;
+         |  int $i = 0;
+         |
+         |  while ($t < $exclusiveItem ^ $stepSign < 0) {
+         |    $arr[$i] = ($elemType) ($t / ${scale}L);
+         |    $timestampAddIntervalCode
+         |    $i += 1;
+         |  }
+         |
+         |  if ($arr.length > $i) {
+         |    $arr = java.util.Arrays.copyOf($arr, $i);
+         |  }
+         |}
+         """.stripMargin
+    }
+  }
+
+  private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = {
+    import num._
+    require(
+      (step > num.zero && start <= stop)
+        || (step < num.zero && start >= stop)
+        || (step == num.zero && start == stop),
+      s"Illegal sequence boundaries: $start to $stop by $step")
+
+    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong
+
+    require(
+      len <= MAX_ROUNDED_ARRAY_LENGTH,
+      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+
+    len.toInt
+  }
+
+  private def genSequenceLengthCode(
+      ctx: CodegenContext,
+      start: String,
+      stop: String,
+      step: String,
+      len: String): String = {
+    val longLen = ctx.freshName("longLen")
+    s"""
+       |if (!(($step > 0 && $start <= $stop) ||
+       |  ($step < 0 && $start >= $stop) ||
+       |  ($step == 0 && $start == $stop))) {
+       |  throw new IllegalArgumentException(
+       |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
+       |}
+       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step;
+       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
+       |  throw new IllegalArgumentException(
+       |    "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
+       |}
+       |int $len = (int) $longLen;
+       """.stripMargin
+  }
+}
+
 /**
  * Returns the array containing the given input value (left) count (right) times.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index caea4fb..d7744eb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -17,10 +17,16 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.types.CalendarInterval
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -484,6 +490,292 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
   }
 
+  test("Sequence of numbers") {
+    // test null handling
+
+    checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L)), null)
+    checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType)), null)
+    checkEvaluation(new Sequence(Literal(null, LongType), Literal(1L), Literal(1L)), null)
+    checkEvaluation(new Sequence(Literal(1L), Literal(null, LongType), Literal(1L)), null)
+    checkEvaluation(new Sequence(Literal(1L), Literal(1L), Literal(null, LongType)), null)
+
+    // test sequence boundaries checking
+
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(2), Literal(1), Literal(0)), EmptyRow, "boundaries: 2 to 1 by 0")
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(2), Literal(1), Literal(1)), EmptyRow, "boundaries: 2 to 1 by 1")
+    checkExceptionInExpression[IllegalArgumentException](
+      new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
+
+    // test sequence with one element (zero step or equal start and stop)
+
+    checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
+    checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(0)), Seq(1))
+    checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(1)), Seq(1))
+    checkEvaluation(new Sequence(Literal(1), Literal(2), Literal(2)), Seq(1))
+    checkEvaluation(new Sequence(Literal(1), Literal(0), Literal(-2)), Seq(1))
+
+    // test sequence of different integral types (ascending and descending)
+
+    checkEvaluation(new Sequence(Literal(1L), Literal(3L), Literal(1L)), Seq(1L, 2L, 3L))
+    checkEvaluation(new Sequence(Literal(-3), Literal(3), Literal(3)), Seq(-3, 0, 3))
+    checkEvaluation(
+      new Sequence(Literal(3.toShort), Literal(-3.toShort), Literal(-3.toShort)),
+      Seq(3.toShort, 0.toShort, -3.toShort))
+    checkEvaluation(
+      new Sequence(Literal(-1.toByte), Literal(-3.toByte), Literal(-1.toByte)),
+      Seq(-1.toByte, -2.toByte, -3.toByte))
+  }
+
+  test("Sequence of timestamps") {
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 12 hours"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-02 00:00:01")),
+      Literal(CalendarInterval.fromString("interval 12 hours"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 12 hours").negate())),
+      Seq(
+        Timestamp.valueOf("2018-01-02 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-01 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-02 00:00:00")),
+      Literal(Timestamp.valueOf("2017-12-31 23:59:59")),
+      Literal(CalendarInterval.fromString("interval 12 hours").negate())),
+      Seq(
+        Timestamp.valueOf("2018-01-02 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-01 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 1 month"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-02-01 00:00:00"),
+        Timestamp.valueOf("2018-03-01 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 1 month").negate())),
+      Seq(
+        Timestamp.valueOf("2018-03-01 00:00:00"),
+        Timestamp.valueOf("2018-02-01 00:00:00"),
+        Timestamp.valueOf("2018-01-01 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-03-03 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 1 month 1 day").negate())),
+      Seq(
+        Timestamp.valueOf("2018-03-03 00:00:00"),
+        Timestamp.valueOf("2018-02-02 00:00:00"),
+        Timestamp.valueOf("2018-01-01 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-31 00:00:00")),
+      Literal(Timestamp.valueOf("2018-04-30 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 1 month"))),
+      Seq(
+        Timestamp.valueOf("2018-01-31 00:00:00"),
+        Timestamp.valueOf("2018-02-28 00:00:00"),
+        Timestamp.valueOf("2018-03-31 00:00:00"),
+        Timestamp.valueOf("2018-04-30 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-03-01 00:00:00")),
+      Literal(CalendarInterval.fromString("interval 1 month 1 second"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-02-01 00:00:01")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-03-01 00:04:06")),
+      Literal(CalendarInterval.fromString("interval 1 month 2 minutes 3 seconds"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-02-01 00:02:03"),
+        Timestamp.valueOf("2018-03-01 00:04:06")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2023-01-01 00:00:00")),
+      Literal(CalendarInterval.fromYearMonthString("1-5"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00.000"),
+        Timestamp.valueOf("2019-06-01 00:00:00.000"),
+        Timestamp.valueOf("2020-11-01 00:00:00.000"),
+        Timestamp.valueOf("2022-04-01 00:00:00.000")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2022-04-01 00:00:00")),
+      Literal(Timestamp.valueOf("2017-01-01 00:00:00")),
+      Literal(CalendarInterval.fromYearMonthString("1-5").negate())),
+      Seq(
+        Timestamp.valueOf("2022-04-01 00:00:00.000"),
+        Timestamp.valueOf("2020-11-01 00:00:00.000"),
+        Timestamp.valueOf("2019-06-01 00:00:00.000"),
+        Timestamp.valueOf("2018-01-01 00:00:00.000")))
+  }
+
+  test("Sequence on DST boundaries") {
+    val timeZone = TimeZone.getTimeZone("Europe/Prague")
+    val dstOffset = timeZone.getDSTSavings
+
+    def noDST(t: Timestamp): Timestamp = new Timestamp(t.getTime - dstOffset)
+
+    DateTimeTestUtils.withDefaultTimeZone(timeZone) {
+      // Spring time change
+      checkEvaluation(new Sequence(
+        Literal(Timestamp.valueOf("2018-03-25 01:30:00")),
+        Literal(Timestamp.valueOf("2018-03-25 03:30:00")),
+        Literal(CalendarInterval.fromString("interval 30 minutes"))),
+        Seq(
+          Timestamp.valueOf("2018-03-25 01:30:00"),
+          Timestamp.valueOf("2018-03-25 03:00:00"),
+          Timestamp.valueOf("2018-03-25 03:30:00")))
+
+      // Autumn time change
+      checkEvaluation(new Sequence(
+        Literal(Timestamp.valueOf("2018-10-28 01:30:00")),
+        Literal(Timestamp.valueOf("2018-10-28 03:30:00")),
+        Literal(CalendarInterval.fromString("interval 30 minutes"))),
+        Seq(
+          Timestamp.valueOf("2018-10-28 01:30:00"),
+          noDST(Timestamp.valueOf("2018-10-28 02:00:00")),
+          noDST(Timestamp.valueOf("2018-10-28 02:30:00")),
+          Timestamp.valueOf("2018-10-28 02:00:00"),
+          Timestamp.valueOf("2018-10-28 02:30:00"),
+          Timestamp.valueOf("2018-10-28 03:00:00"),
+          Timestamp.valueOf("2018-10-28 03:30:00")))
+    }
+  }
+
+  test("Sequence of dates") {
+    DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) {
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2018-01-01")),
+        Literal(Date.valueOf("2018-01-05")),
+        Literal(CalendarInterval.fromString("interval 2 days"))),
+        Seq(
+          Date.valueOf("2018-01-01"),
+          Date.valueOf("2018-01-03"),
+          Date.valueOf("2018-01-05")))
+
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2018-01-01")),
+        Literal(Date.valueOf("2018-03-01")),
+        Literal(CalendarInterval.fromString("interval 1 month"))),
+        Seq(
+          Date.valueOf("2018-01-01"),
+          Date.valueOf("2018-02-01"),
+          Date.valueOf("2018-03-01")))
+
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2018-01-31")),
+        Literal(Date.valueOf("2018-04-30")),
+        Literal(CalendarInterval.fromString("interval 1 month"))),
+        Seq(
+          Date.valueOf("2018-01-31"),
+          Date.valueOf("2018-02-28"),
+          Date.valueOf("2018-03-31"),
+          Date.valueOf("2018-04-30")))
+
+      checkEvaluation(new Sequence(
+        Literal(Date.valueOf("2018-01-01")),
+        Literal(Date.valueOf("2023-01-01")),
+        Literal(CalendarInterval.fromYearMonthString("1-5"))),
+        Seq(
+          Date.valueOf("2018-01-01"),
+          Date.valueOf("2019-06-01"),
+          Date.valueOf("2020-11-01"),
+          Date.valueOf("2022-04-01")))
+
+      checkExceptionInExpression[IllegalArgumentException](
+        new Sequence(
+          Literal(Date.valueOf("1970-01-02")),
+          Literal(Date.valueOf("1970-01-01")),
+          Literal(CalendarInterval.fromString("interval 1 day"))),
+        EmptyRow, "sequence boundaries: 1 to 0 by 1")
+
+      checkExceptionInExpression[IllegalArgumentException](
+        new Sequence(
+          Literal(Date.valueOf("1970-01-01")),
+          Literal(Date.valueOf("1970-02-01")),
+          Literal(CalendarInterval.fromString("interval 1 month").negate())),
+        EmptyRow,
+        s"sequence boundaries: 0 to 2678400000000 by -${28 * CalendarInterval.MICROS_PER_DAY}")
+    }
+  }
+
+  test("Sequence with default step") {
+    // +/- 1 for integral type
+    checkEvaluation(new Sequence(Literal(1), Literal(3)), Seq(1, 2, 3))
+    checkEvaluation(new Sequence(Literal(3), Literal(1)), Seq(3, 2, 1))
+
+    // +/- 1 day for timestamps
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-03 00:00:00"))),
+      Seq(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00"),
+        Timestamp.valueOf("2018-01-03 00:00:00")))
+
+    checkEvaluation(new Sequence(
+      Literal(Timestamp.valueOf("2018-01-03 00:00:00")),
+      Literal(Timestamp.valueOf("2018-01-01 00:00:00"))),
+      Seq(
+        Timestamp.valueOf("2018-01-03 00:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00"),
+        Timestamp.valueOf("2018-01-01 00:00:00")))
+
+    // +/- 1 day for dates
+    checkEvaluation(new Sequence(
+      Literal(Date.valueOf("2018-01-01")),
+      Literal(Date.valueOf("2018-01-03"))),
+      Seq(
+        Date.valueOf("2018-01-01"),
+        Date.valueOf("2018-01-02"),
+        Date.valueOf("2018-01-03")))
+
+    checkEvaluation(new Sequence(
+      Literal(Date.valueOf("2018-01-03")),
+      Literal(Date.valueOf("2018-01-01"))),
+      Seq(
+        Date.valueOf("2018-01-03"),
+        Date.valueOf("2018-01-02"),
+        Date.valueOf("2018-01-01")))
+  }
+
   test("Reverse") {
     // Primitive-type elements
     val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))

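The month-interval cases in the "Sequence of dates" test above (2018-01-31 stepping through 2018-02-28 to 2018-03-31 and 2018-04-30) only come out this way if each element is computed as the start date plus i months, rather than by repeatedly adding the interval to the previous element, which would get stuck on the 28th. A minimal `java.time` sketch of that semantics, an illustration only and not Spark's implementation:

```scala
import java.time.LocalDate

// Sketch of month-interval stepping (hypothetical helper, not Spark's code):
// the i-th element is start plus i months, so plusMonths may clamp to the
// last valid day of a short month without the clamp propagating forward.
object MonthStepSketch {
  def monthSequence(start: LocalDate, stop: LocalDate): Seq[LocalDate] =
    Iterator.from(0)
      .map(i => start.plusMonths(i.toLong))
      .takeWhile(!_.isAfter(stop))
      .toSeq
}
```

Repeated addition would instead yield 2018-02-28, 2018-03-28, 2018-04-28, diverging from the expected values in the test.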
http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ef99ce3..0b4f526 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3486,6 +3486,27 @@ object functions {
   def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
 
   /**
+   * Generate a sequence of integers from start to stop, incrementing by step.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def sequence(start: Column, stop: Column, step: Column): Column = withExpr {
+    new Sequence(start.expr, stop.expr, step.expr)
+  }
+
+  /**
+   * Generate a sequence of integers from start to stop,
+   * incrementing by 1 if start is less than or equal to stop, otherwise -1.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def sequence(start: Column, stop: Column): Column = withExpr {
+    new Sequence(start.expr, stop.expr)
+  }
+
+  /**
   * Creates an array containing the left argument repeated the number of times given by the
   * right argument.
    *

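As a rough standalone illustration of the integral-type semantics the two overloads above document, a default step of +1 or -1 chosen by comparing start and stop, plus a boundary check rejecting a step that can never reach stop, consider this plain-Scala sketch (the names and the exact check are hypothetical, not Spark's code):

```scala
// Plain-Scala sketch (hypothetical, not Spark's implementation) of the
// integral-type semantics documented by the two sequence overloads.
object SequenceSketch {
  // Two-argument form: default the step to +1 or -1 based on direction.
  def sequence(start: Long, stop: Long): Seq[Long] =
    sequence(start, stop, if (start <= stop) 1L else -1L)

  // Three-argument form: reject a step that can never reach `stop`.
  def sequence(start: Long, stop: Long, step: Long): Seq[Long] = {
    require((step > 0 && start <= stop) || (step < 0 && start >= stop) || start == stop,
      s"sequence boundaries: $start to $stop by $step")
    if (start == stop) Seq(start) else start.to(stop, step)
  }
}
```

The `start == stop` guard sidesteps Scala's `NumericRange` rejecting a zero step; Spark's actual validation and error-message format are what the tests in this patch pin down.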
http://git-wip-us.apache.org/repos/asf/spark/blob/2669b4de/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index b109898..4c28e2f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.sql
 
 import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone
 
 import scala.util.Random
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -862,6 +865,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df.selectExpr("array_max(a)"), answer)
   }
 
+  test("sequence") {
+    checkAnswer(Seq((-2, 2)).toDF().select(sequence('_1, '_2)), Seq(Row(Array(-2, -1, 0, 1, 2))))
+    checkAnswer(Seq((7, 2, -2)).toDF().select(sequence('_1, '_2, '_3)), Seq(Row(Array(7, 5, 3))))
+
+    checkAnswer(
+      spark.sql("select sequence(" +
+        "   cast('2018-01-01 00:00:00' as timestamp)" +
+        ",  cast('2018-01-02 00:00:00' as timestamp)" +
+        ",  interval 12 hours)"),
+      Seq(Row(Array(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00")))))
+
+    DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("UTC")) {
+      checkAnswer(
+        spark.sql("select sequence(" +
+          "   cast('2018-01-01' as date)" +
+          ",  cast('2018-03-01' as date)" +
+          ",  interval 1 month)"),
+        Seq(Row(Array(
+          Date.valueOf("2018-01-01"),
+          Date.valueOf("2018-02-01"),
+          Date.valueOf("2018-03-01")))))
+    }
+
+    // test type coercion
+    checkAnswer(
+      Seq((1.toByte, 3L, 1)).toDF().select(sequence('_1, '_2, '_3)),
+      Seq(Row(Array(1L, 2L, 3L))))
+
+    checkAnswer(
+      spark.sql("select sequence(" +
+        "   cast('2018-01-01' as date)" +
+        ",  cast('2018-01-02 00:00:00' as timestamp)" +
+        ",  interval 12 hours)"),
+      Seq(Row(Array(
+        Timestamp.valueOf("2018-01-01 00:00:00"),
+        Timestamp.valueOf("2018-01-01 12:00:00"),
+        Timestamp.valueOf("2018-01-02 00:00:00")))))
+
+    // test invalid data types
+    intercept[AnalysisException] {
+      Seq((true, false)).toDF().selectExpr("sequence(_1, _2)")
+    }
+    intercept[AnalysisException] {
+      Seq((true, false, 42)).toDF().selectExpr("sequence(_1, _2, _3)")
+    }
+    intercept[AnalysisException] {
+      Seq((1, 2, 0.5)).toDF().selectExpr("sequence(_1, _2, _3)")
+    }
+  }
+
   test("reverse function") {
     val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on
 

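The "Sequence on DST boundaries" test earlier in this patch hinges on the autumn fall-back in Europe/Prague: on 2018-10-28 clocks jump from 03:00 CEST back to 02:00 CET, so the 02:00-02:59 wall-clock times occur twice while absolute time advances uniformly. A small `java.time` sketch, independent of Spark and for illustration only, reproducing the local-time pattern that test expects:

```scala
import java.time.{ZoneId, ZonedDateTime}
import java.time.temporal.ChronoUnit

object DstSketch {
  // Step absolute time by 30 minutes across the 2018-10-28 fall-back in
  // Europe/Prague (03:00 CEST -> 02:00 CET) and render each instant as a
  // local wall-clock time; 02:00 and 02:30 each appear twice.
  def localTimes(n: Int): Seq[String] = {
    val zone = ZoneId.of("Europe/Prague")
    val start = ZonedDateTime.of(2018, 10, 28, 1, 30, 0, 0, zone).toInstant
    Iterator.iterate(start)(_.plus(30, ChronoUnit.MINUTES))
      .take(n)
      .map(_.atZone(zone).toLocalTime.toString)
      .toSeq
  }
}
```

Stepping the underlying instant this way yields the local times 01:30, 02:00, 02:30, 02:00, 02:30, 03:00, 03:30, the same seven-element shape as the expected sequence in that test (where `noDST` distinguishes the pre-transition pair).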
