spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-8244] [SQL] string function: find in set
Date Tue, 04 Aug 2015 16:02:06 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 f44b27a2b -> 945da3534


[SPARK-8244] [SQL] string function: find in set

This PR is based on #7186 (it only fixes the merge conflict); thanks to tarekauel.

find_in_set(string str, string strList): int

Returns the 1-based index of the first occurrence of str in strList, where strList is a
comma-delimited string. Returns null if either argument is null. Returns 0 if str is not found
or if the first argument contains any commas. For example,
find_in_set('ab', 'abc,b,ab,c,def') returns 3.

This function is only added to SQL, not to the DataFrame API.

Closes #7186

Author: Tarek Auel <tarek.auel@googlemail.com>
Author: Davies Liu <davies@databricks.com>

Closes #7900 from davies/find_in_set and squashes the following commits:

4334209 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set
8f00572 [Davies Liu] Merge branch 'master' of github.com:apache/spark into find_in_set
243ede4 [Tarek Auel] [SPARK-8244][SQL] hive compatibility
1aaf64e [Tarek Auel] [SPARK-8244][SQL] unit test fix
e4093a4 [Tarek Auel] [SPARK-8244][SQL] final modifier for COMMA_UTF8
0d05df5 [Tarek Auel] Merge branch 'master' into SPARK-8244
208d710 [Tarek Auel] [SPARK-8244] address comments & bug fix
71b2e69 [Tarek Auel] [SPARK-8244] find_in_set
66c7fda [Tarek Auel] Merge branch 'master' into SPARK-8244
61b8ca2 [Tarek Auel] [SPARK-8224] removed loop and split; use unsafe String comparison
4f75a65 [Tarek Auel] Merge branch 'master' into SPARK-8244
e3b20c8 [Tarek Auel] [SPARK-8244] added type check
1c2bbb7 [Tarek Auel] [SPARK-8244] findInSet


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/945da353
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/945da353
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/945da353

Branch: refs/heads/branch-1.5
Commit: 945da3534762a73fe7ffc52c868ff07a0783502b
Parents: f44b27a
Author: Tarek Auel <tarek.auel@googlemail.com>
Authored: Tue Aug 4 08:59:42 2015 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Tue Aug 4 09:01:58 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../catalyst/expressions/stringOperations.scala | 25 ++++++++++++--
 .../expressions/StringExpressionsSuite.scala    | 10 ++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  8 +++++
 .../apache/spark/unsafe/types/UTF8String.java   | 35 ++++++++++++++++++--
 .../spark/unsafe/types/UTF8StringSuite.java     | 12 +++++++
 6 files changed, 87 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index bc08466..6140d1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -177,6 +177,7 @@ object FunctionRegistry {
     expression[ConcatWs]("concat_ws"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
+    expression[FindInSet]("find_in_set"),
     expression[FormatNumber]("format_number"),
     expression[InitCap]("initcap"),
     expression[Lower]("lcase"),

http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 5622529..0cc785d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.text.DecimalFormat
-import java.util.Arrays
-import java.util.Locale
+import java.util.{Arrays, Locale}
 import java.util.regex.{MatchResult, Pattern}
 
 import org.apache.commons.lang3.StringEscapeUtils
@@ -351,6 +350,28 @@ case class EndsWith(left: Expression, right: Expression)
 }
 
 /**
+ * A function that returns the index (1-based) of the given string (left) in the comma-
+ * delimited list (right). Returns 0, if the string wasn't found or if the given
+ * string (left) contains a comma.
+ */
+case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
+    with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+  override protected def nullSafeEval(word: Any, set: Any): Any =
+    set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String])
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (word, set) =>
+      s"${ev.primitive} = $set.findInSet($word);"
+    )
+  }
+
+  override def dataType: DataType = IntegerType
+}
+
+/**
  * A function that trim the spaces from both ends for the specified string.
  */
 case class StringTrim(child: Expression)

http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 906be70..23f36ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -675,4 +675,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
{
     checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null)
     checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null)
   }
