spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-8263] [SQL] substr/substring should also support binary type
Date Sat, 01 Aug 2015 15:48:50 GMT
Repository: spark
Updated Branches:
  refs/heads/master cf6c9ca32 -> c5166f7a6


[SPARK-8263] [SQL] substr/substring should also support binary type

This is based on #7641, thanks to zhichao-li

Closes #7641

Author: zhichao.li <zhichao.li@intel.com>
Author: Davies Liu <davies@databricks.com>

Closes #7848 from davies/substr and squashes the following commits:

461b709 [Davies Liu] remove bytearry from tests
b45377a [Davies Liu] Merge branch 'master' of github.com:apache/spark into substr
01d795e [zhichao.li] scala style
99aa130 [zhichao.li] add substring to dataframe
4f68bfe [zhichao.li] add binary type support for substring


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5166f7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5166f7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5166f7a

Branch: refs/heads/master
Commit: c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce
Parents: cf6c9ca
Author: zhichao.li <zhichao.li@intel.com>
Authored: Sat Aug 1 08:48:46 2015 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Sat Aug 1 08:48:46 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 18 ++++++-
 .../catalyst/expressions/stringOperations.scala | 51 ++++++++++++++++++--
 .../expressions/StringExpressionsSuite.scala    | 15 +++++-
 .../scala/org/apache/spark/sql/functions.scala  | 11 +++++
 .../apache/spark/sql/StringFunctionsSuite.scala | 10 ++++
 .../apache/spark/unsafe/types/UTF8String.java   | 17 ++++---
 6 files changed, 109 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 81dc7d8..96975f5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -64,7 +64,7 @@ __all__ += [
     'year', 'quarter', 'month', 'hour', 'minute', 'second',
     'dayofmonth', 'dayofyear', 'weekofyear']
 
-__all__ += ['soundex']
+__all__ += ['soundex', 'substring', 'substring_index']
 
 
 def _create_function(name, doc=""):
@@ -925,6 +925,22 @@ def trunc(date, format):
 
 @since(1.5)
 @ignore_unicode_prefix
+def substring(str, pos, len):
+    """
+    Substring starts at `pos` and is of length `len` when str is String type or
+    returns the slice of byte array that starts at `pos` in byte and is of length `len`
+    when str is Binary type
+
+    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
+    [Row(s=u'ab')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len))
+
+
+@since(1.5)
+@ignore_unicode_prefix
 def substring_index(str, delim, count):
     """
     Returns the substring from string str before count occurrences of the delimiter delim.

http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 3ce5d6a..4d78c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.text.DecimalFormat
+import java.util.Arrays
 import java.util.Locale
 import java.util.regex.{MatchResult, Pattern}
 
@@ -679,6 +680,34 @@ case class StringSplit(str: Expression, pattern: Expression)
   override def prettyName: String = "split"
 }
 
