spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gurwls...@apache.org
Subject spark git commit: [SPARK-24709][SQL] schema_of_json() - schema inference from an example
Date Wed, 04 Jul 2018 01:38:27 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5585c5765 -> 776f299fc


[SPARK-24709][SQL] schema_of_json() - schema inference from an example

## What changes were proposed in this pull request?

In this PR, I propose adding a new function, *schema_of_json()*, which infers the schema of a JSON string literal. The result of the function is a string containing the schema in DDL format.

One use case is using *schema_of_json()* in combination with *from_json()*. Currently, _from_json()_ requires a schema as a mandatory argument. The *schema_of_json()* function makes it possible to supply an example JSON string that has the same schema as the first argument of _from_json()_, so the schema is inferred from the example. For instance:

```sql
select from_json(json_column, schema_of_json('{"c1": [0], "c2": [{"c3":0}]}'))
from json_table;
```

## How was this patch tested?

Added new tests to `JsonFunctionsSuite` and `JsonExpressionsSuite`, and new SQL tests to `json-functions.sql`.

Author: Maxim Gekk <maxim.gekk@databricks.com>

Closes #21686 from MaxGekk/infer_schema_json.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/776f299f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/776f299f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/776f299f

Branch: refs/heads/master
Commit: 776f299fc8146b400e97185b1577b0fc8f06e14b
Parents: 5585c57
Author: Maxim Gekk <maxim.gekk@databricks.com>
Authored: Wed Jul 4 09:38:18 2018 +0800
Committer: hyukjinkwon <gurwls223@apache.org>
Committed: Wed Jul 4 09:38:18 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  27 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../catalyst/expressions/jsonExpressions.scala  |  52 ++-
 .../sql/catalyst/json/JsonInferSchema.scala     | 348 ++++++++++++++++++
 .../expressions/JsonExpressionsSuite.scala      |   7 +
 .../datasources/json/JsonDataSource.scala       |   2 +-
 .../datasources/json/JsonInferSchema.scala      | 349 -------------------
 .../scala/org/apache/spark/sql/functions.scala  |  42 +++
 .../sql-tests/inputs/json-functions.sql         |   4 +
 .../sql-tests/results/json-functions.sql.out    |  20 +-
 .../apache/spark/sql/JsonFunctionsSuite.scala   |  17 +-
 .../execution/datasources/json/JsonSuite.scala  |   4 +-
 12 files changed, 509 insertions(+), 364 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9652d3e..4d37197 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2189,11 +2189,16 @@ def from_json(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_json(df.value, schema).alias("json")).collect()
     [Row(json=[Row(a=1)])]
