spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pepinoflo <...@git.apache.org>
Subject [GitHub] spark pull request #21208: [SPARK-23925][SQL] Add array_repeat collection fu...
Date Sun, 13 May 2018 23:18:44 GMT
Github user pepinoflo commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21208#discussion_r187816430
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
    @@ -1229,3 +1229,140 @@ case class Flatten(child: Expression) extends UnaryExpression
{
     
       override def prettyName: String = "flatten"
     }
    +
    +/**
    + * Returns the array containing the given input value (left) count (right) times.
    + */
    +@ExpressionDescription(
    +  usage = "_FUNC_(element, count) - Returns the array containing element count times.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_('123', 2);
    +       ['123', '123']
    +  """)
    +case class ArrayRepeat(left: Expression, right: Expression)
    +  extends BinaryExpression with ExpectsInputTypes {
    +
    +  override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
    +
    +  override def nullable: Boolean = right.nullable
    +
    +  override def eval(input: InternalRow): Any = {
    +    val count = right.eval(input)
    +    if (count == null) {
    +      null
    +    } else {
    +      new GenericArrayData(List.fill(count.asInstanceOf[Int])(left.eval(input)))
    +    }
    +  }
    +
    +  override def prettyName: String = "array_repeat"
    +
    +  override def nullSafeCodeGen(ctx: CodegenContext,
    +                               ev: ExprCode,
    +                               f: (String, String) => String): ExprCode = {
    +    val leftGen = left.genCode(ctx)
    +    val rightGen = right.genCode(ctx)
    +    val resultCode = f(leftGen.value, rightGen.value)
    +
    +    if (nullable) {
    +      val nullSafeEval =
    +        leftGen.code +
    +          rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
    +            s"""
    +              ${ev.isNull} = false;
    +              $resultCode
    +            """
    +          }
    +
    +      ev.copy(code =
    +        s"""
    +           | boolean ${ev.isNull} = true;
    +           | ${CodeGenerator.javaType(dataType)} ${ev.value} =
    +           |   ${CodeGenerator.defaultValue(dataType)};
    +           | $nullSafeEval
    +         """.stripMargin
    +      )
    +    } else {
    +      ev.copy(code =
    +        s"""
    +           | boolean ${ev.isNull} = false;
    +           | ${leftGen.code}
    +           | ${rightGen.code}
    +           | ${CodeGenerator.javaType(dataType)} ${ev.value} =
    +           |   ${CodeGenerator.defaultValue(dataType)};
    +           | $resultCode
    +         """.stripMargin
    +        , isNull = FalseLiteral)
    +    }
    +
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +
    +    nullSafeCodeGen(ctx, ev, (l, r) => {
    +      val et = dataType.elementType
    +      val isPrimitive = CodeGenerator.isPrimitiveType(et)
    +
    +      val arrayDataName = ctx.freshName("arrayData")
    +      val arrayName = ctx.freshName("arrayObject")
    +      val numElements = ctx.freshName("numElements")
    +
    +      val genNumElements =
    +        s"""
    +           | int $numElements = 0;
    +           | if ($r > 0) {
    +           |   $numElements = $r;
    +           | }
    +         """.stripMargin
    +
    +      val initialization = if (isPrimitive) {
    +        val arrayName = ctx.freshName("array")
    +        val baseOffset = Platform.BYTE_ARRAY_OFFSET
    +        s"""
    +           | int numBytes = ${et.defaultSize} * $numElements;
    +           | int unsafeArraySizeInBytes =
    +           |   UnsafeArrayData.calculateHeaderPortionInBytes($numElements)
    +           |     + org.apache.spark.unsafe.array.ByteArrayMethods
    +           |       .roundNumberOfBytesToNearestWord(numBytes);
    +           | byte[] $arrayName = new byte[unsafeArraySizeInBytes];
    +           | UnsafeArrayData $arrayDataName = new UnsafeArrayData();
    +           | Platform.putLong($arrayName, $baseOffset, $numElements);
    +           | $arrayDataName.pointTo($arrayName, $baseOffset, unsafeArraySizeInBytes);
    +           | ${ev.value} = $arrayDataName;
    +         """.stripMargin
    +      } else {
    +        s"${ev.value} = new ${classOf[GenericArrayData].getName()}(new Object[$numElements]);"
    +      }
    +
    +      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(et)
    +      val assignments = {
    +        val updateArray = if (isPrimitive) {
    +          val isNull = left.genCode(ctx).isNull
    +          s"""
    +             | if ($isNull) {
    +             |   ${ev.value}.setNullAt(k);
    +             | } else {
    +             |   ${ev.value}.set$primitiveValueTypeName(k, $l);
    +             | }
    +           """.stripMargin
    +        } else {
    +          s"${ev.value}.update(k, $l);"
    --- End diff --
    
    Following ueshin's comment, I now call `update()` only when the element is not `null`,
since all elements are already set to `null` when the array is initialized.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message