spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ues...@apache.org
Subject spark git commit: [SPARK-23934][SQL] Adding map_from_entries function
Date Fri, 22 Jun 2018 07:18:30 GMT
Repository: spark
Updated Branches:
  refs/heads/master dc8a6befa -> 92c2f00bd


[SPARK-23934][SQL] Adding map_from_entries function

## What changes were proposed in this pull request?
The PR adds the `map_from_entries` function that returns a map created from the given array
of entries.

## How was this patch tested?
New tests were added to:
- `CollectionExpressionsSuite`
- `DataFrameFunctionsSuite`

## CodeGen Examples
### Primitive-type Keys and Values
```
val idf = Seq(
  Seq((1, 10), (2, 20), (3, 10)),
  Seq((1, 10), null, (2, 20))
).toDF("a")
idf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */         MapData project_value_0 = null;
/* 044 */
/* 045 */         for (int project_idx_2 = 0; !project_isNull_0 && project_idx_2 <
inputadapter_value_0.numElements(); project_idx_2++) {
/* 046 */           project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_2);
/* 047 */         }
/* 048 */         if (!project_isNull_0) {
/* 049 */           final int project_numEntries_0 = inputadapter_value_0.numElements();
/* 050 */
/* 051 */           final long project_keySectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0,
4);
/* 052 */           final long project_valueSectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0,
4);
/* 053 */           final long project_byteArraySize_0 = 8 + project_keySectionSize_0 + project_valueSectionSize_0;
/* 054 */           if (project_byteArraySize_0 > 2147483632) {
/* 055 */             final Object[] project_keys_0 = new Object[project_numEntries_0];
/* 056 */             final Object[] project_values_0 = new Object[project_numEntries_0];
/* 057 */
/* 058 */             for (int project_idx_1 = 0; project_idx_1 < project_numEntries_0;
project_idx_1++) {
/* 059 */               InternalRow project_entry_1 = inputadapter_value_0.getStruct(project_idx_1,
2);
/* 060 */
/* 061 */               project_keys_0[project_idx_1] = project_entry_1.getInt(0);
/* 062 */               project_values_0[project_idx_1] = project_entry_1.getInt(1);
/* 063 */             }
/* 064 */
/* 065 */             project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0,
project_values_0);
/* 066 */
/* 067 */           } else {
/* 068 */             final byte[] project_byteArray_0 = new byte[(int)project_byteArraySize_0];
/* 069 */             UnsafeMapData project_unsafeMapData_0 = new UnsafeMapData();
/* 070 */             Platform.putLong(project_byteArray_0, 16, project_keySectionSize_0);
/* 071 */             Platform.putLong(project_byteArray_0, 24, project_numEntries_0);
/* 072 */             Platform.putLong(project_byteArray_0, 24 + project_keySectionSize_0,
project_numEntries_0);
/* 073 */             project_unsafeMapData_0.pointTo(project_byteArray_0, 16, (int)project_byteArraySize_0);
/* 074 */             ArrayData project_keyArrayData_0 = project_unsafeMapData_0.keyArray();
/* 075 */             ArrayData project_valueArrayData_0 = project_unsafeMapData_0.valueArray();
/* 076 */
/* 077 */             for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0;
project_idx_0++) {
/* 078 */               InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0,
2);
/* 079 */
/* 080 */               project_keyArrayData_0.setInt(project_idx_0, project_entry_0.getInt(0));
/* 081 */               project_valueArrayData_0.setInt(project_idx_0, project_entry_0.getInt(1));
/* 082 */             }
/* 083 */
/* 084 */             project_value_0 = project_unsafeMapData_0;
/* 085 */           }
/* 086 */
/* 087 */         }
```
### Non-primitive-type Keys and Values
```
val sdf = Seq(
  Seq(("a", null), ("b", "bb"), ("c", "aa")),
  Seq(("a", "aa"), null, (null, "bb"))
).toDF("a")
sdf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen
```
Result:
```
/* 042 */         boolean project_isNull_0 = false;
/* 043 */         MapData project_value_0 = null;
/* 044 */
/* 045 */         for (int project_idx_1 = 0; !project_isNull_0 && project_idx_1 <
inputadapter_value_0.numElements(); project_idx_1++) {
/* 046 */           project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_1);
/* 047 */         }
/* 048 */         if (!project_isNull_0) {
/* 049 */           final int project_numEntries_0 = inputadapter_value_0.numElements();
/* 050 */
/* 051 */           final Object[] project_keys_0 = new Object[project_numEntries_0];
/* 052 */           final Object[] project_values_0 = new Object[project_numEntries_0];
/* 053 */
/* 054 */           for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++)
{
/* 055 */             InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0,
2);
/* 056 */
/* 057 */             if (project_entry_0.isNullAt(0)) {
/* 058 */               throw new RuntimeException("The first field from a struct (key) can't
be null.");
/* 059 */             }
/* 060 */
/* 061 */             project_keys_0[project_idx_0] = project_entry_0.getUTF8String(0);
/* 062 */             project_values_0[project_idx_0] = project_entry_0.getUTF8String(1);
/* 063 */           }
/* 064 */
/* 065 */           project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0,
project_values_0);
/* 066 */
/* 067 */         }
```