+    >>> schema = schema_of_json(lit('''{"a": 0}'''))
+    >>> df.select(from_json(df.value, schema).alias("json")).collect()
+    [Row(json=Row(a=1))]
     """
 
     sc = SparkContext._active_spark_context
     if isinstance(schema, DataType):
         schema = schema.json()
+    elif isinstance(schema, Column):
+        schema = _to_java_column(schema)
     jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
     return Column(jc)
 
@@ -2235,6 +2240,28 @@ def to_json(col, options={}):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def schema_of_json(col):
+    """
+    Parses a column containing a JSON string and infers its schema in DDL format.
+
+    :param col: string column in json format
+
+    >>> from pyspark.sql.types import *
+    >>> data = [(1, '{"a": 1}')]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(schema_of_json(df.value).alias("json")).collect()
+    [Row(json=u'struct<a:bigint>')]
+    >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
+    [Row(json=u'struct<a:bigint>')]
+    """
+
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.schema_of_json(_to_java_column(col))
+    return Column(jc)
+
+
 @since(1.5)
 def size(col):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a574d8a..80a0af6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -505,6 +505,7 @@ object FunctionRegistry {
     // json
     expression[StructsToJson]("to_json"),
     expression[JsonToStructs]("from_json"),
+    expression[SchemaOfJson]("schema_of_json"),
 
     // cast
     expression[Cast]("cast"),

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index f6d74f5..8cd8605 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter}
+import java.io._
 
 import scala.util.parsing.combinator.RegexParsers
 
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -525,17 +526,19 @@ case class JsonToStructs(
   override def nullable: Boolean = true
 
   // Used in `FunctionRegistry`
-  def this(child: Expression, schema: Expression) =
+  def this(child: Expression, schema: Expression, options: Map[String, String]) =
     this(
-      schema = JsonExprUtils.validateSchemaLiteral(schema),
-      options = Map.empty[String, String],
+      schema = JsonExprUtils.evalSchemaExpr(schema),
+      options = options,
       child = child,
       timeZoneId = None,
       forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
 
+  def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])
+
   def this(child: Expression, schema: Expression, options: Expression) =
     this(
-      schema = JsonExprUtils.validateSchemaLiteral(schema),
+      schema = JsonExprUtils.evalSchemaExpr(schema),
       options = JsonExprUtils.convertToMapData(options),
       child = child,
       timeZoneId = None,
@@ -744,11 +747,44 @@ case class StructsToJson(
   override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil
 }
 
+/**
+ * A function infers schema of JSON string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('[{"col":0}]');
+       array<struct<col:int>>
+  """,
+  since = "2.4.0")
+case class SchemaOfJson(child: Expression)
+  extends UnaryExpression with String2StringExpression with CodegenFallback {
+
+  private val jsonOptions = new JSONOptions(Map.empty, "UTC")
+  private val jsonFactory = new JsonFactory()
+  jsonOptions.setJacksonOptions(jsonFactory)
+
+  override def convert(v: UTF8String): UTF8String = {
+    val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser =>
+      parser.nextToken()
+      inferField(parser, jsonOptions)
+    }
+
+    UTF8String.fromString(dt.catalogString)
+  }
+}
+
 object JsonExprUtils {
 
-  def validateSchemaLiteral(exp: Expression): DataType = exp match {
+  def evalSchemaExpr(exp: Expression): DataType = exp match {
     case Literal(s, StringType) => DataType.fromDDL(s.toString)
-    case e => throw new AnalysisException(s"Expected a string literal instead of $e")
+    case e @ SchemaOfJson(_: Literal) =>
+      val ddlSchema = e.eval().asInstanceOf[UTF8String]
+      DataType.fromDDL(ddlSchema.toString)
+    case e => throw new AnalysisException(
+      "Schema should be specified in DDL format as a string literal" +
+      s" or output of the schema_of_json function instead of ${e.sql}")
   }
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
new file mode 100644
index 0000000..491ca00
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import java.util.Comparator
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+private[sql] object JsonInferSchema {
+
+  /**
+   * Infer the type of a collection of json records in three stages:
+   *   1. Infer the type of each record
+   *   2. Merge types by choosing the lowest type necessary to cover equal keys
+   *   3. Replace any remaining null fields with string, the top type
+   */
+  def infer[T](
+      json: RDD[T],
+      configOptions: JSONOptions,
+      createParser: (JsonFactory, T) => JsonParser): StructType = {
+    val parseMode = configOptions.parseMode
+    val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
+
+    // In each RDD partition, perform schema inference on each row and merge afterwards.
+    val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+    val mergedTypesFromPartitions = json.mapPartitions { iter =>
+      val factory = new JsonFactory()
+      configOptions.setJacksonOptions(factory)
+      iter.flatMap { row =>
+        try {
+          Utils.tryWithResource(createParser(factory, row)) { parser =>
+            parser.nextToken()
+            Some(inferField(parser, configOptions))
+          }
+        } catch {
+          case  e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
+            case PermissiveMode =>
+              Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
+            case DropMalformedMode =>
+              None
+            case FailFastMode =>
+              throw new SparkException("Malformed records are detected in schema inference. " +
+                s"Parse Mode: ${FailFastMode.name}.", e)
+          }
+        }
+      }.reduceOption(typeMerger).toIterator
+    }
+
+    // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because
+    // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have
+    // active SparkSession and `SQLConf.get` may point to the wrong configs.
+    val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
+
+    canonicalizeType(rootType, configOptions) match {
+      case Some(st: StructType) => st
+      case _ =>
+        // canonicalizeType erases all empty structs, including the only one we want to keep
+        StructType(Nil)
+    }
+  }
+
+  private[this] val structFieldComparator = new Comparator[StructField] {
+    override def compare(o1: StructField, o2: StructField): Int = {
+      o1.name.compareTo(o2.name)
+    }
+  }
+
+  private def isSorted(arr: Array[StructField]): Boolean = {
+    var i: Int = 0
+    while (i < arr.length - 1) {
+      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+        return false
+      }
+      i += 1
+    }
+    true
+  }
+
+  /**
+   * Infer the type of a json document from the parser's token stream
+   */
+  def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+    import com.fasterxml.jackson.core.JsonToken._
+    parser.getCurrentToken match {
+      case null | VALUE_NULL => NullType
+
+      case FIELD_NAME =>
+        parser.nextToken()
+        inferField(parser, configOptions)
+
+      case VALUE_STRING if parser.getTextLength < 1 =>
+        // Zero length strings and nulls have special handling to deal
+        // with JSON generators that do not distinguish between the two.
+        // To accurately infer types for empty strings that are really
+        // meant to represent nulls we assume that the two are isomorphic
+        // but will defer treating null fields as strings until all the
+        // record fields' types have been combined.
+        NullType
+
+      case VALUE_STRING => StringType
+      case START_OBJECT =>
+        val builder = Array.newBuilder[StructField]
+        while (nextUntil(parser, END_OBJECT)) {
+          builder += StructField(
+            parser.getCurrentName,
+            inferField(parser, configOptions),
+            nullable = true)
+        }
+        val fields: Array[StructField] = builder.result()
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(fields, structFieldComparator)
+        StructType(fields)
+
+      case START_ARRAY =>
+        // If this JSON array is empty, we use NullType as a placeholder.
+        // If this array is not empty in other JSON objects, we can resolve
+        // the type as we pass through all JSON objects.
+        var elementType: DataType = NullType
+        while (nextUntil(parser, END_ARRAY)) {
+          elementType = compatibleType(
+            elementType, inferField(parser, configOptions))
+        }
+
+        ArrayType(elementType)
+
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
+
+      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
+
+      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+        import JsonParser.NumberType._
+        parser.getNumberType match {
+          // For Integer values, use LongType by default.
+          case INT | LONG => LongType
+          // Since we do not have a data type backed by BigInteger,
+          // when we see a Java BigInteger, we use DecimalType.
+          case BIG_INTEGER | BIG_DECIMAL =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE =>
+            DoubleType
+        }
+
+      case VALUE_TRUE | VALUE_FALSE => BooleanType
+    }
+  }
+
+  /**
+   * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
+   * drops NullTypes or converts them to StringType based on provided options.
+   */
+  private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+    case at: ArrayType =>
+      canonicalizeType(at.elementType, options)
+        .map(t => at.copy(elementType = t))
+
+    case StructType(fields) =>
+      val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+        canonicalizeType(f.dataType, options)
+          .map(t => f.copy(dataType = t))
+      }
+      // SPARK-8093: empty structs should be deleted
+      if (canonicalFields.isEmpty) {
+        None
+      } else {
+        Some(StructType(canonicalFields))
+      }
+
+    case NullType =>
+      if (options.dropFieldIfAllNull) {
+        None
+      } else {
+        Some(StringType)
+      }
+
+    case other => Some(other)
+  }
+
+  private def withCorruptField(
+      struct: StructType,
+      other: DataType,
+      columnNameOfCorruptRecords: String,
+      parseMode: ParseMode) = parseMode match {
+    case PermissiveMode =>
+      // If we see any other data type at the root level, we get records that cannot be
+      // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+      if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+        // If this given struct does not have a column used for corrupt records,
+        // add this field.
+        val newFields: Array[StructField] =
+          StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(newFields, structFieldComparator)
+        StructType(newFields)
+      } else {
+        // Otherwise, just return this struct.
+        struct
+      }
+
+    case DropMalformedMode =>
+      // If corrupt record handling is disabled we retain the valid schema and discard the other.
+      struct
+
+    case FailFastMode =>
+      // If `other` is not struct type, consider it as malformed one and throws an exception.
+      throw new SparkException("Malformed records are detected in schema inference. " +
+        s"Parse Mode: ${FailFastMode.name}. Reasons: Failed to infer a common schema. " +
+        s"Struct types are expected, but `${other.catalogString}` was found.")
+  }
+
+  /**
+   * Remove top-level ArrayType wrappers and merge the remaining schemas
+   */
+  private def compatibleRootType(
+      columnNameOfCorruptRecords: String,
+      parseMode: ParseMode): (DataType, DataType) => DataType = {
+    // Since we support array of json objects at the top level,
+    // we need to check the element type and find the root level data type.
+    case (ArrayType(ty1, _), ty2) =>
+      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+    case (ty1, ArrayType(ty2, _)) =>
+      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+    // Discard null/empty documents
+    case (struct: StructType, NullType) => struct
+    case (NullType, struct: StructType) => struct
+    case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+    case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+    // If we get anything else, we call compatibleType.
+    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
+    case (ty1, ty2) => compatibleType(ty1, ty2)
+  }
+
+  private[this] val emptyStructFieldArray = Array.empty[StructField]
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  def compatibleType(t1: DataType, t2: DataType): DataType = {
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+
+        case (StructType(fields1), StructType(fields2)) =>
+          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+          // building a hash map or performing additional sorting.
+          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+          val newFields = new java.util.ArrayList[StructField]()
+
+          var f1Idx = 0
+          var f2Idx = 0
+
+          while (f1Idx < fields1.length && f2Idx < fields2.length) {
+            val f1Name = fields1(f1Idx).name
+            val f2Name = fields2(f2Idx).name
+            val comp = f1Name.compareTo(f2Name)
+            if (comp == 0) {
+              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              newFields.add(StructField(f1Name, dataType, nullable = true))
+              f1Idx += 1
+              f2Idx += 1
+            } else if (comp < 0) { // f1Name < f2Name
+              newFields.add(fields1(f1Idx))
+              f1Idx += 1
+            } else { // f1Name > f2Name
+              newFields.add(fields2(f2Idx))
+              f2Idx += 1
+            }
+          }
+          while (f1Idx < fields1.length) {
+            newFields.add(fields1(f1Idx))
+            f1Idx += 1
+          }
+          while (f2Idx < fields2.length) {
+            newFields.add(fields2(f2Idx))
+            f2Idx += 1
+          }
+          StructType(newFields.toArray(emptyStructFieldArray))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `findTightestCommonType`. Both cases below will be executed only when the given
+        // `DecimalType` is not capable of the given `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(DecimalType.forType(t1), t2)
+        case (t1: DecimalType, t2: IntegralType) =>
+          compatibleType(t1, DecimalType.forType(t2))
+
+        // strings and every string is a Json object.
+        case (_, _) => StringType
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 00e9763..52203b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -706,4 +706,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
       assert(schemaToCompare == schema)
     }
   }