+
+  test("find in set") {
+    checkEvaluation(
+      FindInSet(Literal.create(null, StringType), Literal.create(null, StringType)), null)
+    checkEvaluation(FindInSet(Literal("ab"), Literal.create(null, StringType)), null)
+    checkEvaluation(FindInSet(Literal.create(null, StringType), Literal("abc,b,ab,c,def")),
null)
+    checkEvaluation(FindInSet(Literal("ab"), Literal("abc,b,ab,c,def")), 3)
+    checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
+    checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 431dcf7..6137527 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -208,6 +208,14 @@ class DataFrameFunctionsSuite extends QueryTest {
       Row(2743272264L, 2180413220L))
   }
 
+  test("string function find_in_set") {
+    val df = Seq(("abc,b,ab,c,def", "abc,b,ab,c,def")).toDF("a", "b")
+
+    checkAnswer(
+      df.selectExpr("find_in_set('ab', a)", "find_in_set('x', b)"),
+      Row(3, 0))
+  }
+
   test("conditional function: least") {
     checkAnswer(
       testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),

http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index d80bd57..febbe3d 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -54,8 +54,9 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable
{
     5, 5, 5, 5,
     6, 6};
 
-  private static ByteOrder byteOrder = ByteOrder.nativeOrder();
+  private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN;
 
+  private static final UTF8String COMMA_UTF8 = UTF8String.fromString(",");
   public static final UTF8String EMPTY_UTF8 = UTF8String.fromString("");
 
   /**
@@ -179,7 +180,7 @@ public final class UTF8String implements Comparable<UTF8String>,
Serializable {
     // After getting the data, we use a mask to mask out data that is not part of the string.
     long p;
     long mask = 0;
-    if (byteOrder == ByteOrder.LITTLE_ENDIAN) {
+    if (isLittleEndian) {
       if (numBytes >= 8) {
         p = PlatformDependent.UNSAFE.getLong(base, offset);
       } else if (numBytes > 4) {
@@ -411,6 +412,36 @@ public final class UTF8String implements Comparable<UTF8String>,
Serializable {
     return fromString(sb.toString());
   }
 
+  /*
+   * Returns the index of the string `match` in this String. This string has to be a comma
separated
+   * list. If `match` contains a comma 0 will be returned. If the `match` isn't part of this
String,
+   * 0 will be returned, else the index of match (1-based index)
+   */
+  public int findInSet(UTF8String match) {
+    if (match.contains(COMMA_UTF8)) {
+      return 0;
+    }
+
+    int n = 1, lastComma = -1;
+    for (int i = 0; i < numBytes; i++) {
+      if (getByte(i) == (byte) ',') {
+        if (i - (lastComma + 1) == match.numBytes &&
+          ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset,
+            match.numBytes)) {
+          return n;
+        }
+        lastComma = i;
+        n++;
+      }
+    }
+    if (numBytes - (lastComma + 1) == match.numBytes &&
+      ByteArrayMethods.arrayEquals(base, offset + (lastComma + 1), match.base, match.offset,
+        match.numBytes)) {
+      return n;
+    }
+    return 0;
+  }
+
   /**
    * Copy the bytes from the current UTF8String, and make a new UTF8String.
    * @param start the start position of the current UTF8String in bytes.

http://git-wip-us.apache.org/repos/asf/spark/blob/945da353/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 9b3190f..b30c94c 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -400,6 +400,18 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void findInSet() {
+    assertEquals(fromString("ab").findInSet(fromString("ab")), 1);
+    assertEquals(fromString("a,b").findInSet(fromString("b")), 2);
+    assertEquals(fromString("abc,b,ab,c,def").findInSet(fromString("ab")), 3);
+    assertEquals(fromString("ab,abc,b,ab,c,def").findInSet(fromString("ab")), 1);
+    assertEquals(fromString(",,,ab,abc,b,ab,c,def").findInSet(fromString("ab")), 4);
+    assertEquals(fromString(",ab,abc,b,ab,c,def").findInSet(fromString("")), 1);
+    assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("ab")), 4);
+    assertEquals(fromString("数据砖头,abc,b,ab,c,def").findInSet(fromString("def")),
6);
+  }
+
+  @Test
   public void soundex() {
     assertEquals(fromString("Robert").soundex(), fromString("R163"));
     assertEquals(fromString("Rupert").soundex(), fromString("R163"));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message