spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From davies <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-8935][SQL] Implement code generation fo...
Date Wed, 22 Jul 2015 05:50:58 GMT
Github user davies commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7365#discussion_r35183319
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
---
    @@ -418,51 +418,515 @@ case class Cast(child: Expression, dataType: DataType)
       protected override def nullSafeEval(input: Any): Any = cast(input)
     
       override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    -    // TODO: Add support for more data types.
    -    (child.dataType, dataType) match {
    +    val eval = child.gen(ctx)
    +    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
    +    eval.code +
    +      castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast)
    +  }
    +
    +  // three function arguments are: child.primitive, result.primitive and result.isNull
    +  // it returns the code snippets to be put in null safe evaluation region
    +  private[this] type CastFunction = (String, String, String) => String
    +
    +  private[this] def nullSafeCastFunction(
    +      from: DataType,
    +      to: DataType,
    +      ctx: CodeGenContext): CastFunction = to match {
    +
    +    case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;"
    +    case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;"
    +    case StringType => castToStringCode(from, ctx)
    +    case BinaryType => castToBinaryCode(from)
    +    case DateType => castToDateCode(from, ctx)
    +    case decimal: DecimalType => castToDecimalCode(from, decimal)
    +    case TimestampType => castToTimestampCode(from, ctx)
    +    case IntervalType => castToIntervalCode(from)
    +    case BooleanType => castToBooleanCode(from)
    +    case ByteType => castToByteCode(from)
    +    case ShortType => castToShortCode(from)
    +    case IntegerType => castToIntCode(from)
    +    case FloatType => castToFloatCode(from)
    +    case LongType => castToLongCode(from)
    +    case DoubleType => castToDoubleCode(from)
    +
    +    case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx)
    +    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
    +    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct,
ctx)
    +  }
    +
    +  // Since we need to cast child expressions recursively inside ComplexTypes, such as
Map's
    +  // Key and Value, Struct's field, we need to name out all the variable names involved
in a cast.
    +  private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String,
    +    resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction):
String = {
    +    s"""
    +      boolean $resultNull = $childNull;
    +      ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
    +      if (!${childNull}) {
    +        ${cast(childPrim, resultPrim, resultNull)}
    +      }
    +    """
    +  }
    +
    +  private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction
