From: davies@apache.org
To: commits@spark.apache.org
Message-Id: <4bc12e378e934999b01cf42209640ce9@git.apache.org>
Subject: spark git commit: [SPARK-8263] [SQL] substr/substring should also support binary type
Date: Sat, 1 Aug 2015 15:48:50 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master cf6c9ca32 -> c5166f7a6


[SPARK-8263] [SQL] substr/substring should also support binary type

This is based on #7641, thanks to zhichao-li

Closes #7641

Author: zhichao.li
Author: Davies Liu

Closes #7848 from davies/substr and squashes the following commits:

461b709 [Davies Liu] remove bytearry from tests
b45377a [Davies Liu] Merge branch 'master' of github.com:apache/spark into substr
01d795e [zhichao.li] scala style
99aa130 [zhichao.li] add substring to dataframe
4f68bfe [zhichao.li] add binary type support for substring


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5166f7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5166f7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5166f7a

Branch: refs/heads/master
Commit: c5166f7a69faeaa8a41a774c73c1ed4d4c2cf0ce
Parents: cf6c9ca
Author: zhichao.li
Authored: Sat Aug 1 08:48:46 2015 -0700
Committer: Davies Liu
Committed: Sat Aug 1 08:48:46 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 18 ++++++-
 .../catalyst/expressions/stringOperations.scala | 51 ++++++++++++++++++--
 .../expressions/StringExpressionsSuite.scala    | 15 +++++-
 .../scala/org/apache/spark/sql/functions.scala  | 11 +++++
 .../apache/spark/sql/StringFunctionsSuite.scala | 10 ++++
 .../apache/spark/unsafe/types/UTF8String.java   | 17 ++++---
 6 files changed, 109 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 81dc7d8..96975f5 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -64,7 +64,7 @@ __all__ += [
     'year', 'quarter', 'month', 'hour', 'minute', 'second', 'dayofmonth', 'dayofyear', 'weekofyear']
 
-__all__ += ['soundex']
+__all__ += ['soundex', 'substring', 'substring_index']
 
 
 def _create_function(name, doc=""):
@@ -925,6 +925,22 @@ def trunc(date, format):
 
 @since(1.5)
 @ignore_unicode_prefix
+def substring(str, pos, len):
+    """
+    Substring starts at `pos` and is of length `len` when str is String type or
+    returns the slice of byte array that starts at `pos` in byte and is of length `len`
+    when str is Binary type
+
+    >>> df = sqlContext.createDataFrame([('abcd',)], ['s',])
+    >>> df.select(substring(df.s, 1, 2).alias('s')).collect()
+    [Row(s=u'ab')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.substring(_to_java_column(str), pos, len))
+
+
+@since(1.5)
+@ignore_unicode_prefix
 def substring_index(str, delim, count):
     """
     Returns the substring from string str before count occurrences of the delimiter delim.


http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 3ce5d6a..4d78c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.text.DecimalFormat
+import java.util.Arrays
 import java.util.Locale
 import java.util.regex.{MatchResult, Pattern}
 
@@ -679,6 +680,34 @@ case class StringSplit(str: Expression, pattern: Expression)
 
   override def prettyName: String = "split"
 }
+object Substring {
+  def subStringBinarySQL(bytes: Array[Byte], pos: Int, len: Int): Array[Byte] = {
+    if (pos > bytes.length) {
+      return Array[Byte]()
+    }
+
+    var start = if (pos > 0) {
+      pos - 1
+    } else if (pos < 0) {
+      bytes.length + pos
+    } else {
+      0
+    }
+
+    val end = if ((bytes.length - start) < len) {
+      bytes.length
+    } else {
+      start + len
+    }
+
+    start = Math.max(start, 0) // underflow
+    if (start < end) {
+      Arrays.copyOfRange(bytes, start, end)
+    } else {
+      Array[Byte]()
+    }
+  }
+}
 
 /**
  * A function that takes a substring of its first argument starting at a given position.
 * Defined for String and Binary types.
@@ -690,18 +719,31 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
     this(str, pos, Literal(Integer.MAX_VALUE))
   }
 
-  override def dataType: DataType = StringType
+  override def dataType: DataType = str.dataType
 
-  override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(StringType, BinaryType), IntegerType, IntegerType)
 
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
   override def nullSafeEval(string: Any, pos: Any, len: Any): Any = {
-    string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+    str.dataType match {
+      case StringType => string.asInstanceOf[UTF8String]
+        .substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
+      case BinaryType => Substring.subStringBinarySQL(string.asInstanceOf[Array[Byte]],
+        pos.asInstanceOf[Int], len.asInstanceOf[Int])
+    }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)")
+
+    val cls = classOf[Substring].getName
+    defineCodeGen(ctx, ev, (string, pos, len) => {
+      str.dataType match {
+        case StringType => s"$string.substringSQL($pos, $len)"
+        case BinaryType => s"$cls.subStringBinarySQL($string, $pos, $len)"
+      }
+    })
   }
 }
 
@@ -1161,4 +1203,3 @@ case class FormatNumber(x: Expression, d: Expression)
 
   override def prettyName: String = "format_number"
 }
-


http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index ad87ab3..89c1e33 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -186,6 +185,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(s.substr(0), "example", row)
     checkEvaluation(s.substring(0, 2), "ex", row)
     checkEvaluation(s.substring(0), "example", row)
+
+    val bytes = Array[Byte](1, 2, 3, 4)
+    checkEvaluation(Substring(bytes, 0, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, 1, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, 2, 2), Array[Byte](2, 3))
+    checkEvaluation(Substring(bytes, 3, 2), Array[Byte](3, 4))
+    checkEvaluation(Substring(bytes, 4, 2), Array[Byte](4))
+    checkEvaluation(Substring(bytes, 8, 2), Array[Byte]())
+    checkEvaluation(Substring(bytes, -1, 2), Array[Byte](4))
+    checkEvaluation(Substring(bytes, -2, 2), Array[Byte](3, 4))
+    checkEvaluation(Substring(bytes, -3, 2), Array[Byte](2, 3))
+    checkEvaluation(Substring(bytes, -4, 2), Array[Byte](1, 2))
+    checkEvaluation(Substring(bytes, -5, 2), Array[Byte](1))
+    checkEvaluation(Substring(bytes, -8, 2), Array[Byte]())
   }
 
   test("string substring_index function") {
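
The checkEvaluation cases above pin down the SQL-style position semantics of the new
binary substring: positions are 1-based, a position of 0 behaves like 1, and negative
positions count back from the end of the byte array. A minimal REPL-style sketch
(illustrative only, not part of the patch; it just calls the Substring.subStringBinarySQL
helper added above, with values mirroring the tests):

  import org.apache.spark.sql.catalyst.expressions.Substring

  val bytes = Array[Byte](1, 2, 3, 4)
  Substring.subStringBinarySQL(bytes, 2, 2)   // Array(2, 3): start at the 2nd byte, take 2
  Substring.subStringBinarySQL(bytes, 0, 2)   // Array(1, 2): position 0 behaves like position 1
  Substring.subStringBinarySQL(bytes, -2, 2)  // Array(3, 4): 2nd byte counted from the end, take 2
  Substring.subStringBinarySQL(bytes, 8, 2)   // Array():     start past the end yields an empty array
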
http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3c9421f..babfe21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1726,6 +1726,17 @@ object functions {
   def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
 
   /**
+   * Substring starts at `pos` and is of length `len` when str is String type or
+   * returns the slice of byte array that starts at `pos` in byte and is of length `len`
+   * when str is Binary type
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  def substring(str: Column, pos: Int, len: Int): Column =
+    Substring(str.expr, lit(pos).expr, lit(len).expr)
+
+  /**
    * Computes the Levenshtein distance of the two given string columns.
    * @group string_funcs
    * @since 1.5.0


http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 628da95..f40233d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -103,6 +103,16 @@ class StringFunctionsSuite extends QueryTest {
       Row("AQIDBA==", bytes))
   }
 
+  test("string / binary substring function") {
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+ val df = Seq(("1世3", Array[Byte](1, 2, 3, 4))).toDF("a", "b") + checkAnswer(df.select(substring($"a", 1, 2)), Row("1世")) + checkAnswer(df.select(substring($"b", 2, 2)), Row(Array[Byte](2,3))) + checkAnswer(df.selectExpr("substring(a, 1, 2)"), Row("1世")) + // scalastyle:on + } + test("string encode/decode function") { val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116) // scalastyle:off http://git-wip-us.apache.org/repos/asf/spark/blob/c5166f7a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java ---------------------------------------------------------------------- diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f6dafe9..208503d 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -198,7 +198,7 @@ public final class UTF8String implements Comparable, Serializable { */ public UTF8String substring(final int start, final int until) { if (until <= start || start >= numBytes) { - return UTF8String.EMPTY_UTF8; + return EMPTY_UTF8; } int i = 0; @@ -214,9 +214,13 @@ public final class UTF8String implements Comparable, Serializable { c += 1; } - byte[] bytes = new byte[i - j]; - copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); - return fromBytes(bytes); + if (i > j) { + byte[] bytes = new byte[i - j]; + copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j); + return fromBytes(bytes); + } else { + return EMPTY_UTF8; + } } public UTF8String substringSQL(int pos, int length) { @@ -226,8 +230,9 @@ public final class UTF8String implements Comparable, Serializable { // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. - int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0); - int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length; + int len = numChars(); + int start = (pos > 0) ? pos -1 : ((pos < 0) ? len + pos : 0); + int end = (length == Integer.MAX_VALUE) ? len : start + length; return substring(start, end); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org