spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-8241][SQL] string function: concat_ws.
Date Sun, 19 Jul 2015 23:48:50 GMT
Repository: spark
Updated Branches:
  refs/heads/master 7a8124534 -> 163e3f1df


[SPARK-8241][SQL] string function: concat_ws.

I also changed the semantics of concat w.r.t. null back to the same behavior as Hive.
That is to say, concat now returns null if any input is null.

Author: Reynold Xin <rxin@databricks.com>

Closes #7504 from rxin/concat_ws and squashes the following commits:

83fd950 [Reynold Xin] Fixed type casting.
3ae85f7 [Reynold Xin] Write null better.
cdc7be6 [Reynold Xin] Added code generation for pure string mode.
a61c4e4 [Reynold Xin] Updated comments.
2d51406 [Reynold Xin] [SPARK-8241][SQL] string function: concat_ws.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/163e3f1d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/163e3f1d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/163e3f1d

Branch: refs/heads/master
Commit: 163e3f1df94f6b7d3dadb46a87dbb3a2bade3f95
Parents: 7a81245
Author: Reynold Xin <rxin@databricks.com>
Authored: Sun Jul 19 16:48:47 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Sun Jul 19 16:48:47 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    | 11 ++-
 .../catalyst/expressions/stringOperations.scala | 72 +++++++++++++++++---
 .../org/apache/spark/sql/types/DataType.scala   |  2 +-
 .../analysis/HiveTypeCoercionSuite.scala        | 11 ++-
 .../expressions/StringExpressionsSuite.scala    | 31 ++++++++-
 .../scala/org/apache/spark/sql/functions.scala  | 24 +++++++
 .../apache/spark/sql/StringFunctionsSuite.scala | 19 ++++--
 .../hive/execution/HiveCompatibilitySuite.scala |  4 +-
 .../apache/spark/unsafe/types/UTF8String.java   | 58 ++++++++++++++--
 .../spark/unsafe/types/UTF8StringSuite.java     | 62 +++++++++++++----
 10 files changed, 256 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 4b256ad..71e87b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -153,6 +153,7 @@ object FunctionRegistry {
     expression[Ascii]("ascii"),
     expression[Base64]("base64"),
     expression[Concat]("concat"),
+    expression[ConcatWs]("concat_ws"),
     expression[Encode]("encode"),
     expression[Decode]("decode"),
     expression[FormatNumber]("format_number"),
@@ -211,7 +212,10 @@ object FunctionRegistry {
     val builder = (expressions: Seq[Expression]) => {
       if (varargCtor.isDefined) {
         // If there is an apply method that accepts Seq[Expression], use that one.
-        varargCtor.get.newInstance(expressions).asInstanceOf[Expression]
+        Try(varargCtor.get.newInstance(expressions).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       } else {
         // Otherwise, find an ctor method that matches the number of arguments, and use that.
         val params = Seq.fill(expressions.size)(classOf[Expression])
@@ -221,7 +225,10 @@ object FunctionRegistry {
           case Failure(e) =>
             throw new AnalysisException(s"Invalid number of arguments for function $name")
         }
-        f.newInstance(expressions : _*).asInstanceOf[Expression]
+        Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
+          case Success(e) => e
+          case Failure(e) => throw new AnalysisException(e.getMessage)
+        }
       }
     }
     (name, builder)

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 560b1bc..5f8ac71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -34,19 +34,14 @@ import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * An expression that concatenates multiple input strings into a single string.
- * Input expressions that are evaluated to nulls are skipped.
- *
- * For example, `concat("a", null, "b")` is evaluated to `"ab"`.
- *
- * Note that this is different from Hive since Hive outputs null if any input is null.
- * We never output null.
+ * If any input is null, concat returns null.
  */
 case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes
{
 
   override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
   override def dataType: DataType = StringType
 
-  override def nullable: Boolean = false
+  override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
@@ -56,15 +51,76 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String
= {
     val evals = children.map(_.gen(ctx))
-    val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(",
")
+    val inputs = evals.map { eval =>
+      s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+    }.mkString(", ")
     evals.map(_.code).mkString("\n") + s"""
       boolean ${ev.isNull} = false;
       UTF8String ${ev.primitive} = UTF8String.concat($inputs);
+      if (${ev.primitive} == null) {
+        ${ev.isNull} = true;
+      }
     """
   }
 }
 
 
+/**
+ * An expression that concatenates multiple input strings or array of strings into a single
string,
+ * using a given separator (the first child).
+ *
+ * Returns null if the separator is null. Otherwise, concat_ws skips all null values.
+ */
+case class ConcatWs(children: Seq[Expression])
+  extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+  require(children.nonEmpty, s"$prettyName requires at least one argument.")
+
+  override def prettyName: String = "concat_ws"
+
+  /** The 1st child (separator) is str, and rest are either str or array of str. */
+  override def inputTypes: Seq[AbstractDataType] = {
+    val arrayOrStr = TypeCollection(ArrayType(StringType), StringType)
+    StringType +: Seq.fill(children.size - 1)(arrayOrStr)
+  }
+
+  override def dataType: DataType = StringType
+
+  override def nullable: Boolean = children.head.nullable
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = {
+    val flatInputs = children.flatMap { child =>
+      child.eval(input) match {
+        case s: UTF8String => Iterator(s)
+        case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]]
+        case null => Iterator(null.asInstanceOf[UTF8String])
+      }
+    }
+    UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*)
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String
= {
+    if (children.forall(_.dataType == StringType)) {
+      // All children are strings. In that case we can construct a fixed size array.
+      val evals = children.map(_.gen(ctx))
+
+      val inputs = evals.map { eval =>
+        s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}"
+      }.mkString(", ")
+
+      evals.map(_.code).mkString("\n") + s"""
+        UTF8String ${ev.primitive} = UTF8String.concatWs($inputs);
+        boolean ${ev.isNull} = ${ev.primitive} == null;
+      """
+    } else {
+      // Contains a mix of strings and array<string>s. Fall back to interpreted mode
for now.
+      super.genCode(ctx, ev)
+    }
+  }
+}
+
+
 trait StringRegexExpression extends ImplicitCastInputTypes {
   self: BinaryExpression =>
 

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 2d133ee..e98fd25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
 
   override private[sql] def defaultConcreteType: DataType = this
 
-  override private[sql] def acceptsType(other: DataType): Boolean = this == other
+  override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index f9442bc..7ee2333 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -37,7 +37,6 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(NullType, IntegerType, IntegerType)
     shouldCast(NullType, DecimalType, DecimalType.Unlimited)
 
-    // TODO: write the entire implicit cast table out for test cases.
     shouldCast(ByteType, IntegerType, IntegerType)
     shouldCast(IntegerType, IntegerType, IntegerType)
     shouldCast(IntegerType, LongType, LongType)
@@ -86,6 +85,16 @@ class HiveTypeCoercionSuite extends PlanTest {
       DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
       shouldCast(tpe, NumericType, tpe)
     }
+
+    shouldCast(
+      ArrayType(StringType, false),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, false))
+
+    shouldCast(
+      ArrayType(StringType, true),
+      TypeCollection(ArrayType(StringType), StringType),
+      ArrayType(StringType, true))
   }
 
   test("ineligible implicit type cast") {

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 0ed567a..96f433b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -26,7 +26,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
{
 
   test("concat") {
     def testConcat(inputs: String*): Unit = {
-      val expected = inputs.filter(_ != null).mkString
+      val expected = if (inputs.contains(null)) null else inputs.mkString
       checkEvaluation(Concat(inputs.map(Literal.create(_, StringType))), expected, EmptyRow)
     }
 
@@ -46,6 +46,35 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
{
     // scalastyle:on
   }
 
+  test("concat_ws") {
+    def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = {
+      val inputExprs = inputs.map {
+        case s: Seq[_] => Literal.create(s, ArrayType(StringType))
+        case null => Literal.create(null, StringType)
+        case s: String => Literal.create(s, StringType)
+      }
+      val sepExpr = Literal.create(sep, StringType)
+      checkEvaluation(ConcatWs(sepExpr +: inputExprs), expected, EmptyRow)
+    }
+
+    // scalastyle:off
+    // non ascii characters are not allowed in the code, so we disable the scalastyle here.
+    testConcatWs(null, null)
+    testConcatWs(null, null, "a", "b")
+    testConcatWs("", "")
+    testConcatWs("ab", "哈哈", "ab")
+    testConcatWs("a哈哈b", "哈哈", "a", "b")
+    testConcatWs("a哈哈b", "哈哈", "a", null, "b")
+    testConcatWs("a哈哈b哈哈c", "哈哈", null, "a", null, "b", "c")
+
+    testConcatWs("ab", "哈哈", Seq("ab"))
+    testConcatWs("a哈哈b", "哈哈", Seq("a", "b"))
+    testConcatWs("a哈哈b哈哈c哈哈d", "哈哈", Seq("a", null, "b"), null, "c", Seq(null,
"d"))
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq.empty[String])
+    testConcatWs("a哈哈b哈哈c", "哈哈", Seq("a", null, "b"), null, "c", Seq[String](null))
+    // scalastyle:on
+  }
+
   test("StringComparison") {
     val row = create_row("abc", null)
     val c1 = 'a.string.at(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f67c894..b5140dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1733,6 +1733,30 @@ object functions {
   }
 
   /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, exprs: Column*): Column = {
+    ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
+  }
+
+  /**
+   * Concatenates input strings together into a single string, using the given separator.
+   *
+   * This is the variant of concat_ws that takes in the column names.
+   *
+   * @group string_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat_ws(sep: String, columnName: String, columnNames: String*): Column = {
+    concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*)
+  }
+
+  /**
    * Computes the length of a given string / binary value.
    *
    * @group string_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 4eff33e..fe4de8d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -30,14 +30,25 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
 
     checkAnswer(
-      df.select(concat($"a", $"b", $"c")),
-      Row("ab"))
+      df.select(concat($"a", $"b"), concat($"a", $"b", $"c")),
+      Row("ab", null))
 
     checkAnswer(
-      df.selectExpr("concat(a, b, c)"),
-      Row("ab"))
+      df.selectExpr("concat(a, b)", "concat(a, b, c)"),
+      Row("ab", null))
   }
 
+  test("string concat_ws") {
+    val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
+
+    checkAnswer(
+      df.select(concat_ws("||", $"a", $"b", $"c")),
+      Row("a||b"))
+
+    checkAnswer(
+      df.selectExpr("concat_ws('||', a, b, c)"),
+      Row("a||b"))
+  }
 
   test("string Levenshtein distance") {
     val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2689d90..b12b383 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -263,9 +263,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter
{
     "timestamp_2",
     "timestamp_udf",
 
-    // Hive outputs NULL if any concat input has null. We never output null for concat.
-    "udf_concat",
-
     // Unlike Hive, we do support log base in (0, 1.0], therefore disable this
     "udf7"
   )
@@ -856,6 +853,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter
{
     "udf_case",
     "udf_ceil",
     "udf_ceiling",
+    "udf_concat",
     "udf_concat_insert1",
     "udf_concat_insert2",
     "udf_concat_ws",

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9723b6e..3eecd65 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -397,19 +397,16 @@ public final class UTF8String implements Comparable<UTF8String>,
Serializable {
   }
 
   /**
-   * Concatenates input strings together into a single string. A null input is skipped.
-   * For example, concat("a", null, "c") would yield "ac".
+   * Concatenates input strings together into a single string. Returns null if any input
is null.
    */
   public static UTF8String concat(UTF8String... inputs) {
-    if (inputs == null) {
-      return fromBytes(new byte[0]);
-    }
-
     // Compute the total length of the result.
     int totalLength = 0;
     for (int i = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         totalLength += inputs[i].numBytes;
+      } else {
+        return null;
       }
     }
 
@@ -417,6 +414,45 @@ public final class UTF8String implements Comparable<UTF8String>,
Serializable {
     final byte[] result = new byte[totalLength];
     int offset = 0;
     for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].numBytes;
+      PlatformDependent.copyMemory(
+        inputs[i].base, inputs[i].offset,
+        result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return fromBytes(result);
+  }
+
+  /**
+   * Concatenates input strings together into a single string using the separator.
+   * A null input is skipped. For example, concatWs(",", "a", null, "c") would yield "a,c".
+   */
+  public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) {
+    if (separator == null) {
+      return null;
+    }
+
+    int numInputBytes = 0;  // total number of bytes from the inputs
+    int numInputs = 0;      // number of non-null inputs
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        numInputBytes += inputs[i].numBytes;
+        numInputs++;
+      }
+    }
+
+    if (numInputs == 0) {
+      // Return an empty string if there is no input, or all the inputs are null.
+      return fromBytes(new byte[0]);
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it.
+    // The size of the new array is the size of all inputs, plus the separators.
+    final byte[] result = new byte[numInputBytes + (numInputs - 1) * separator.numBytes];
+    int offset = 0;
+
+    for (int i = 0, j = 0; i < inputs.length; i++) {
       if (inputs[i] != null) {
         int len = inputs[i].numBytes;
         PlatformDependent.copyMemory(
@@ -424,6 +460,16 @@ public final class UTF8String implements Comparable<UTF8String>,
Serializable {
           result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
           len);
         offset += len;
+
+        j++;
+        // Add separator if this is not the last input.
+        if (j < numInputs) {
+          PlatformDependent.copyMemory(
+            separator.base, separator.offset,
+            result, PlatformDependent.BYTE_ARRAY_OFFSET + offset,
+            separator.numBytes);
+          offset += separator.numBytes;
+        }
       }
     }
     return fromBytes(result);

http://git-wip-us.apache.org/repos/asf/spark/blob/163e3f1d/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 0db7522..7d0c49e 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -88,16 +88,50 @@ public class UTF8StringSuite {
 
   @Test
   public void concatTest() {
-    assertEquals(concat(), fromString(""));
-    assertEquals(concat(null), fromString(""));
-    assertEquals(concat(fromString("")), fromString(""));
-    assertEquals(concat(fromString("ab")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b")), fromString("ab"));
-    assertEquals(concat(fromString("a"), fromString("b"), fromString("c")), fromString("abc"));
-    assertEquals(concat(fromString("a"), null, fromString("c")), fromString("ac"));
-    assertEquals(concat(fromString("a"), null, null), fromString("a"));
-    assertEquals(concat(null, null, null), fromString(""));
-    assertEquals(concat(fromString("数据"), fromString("砖头")), fromString("数据砖头"));
+    assertEquals(fromString(""), concat());
+    assertEquals(null, concat((UTF8String) null));
+    assertEquals(fromString(""), concat(fromString("")));
+    assertEquals(fromString("ab"), concat(fromString("ab")));
+    assertEquals(fromString("ab"), concat(fromString("a"), fromString("b")));
+    assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, fromString("c")));
+    assertEquals(null, concat(fromString("a"), null, null));
+    assertEquals(null, concat(null, null, null));
+    assertEquals(fromString("数据砖头"), concat(fromString("数据"), fromString("砖头")));
+  }
+
+  @Test
+  public void concatWsTest() {
+    // Returns null if the separator is null
+    assertEquals(null, concatWs(null, (UTF8String)null));
+    assertEquals(null, concatWs(null, fromString("a")));
+
+    // If the separator is not null, concatWs skips all null inputs and never returns null.
+    UTF8String sep = fromString("哈哈");
+    assertEquals(
+      fromString(""),
+      concatWs(sep, fromString("")));
+    assertEquals(
+      fromString("ab"),
+      concatWs(sep, fromString("ab")));
+    assertEquals(
+      fromString("a哈哈b"),
+      concatWs(sep, fromString("a"), fromString("b")));
+    assertEquals(
+      fromString("a哈哈b哈哈c"),
+      concatWs(sep, fromString("a"), fromString("b"), fromString("c")));
+    assertEquals(
+      fromString("a哈哈c"),
+      concatWs(sep, fromString("a"), null, fromString("c")));
+    assertEquals(
+      fromString("a"),
+      concatWs(sep, fromString("a"), null, null));
+    assertEquals(
+      fromString(""),
+      concatWs(sep, null, null, null));
+    assertEquals(
+      fromString("数据哈哈砖头"),
+      concatWs(sep, fromString("数据"), fromString("砖头")));
   }
 
   @Test
@@ -215,14 +249,18 @@ public class UTF8StringSuite {
     assertEquals(fromString("??数据砖头"), fromString("数据砖头").lpad(6, fromString("????")));
     assertEquals(fromString("孙行数据砖头"), fromString("数据砖头").lpad(6, fromString("孙行者")));
     assertEquals(fromString("孙行者数据砖头"), fromString("数据砖头").lpad(7,
fromString("孙行者")));
-    assertEquals(fromString("孙行者孙行者孙行数据砖头"), fromString("数据砖头").lpad(12,
fromString("孙行者")));
+    assertEquals(
+      fromString("孙行者孙行者孙行数据砖头"),
+      fromString("数据砖头").lpad(12, fromString("孙行者")));
 
     assertEquals(fromString("数据砖"), fromString("数据砖头").rpad(3, fromString("????")));
     assertEquals(fromString("数据砖头?"), fromString("数据砖头").rpad(5, fromString("????")));
     assertEquals(fromString("数据砖头??"), fromString("数据砖头").rpad(6, fromString("????")));
     assertEquals(fromString("数据砖头孙行"), fromString("数据砖头").rpad(6, fromString("孙行者")));
     assertEquals(fromString("数据砖头孙行者"), fromString("数据砖头").rpad(7,
fromString("孙行者")));
-    assertEquals(fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12,
fromString("孙行者")));
+    assertEquals(
+      fromString("数据砖头孙行者孙行者孙行"),
+      fromString("数据砖头").rpad(12, fromString("孙行者")));
   }
   
   @Test


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message