= {
    +    from match {
    +      case BinaryType =>
    +        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
    +      case DateType =>
    +        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
    +          org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
    +      case TimestampType =>
    +        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
    +          org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));"""
    +      case _ =>
    +        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
    +    }
    +  }
    +
    +  private[this] def castToBinaryCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.getBytes();"
    +  }
     
    -      case (BinaryType, StringType) =>
    -        defineCodeGen (ctx, ev, c =>
    -          s"UTF8String.fromBytes($c)")
    +  private[this] def castToDateCode(
    +      from: DataType,
    +      ctx: CodeGenContext): CastFunction = from match {
    +    case StringType =>
    +      val intOpt = ctx.freshName("intOpt")
    +      (c, evPrim, evNull) => s"""
    +        scala.Option<Integer> $intOpt =
    +          org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c);
    +        if ($intOpt.isDefined()) {
    +          $evPrim = ((Integer) $intOpt.get()).intValue();
    +        } else {
    +          $evNull = true;
    +        }
    +       """
    +    case TimestampType =>
    +      (c, evPrim, evNull) =>
    +        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c
/ 1000L);";
    +    case _ =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +  }
    +
    +  private[this] def changePrecision(d: String, decimalType: DecimalType,
    +      evPrim: String, evNull: String): String = {
    +    decimalType match {
    +      case DecimalType.Unlimited =>
    +        s"$evPrim = $d;"
    +      case DecimalType.Fixed(precision, scale) =>
    +        s"""
    +          if ($d.changePrecision($precision, $scale)) {
    +            $evPrim = $d;
    +          } else {
    +            $evNull = true;
    +          }
    +        """
    +    }
    +  }
     
    -      case (DateType, StringType) =>
    -        defineCodeGen(ctx, ev, c =>
    -          s"""UTF8String.fromString(
    -                org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
    +  private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction
= {
    +    from match {
    +      case StringType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  new scala.math.BigDecimal(
    +                    new java.math.BigDecimal($c.toString())));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """
    +      case BooleanType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = null;
    +            if ($c) {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
    +            } else {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
    +            }
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case DateType =>
    +        // date can't cast to decimal in Hive
    +        (c, evPrim, evNull) => s"$evNull = true;"
    +      case TimestampType =>
    +        // Note that we lose precision here.
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set(
    +                scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case DecimalType() =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case LongType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set($c);
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case x: NumericType =>
    +        // All other numeric types can be represented precisely as Doubles
    +        (c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  scala.math.BigDecimal.valueOf((double) $c));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """
    +    }
    +  }
    +
    +  private[this] def castToTimestampCode(
    +      from: DataType,
    +      ctx: CodeGenContext): CastFunction = from match {
    +    case StringType =>
    +      val longOpt = ctx.freshName("longOpt")
    +      (c, evPrim, evNull) =>
    +        s"""
    +          scala.Option<Long> $longOpt =
    +            org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c);
    +          if ($longOpt.isDefined()) {
    +            $evPrim = ((Long) $longOpt.get()).longValue();
    +          } else {
    +            $evNull = true;
    +          }
    +         """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;"
    +    case _: IntegralType =>
    +      (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
    +    case DateType =>
    +      (c, evPrim, evNull) =>
    +        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c)
* 1000;"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};"
    +    case DoubleType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          if (Double.isNaN($c) || Double.isInfinite($c)) {
    +            $evNull = true;
    +          } else {
    +            $evPrim = (long)($c * 1000000L);
    +          }
    +        """
    +    case FloatType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          if (Float.isNaN($c) || Float.isInfinite($c)) {
    +            $evNull = true;
    +          } else {
    +            $evPrim = (long)($c * 1000000L);
    +          }
    +        """
    +  }
    +
    +  private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());"
    +  }
    +
    +  private[this] def decimalToTimestampCode(d: String): String =
    +    s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()"
    +  private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L"
    +  private[this] def timestampToIntegerCode(ts: String): String =
    +    s"java.lang.Math.floor((double) $ts / 1000000L)"
    +  private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0"
    +
    +  private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c != 0;"
    +    case DateType =>
    +      // Hive would return null when cast from date to boolean
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = !$c.isZero();"
    +    case n: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c != 0;"
    +  }
    +
    +  private[this] def castToByteCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Byte.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toByte();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
    +  }
     
    -      case (TimestampType, StringType) =>
    -        defineCodeGen(ctx, ev, c =>
    -          s"""UTF8String.fromString(
    -                org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
    +  private[this] def castToShortCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Short.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toShort();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (short) $c;"
    +  }
     
    -      case (_, StringType) =>
    -        defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))")
    +  private[this] def castToIntCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Integer.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toInt();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (int) $c;"
    +  }
     
    -      case (StringType, IntervalType) =>
    -        defineCodeGen(ctx, ev, c =>
    -          s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")
    +  private[this] def castToLongCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Long.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toLong();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (long) $c;"
    +  }
     
    -      // fallback for DecimalType, this must be before other numeric types
    -      case (_, dt: DecimalType) =>
    -        super.genCode(ctx, ev)
    +  private[this] def castToFloatCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Float.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toFloat();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (float) $c;"
    +  }
     
    -      case (BooleanType, dt: NumericType) =>
    -        defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
    +  private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = Double.valueOf($c.toString());
    +          } catch (java.lang.NumberFormatException e) {
    +            $evNull = true;
    +          }
    +        """
    +    case BooleanType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;"
    +    case DateType =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +    case TimestampType =>
    +      (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};"
    +    case DecimalType() =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.toDouble();"
    +    case x: NumericType =>
    +      (c, evPrim, evNull) => s"$evPrim = (double) $c;"
    +  }
     
    -      case (dt: DecimalType, BooleanType) =>
    -        defineCodeGen(ctx, ev, c => s"!$c.isZero()")
    +  private[this] def unboxPrimitive(ctx: CodeGenContext, dt: DataType, obj: String): String
= {
    +    dt match {
    +      case _: IntegralType | FloatType | DoubleType =>
    +        s"((${ctx.boxedType(dt)}) $obj).${ctx.javaType(dt)}Value()";
    +      case _ =>
    +        s"(${ctx.javaType(dt)}) $obj"
    +    }
    +  }
     
    -      case (dt: NumericType, BooleanType) =>
    -        defineCodeGen(ctx, ev, c => s"$c != 0")
    +  private[this] def castArrayCode(
    +      from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
    +    val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)
    +
    +    val arraySeqClass = "scala.collection.mutable.ArraySeq"
    +    val fromElementNull = ctx.freshName("feNull")
    +    val fromElementPrim = ctx.freshName("fePrim")
    +    val toElementNull = ctx.freshName("teNull")
    +    val toElementPrim = ctx.freshName("tePrim")
    +    val size = ctx.freshName("n")
    +    val j = ctx.freshName("j")
    +    val result = ctx.freshName("result")
    +
    +    (c, evPrim, evNull) =>
    +      s"""
    +        final int $size = $c.size();
    +        final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
    +        for (int $j = 0; $j < $size; $j ++) {
    +          if ($c.apply($j) == null) {
    +            $result.update($j, null);
    +          } else {
    +            boolean $fromElementNull = false;
    +            ${ctx.javaType(from.elementType)} $fromElementPrim =
    +              ${unboxPrimitive(ctx, from.elementType, s"$c.apply($j)")};
    --- End diff --
    
    You should use the type-specific getters to access them, because UnsafeRow does not support the generic getter for primitive types (see `ctx.getColumn`). Then we won't need `unboxPrimitive()` at all.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message