spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-8245][SQL] FormatNumber/Length Support for Expression
Date Thu, 16 Jul 2015 04:47:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9c64a75bf -> 42dea3acf


[SPARK-8245][SQL] FormatNumber/Length Support for Expression

- `BinaryType` for `Length`
- `FormatNumber`
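
For context, a minimal sketch of how the new API surface is meant to be used
from Scala (assumes an existing SQLContext named sqlContext; the column names
are illustrative, and the expected values are mirrored from the tests below):

    import org.apache.spark.sql.functions.{format_number, length}

    // one string column and one double column, exercising both new functions
    val df = sqlContext.createDataFrame(Seq(("ABC", 12831273.23481))).toDF("s", "n")
    df.select(length("s"), format_number("n", 3)).show()
    // -> 3 and "12,831,273.235"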

Author: Cheng Hao <hao.cheng@intel.com>

Closes #7034 from chenghao-intel/expression and squashes the following commits:

e534b87 [Cheng Hao] python api style issue
601bbf5 [Cheng Hao] add python API support
3ebe288 [Cheng Hao] update as feedback
52274f7 [Cheng Hao] add support for udf_format_number and length for binary


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42dea3ac
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42dea3ac
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42dea3ac

Branch: refs/heads/master
Commit: 42dea3acf90ec506a0b79720b55ae1d753cc7544
Parents: 9c64a75
Author: Cheng Hao <hao.cheng@intel.com>
Authored: Wed Jul 15 21:47:21 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Wed Jul 15 21:47:21 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 25 ++++--
 .../catalyst/analysis/FunctionRegistry.scala    |  5 +-
 .../catalyst/expressions/stringOperations.scala | 94 ++++++++++++++++++--
 .../expressions/StringFunctionsSuite.scala      | 53 ++++++++---
 .../scala/org/apache/spark/sql/functions.scala  | 32 ++++++-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 93 ++++++++++++++++---
 6 files changed, 261 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index dca39fa..e0816b3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -39,6 +39,8 @@ __all__ = [
     'coalesce',
     'countDistinct',
     'explode',
+    'format_number',
+    'length',
     'log2',
     'md5',
     'monotonicallyIncreasingId',
@@ -47,7 +49,6 @@ __all__ = [
     'sha1',
     'sha2',
     'sparkPartitionId',
-    'strlen',
     'struct',
     'udf',
     'when']
@@ -506,14 +507,28 @@ def sparkPartitionId():
 
 @ignore_unicode_prefix
 @since(1.5)
-def strlen(col):
-    """Calculates the length of a string expression.
+def length(col):
+    """Calculates the length of a string or binary expression.
 
-    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect()
+    >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect()
     [Row(length=3)]
     """
     sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.strlen(_to_java_column(col)))
+    return Column(sc._jvm.functions.length(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def format_number(col, d):
+    """Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+       and returns the result as a string.
+    :param col: the column name of the numeric value to be formatted
+    :param d: the number of decimal places to round to
+    >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect()
+    [Row(v=u'5.0000')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.format_number(_to_java_column(col), d))
 
 
 @ignore_unicode_prefix

http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d2678ce..e0beafe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -152,11 +152,12 @@ object FunctionRegistry {
     expression[Base64]("base64"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
-    expression[StringInstr]("instr"),
+    expression[FormatNumber]("format_number"),
     expression[Lower]("lcase"),
     expression[Lower]("lower"),
-    expression[StringLength]("length"),
+    expression[Length]("length"),
     expression[Levenshtein]("levenshtein"),
+    expression[StringInstr]("instr"),
     expression[StringLocate]("locate"),
     expression[StringLPad]("lpad"),
     expression[StringTrimLeft]("ltrim"),
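
With these registrations both functions are also callable by name from SQL
expression strings; a minimal sketch (sqlContext is assumed to exist, and
range(1) merely produces a one-row DataFrame to drive selectExpr):

    val df = sqlContext.range(1)
    df.selectExpr("length('spark')", "format_number(12831273.83421, 0)").show()
    // -> 5 and "12,831,274", matching the expression tests below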

http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 03b55ce..c64afe7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -17,11 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.text.DecimalFormat
 import java.util.Locale
 import java.util.regex.Pattern
 
-import org.apache.commons.lang3.StringUtils
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -553,17 +552,22 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
 }
 
 /**
- * A function that return the length of the given string expression.
+ * A function that returns the length of the given string or binary expression.
  */
-case class StringLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def dataType: DataType = IntegerType
-  override def inputTypes: Seq[DataType] = Seq(StringType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
 
-  protected override def nullSafeEval(string: Any): Any =
-    string.asInstanceOf[UTF8String].numChars
+  protected override def nullSafeEval(value: Any): Any = child.dataType match {
+    case StringType => value.asInstanceOf[UTF8String].numChars
+    case BinaryType => value.asInstanceOf[Array[Byte]].length
+  }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, c => s"($c).numChars()")
+    child.dataType match {
+      case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()")
+      case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length")
+    }
   }
 
   override def prettyName: String = "length"
@@ -668,3 +672,77 @@ case class Encode(value: Expression, charset: Expression)
   }
 }
 
+/**
+ * Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
+ * and returns the result as a string. If D is 0, the result has no decimal point or
+ * fractional part.
+ */
+case class FormatNumber(x: Expression, d: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def left: Expression = x
+  override def right: Expression = d
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
+
+  // The last d value seen. The cached DecimalFormat pattern is rebuilt
+  // whenever the incoming d value differs from this one.
+  @transient
+  private var lastDValue: Int = -100
+
+  // A cached pattern buffer; for performance, the DecimalFormat is updated
+  // only when the d value changes.
+  @transient
+  private val pattern: StringBuffer = new StringBuffer()
+
+  @transient
+  private val numberFormat: DecimalFormat = new DecimalFormat("")
+
+  override def eval(input: InternalRow): Any = {
+    val xObject = x.eval(input)
+    if (xObject == null) {
+      return null
+    }
+
+    val dObject = d.eval(input)
+
+    if (dObject == null || dObject.asInstanceOf[Int] < 0) {
+      return null
+    }
+    val dValue = dObject.asInstanceOf[Int]
+
+    if (dValue != lastDValue) {
+      // rebuild the pattern only when the d value changes
+      pattern.delete(0, pattern.length())
+      pattern.append("#,###,###,###,###,###,##0")
+
+      // append the decimal point and d zeros
+      if (dValue > 0) {
+        pattern.append(".")
+
+        var i = 0
+        while (i < dValue) {
+          i += 1
+          pattern.append("0")
+        }
+      }
+      val dFormat = new DecimalFormat(pattern.toString())
+      lastDValue = dValue
+      numberFormat.applyPattern(dFormat.toPattern())
+    }
+
+    x.dataType match {
+      case ByteType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Byte]))
+      case ShortType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Short]))
+      case FloatType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Float]))
+      case IntegerType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Int]))
+      case LongType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Long]))
+      case DoubleType => UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Double]))
+      case _: DecimalType =>
+        UTF8String.fromString(numberFormat.format(xObject.asInstanceOf[Decimal].toJavaBigDecimal))
+    }
+  }
+
+  override def prettyName: String = "format_number"
+}
+
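
As a standalone illustration of the caching logic above, this sketch mirrors
the pattern eval builds for d = 3 using plain java.text (the input value is
taken from the test suite below):

    import java.text.DecimalFormat

    // eval appends "." and d zeros to the base pattern when d > 0
    val pattern = "#,###,###,###,###,###,##0" + "." + "0" * 3
    val nf = new DecimalFormat(pattern)
    println(nf.format(12831273.23481))  // 12,831,273.235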

http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index b19f4ee..5d7763b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
+import org.apache.spark.sql.types._
 
 
 class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -216,15 +216,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  test("length for string") {
-    val a = 'a.string.at(0)
-    checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))
-    checkEvaluation(StringLength(a), 5, create_row("abdef"))
-    checkEvaluation(StringLength(a), 0, create_row(""))
-    checkEvaluation(StringLength(a), null, create_row(null))
-    checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
-  }
-
   test("ascii for string") {
     val a = 'a.string.at(0)
     checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))
@@ -426,4 +417,46 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(
       StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
   }
+
+  test("length for string / binary") {
+    val a = 'a.string.at(0)
+    val b = 'b.binary.at(0)
+    val bytes = Array[Byte](1, 2, 3, 1, 2)
+    val string = "abdef"
+
+    // scalastyle:off
+    // non-ASCII characters are not allowed in the source code, so scalastyle is disabled here.
+    checkEvaluation(Length(Literal("a花花c")), 4, create_row(string))
+    // scalastyle:on
+    checkEvaluation(Length(Literal(bytes)), 5, create_row(Array[Byte]()))
+
+    checkEvaluation(Length(a), 5, create_row(string))
+    checkEvaluation(Length(b), 5, create_row(bytes))
+
+    checkEvaluation(Length(a), 0, create_row(""))
+    checkEvaluation(Length(b), 0, create_row(Array[Byte]()))
+
+    checkEvaluation(Length(a), null, create_row(null))
+    checkEvaluation(Length(b), null, create_row(null))
+
+    checkEvaluation(Length(Literal.create(null, StringType)), null, create_row(string))
+    checkEvaluation(Length(Literal.create(null, BinaryType)), null, create_row(bytes))
+  }
+
+  test("number format") {
+    checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Byte]), Literal(3)), "4.000")
+    checkEvaluation(FormatNumber(Literal(4.asInstanceOf[Short]), Literal(3)), "4.000")
+    checkEvaluation(FormatNumber(Literal(4.0f), Literal(3)), "4.000")
+    checkEvaluation(FormatNumber(Literal(4), Literal(3)), "4.000")
+    checkEvaluation(FormatNumber(Literal(12831273.23481d), Literal(3)), "12,831,273.235")
+    checkEvaluation(FormatNumber(Literal(12831273.83421d), Literal(0)), "12,831,274")
+    checkEvaluation(FormatNumber(Literal(123123324123L), Literal(3)), "123,123,324,123.000")
+    checkEvaluation(FormatNumber(Literal(123123324123L), Literal(-1)), null)
+    checkEvaluation(
+      FormatNumber(
+        Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)),
+      "15,159,339,180,002,773.2778")
+    checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
+    checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
+  }
 }
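
The non-ASCII case above shows that Length counts characters for StringType
rather than bytes; a small sketch of the distinction (using the UTF8String
class the expression itself operates on):

    import org.apache.spark.unsafe.types.UTF8String

    val s = UTF8String.fromString("a\u82b1\u82b1c")  // "a花花c", as in the test
    println(s.numChars)         // 4 -> what Length returns for a StringType input
    println(s.getBytes.length)  // 8 -> what Length returns for the same bytes as BinaryType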

http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c7deaca..d6da284 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1685,20 +1685,44 @@ object functions {
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Computes the length of a given string value.
+   * Computes the length of a given string / binary value.
    *
    * @group string_funcs
    * @since 1.5.0
    */
-  def strlen(e: Column): Column = StringLength(e.expr)
+  def length(e: Column): Column = Length(e.expr)
 
   /**
-   * Computes the length of a given string column.
+   * Computes the length of a given string / binary column.
    *
    * @group string_funcs
    * @since 1.5.0
    */
-  def strlen(columnName: String): Column = strlen(Column(columnName))
+  def length(columnName: String): Column = length(Column(columnName))
+
+  /**
+   * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+   * and returns the result as a string.
+   * If d is 0, the result has no decimal point or fractional part.
+   * If d < 0, the result will be null.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
+
+  /**
+   * Formats the number X to a format like '#,###,###.##', rounded to d decimal places,
+   * and returns the result as a string.
+   * If d is 0, the result has no decimal point or fractional part.
+   * If d < 0, the result will be null.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def format_number(columnXName: String, d: Int): Column = {
+    format_number(Column(columnXName), d)
+  }
 
   /**
    * Computes the Levenshtein distance of the two given strings.

http://git-wip-us.apache.org/repos/asf/spark/blob/42dea3ac/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 70bd787..6dccdd8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,17 +208,6 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row(2743272264L, 2180413220L))
   }
 
-  test("string length function") {
-    val df = Seq(("abc", "")).toDF("a", "b")
-    checkAnswer(
-      df.select(strlen($"a"), strlen("b")),
-      Row(3, 0))
-
-    checkAnswer(
-      df.selectExpr("length(a)", "length(b)"),
-      Row(3, 0))
-  }
-
   test("Levenshtein distance") {
     val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
     checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
@@ -433,11 +422,91 @@ class DataFrameFunctionsSuite extends QueryTest {
     val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
     checkAnswer(
       doubleData.select(pmod('a, 'b)),
-      Seq(Row(3.1000000000000005))  // same as hive
+      Seq(Row(3.1000000000000005)) // same as hive
     )
     checkAnswer(
       doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
       Seq(Row(2))
     )
   }
+
+  test("string / binary length function") {
+    val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c")
+    checkAnswer(
+      df.select(length($"a"), length("a"), length($"b"), length("b")),
+      Row(3, 3, 4, 4))
+
+    checkAnswer(
+      df.selectExpr("length(a)", "length(b)"),
+      Row(3, 4))
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("length(c)"), // int type of the argument is unacceptable
+        Row("5.0000"))
+    }
+  }
+
+  test("number format function") {
+    val tuple =
+      ("aa", 1.asInstanceOf[Byte], 2.asInstanceOf[Short],
+        3.13223f, 4, 5L, 6.48173d, Decimal(7.128381))
+    val df =
+      Seq(tuple)
+        .toDF(
+          "a", // string "aa"
+          "b", // byte    1
+          "c", // short   2
+          "d", // float   3.13223f
+          "e", // integer 4
+          "f", // long    5L
+          "g", // double  6.48173d
+          "h") // decimal 7.128381
+
+    checkAnswer(
+      df.select(
+        format_number($"f", 4),
+        format_number("f", 4)),
+      Row("5.0000", "5.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer
+      Row("1.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(c, e)"), // convert the 1st argument to integer
+      Row("2.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(d, e)"), // convert the 1st argument to double
+      Row("3.1322"))
+
+    checkAnswer(
+      df.selectExpr("format_number(e, e)"), // not convert anything
+      Row("4.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(f, e)"), // not convert anything
+      Row("5.0000"))
+
+    checkAnswer(
+      df.selectExpr("format_number(g, e)"), // not convert anything
+      Row("6.4817"))
+
+    checkAnswer(
+      df.selectExpr("format_number(h, e)"), // not convert anything
+      Row("7.1284"))
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("format_number(a, e)"), // string type of the 1st argument is unacceptable
+        Row("5.0000"))
+    }
+
+    intercept[AnalysisException] {
+      checkAnswer(
+        df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable
+        Row("5.0000"))
+    }
+  }
 }



