spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22973][SQL] Fix incorrect results of Casting Map to String
Date Sun, 07 Jan 2018 05:42:10 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9a7048b28 -> 18e941499


[SPARK-22973][SQL] Fix incorrect results of Casting Map to String

## What changes were proposed in this pull request?
This PR fixes incorrect results when casting maps to strings:
```
scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t")
scala> sql("SELECT cast(a as String) FROM t").show(false)
+----------------------------------------------------------------+
|a                                                               |
+----------------------------------------------------------------+
|org.apache.spark.sql.catalyst.expressions.UnsafeMapData@38bdd75d|
+----------------------------------------------------------------+
```
This PR changes the result to:
```
+----------------+
|a               |
+----------------+
|[1 -> a, 2 -> b]|
+----------------+
```

## How was this patch tested?
Added tests in `CastSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #20166 from maropu/SPARK-22973.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/18e94149
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/18e94149
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/18e94149

Branch: refs/heads/master
Commit: 18e94149992618a2b4e6f0fd3b3f4594e1745224
Parents: 9a7048b
Author: Takeshi Yamamuro <yamamuro@apache.org>
Authored: Sun Jan 7 13:42:01 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Sun Jan 7 13:42:01 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 89 ++++++++++++++++++++
 .../sql/catalyst/expressions/CastSuite.scala    | 28 ++++++
 2 files changed, 117 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/18e94149/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d4fc5e0..f2de4c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         builder.append("]")
         builder.build()
       })
+    case MapType(kt, vt, _) =>
+      buildCast[MapData](_, map => {
+        val builder = new UTF8StringBuilder
+        builder.append("[")
+        if (map.numElements > 0) {
+          val keyArray = map.keyArray()
+          val valueArray = map.valueArray()
+          val keyToUTF8String = castToString(kt)
+          val valueToUTF8String = castToString(vt)
+          builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String])
+          builder.append(" ->")
+          if (!valueArray.isNullAt(0)) {
+            builder.append(" ")
+            builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String])
+          }
+          var i = 1
+          while (i < map.numElements) {
+            builder.append(", ")
+            builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String])
+            builder.append(" ->")
+            if (!valueArray.isNullAt(i)) {
+              builder.append(" ")
+              builder.append(valueToUTF8String(valueArray.get(i, vt))
+                .asInstanceOf[UTF8String])
+            }
+            i += 1
+          }
+        }
+        builder.append("]")
+        builder.build()
+      })
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
   }
 
@@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
      """.stripMargin
   }
 
+  private def writeMapToStringBuilder(
+      kt: DataType,
+      vt: DataType,
+      map: String,
+      buffer: String,
+      ctx: CodegenContext): String = {
+
+    def dataToStringFunc(func: String, dataType: DataType) = {
+      val funcName = ctx.freshName(func)
+      val dataToStringCode = castToStringCode(dataType, ctx)
+      ctx.addNewFunction(funcName,
+        s"""
+           |private UTF8String $funcName(${ctx.javaType(dataType)} data) {
+           |  UTF8String dataStr = null;
+           |  ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)}
+           |  return dataStr;
+           |}
+         """.stripMargin)
+    }
+
+    val keyToStringFunc = dataToStringFunc("keyToString", kt)
+    val valueToStringFunc = dataToStringFunc("valueToString", vt)
+    val loopIndex = ctx.freshName("loopIndex")
+    s"""
+       |$buffer.append("[");
+       |if ($map.numElements() > 0) {
+       |  $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")}));
+       |  $buffer.append(" ->");
+       |  if (!$map.valueArray().isNullAt(0)) {
+       |    $buffer.append(" ");
+       |    $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")}));
+       |  }
+       |  for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
+       |    $buffer.append(", ");
+       |    $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)}));
+       |    $buffer.append(" ->");
+       |    if (!$map.valueArray().isNullAt($loopIndex)) {
+       |      $buffer.append(" ");
+       |      $buffer.append($valueToStringFunc(
+       |        ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)}));
+       |    }
+       |  }
+       |}
+       |$buffer.append("]");
+     """.stripMargin
+  }
+
   private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
     from match {
       case BinaryType =>
@@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
              |$evPrim = $buffer.build();
            """.stripMargin
         }
+      case MapType(kt, vt, _) =>
+        (c, evPrim, evNull) => {
+          val buffer = ctx.freshName("buffer")
+          val bufferClass = classOf[UTF8StringBuilder].getName
+          val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx)
+          s"""
+             |$bufferClass $buffer = new $bufferClass();
+             |$writeMapElemCode;
+             |$evPrim = $buffer.build();
+           """.stripMargin
+        }
       case _ =>
         (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/18e94149/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e3ed717..1445bb8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       StringType)
     checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
   }
+
+  test("SPARK-22973 Cast map to string") {
+    val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType)
+    checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]")
+    val ret2 = cast(
+      Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)),
+      StringType)
+    checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]")
+    val ret3 = cast(
+      Literal.create(Map(
+        1 -> Date.valueOf("2014-12-03"),
+        2 -> Date.valueOf("2014-12-04"),
+        3 -> Date.valueOf("2014-12-05"))),
+      StringType)
+    checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]")
+    val ret4 = cast(
+      Literal.create(Map(
+        1 -> Timestamp.valueOf("2014-12-03 13:01:00"),
+        2 -> Timestamp.valueOf("2014-12-04 15:05:00"))),
+      StringType)
+    checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]")
+    val ret5 = cast(
+      Literal.create(Map(
+        1 -> Array(1, 2, 3),
+        2 -> Array(4, 5, 6))),
+      StringType)
+    checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message