spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dav...@apache.org
Subject spark git commit: [SPARK-8266] [SQL] add function translate
Date Thu, 06 Aug 2015 16:02:45 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 29ace3bbf -> cab86c4b7


[SPARK-8266] [SQL] add function translate

![translate](http://www.w3resource.com/PostgreSQL/postgresql-translate-function.png)

Author: zhichao.li <zhichao.li@intel.com>

Closes #7709 from zhichao-li/translate and squashes the following commits:

9418088 [zhichao.li] refine checking condition
f2ab77a [zhichao.li] clone string
9d88f2d [zhichao.li] fix indent
6aa2962 [zhichao.li] style
e575ead [zhichao.li] add python api
9d4bab0 [zhichao.li] add special case for fodable and refactor unittest
eda7ad6 [zhichao.li] update to use TernaryExpression
cdfd4be [zhichao.li] add function translate

(cherry picked from commit aead18ffca36830e854fba32a1cac11a0b2e31d5)
Signed-off-by: Davies Liu <davies.liu@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cab86c4b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cab86c4b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cab86c4b

Branch: refs/heads/branch-1.5
Commit: cab86c4b7348219fd2f679560ccb5a673612a09a
Parents: 29ace3b
Author: zhichao.li <zhichao.li@intel.com>
Authored: Thu Aug 6 09:02:30 2015 -0700
Committer: Davies Liu <davies.liu@gmail.com>
Committed: Thu Aug 6 09:02:41 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 16 ++++
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../sql/catalyst/expressions/Expression.scala   |  4 +-
 .../catalyst/expressions/stringOperations.scala | 79 +++++++++++++++++++-
 .../expressions/StringExpressionsSuite.scala    | 14 ++++
 .../scala/org/apache/spark/sql/functions.scala  | 21 ++++--
 .../apache/spark/sql/StringFunctionsSuite.scala |  6 ++
 .../apache/spark/unsafe/types/UTF8String.java   | 16 ++++
 .../spark/unsafe/types/UTF8StringSuite.java     | 31 ++++++++
 9 files changed, 180 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9f0d71d..b5c6a01 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1290,6 +1290,22 @@ def length(col):
     return Column(sc._jvm.functions.length(_to_java_column(col)))
 
 
+@ignore_unicode_prefix
+@since(1.5)
+def translate(srcCol, matching, replace):
+    """A function translate any character in the `srcCol` by a character in `matching`.
+    The characters in `replace` is corresponding to the characters in `matching`.
+    The translate will happen when any character in the string matching with the character
+    in the `matching`.
+
+    >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
+    .alias('r')).collect()
+    [Row(r=u'1a2s3ae')]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace))
+
+
 # ---------------------- Collection functions ------------------------------
 
 @since(1.4)

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 94c355f..cd5a90d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -203,6 +203,7 @@ object FunctionRegistry {
     expression[Substring]("substr"),
     expression[Substring]("substring"),
     expression[SubstringIndex]("substring_index"),
+    expression[StringTranslate]("translate"),
     expression[StringTrim]("trim"),
     expression[UnBase64]("unbase64"),
     expression[Upper]("ucase"),

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ef2fc2e..0b98f55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -444,7 +444,7 @@ abstract class TernaryExpression extends Expression {
   override def nullable: Boolean = children.exists(_.nullable)
 
   /**
-   * Default behavior of evaluation according to the default nullability of BinaryExpression.
+   * Default behavior of evaluation according to the default nullability of TernaryExpression.
    * If subclass of BinaryExpression override nullable, probably should also override this.
    */
   override def eval(input: InternalRow): Any = {
@@ -463,7 +463,7 @@ abstract class TernaryExpression extends Expression {
   }
 
   /**
-   * Called by default [[eval]] implementation.  If subclass of BinaryExpression keep the default
+   * Called by default [[eval]] implementation.  If subclass of TernaryExpression keep the default
    * nullability, they can override this method to save null-check code.  If we need full control
    * of evaluation process, we should override [[eval]].
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 0cc785d..76666bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.text.DecimalFormat
-import java.util.{Arrays, Locale}
+import java.util.Arrays
+import java.util.{Map => JMap, HashMap}
+import java.util.Locale
 import java.util.regex.{MatchResult, Pattern}
 
 import org.apache.commons.lang3.StringEscapeUtils
@@ -349,6 +351,81 @@ case class EndsWith(left: Expression, right: Expression)
   }
 }
 
+object StringTranslate {
+
+  def buildDict(matchingString: UTF8String, replaceString: UTF8String)
+    : JMap[Character, Character] = {
+    val matching = matchingString.toString()
+    val replace = replaceString.toString()
+    val dict = new HashMap[Character, Character]()
+    var i = 0
+    while (i < matching.length()) {
+      val rep = if (i < replace.length()) replace.charAt(i) else '\0'
+      if (null == dict.get(matching.charAt(i))) {
+        dict.put(matching.charAt(i), rep)
+      }
+      i += 1
+    }
+    dict
+  }
+}
+
+/**
+ * A function translate any character in the `srcExpr` by a character in `replaceExpr`.
+ * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`.
+ * The translate will happen when any character in the string matching with the character
+ * in the `matchingExpr`.
+ */
+case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes {
+
+  @transient private var lastMatching: UTF8String = _
+  @transient private var lastReplace: UTF8String = _
+  @transient private var dict: JMap[Character, Character] = _
+
+  override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
+    if (matchingEval != lastMatching || replaceEval != lastReplace) {
+      lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
+      lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
+      dict = StringTranslate.buildDict(lastMatching, lastReplace)
+    }
+    srcEval.asInstanceOf[UTF8String].translate(dict)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val termLastMatching = ctx.freshName("lastMatching")
+    val termLastReplace = ctx.freshName("lastReplace")
+    val termDict = ctx.freshName("dict")
+    val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
+
+    ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
+    ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
+    ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
+
+    nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
+      val check = if (matchingExpr.foldable && replaceExpr.foldable) {
+        s"${termDict} == null"
+      } else {
+        s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
+      }
+      s"""if ($check) {
+        // Not all of them is literal or matching or replace value changed
+        ${termLastMatching} = ${matching}.clone();
+        ${termLastReplace} = ${replace}.clone();
+        ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
+          .buildDict(${termLastMatching}, ${termLastReplace});
+      }
+      ${ev.primitive} = ${src}.translate(${termDict});
+      """
+    })
+  }
+
+  override def dataType: DataType = StringType
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+  override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil
+  override def prettyName: String = "translate"
+}
+
 /**
  * A function that returns the index (1-based) of the given string (left) in the comma-
  * delimited list (right). Returns 0, if the string wasn't found or if the given

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 23f36ca..426dc27 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(SoundEx(Literal("!!")), "!!")
   }
 
+  test("translate") {
+    checkEvaluation(
+      StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
+    checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate")
+    checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae")
+    // test for multiple mapping
+    checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd")
+    checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd")
+    // scalastyle:off
+    // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+    checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
+    // scalastyle:on
+  }
+
   test("TRIM/LTRIM/RTRIM") {
     val s = 'a.string.at(0)
     checkEvaluation(StringTrim(Literal(" aa  ")), "aa", create_row(" abdef "))

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 39aa905..79c5f59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1100,11 +1100,11 @@ object functions {
   }
 
   /**
-    * Computes hex value of the given column.
-    *
-    * @group math_funcs
-    * @since 1.5.0
-    */
+   * Computes hex value of the given column.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
   def hex(column: Column): Column = Hex(column.expr)
 
   /**
@@ -1863,6 +1863,17 @@ object functions {
   def substring_index(str: Column, delim: String, count: Int): Column =
     SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
 
+  /* Translate any character in the src by a character in replaceString.
+  * The characters in replaceString is corresponding to the characters in matchingString.
+  * The translate will happen when any character in the string matching with the character
+  * in the matchingString.
+  *
+  * @group string_funcs
+  * @since 1.5.0
+  */
+  def translate(src: Column, matchingString: String, replaceString: String): Column =
+    StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr)
+
   /**
    * Trim the spaces from both ends for the specified string column.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index ab5da6e..ca298b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -128,6 +128,12 @@ class StringFunctionsSuite extends QueryTest {
     // scalastyle:on
   }
 
+  test("string translate") {
+    val df = Seq(("translate", "")).toDF("a", "b")
+    checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae"))
+    checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae"))
+  }
+
   test("string trim functions") {
     val df = Seq(("  example  ", "")).toDF("a", "b")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index febbe3d..d101442 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -22,6 +22,7 @@ import java.io.Serializable;
 import java.io.UnsupportedEncodingException;
 import java.nio.ByteOrder;
 import java.util.Arrays;
+import java.util.Map;
 
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -795,6 +796,21 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     return res;
   }
 
+  // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes
+  public UTF8String translate(Map<Character, Character> dict) {
+    String srcStr = this.toString();
+
+    StringBuilder sb = new StringBuilder();
+    for(int k = 0; k< srcStr.length(); k++) {
+      if (null == dict.get(srcStr.charAt(k))) {
+        sb.append(srcStr.charAt(k));
+      } else if ('\0' != dict.get(srcStr.charAt(k))){
+        sb.append(dict.get(srcStr.charAt(k)));
+      }
+    }
+    return fromString(sb.toString());
+  }
+
   @Override
   public String toString() {
     try {

http://git-wip-us.apache.org/repos/asf/spark/blob/cab86c4b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index b30c94c..98aa8a2 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -19,7 +19,9 @@ package org.apache.spark.unsafe.types;
 
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
+import java.util.HashMap;
 
+import com.google.common.collect.ImmutableMap;
 import org.junit.Test;
 
 import static junit.framework.Assert.*;
@@ -392,6 +394,35 @@ public class UTF8StringSuite {
   }
 
   @Test
+  public void translate() {
+    assertEquals(
+      fromString("1a2s3ae"),
+      fromString("translate").translate(ImmutableMap.of(
+        'r', '1',
+        'n', '2',
+        'l', '3',
+        't', '\0'
+      )));
+    assertEquals(
+      fromString("translate"),
+      fromString("translate").translate(new HashMap<Character, Character>()));
+    assertEquals(
+      fromString("asae"),
+      fromString("translate").translate(ImmutableMap.of(
+        'r', '\0',
+        'n', '\0',
+        'l', '\0',
+        't', '\0'
+      )));
+    assertEquals(
+      fromString("aa世b"),
+      fromString("花花世界").translate(ImmutableMap.of(
+        '花', 'a',
+        '界', 'b'
+      )));
+  }
+
+  @Test
   public void createBlankString() {
     assertEquals(fromString(" "), blankString(1));
     assertEquals(fromString("  "), blankString(2));


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message