Author: Marek Novotny <mn.mikke@gmail.com>

Closes #21282 from mn-mikke/feature/array-api-map_from_entries-to-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92c2f00b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92c2f00b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92c2f00b

Branch: refs/heads/master
Commit: 92c2f00bd275a90b6912fb8c8cf542002923629c
Parents: dc8a6be
Author: Marek Novotny <mn.mikke@gmail.com>
Authored: Fri Jun 22 16:18:22 2018 +0900
Committer: Takuya UESHIN <ueshin@databricks.com>
Committed: Fri Jun 22 16:18:22 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  20 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/codegen/CodeGenerator.scala     |  30 +++
 .../expressions/collectionOperations.scala      | 235 +++++++++++++++++--
 .../CollectionExpressionsSuite.scala            |  51 ++++
 .../scala/org/apache/spark/sql/functions.scala  |   7 +
 .../spark/sql/DataFrameFunctionsSuite.scala     |  50 ++++
 7 files changed, 378 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 11b179f..5f5d733 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2412,6 +2412,26 @@ def map_entries(col):
     return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
 
 
+@since(2.4)
+def map_from_entries(col):
+    """
+    Collection function: Returns a map created from the given array of entries.
+
+    :param col: name of column or expression
+
+    >>> from pyspark.sql.functions import map_from_entries
+    >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
+    >>> df.select(map_from_entries("data").alias("map")).show()
+    +----------------+
+    |             map|
+    +----------------+
+    |[1 -> a, 2 -> b]|
+    +----------------+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))
+
+
 @ignore_unicode_prefix
 @since(2.4)
 def array_repeat(col, count):

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 4b09b9a..8abc616 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -421,6 +421,7 @@ object FunctionRegistry {
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[MapEntries]("map_entries"),
+    expression[MapFromEntries]("map_from_entries"),
     expression[Size]("size"),
     expression[Slice]("slice"),
     expression[Size]("cardinality"),

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 66315e5..4cc0968 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -820,6 +820,36 @@ class CodegenContext {
   }
 
   /**
+   * Generates code to do null safe execution when accessing properties of complex
+   * ArrayData elements.
+   *
+   * @param nullElements used to decide whether the ArrayData might contain null or not.
+   * @param isNull a variable indicating whether the result will be evaluated to null or
not.
+   * @param arrayData a variable name representing the ArrayData.
+   * @param execute the code that should be executed only if the ArrayData doesn't contain
+   *                any null.
+   */
+  def nullArrayElementsSaveExec(
+      nullElements: Boolean,
+      isNull: String,
+      arrayData: String)(
+      execute: String): String = {
+    val i = freshName("idx")
+    if (nullElements) {
+      s"""
+         |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) {
+         |  $isNull |= $arrayData.isNullAt($i);
+         |}
+         |if (!$isNull) {
+         |  $execute
+         |}
+       """.stripMargin
+    } else {
+      execute
+    }
+  }
+
+  /**
    * Splits the generated code of expressions into multiple functions, because function has
    * 64kb code size limit in JVM. If the class to which the function would be inlined would
grow
    * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to
it

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 7c064a1..3afabe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
@@ -476,6 +476,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with
ExpectsInp
 }
 
 /**
+ * Returns a map created from the given array of entries.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
+       {1:"a",2:"b"}
+  """,
+  since = "2.4.0")
+case class MapFromEntries(child: Expression) extends UnaryExpression {
+
+  @transient
+  private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType
match {
+    case ArrayType(
+      StructType(Array(
+        StructField(_, keyType, keyNullable, _),
+        StructField(_, valueType, valueNullable, _))),
+      containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable,
containsNull))
+    case _ => None
+  }
+
+  private def nullEntries: Boolean = dataTypeDetails.get._3
+
+  override def nullable: Boolean = child.nullable || nullEntries
+
+  override def dataType: MapType = dataTypeDetails.get._1
+
+  override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
+    case Some(_) => TypeCheckResult.TypeCheckSuccess
+    case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
+      s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.")
+  }
+
+  override protected def nullSafeEval(input: Any): Any = {
+    val arrayData = input.asInstanceOf[ArrayData]
+    val numEntries = arrayData.numElements()
+    var i = 0
+    if(nullEntries) {
+      while (i < numEntries) {
+        if (arrayData.isNullAt(i)) return null
+        i += 1
+      }
+    }
+    val keyArray = new Array[AnyRef](numEntries)
+    val valueArray = new Array[AnyRef](numEntries)
+    i = 0
+    while (i < numEntries) {
+      val entry = arrayData.getStruct(i, 2)
+      val key = entry.get(0, dataType.keyType)
+      if (key == null) {
+        throw new RuntimeException("The first field from a struct (key) can't be null.")
+      }
+      keyArray.update(i, key)
+      val value = entry.get(1, dataType.valueType)
+      valueArray.update(i, value)
+      i += 1
+    }
+    ArrayBasedMapData(keyArray, valueArray)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      val numEntries = ctx.freshName("numEntries")
+      val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
+      val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
+      val code = if (isKeyPrimitive && isValuePrimitive) {
+        genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
+      } else {
+        genCodeForAnyElements(ctx, c, ev.value, numEntries)
+      }
+      ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
+        s"""
+           |final int $numEntries = $c.numElements();
+           |$code
+         """.stripMargin
+      }
+    })
+  }
+
+  private def genCodeForAssignmentLoop(
+      ctx: CodegenContext,
+      childVariable: String,
+      mapData: String,
+      numEntries: String,
+      keyAssignment: (String, String) => String,
+      valueAssignment: (String, String) => String): String = {
+    val entry = ctx.freshName("entry")
+    val i = ctx.freshName("idx")
+
+    val nullKeyCheck = if (dataTypeDetails.get._2) {
+      s"""
+         |if ($entry.isNullAt(0)) {
+         |  throw new RuntimeException("The first field from a struct (key) can't be null.");
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    s"""
+       |for (int $i = 0; $i < $numEntries; $i++) {
+       |  InternalRow $entry = $childVariable.getStruct($i, 2);
+       |  $nullKeyCheck
+       |  ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)}
+       |  ${valueAssignment(entry, i)}
+       |}
+     """.stripMargin
+  }
+
+  private def genCodeForPrimitiveElements(
+      ctx: CodegenContext,
+      childVariable: String,
+      mapData: String,
+      numEntries: String): String = {
+    val byteArraySize = ctx.freshName("byteArraySize")
+    val keySectionSize = ctx.freshName("keySectionSize")
+    val valueSectionSize = ctx.freshName("valueSectionSize")
+    val data = ctx.freshName("byteArray")
+    val unsafeMapData = ctx.freshName("unsafeMapData")
+    val keyArrayData = ctx.freshName("keyArrayData")
+    val valueArrayData = ctx.freshName("valueArrayData")
+
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+    val keySize = dataType.keyType.defaultSize
+    val valueSize = dataType.valueType.defaultSize
+    val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
+    val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"
+    val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
+    val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)
+
+    val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx,
$key);"
+    val valueAssignment = (entry: String, idx: String) => {
+      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
+      val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
+      if (dataType.valueContainsNull) {
+        s"""
+           |if ($entry.isNullAt(1)) {
+           |  $valueArrayData.setNullAt($idx);
+           |} else {
+           |  $valueNullUnsafeAssignment
+           |}
+         """.stripMargin
+      } else {
+        valueNullUnsafeAssignment
+      }
+    }
+    val assignmentLoop = genCodeForAssignmentLoop(
+      ctx,
+      childVariable,
+      mapData,
+      numEntries,
+      keyAssignment,
+      valueAssignment
+    )
+
+    s"""
+       |final long $keySectionSize = $kByteSize;
+       |final long $valueSectionSize = $vByteSize;
+       |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
+       |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
+       |} else {
+       |  final byte[] $data = new byte[(int)$byteArraySize];
+       |  UnsafeMapData $unsafeMapData = new UnsafeMapData();
+       |  Platform.putLong($data, $baseOffset, $keySectionSize);
+       |  Platform.putLong($data, ${baseOffset + 8}, $numEntries);
+       |  Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
+       |  $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
+       |  ArrayData $keyArrayData = $unsafeMapData.keyArray();
+       |  ArrayData $valueArrayData = $unsafeMapData.valueArray();
+       |  $assignmentLoop
+       |  $mapData = $unsafeMapData;
+       |}
+     """.stripMargin
+  }
+
+  private def genCodeForAnyElements(
+      ctx: CodegenContext,
+      childVariable: String,
+      mapData: String,
+      numEntries: String): String = {
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val mapDataClass = classOf[ArrayBasedMapData].getName()
+
+    val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
+    val valueAssignment = (entry: String, idx: String) => {
+      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
+      if (dataType.valueContainsNull && isValuePrimitive) {
+        s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
+      } else {
+        s"$values[$idx] = $value;"
+      }
+    }
+    val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
+    val assignmentLoop = genCodeForAssignmentLoop(
+      ctx,
+      childVariable,
+      mapData,
+      numEntries,
+      keyAssignment,
+      valueAssignment)
+
+    s"""
+       |final Object[] $keys = new Object[$numEntries];
+       |final Object[] $values = new Object[$numEntries];
+       |$assignmentLoop
+       |$mapData = $mapDataClass.apply($keys, $values);
+     """.stripMargin
+  }
+
+  override def prettyName: String = "map_from_entries"
+}
+
+
+/**
  * Common base class for [[SortArray]] and [[ArraySort]].
  */
 trait ArraySortLike extends ExpectsInputTypes {
@@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
       } else {
         genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
       }
-      if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
+      ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code)
     })
   }
 
-  private def nullElementsProtection(
-      ev: ExprCode,
-      childVariableName: String,
-      coreLogic: String): String = {
-    s"""
-    |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++)
{
-    |  ${ev.isNull} |= $childVariableName.isNullAt(z);
-    |}
-    |if (!${ev.isNull}) {
-    |  $coreLogic
-    |}
-    """.stripMargin
-  }
-
   private def genCodeForNumberOfElements(
       ctx: CodegenContext,
       childVariableName: String) : (String, String) = {

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index f377f9c..5b8cf51 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapEntries(ms2), null)
   }
 
+  test("MapFromEntries") {
+    def arrayType(keyType: DataType, valueType: DataType) : DataType = {
+      ArrayType(
+        StructType(Seq(
+          StructField("a", keyType),
+          StructField("b", valueType))),
+        true)
+    }
+    def r(values: Any*): InternalRow = create_row(values: _*)
+
+    // Primitive-type keys and values
+    val aiType = arrayType(IntegerType, IntegerType)
+    val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
+    val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
+    val ai2 = Literal.create(Seq.empty, aiType)
+    val ai3 = Literal.create(null, aiType)
+    val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
+    val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
+    val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType)
+
+    checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
+    checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
+    checkEvaluation(MapFromEntries(ai2), Map.empty)
+    checkEvaluation(MapFromEntries(ai3), null)
+    checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1))
+    checkExceptionInExpression[RuntimeException](
+      MapFromEntries(ai5),
+      "The first field from a struct (key) can't be null.")
+    checkEvaluation(MapFromEntries(ai6), null)
+
+    // Non-primitive-type keys and values
+    val asType = arrayType(StringType, StringType)
+    val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
+    val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
+    val as2 = Literal.create(Seq.empty, asType)
+    val as3 = Literal.create(null, asType)
+    val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
+    val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)
+    val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType)
+
+    checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
+    checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
+    checkEvaluation(MapFromEntries(as2), Map.empty)
+    checkEvaluation(MapFromEntries(as3), null)
+    checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a"))
+    checkExceptionInExpression[RuntimeException](
+      MapFromEntries(as5),
+      "The first field from a struct (key) can't be null.")
+    checkEvaluation(MapFromEntries(as6), null)
+  }
+
   test("Sort Array") {
     val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c296a1b..f792a6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3527,6 +3527,13 @@ object functions {
   def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }
 
   /**
+   * Returns a map created from the given array of entries.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) }
+
+  /**
    * Returns a merged array of structs in which the N-th struct contains all N-th values
of input
    * arrays.
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index fcdd33f..25fdbab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -633,6 +633,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext
{
     checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected)
   }
 
+  test("map_from_entries function") {
+    def dummyFilter(c: Column): Column = c.isNull || c.isNotNull
+    val oneRowDF = Seq(3215).toDF("i")
+
+    // Test cases with primitive-type keys and values
+    val idf = Seq(
+      Seq((1, 10), (2, 20), (3, 10)),
+      Seq((1, 10), null, (2, 20)),
+      Seq.empty,
+      null
+    ).toDF("a")
+    val iExpected = Seq(
+      Row(Map(1 -> 10, 2 -> 20, 3 -> 10)),
+      Row(null),
+      Row(Map.empty),
+      Row(null))
+
+    checkAnswer(idf.select(map_from_entries('a)), iExpected)
+    checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected)
+    checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected)
+    checkAnswer(
+      oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"),
+      Seq(Row(Map(1 -> null, 2 -> null)))
+    )
+    checkAnswer(
+      oneRowDF.filter(dummyFilter('i))
+        .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"),
+      Seq(Row(Map(1 -> null, 2 -> null)))
+    )
+
+    // Test cases with non-primitive-type keys and values
+    val sdf = Seq(
+      Seq(("a", "aa"), ("b", "bb"), ("c", "aa")),
+      Seq(("a", "aa"), null, ("b", "bb")),
+      Seq(("a", null), ("b", null)),
+      Seq.empty,
+      null
+    ).toDF("a")
+    val sExpected = Seq(
+      Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")),
+      Row(null),
+      Row(Map("a" -> null, "b" -> null)),
+      Row(Map.empty),
+      Row(null))
+
+    checkAnswer(sdf.select(map_from_entries('a)), sExpected)
+    checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected)
+    checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected)
+  }
+
   test("array contains function") {
     val df = Seq(
       (Seq[Int](1, 2), "x", 1),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message