spark-commits mailing list archives

From ues...@apache.org
Subject spark git commit: [SPARK-23736][SQL] Extending the concat function to support array columns
Date Fri, 20 Apr 2018 05:58:18 GMT
Repository: spark
Updated Branches:
  refs/heads/master b3fde5a41 -> e6b466084


[SPARK-23736][SQL] Extending the concat function to support array columns

## What changes were proposed in this pull request?
The PR adds logic for easy concatenation of multiple array columns and covers (see the usage sketch below):
- The `Concat` expression has been extended to support array columns
- A Python wrapper has been added
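
For illustration, a minimal usage sketch of the extended function (assuming a Spark shell with `spark.implicits._` and `org.apache.spark.sql.functions._` in scope; the output shown is indicative):
```
val df = Seq((Seq(1, 2), Seq(3, 4))).toDF("a", "b")

// concat now accepts array columns and keeps the element type
df.select(concat($"a", $"b")).show()
// +------------+
// |concat(a, b)|
// +------------+
// |[1, 2, 3, 4]|
// +------------+

// Equivalent SQL: SELECT concat(array(1, 2), array(3, 4))
```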

## How was this patch tested?
New tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
- typeCoercion/native/concat.sql

## Codegen examples
### Primitive-type elements
```
val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */         ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */         if (!false) {
/* 044 */           project_args[0] = inputadapter_value;
/* 045 */         }
/* 046 */
/* 047 */         boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */         ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */         null : (inputadapter_row.getArray(1));
/* 050 */         if (!inputadapter_isNull1) {
/* 051 */           project_args[1] = inputadapter_value1;
/* 052 */         }
/* 053 */
/* 054 */         ArrayData project_value = new Object() {
/* 055 */           public ArrayData concat(ArrayData[] args) {
/* 056 */             for (int z = 0; z < 2; z++) {
/* 057 */               if (args[z] == null) return null;
/* 058 */             }
/* 059 */
/* 060 */             long project_numElements = 0L;
/* 061 */             for (int z = 0; z < 2; z++) {
/* 062 */               project_numElements += args[z].numElements();
/* 063 */             }
/* 064 */             if (project_numElements > 2147483632) {
/* 065 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */                 " elements due to exceeding the array size limit 2147483632.");
/* 067 */             }
/* 068 */
/* 069 */             long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 070 */               project_numElements,
/* 071 */               4);
/* 072 */             if (project_size > 2147483632) {
/* 073 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size +
/* 074 */                 " bytes of data due to exceeding the limit 2147483632 bytes" +
/* 075 */                 " for UnsafeArrayData.");
/* 076 */             }
/* 077 */
/* 078 */             byte[] project_array = new byte[(int)project_size];
/* 079 */             UnsafeArrayData project_arrayData = new UnsafeArrayData();
/* 080 */             Platform.putLong(project_array, 16, project_numElements);
/* 081 */             project_arrayData.pointTo(project_array, 16, (int)project_size);
/* 082 */             int project_counter = 0;
/* 083 */             for (int y = 0; y < 2; y++) {
/* 084 */               for (int z = 0; z < args[y].numElements(); z++) {
/* 085 */                 if (args[y].isNullAt(z)) {
/* 086 */                   project_arrayData.setNullAt(project_counter);
/* 087 */                 } else {
/* 088 */                   project_arrayData.setInt(
/* 089 */                     project_counter,
/* 090 */                     args[y].getInt(z)
/* 091 */                   );
/* 092 */                 }
/* 093 */                 project_counter++;
/* 094 */               }
/* 095 */             }
/* 096 */             return project_arrayData;
/* 097 */           }
/* 098 */         }.concat(project_args);
/* 099 */         boolean project_isNull = project_value == null;
```
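
Note: 2147483632 in the generated guards is `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` (`Integer.MAX_VALUE - 15`), a conservative upper bound on JVM array sizes; the generated helper also returns `null` as soon as any input array is null.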

### Non-primitive-type elements
```
val df = Seq(
  (Seq("aa", "bb"), Seq("ccc", "ddd")),
  (Seq("x", "y"), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */         ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */         if (!false) {
/* 044 */           project_args[0] = inputadapter_value;
/* 045 */         }
/* 046 */
/* 047 */         boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */         ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */         null : (inputadapter_row.getArray(1));
/* 050 */         if (!inputadapter_isNull1) {
/* 051 */           project_args[1] = inputadapter_value1;
/* 052 */         }
/* 053 */
/* 054 */         ArrayData project_value = new Object() {
/* 055 */           public ArrayData concat(ArrayData[] args) {
/* 056 */             for (int z = 0; z < 2; z++) {
/* 057 */               if (args[z] == null) return null;
/* 058 */             }
/* 059 */
/* 060 */             long project_numElements = 0L;
/* 061 */             for (int z = 0; z < 2; z++) {
/* 062 */               project_numElements += args[z].numElements();
/* 063 */             }
/* 064 */             if (project_numElements > 2147483632) {
/* 065 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */                 " elements due to exceeding the array size limit 2147483632.");
/* 067 */             }
/* 068 */
/* 069 */             Object[] project_arrayObjects = new Object[(int)project_numElements];
/* 070 */             int project_counter = 0;
/* 071 */             for (int y = 0; y < 2; y++) {
/* 072 */               for (int z = 0; z < args[y].numElements(); z++) {
/* 073 */                 project_arrayObjects[project_counter] = args[y].getUTF8String(z);
/* 074 */                 project_counter++;
/* 075 */               }
/* 076 */             }
/* 077 */             return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects);
/* 078 */           }
/* 079 */         }.concat(project_args);
/* 080 */         boolean project_isNull = project_value == null;
```
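
Note that for non-primitive elements the generated helper collects values into an `Object[]` and returns a `GenericArrayData`, rather than writing into an `UnsafeArrayData` as in the primitive case; variable-length values such as `UTF8String` don't fit the fixed-width `set*` writers used above.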

Author: mn-mikke <mrkAha12346github>

Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6b46608
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6b46608
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6b46608

Branch: refs/heads/master
Commit: e6b466084c26fbb9b9e50dd5cc8b25da7533ac72
Parents: b3fde5a
Author: mn-mikke <mrkAha12346github>
Authored: Fri Apr 20 14:58:11 2018 +0900
Committer: Takuya UESHIN <ueshin@databricks.com>
Committed: Fri Apr 20 14:58:11 2018 +0900

----------------------------------------------------------------------
 .../spark/unsafe/array/ByteArrayMethods.java    |   6 +-
 python/pyspark/sql/functions.py                 |  34 +--
 .../catalyst/expressions/UnsafeArrayData.java   |  10 +
 .../catalyst/analysis/FunctionRegistry.scala    |   2 +-
 .../sql/catalyst/analysis/TypeCoercion.scala    |   8 +
 .../expressions/collectionOperations.scala      | 220 ++++++++++++++++++-
 .../expressions/stringExpressions.scala         |  81 -------
 .../CollectionExpressionsSuite.scala            |  41 ++++
 .../scala/org/apache/spark/sql/functions.scala  |  20 +-
 .../inputs/typeCoercion/native/concat.sql       |  62 ++++++
 .../results/typeCoercion/native/concat.sql.out  |  78 +++++++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  74 +++++++
 .../spark/sql/execution/command/DDLSuite.scala  |   4 +-
 13 files changed, 529 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
index 4bc9955..ef0f78d 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -33,7 +33,11 @@ public class ByteArrayMethods {
   }
 
   public static int roundNumberOfBytesToNearestWord(int numBytes) {
-    int remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
+    return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+  }
+
+  public static long roundNumberOfBytesToNearestWord(long numBytes) {
+    long remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
     if (remainder == 0) {
       return numBytes;
     } else {
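
The new `long` overload lets the size computation for concatenated arrays run in 64-bit arithmetic before any limit check; rounding in `int` could silently overflow near `Integer.MAX_VALUE`. A minimal sketch of the same rounding logic in Scala (illustration only, not part of the patch):
```
// Round a byte count up to the nearest 8-byte word, in 64-bit arithmetic.
def roundToWord(numBytes: Long): Long = {
  val remainder = numBytes & 0x07L // equivalent to numBytes % 8
  if (remainder == 0) numBytes else numBytes + (8 - remainder)
}

assert(roundToWord(13L) == 16L)
assert(roundToWord(2147483646L) == 2147483648L) // fits in a Long, would overflow an Int
```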

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1be68f2..da32ab2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1427,21 +1427,6 @@ del _name, _doc
 
 @since(1.5)
 @ignore_unicode_prefix
-def concat(*cols):
-    """
-    Concatenates multiple input columns together into a single column.
-    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-
-    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
-    >>> df.select(concat(df.s, df.d).alias('s')).collect()
-    [Row(s=u'abcd123')]
-    """
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
-@since(1.5)
-@ignore_unicode_prefix
 def concat_ws(sep, *cols):
     """
     Concatenates multiple input string columns together into a single string column,
@@ -1845,6 +1830,25 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+    """
+    Concatenates multiple input columns together into a single column.
+    The function works with strings, binary and compatible array columns.
+
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df.select(concat(df.s, df.d).alias('s')).collect()
+    [Row(s=u'abcd123')]
+
+    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+    >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
 @since(2.4)
 def array_position(col, value):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 8546c28..d5d934b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -56,9 +56,19 @@ import org.apache.spark.unsafe.types.UTF8String;
 public final class UnsafeArrayData extends ArrayData {
 
   public static int calculateHeaderPortionInBytes(int numFields) {
+    return (int)calculateHeaderPortionInBytes((long)numFields);
+  }
+
+  public static long calculateHeaderPortionInBytes(long numFields) {
     return 8 + ((numFields + 63)/ 64) * 8;
   }
 
+  public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
+    long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
+    return size;
+  }
+
   private Object baseObject;
   private long baseOffset;
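
As a worked example of the new helper (illustration only): for an array of 5 ints, the header is an 8-byte element count plus one 8-byte word of null-tracking bits per 64 fields, and the 20 bytes of int data round up to the next word:
```
val numFields = 5L
val header = 8 + ((numFields + 63) / 64) * 8   // 16 bytes of header
val data = ((numFields * 4 + 7) / 8) * 8       // 20 bytes of data rounded up to 24
assert(header + data == 40L)                   // calculateSizeOfUnderlyingByteArray(5, 4) == 40
```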
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a44f2d5..c41f16c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -308,7 +308,6 @@ object FunctionRegistry {
     expression[BitLength]("bit_length"),
     expression[Length]("char_length"),
     expression[Length]("character_length"),
-    expression[Concat]("concat"),
     expression[ConcatWs]("concat_ws"),
     expression[Decode]("decode"),
     expression[Elt]("elt"),
@@ -413,6 +412,7 @@ object FunctionRegistry {
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
     expression[Reverse]("reverse"),
+    expression[Concat]("concat"),
     CreateStruct.registryEntry,
 
     // misc functions

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 281f206..cfcbd8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -520,6 +520,14 @@ object TypeCoercion {
           case None => a
         }
 
+      case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+        !haveSameType(children) =>
+        val types = children.map(_.dataType)
+        findWiderCommonType(types) match {
+          case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+          case None => c
+        }
+
       case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
         val newKeys = if (haveSameType(m.keys)) {
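
To illustrate the effect of this rule (a hedged example; same shell assumptions as the earlier sketch, schema output indicative): mixing `array<int>` and `array<bigint>` children now coerces both to the wider `array<bigint>` before `Concat` is applied:
```
val df = Seq((Seq(1, 2), Seq(3L))).toDF("a", "b")
df.select(concat($"a", $"b")).printSchema()
// root
//  |-- concat(a, b): array (nullable = true)
//  |    |-- element: long (containsNull = false)
```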

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index dba426e..c16793b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 /**
  * Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
 
   override def prettyName: String = "element_at"
 }
+
+/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark', 'SQL');
+       SparkSQL
+      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+       [1,2,3,4,5,6]
+  """)
+case class Concat(children: Seq[Expression]) extends Expression {
+
+  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val childTypes = children.map(_.dataType)
+      if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
+            s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+    }
+  }
+
+  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
+
+  lazy val javaType: String = CodeGenerator.javaType(dataType)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = dataType match {
+    case BinaryType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    case StringType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    case ArrayType(elementType, _) =>
+      val inputs = children.toStream.map(_.eval(input))
+      if (inputs.contains(null)) {
+        null
+      } else {
+        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+        if (numberOfElements > MAX_ARRAY_LENGTH) {
+          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+        }
+        val finalData = new Array[AnyRef](numberOfElements.toInt)
+        var position = 0
+        for(ad <- arrayData) {
+          val arr = ad.toObjectArray(elementType)
+          Array.copy(arr, 0, finalData, position, arr.length)
+          position += arr.length
+        }
+        new GenericArrayData(finalData)
+      }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val args = ctx.freshName("args")
+
+    val inputs = evals.zipWithIndex.map { case (eval, index) =>
+      s"""
+        ${eval.code}
+        if (!${eval.isNull}) {
+          $args[$index] = ${eval.value};
+        }
+      """
+    }
+
+    val (concatenator, initCode) = dataType match {
+      case BinaryType =>
+        (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+      case StringType =>
+        ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+      case ArrayType(elementType, _) =>
+        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
+          genCodeForPrimitiveArrays(ctx, elementType)
+        } else {
+          genCodeForNonPrimitiveArrays(ctx, elementType)
+        }
+        (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
+    }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = inputs,
+      funcName = "valueConcat",
+      extraArguments = (s"$javaType[]", args) :: Nil)
+    ev.copy(s"""
+      $initCode
+      $codes
+      $javaType ${ev.value} = $concatenator.concat($args);
+      boolean ${ev.isNull} = ${ev.value} == null;
+    """)
+  }
+
+  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
+    val numElements = ctx.freshName("numElements")
+    val code = s"""
+        |long $numElements = 0L;
+        |for (int z = 0; z < ${children.length}; z++) {
+        |  $numElements += args[z].numElements();
+        |}
+        |if ($numElements > $MAX_ARRAY_LENGTH) {
+        |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
+        |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+        |}
+      """.stripMargin
+
+    (code, numElements)
+  }
+
+  private def nullArgumentProtection() : String = {
+    if (nullable) {
+      s"""
+         |for (int z = 0; z < ${children.length}; z++) {
+         |  if (args[z] == null) return null;
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+  }
+
+  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val arrayName = ctx.freshName("array")
+    val arraySizeName = ctx.freshName("size")
+    val counter = ctx.freshName("counter")
+    val arrayData = ctx.freshName("arrayData")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    val unsafeArraySizeInBytes = s"""
+      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+      |  $numElemName,
+      |  ${elementType.defaultSize});
+      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
+      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
+      |    " for UnsafeArrayData.");
+      |}
+      """.stripMargin
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    s"""
+       |new Object() {
+       |  public ArrayData concat($javaType[] args) {
+       |    ${nullArgumentProtection()}
+       |    $numElemCode
+       |    $unsafeArraySizeInBytes
+       |    byte[] $arrayName = new byte[(int)$arraySizeName];
+       |    UnsafeArrayData $arrayData = new UnsafeArrayData();
+       |    Platform.putLong($arrayName, $baseOffset, $numElemName);
+       |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+       |    int $counter = 0;
+       |    for (int y = 0; y < ${children.length}; y++) {
+       |      for (int z = 0; z < args[y].numElements(); z++) {
+       |        if (args[y].isNullAt(z)) {
+       |          $arrayData.setNullAt($counter);
+       |        } else {
+       |          $arrayData.set$primitiveValueTypeName(
+       |            $counter,
+       |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+       |          );
+       |        }
+       |        $counter++;
+       |      }
+       |    }
+       |    return $arrayData;
+       |  }
+       |}""".stripMargin.stripPrefix("\n")
+  }
+
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayData = ctx.freshName("arrayObjects")
+    val counter = ctx.freshName("counter")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    s"""
+       |new Object() {
+       |  public ArrayData concat($javaType[] args) {
+       |    ${nullArgumentProtection()}
+       |    $numElemCode
+       |    Object[] $arrayData = new Object[(int)$numElemName];
+       |    int $counter = 0;
+       |    for (int y = 0; y < ${children.length}; y++) {
+       |      for (int z = 0; z < args[y].numElements(); z++) {
+       |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
+       |        $counter++;
+       |      }
+       |    }
+       |    return new $genericArrayClass($arrayData);
+       |  }
+       |}""".stripMargin.stripPrefix("\n")
+  }
+
+  override def toString: String = s"concat(${children.mkString(", ")})"
+
+  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+}
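
A side note on the interpreted path above: `children.toStream.map(_.eval(input))` keeps evaluation lazy, so `inputs.contains(null)` stops evaluating the remaining children as soon as one yields null (presumably the reason a `Stream` is used instead of a strict collection).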

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5a02ca0..ea005a2 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,87 +37,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 
 /**
- * An expression that concatenates multiple inputs into a single output.
- * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
- * If any input is null, concat returns null.
- */
-@ExpressionDescription(
-  usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('Spark', 'SQL');
-       SparkSQL
-  """)
-case class Concat(children: Seq[Expression]) extends Expression {
-
-  private lazy val isBinaryMode: Boolean = dataType == BinaryType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.isEmpty) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      val childTypes = children.map(_.dataType)
-      if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
-        return TypeCheckResult.TypeCheckFailure(
-          s"input to function $prettyName should have StringType or BinaryType, but it's " +
-            childTypes.map(_.simpleString).mkString("[", ", ", "]"))
-      }
-      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
-    }
-  }
-
-  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
-
-  override def nullable: Boolean = children.exists(_.nullable)
-  override def foldable: Boolean = children.forall(_.foldable)
-
-  override def eval(input: InternalRow): Any = {
-    if (isBinaryMode) {
-      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
-      ByteArray.concat(inputs: _*)
-    } else {
-      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-      UTF8String.concat(inputs : _*)
-    }
-  }
-
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val evals = children.map(_.genCode(ctx))
-    val args = ctx.freshName("args")
-
-    val inputs = evals.zipWithIndex.map { case (eval, index) =>
-      s"""
-        ${eval.code}
-        if (!${eval.isNull}) {
-          $args[$index] = ${eval.value};
-        }
-      """
-    }
-
-    val (concatenator, initCode) = if (isBinaryMode) {
-      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
-    } else {
-      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
-    }
-    val codes = ctx.splitExpressionsWithCurrentInputs(
-      expressions = inputs,
-      funcName = "valueConcat",
-      extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
-    ev.copy(s"""
-      $initCode
-      $codes
-      ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
-      boolean ${ev.isNull} = ${ev.value} == null;
-    """)
-  }
-
-  override def toString: String = s"concat(${children.mkString(", ")})"
-
-  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
-}
-
-
-/**
  * An expression that concatenates multiple input strings or array of strings into a single string,
  * using a given separator (the first child).
  *

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 7d8fe21..43c5dda 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
     checkEvaluation(ElementAt(m2, Literal("a")), null)
   }
+
+  test("Concat") {
+    // Primitive-type elements
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+    val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
+    val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+    val ai4 = Literal.create(null, ArrayType(IntegerType))
+
+    checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5))
+    checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5))
+    checkEvaluation(Concat(Seq(ai4)), null)
+    checkEvaluation(Concat(Seq(ai0, ai4)), null)
+    checkEvaluation(Concat(Seq(ai4, ai0)), null)
+
+    // Non-primitive-type elements
+    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+    val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
+    val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
+    val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
+    val as4 = Literal.create(null, ArrayType(StringType))
+
+    val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType)))
+    val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType)))
+
+    checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e"))
+    checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e"))
+    checkEvaluation(Concat(Seq(as4)), null)
+    checkEvaluation(Concat(Seq(as0, as4)), null)
+    checkEvaluation(Concat(Seq(as4, as0)), null)
+
+    checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9c85803..bea8c0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2229,16 +2229,6 @@ object functions {
   def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
-   * Concatenates multiple input columns together into a single column.
-   * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-   *
-   * @group string_funcs
-   * @since 1.5.0
-   */
-  @scala.annotation.varargs
-  def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
-
-  /**
    * Concatenates multiple input string columns together into a single string column,
    * using the given separator.
    *
@@ -3039,6 +3029,16 @@ object functions {
   }
 
   /**
+   * Concatenates multiple input columns together into a single column.
+   * The function works with strings, binary and compatible array columns.
+   *
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
+
+  /**
    * Locates the position of the first occurrence of the value in the given array as long.
    * Returns null if either of the arguments are null.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
index 0beebec..db00a18 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
@@ -91,3 +91,65 @@ FROM (
     encode(string(id + 3), 'utf-8') col4
   FROM range(10)
 );
+
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+  array(true, false), array(true),
+  array(2Y, 1Y), array(3Y, 4Y),
+  array(2S, 1S), array(3S, 4S),
+  array(2, 1), array(3, 4),
+  array(2L, 1L), array(3L, 4L),
+  array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
+  array(2.0D, 1.0D), array(3.0D, 4.0D),
+  array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+  array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  array(timestamp '2016-11-11 20:54:00.000'),
+  array('a', 'b'), array('c', 'd'),
+  array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+  array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+  array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+  boolean_array1, boolean_array2,
+  tinyint_array1, tinyint_array2,
+  smallint_array1, smallint_array2,
+  int_array1, int_array2,
+  bigint_array1, bigint_array2,
+  decimal_array1, decimal_array2,
+  double_array1, double_array2,
+  float_array1, float_array2,
+  date_array1, data_array2,
+  timestamp_array1, timestamp_array2,
+  string_array1, string_array2,
+  array_array1, array_array2,
+  struct_array1, struct_array2,
+  map_array1, map_array2
+);
+
+-- Concatenate arrays of the same type
+SELECT
+    (boolean_array1 || boolean_array2) boolean_array,
+    (tinyint_array1 || tinyint_array2) tinyint_array,
+    (smallint_array1 || smallint_array2) smallint_array,
+    (int_array1 || int_array2) int_array,
+    (bigint_array1 || bigint_array2) bigint_array,
+    (decimal_array1 || decimal_array2) decimal_array,
+    (double_array1 || double_array2) double_array,
+    (float_array1 || float_array2) float_array,
+    (date_array1 || data_array2) data_array,
+    (timestamp_array1 || timestamp_array2) timestamp_array,
+    (string_array1 || string_array2) string_array,
+    (array_array1 || array_array2) array_array,
+    (struct_array1 || struct_array2) struct_array,
+    (map_array1 || map_array2) map_array
+FROM various_arrays;
+
+-- Concatenate arrays of different types
+SELECT
+    (tinyint_array1 || smallint_array2) ts_array,
+    (smallint_array1 || int_array2) si_array,
+    (int_array1 || bigint_array2) ib_array,
+    (double_array1 || float_array2) df_array,
+    (string_array1 || data_array2) std_array,
+    (timestamp_array1 || string_array2) tst_array,
+    (string_array1 || int_array2) sti_array
+FROM various_arrays;

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
index 09729fd..62befc5 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
@@ -237,3 +237,81 @@ struct<col:binary>
 78910
 891011
 9101112
+
+
+-- !query 11
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+  array(true, false), array(true),
+  array(2Y, 1Y), array(3Y, 4Y),
+  array(2S, 1S), array(3S, 4S),
+  array(2, 1), array(3, 4),
+  array(2L, 1L), array(3L, 4L),
+  array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
+  array(2.0D, 1.0D), array(3.0D, 4.0D),
+  array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+  array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  array(timestamp '2016-11-11 20:54:00.000'),
+  array('a', 'b'), array('c', 'd'),
+  array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+  array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+  array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+  boolean_array1, boolean_array2,
+  tinyint_array1, tinyint_array2,
+  smallint_array1, smallint_array2,
+  int_array1, int_array2,
+  bigint_array1, bigint_array2,
+  decimal_array1, decimal_array2,
+  double_array1, double_array2,
+  float_array1, float_array2,
+  date_array1, data_array2,
+  timestamp_array1, timestamp_array2,
+  string_array1, string_array2,
+  array_array1, array_array2,
+  struct_array1, struct_array2,
+  map_array1, map_array2
+)
+-- !query 11 schema
+struct<>
+-- !query 11 output
+
+
+
+-- !query 12
+SELECT
+    (boolean_array1 || boolean_array2) boolean_array,
+    (tinyint_array1 || tinyint_array2) tinyint_array,
+    (smallint_array1 || smallint_array2) smallint_array,
+    (int_array1 || int_array2) int_array,
+    (bigint_array1 || bigint_array2) bigint_array,
+    (decimal_array1 || decimal_array2) decimal_array,
+    (double_array1 || double_array2) double_array,
+    (float_array1 || float_array2) float_array,
+    (date_array1 || data_array2) data_array,
+    (timestamp_array1 || timestamp_array2) timestamp_array,
+    (string_array1 || string_array2) string_array,
+    (array_array1 || array_array2) array_array,
+    (struct_array1 || struct_array2) struct_array,
+    (map_array1 || map_array2) map_array
+FROM various_arrays
+-- !query 12 schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,data_array:array<date>,timestamp_array:array<timestamp>,string_array:array<string>,array_array:array<array<string>>,struct_array:array<struct<col1:string,col2:int>>,map_array:array<map<string,int>>>
+-- !query 12 output
+[true,false,true]	[2,1,3,4]	[2,1,3,4]	[2,1,3,4]	[2,1,3,4]	[9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809]	[2.0,1.0,3.0,4.0]	[2.0,1.0,3.0,4.0]	[2016-03-14,2016-03-13,2016-03-12,2016-03-11]	[2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0]	["a","b","c","d"]	[["a","b"],["c","d"],["e"],["f"]]	[{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}]	[{"a":1},{"b":2},{"c":3},{"d":4}]
+
+
+-- !query 13
+SELECT
+    (tinyint_array1 || smallint_array2) ts_array,
+    (smallint_array1 || int_array2) si_array,
+    (int_array1 || bigint_array2) ib_array,
+    (double_array1 || float_array2) df_array,
+    (string_array1 || data_array2) std_array,
+    (timestamp_array1 || string_array2) tst_array,
+    (string_array1 || int_array2) sti_array
+FROM various_arrays
+-- !query 13 schema
+struct<ts_array:array<smallint>,si_array:array<int>,ib_array:array<bigint>,df_array:array<double>,std_array:array<string>,tst_array:array<string>,sti_array:array<string>>
+-- !query 13 output
+[2,1,3,4]	[2,1,3,4]	[2,1,3,4]	[2.0,1.0,3.0,4.0]	["a","b","2016-03-12","2016-03-11"]	["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"]	["a","b","3","4"]

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7c976c1..25e5cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("concat function - arrays") {
+    val nseqi : Seq[Int] = null
+    val nseqs : Seq[String] = null
+    val df = Seq(
+
+      (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs),
+      (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs)
+    ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn")
+
+    val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on
+
+    // Simple test cases
+    checkAnswer(
+      df.selectExpr("array(1, 2, 3L)"),
+      Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L)))
+    )
+
+    checkAnswer (
+      df.select(concat($"i1", $"s1")),
+      Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
+    )
+    checkAnswer(
+      df.select(concat($"i1", $"i2", $"i3")),
+      Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+    )
+    checkAnswer(
+      df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")),
+      Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+    )
+    checkAnswer(
+      df.selectExpr("concat(array(1, null), i2, i3)"),
+      Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2)))
+    )
+    checkAnswer(
+      df.select(concat($"s1", $"s2", $"s3")),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+    checkAnswer(
+      df.selectExpr("concat(s1, s2, s3)"),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+    checkAnswer(
+      df.filter(dummyFilter($"s1")).select(concat($"s1", $"s2", $"s3")),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+
+    // Null test cases
+    checkAnswer(
+      df.select(concat($"i1", $"in")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"in", $"i1")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"s1", $"sn")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"sn", $"s1")),
+      Seq(Row(null), Row(null))
+    )
+
+    // Type error test cases
+    intercept[AnalysisException] {
+      df.selectExpr("concat(i1, i2, null)")
+    }
+
+    intercept[AnalysisException] {
+      df.selectExpr("concat(i1, array(i1, i2))")
+    }
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index cbd7f9d..3998cec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       sql("DESCRIBE FUNCTION 'concat'"),
       Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") ::
         Row("Function: concat") ::
-        Row("Usage: concat(str1, str2, ..., strN) - " +
-            "Returns the concatenation of str1, str2, ..., strN.") :: Nil
+        Row("Usage: concat(col1, col2, ..., colN) - " +
+            "Returns the concatenation of col1, col2, ..., colN.") :: Nil
     )
     // extended mode
     checkAnswer(