+
+  test("SPARK-24709: infer schema of json strings") {
+    checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct<col:bigint>")
+    checkEvaluation(
+      SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")),
+      "struct<col0:array<string>,col1:struct<col2:string>>")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3b6df45..2fee212 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
deleted file mode 100644
index 8e1b430..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ /dev/null
@@ -1,349 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.json
-
-import java.util.Comparator
-
-import com.fasterxml.jackson.core._
-
-import org.apache.spark.SparkException
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.catalyst.json.JSONOptions
-import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
-import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
-
-private[sql] object JsonInferSchema {
-
-  /**
-   * Infer the type of a collection of json records in three stages:
-   *   1. Infer the type of each record
-   *   2. Merge types by choosing the lowest type necessary to cover equal keys
-   *   3. Replace any remaining null fields with string, the top type
-   */
-  def infer[T](
-      json: RDD[T],
-      configOptions: JSONOptions,
-      createParser: (JsonFactory, T) => JsonParser): StructType = {
-    val parseMode = configOptions.parseMode
-    val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
-
-    // In each RDD partition, perform schema inference on each row and merge afterwards.
-    val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
-    val mergedTypesFromPartitions = json.mapPartitions { iter =>
-      val factory = new JsonFactory()
-      configOptions.setJacksonOptions(factory)
-      iter.flatMap { row =>
-        try {
-          Utils.tryWithResource(createParser(factory, row)) { parser =>
-            parser.nextToken()
-            Some(inferField(parser, configOptions))
-          }
-        } catch {
-          case  e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
-            case PermissiveMode =>
-              Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
-            case DropMalformedMode =>
-              None
-            case FailFastMode =>
-              throw new SparkException("Malformed records are detected in schema inference. " +
-                s"Parse Mode: ${FailFastMode.name}.", e)
-          }
-        }
-      }.reduceOption(typeMerger).toIterator
-    }
-
-    // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because
-    // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have
-    // active SparkSession and `SQLConf.get` may point to the wrong configs.
-    val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
-
-    canonicalizeType(rootType, configOptions) match {
-      case Some(st: StructType) => st
-      case _ =>
-        // canonicalizeType erases all empty structs, including the only one we want to keep
-        StructType(Nil)
-    }
-  }
-
-  private[this] val structFieldComparator = new Comparator[StructField] {
-    override def compare(o1: StructField, o2: StructField): Int = {
-      o1.name.compareTo(o2.name)
-    }
-  }
-
-  private def isSorted(arr: Array[StructField]): Boolean = {
-    var i: Int = 0
-    while (i < arr.length - 1) {
-      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
-        return false
-      }
-      i += 1
-    }
-    true
-  }
-
-  /**
-   * Infer the type of a json document from the parser's token stream
-   */
-  private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
-    import com.fasterxml.jackson.core.JsonToken._
-    parser.getCurrentToken match {
-      case null | VALUE_NULL => NullType
-
-      case FIELD_NAME =>
-        parser.nextToken()
-        inferField(parser, configOptions)
-
-      case VALUE_STRING if parser.getTextLength < 1 =>
-        // Zero length strings and nulls have special handling to deal
-        // with JSON generators that do not distinguish between the two.
-        // To accurately infer types for empty strings that are really
-        // meant to represent nulls we assume that the two are isomorphic
-        // but will defer treating null fields as strings until all the
-        // record fields' types have been combined.
-        NullType
-
-      case VALUE_STRING => StringType
-      case START_OBJECT =>
-        val builder = Array.newBuilder[StructField]
-        while (nextUntil(parser, END_OBJECT)) {
-          builder += StructField(
-            parser.getCurrentName,
-            inferField(parser, configOptions),
-            nullable = true)
-        }
-        val fields: Array[StructField] = builder.result()
-        // Note: other code relies on this sorting for correctness, so don't remove it!
-        java.util.Arrays.sort(fields, structFieldComparator)
-        StructType(fields)
-
-      case START_ARRAY =>
-        // If this JSON array is empty, we use NullType as a placeholder.
-        // If this array is not empty in other JSON objects, we can resolve
-        // the type as we pass through all JSON objects.
-        var elementType: DataType = NullType
-        while (nextUntil(parser, END_ARRAY)) {
-          elementType = compatibleType(
-            elementType, inferField(parser, configOptions))
-        }
-
-        ArrayType(elementType)
-
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
-
-      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
-
-      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
-        import JsonParser.NumberType._
-        parser.getNumberType match {
-          // For Integer values, use LongType by default.
-          case INT | LONG => LongType
-          // Since we do not have a data type backed by BigInteger,
-          // when we see a Java BigInteger, we use DecimalType.
-          case BIG_INTEGER | BIG_DECIMAL =>
-            val v = parser.getDecimalValue
-            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
-              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
-            } else {
-              DoubleType
-            }
-          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
-            val v = parser.getDecimalValue
-            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
-              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
-            } else {
-              DoubleType
-            }
-          case FLOAT | DOUBLE =>
-            DoubleType
-        }
-
-      case VALUE_TRUE | VALUE_FALSE => BooleanType
-    }
-  }
-
-  /**
-   * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
-   * drops NullTypes or converts them to StringType based on provided options.
-   */
-  private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
-    case at: ArrayType =>
-      canonicalizeType(at.elementType, options)
-        .map(t => at.copy(elementType = t))
-
-    case StructType(fields) =>
-      val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
-        canonicalizeType(f.dataType, options)
-          .map(t => f.copy(dataType = t))
-      }
-      // SPARK-8093: empty structs should be deleted
-      if (canonicalFields.isEmpty) {
-        None
-      } else {
-        Some(StructType(canonicalFields))
-      }
-
-    case NullType =>
-      if (options.dropFieldIfAllNull) {
-        None
-      } else {
-        Some(StringType)
-      }
-
-    case other => Some(other)
-  }
-
-  private def withCorruptField(
-      struct: StructType,
-      other: DataType,
-      columnNameOfCorruptRecords: String,
-      parseMode: ParseMode) = parseMode match {
-    case PermissiveMode =>
-      // If we see any other data type at the root level, we get records that cannot be
-      // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
-      if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
-        // If this given struct does not have a column used for corrupt records,
-        // add this field.
-        val newFields: Array[StructField] =
-          StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
-        // Note: other code relies on this sorting for correctness, so don't remove it!
-        java.util.Arrays.sort(newFields, structFieldComparator)
-        StructType(newFields)
-      } else {
-        // Otherwise, just return this struct.
-        struct
-      }
-
-    case DropMalformedMode =>
-      // If corrupt record handling is disabled we retain the valid schema and discard the other.
-      struct
-
-    case FailFastMode =>
-      // If `other` is not a struct type, consider it malformed and throw an exception.
-      throw new SparkException("Malformed records are detected in schema inference. " +
-        s"Parse Mode: ${FailFastMode.name}. Reasons: Failed to infer a common schema. " +
-        s"Struct types are expected, but `${other.catalogString}` was found.")
-  }
-
-  /**
-   * Remove top-level ArrayType wrappers and merge the remaining schemas
-   */
-  private def compatibleRootType(
-      columnNameOfCorruptRecords: String,
-      parseMode: ParseMode): (DataType, DataType) => DataType = {
-    // Since we support array of json objects at the top level,
-    // we need to check the element type and find the root level data type.
-    case (ArrayType(ty1, _), ty2) =>
-      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
-    case (ty1, ArrayType(ty2, _)) =>
-      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
-    // Discard null/empty documents
-    case (struct: StructType, NullType) => struct
-    case (NullType, struct: StructType) => struct
-    case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
-      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
-    case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
-      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
-    // If we get anything else, we call compatibleType.
-    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
-    case (ty1, ty2) => compatibleType(ty1, ty2)
-  }
-
-  private[this] val emptyStructFieldArray = Array.empty[StructField]
-
-  /**
-   * Returns the most general data type for two given data types.
-   */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
-    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
-      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
-      (t1, t2) match {
-        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
-        // in most case, also have better precision.
-        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
-          DoubleType
-
-        case (t1: DecimalType, t2: DecimalType) =>
-          val scale = math.max(t1.scale, t2.scale)
-          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
-          if (range + scale > 38) {
-            // DecimalType can't support precision > 38
-            DoubleType
-          } else {
-            DecimalType(range + scale, scale)
-          }
-
-        case (StructType(fields1), StructType(fields2)) =>
-          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
-          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
-          // building a hash map or performing additional sorting.
-          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
-          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
-
-          val newFields = new java.util.ArrayList[StructField]()
-
-          var f1Idx = 0
-          var f2Idx = 0
-
-          while (f1Idx < fields1.length && f2Idx < fields2.length) {
-            val f1Name = fields1(f1Idx).name
-            val f2Name = fields2(f2Idx).name
-            val comp = f1Name.compareTo(f2Name)
-            if (comp == 0) {
-              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
-              newFields.add(StructField(f1Name, dataType, nullable = true))
-              f1Idx += 1
-              f2Idx += 1
-            } else if (comp < 0) { // f1Name < f2Name
-              newFields.add(fields1(f1Idx))
-              f1Idx += 1
-            } else { // f1Name > f2Name
-              newFields.add(fields2(f2Idx))
-              f2Idx += 1
-            }
-          }
-          while (f1Idx < fields1.length) {
-            newFields.add(fields1(f1Idx))
-            f1Idx += 1
-          }
-          while (f2Idx < fields2.length) {
-            newFields.add(fields2(f2Idx))
-            f2Idx += 1
-          }
-          StructType(newFields.toArray(emptyStructFieldArray))
-
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-
-        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
-        // `findTightestCommonType`. Both cases below will be executed only when the given
-        // `DecimalType` is not capable of the given `IntegralType`.
-        case (t1: IntegralType, t2: DecimalType) =>
-          compatibleType(DecimalType.forType(t1), t2)
-        case (t1: DecimalType, t2: IntegralType) =>
-          compatibleType(t1, DecimalType.forType(t2))
-
-        // Any other pair of types falls back to StringType.
-        case (_, _) => StringType
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acca957..614f65f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3382,6 +3382,48 @@ object functions {
   }
 
   /**
+   * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+   * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+   * Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def from_json(e: Column, schema: Column): Column = {
+    from_json(e, schema, Map.empty[String, String].asJava)
+  }
+
+  /**
+   * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+   * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+   * Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options as the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
+    withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap))
+  }
+
+  /**
+   * Parses a column containing a JSON string and infers its schema.
+   *
+   * @param e a string column containing JSON data.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr))
+
+  /**
    * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
    * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
    * Throws an exception, in the case of an unsupported type.

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
index dc15d13..79fdd58 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -35,3 +35,7 @@ DROP VIEW IF EXISTS jsonTable;
 -- from_json - complex types
 select from_json('{"a":1, "b":2}', 'map<string, int>');
 select from_json('{"a":1, "b":"2"}', 'struct<a:int,b:string>');
+
+-- infer schema of json literal
+select schema_of_json('{"c1":0, "c2":[1]}');
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'));

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 2b3288d..3d49323 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 28
+-- Number of queries: 30
 
 
 -- !query 0
@@ -183,7 +183,7 @@ select from_json('{"a":1}', 1)
 struct<>
 -- !query 17 output
 org.apache.spark.sql.AnalysisException
-Expected a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7
 
 
 -- !query 18
@@ -274,3 +274,19 @@ select from_json('{"a":1, "b":"2"}', 'struct<a:int,b:string>')
 struct<jsontostructs({"a":1, "b":"2"}):struct<a:int,b:string>>
 -- !query 27 output
 {"a":1,"b":"2"}
+
+
+-- !query 28
+select schema_of_json('{"c1":0, "c2":[1]}')
+-- !query 28 schema
+struct<schemaofjson({"c1":0, "c2":[1]}):string>
+-- !query 28 output
+struct<c1:bigint,c2:array<bigint>>
+
+
+-- !query 29
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'))
+-- !query 29 schema
+struct<jsontostructs({"c1":[1, 2, 3]}):struct<c1:array<bigint>>>
+-- !query 29 output
+{"c1":[1,2,3]}

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 7bf17cb..d3b2701 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     val errMsg1 = intercept[AnalysisException] {
       df3.selectExpr("from_json(value, 1)")
     }
-    assert(errMsg1.getMessage.startsWith("Expected a string literal instead of"))
+    assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string"))
     val errMsg2 = intercept[AnalysisException] {
       df3.selectExpr("""from_json(value, 'time InvalidType')""")
     }
@@ -392,4 +392,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)),
       Row(null))
   }
+
+  test("SPARK-24709: infers schemas of json strings and pass them to from_json") {
+    val in = Seq("""{"a": [1, 2, 3]}""").toDS()
+    val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed")
+    val expected = StructType(StructField(
+      "parsed",
+      StructType(StructField(
+        "a",
+        ArrayType(LongType, true), true) :: Nil),
+      true) :: Nil)
+
+    assert(out.schema == expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 897424d..eab15b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{functions => F, _}
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.ExternalRDD
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message