+object Substring {
+  def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = {
+    if (pos > bytes.length) {
+      return Array[Byte]()
+    }
+
+    var start = if (pos > 0) {
+      pos - 1
+    } else if (pos < 0) {
+      bytes.length + pos
+    } else {
+      0
+    }
+
+    val end = if ((bytes.length - start) < len) {
+      bytes.length
+    } else {
+      start + len
+    }
+
+    start = Math.max(start, 0)  // underflow
+    if (start < end) {
+      Arrays.copyOfRange(bytes, start, end)
+    } else {
+      Array[Byte]()
+    }
+  }
+}
 /**
  * A function that takes a substring of its first argument starting at a given position.
  * Defined for String and Binary types.
@@ -690,18 +719,31 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
     this(str, pos, Literal(Integer.MAX_VALUE))
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = str.dataType
 
-  override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
 
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
   override def nullSafeEval(string: Any, pos: Any, len: Any): Any = {
-    string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+    str.dataType match {
+      case StringType => string.asInstanceOf[UTF8String]
+        .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+      case BinaryType => Substring.subStringBinarySQL(string.asInstanceOf[Array[Byte]],
+        pos.asInstanceOf[Int], len.asInstanceOf[Int])
+    }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)")
+
+    val cls = classOf[Substring].getName
+    defineCodeGen(ctx, ev, (string, pos, len) => {
+      str.dataType match {
+        case StringType => s"$string.substringSQL($pos, $len)"
+        case BinaryType => s"$cls.subStringBinarySQL($string, $pos, $len)"
+      }
+    })
   }
 }
 
@@ -1161,4 +1203,3 @@ case class FormatNumber(x: Expression, d: Expression)
 
   override def prettyName: String = "format_number"
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index ad87ab3..89c1e33 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -186,6 +185,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(s.substr(0), "example", row)
     checkEvaluation(s.substring(0, 2), "ex", row)
     checkEvaluation(s.substring(0), "example", row)
+
+    val bytes = Array[Byte](1, 2, 3, 4)
+    checkEvaluation(Substring(bytes, 0, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, 1, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, 2, 2), Array[Byte](2, 3))
+    checkEvaluation(Substring(bytes, 3, 2), Array[Byte](3, 4))
+    checkEvaluation(Substring(bytes, 4, 2), Array[Byte](4))
+    checkEvaluation(Substring(bytes, 8, 2), Array[Byte]())
+    checkEvaluation(Substring(bytes, -1, 2), Array[Byte](4))
+    checkEvaluation(Substring(bytes, -2, 2), Array[Byte](3, 4))
+    checkEvaluation(Substring(bytes, -3, 2), Array[Byte](2, 3))
+    checkEvaluation(Substring(bytes, -4, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, -5, 2), Array[Byte](1))
+    checkEvaluation(Substring(bytes, -8, 2), Array[Byte]())
   }
 
   test("string substring_index function") {

http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3c9421f..babfe21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1726,6 +1726,17 @@ object functions {
   def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
 
   /**
+   * Substring starts at `pos` and is of length `len` when str is String type or
+   * returns the slice of byte array that starts at `pos` in byte and is of length `len`
+   * when str is Binary type
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def substring(str: Column, pos: Int, len: Int): Column =
+    Substring(str.expr, lit(pos).expr, lit(len).expr)
+
+  /**
    * Computes the Levenshtein distance of the two given string columns.
    * @group string_funcs
    * @since 1.5.0

http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 628da95..f40233d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -103,6 +103,16 @@ class StringFunctionsSuite extends QueryTest {
       Row("AQIDBA==", bytes))
   }
 
+  test("string / binary substring function") {
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    val df = Seq(("1世3", Array[Byte](1, 2, 3, 4))).toDF("a", "b")
+    checkAnswer(df.select(substring($"a", 1, 2)), Row("1世"))
+    checkAnswer(df.select(substring($"b", 2, 2)), Row(Array[Byte](2,3)))
+    checkAnswer(df.selectExpr("substring(a, 1, 2)"), Row("1世"))
+    // scalastyle:on
+  }
+
   test("string encode/decode function") {
     val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
     // scalastyle:off

http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index f6dafe9..208503d 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -198,7 +198,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
    */
   public UTF8String substring(final int start, final int until) {
     if (until <= start || start >= numBytes) {
-      return UTF8String.EMPTY_UTF8;
+      return EMPTY_UTF8;
     }
 
     int i = 0;
@@ -214,9 +214,13 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
       c += 1;
     }
 
-    byte[] bytes = new byte[i - j];
-    copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
-    return fromBytes(bytes);
+    if (i > j) {
+      byte[] bytes = new byte[i - j];
+      copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
+      return fromBytes(bytes);
+    } else {
+      return EMPTY_UTF8;
+    }
   }
 
   public UTF8String substringSQL(int pos, int length) {
@@ -226,8 +230,9 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     // refers to element i-1 in the sequence. If a start index i is less than 0, it refers
     // to the -ith element before the end of the sequence. If a start index i is 0, it
     // refers to the first element.
-    int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0);
-    int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length;
+    int len = numChars();
+    int start = (pos > 0) ? pos -1 : ((pos < 0) ? len + pos : 0);
+    int end = (length == Integer.MAX_VALUE) ? len : start + length;
     return substring(start, end);
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message