spark-commits mailing list archives

From lix...@apache.org
Subject spark git commit: [SPARK-23786][SQL] Checking column names of csv headers
Date Mon, 04 Jun 2018 05:02:23 GMT
Repository: spark
Updated Branches:
  refs/heads/master 416cd1fd9 -> 1d9338bb1


[SPARK-23786][SQL] Checking column names of csv headers

## What changes were proposed in this pull request?

Currently, column names in CSV headers are not checked against the provided schema of
CSV data. This can cause errors like those shown in [SPARK-23786](https://issues.apache.org/jira/browse/SPARK-23786)
and https://github.com/apache/spark/pull/20894#issuecomment-375957777. I introduced a new CSV
option, `enforceSchema`. If it is enabled (`true` by default), Spark forcibly applies the provided
or inferred schema to CSV files. In that case, CSV headers are ignored and not checked against
the schema. If `enforceSchema` is set to `false`, additional checks are performed. For
example, if the columns in the CSV header and in the schema are ordered differently, the following
exception is thrown:

```
java.lang.IllegalArgumentException: CSV file header does not contain the expected fields
 Header: depth, temperature
 Schema: temperature, depth
CSV file: marina.csv
```
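
To illustrate the intended usage, here is a minimal sketch (the file name and schema are taken
from the error message above; the path and session setup are assumptions made for the example):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StructType}

val spark = SparkSession.builder().appName("enforceSchemaExample").getOrCreate()

// Schema declared in a different order than the CSV header "depth, temperature".
val schema = new StructType()
  .add("temperature", DoubleType)
  .add("depth", DoubleType)

// With enforceSchema=false (non-default), the header is validated against the schema,
// so the ordering mismatch above is reported instead of columns being silently mis-assigned.
val df = spark.read
  .schema(schema)
  .option("header", true)
  .option("enforceSchema", false)
  .csv("marina.csv")
```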

## How was this patch tested?

The changes were tested by the existing tests in CSVSuite and by two new tests.

Author: Maxim Gekk <maxim.gekk@databricks.com>
Author: Maxim Gekk <max.gekk@gmail.com>

Closes #20894 from MaxGekk/check-column-names.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d9338bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d9338bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d9338bb

Branch: refs/heads/master
Commit: 1d9338bb10b953daddb23b8879ff99aa5c57dbea
Parents: 416cd1f
Author: Maxim Gekk <maxim.gekk@databricks.com>
Authored: Sun Jun 3 22:02:21 2018 -0700
Committer: Xiao Li <gatorsmile@gmail.com>
Committed: Sun Jun 3 22:02:21 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py                |  15 +-
 python/pyspark/sql/streaming.py                 |  15 +-
 python/pyspark/sql/tests.py                     |  18 ++
 .../org/apache/spark/sql/DataFrameReader.scala  |  19 ++
 .../datasources/csv/CSVDataSource.scala         | 126 +++++++++++-
 .../datasources/csv/CSVFileFormat.scala         |   9 +-
 .../execution/datasources/csv/CSVOptions.scala  |   6 +
 .../execution/datasources/csv/CSVUtils.scala    |  22 ++-
 .../datasources/csv/UnivocityParser.scala       |  26 +--
 .../execution/datasources/csv/CSVSuite.scala    | 192 +++++++++++++++++++
 10 files changed, 411 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 448a473..a0e20d3 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -346,7 +346,7 @@ class DataFrameReader(OptionUtils):
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
             columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
-            samplingRatio=None):
+            samplingRatio=None, enforceSchema=None):
         """Loads a CSV file and returns the result as a  :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -373,6 +373,16 @@ class DataFrameReader(OptionUtils):
                        default value, ``false``.
         :param inferSchema: infers the input schema automatically from data. It requires one extra
                        pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will be
+                              ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or the first header in RDD
+                              if the ``header`` option is set to ``true``. Field names in the schema
+                              and column names in CSV headers are checked by their positions
+                              taking into account ``spark.sql.caseSensitive``. If None is set,
+                              ``true`` is used by default. Though the default value is ``true``,
+                              it is recommended to disable the ``enforceSchema`` option
+                              to avoid incorrect results.
         :param ignoreLeadingWhiteSpace: A flag indicating whether or not leading whitespaces from
                                         values being read should be skipped. If None is set, it
                                         uses the default value, ``false``.
@@ -449,7 +459,8 @@ class DataFrameReader(OptionUtils):
             maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
-            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio)
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio,
+            enforceSchema=enforceSchema)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/python/pyspark/sql/streaming.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 15f9407..fae50b3 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -564,7 +564,8 @@ class DataStreamReader(OptionUtils):
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
-            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+            enforceSchema=None):
         """Loads a CSV file stream and returns the result as a  :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -592,6 +593,16 @@ class DataStreamReader(OptionUtils):
                        default value, ``false``.
         :param inferSchema: infers the input schema automatically from data. It requires one extra
                        pass over the data. If None is set, it uses the default value, ``false``.
+        :param enforceSchema: If it is set to ``true``, the specified or inferred schema will be
+                              forcibly applied to datasource files, and headers in CSV files will be
+                              ignored. If the option is set to ``false``, the schema will be
+                              validated against all headers in CSV files or the first header in RDD
+                              if the ``header`` option is set to ``true``. Field names in the schema
+                              and column names in CSV headers are checked by their positions
+                              taking into account ``spark.sql.caseSensitive``. If None is set,
+                              ``true`` is used by default. Though the default value is ``true``,
+                              it is recommended to disable the ``enforceSchema`` option
+                              to avoid incorrect results.
         :param ignoreLeadingWhiteSpace: a flag indicating whether or not leading whitespaces from
                                         values being read should be skipped. If None is set, it
                                         uses the default value, ``false``.
@@ -664,7 +675,7 @@ class DataStreamReader(OptionUtils):
             maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
-            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, enforceSchema=enforceSchema)
         if isinstance(path, basestring):
             return self._df(self._jreader.csv(path))
         else:

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a245093..ea2dd76 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3056,6 +3056,24 @@ class SQLTests(ReusedSQLTestCase):
             .csv(rdd, samplingRatio=0.5).schema
         self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
 
+    def test_checking_csv_header(self):
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+        try:
+            self.spark.createDataFrame([[1, 1000], [2000, 2]])\
+                .toDF('f1', 'f2').write.option("header", "true").csv(path)
+            schema = StructType([
+                StructField('f2', IntegerType(), nullable=True),
+                StructField('f1', IntegerType(), nullable=True)])
+            df = self.spark.read.option('header', 'true').schema(schema)\
+                .csv(path, enforceSchema=False)
+            self.assertRaisesRegexp(
+                Exception,
+                "CSV header does not conform to the schema",
+                lambda: df.collect())
+        finally:
+            shutil.rmtree(path)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ac4580a..de6be5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -22,6 +22,7 @@ import java.util.{Locale, Properties}
 import scala.collection.JavaConverters._
 
 import com.fasterxml.jackson.databind.ObjectMapper
+import com.univocity.parsers.csv.CsvParser
 
 import org.apache.spark.Partition
 import org.apache.spark.annotation.InterfaceStability
@@ -474,6 +475,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * it determines the columns as string types and it reads only the first line to determine the
    * names and the number of fields.
    *
+   * If the enforceSchema is set to `false`, only the CSV header in the first line is checked
+   * to conform specified or inferred schema.
+   *
    * @param csvDataset input Dataset with one CSV row per record
    * @since 2.2.0
    */
@@ -499,6 +503,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+      CSVDataSource.checkHeader(
+        firstLine,
+        new CsvParser(parsedOptions.asParserSettings),
+        actualSchema,
+        csvDataset.getClass.getCanonicalName,
+        parsedOptions.enforceSchema,
+        sparkSession.sessionState.conf.caseSensitiveAnalysis)
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)
 
@@ -539,6 +550,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`comment` (default empty string): sets a single character used for skipping lines
    * beginning with this character. By default, it is disabled.</li>
    * <li>`header` (default `false`): uses the first line as names of columns.</li>
+   * <li>`enforceSchema` (default `true`): If it is set to `true`, the specified or inferred schema
+   * will be forcibly applied to datasource files, and headers in CSV files will be ignored.
+   * If the option is set to `false`, the schema will be validated against all headers in CSV files
+   * in the case when the `header` option is set to `true`. Field names in the schema
+   * and column names in CSV headers are checked by their positions taking into account
+   * `spark.sql.caseSensitive`. Though the default value is true, it is recommended to disable
+   * the `enforceSchema` option to avoid incorrect results.</li>
    * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
    * requires one extra pass over the data.</li>
    * <li>`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.</li>
@@ -583,6 +601,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
    * <li>`multiLine` (default `false`): parse one record, which may span multiple lines.</li>
    * </ul>
+   *
    * @since 2.0.0
    */
   @scala.annotation.varargs

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index dc54d18..82322df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
+import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -50,7 +51,10 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow]
+      requiredSchema: StructType,
+      // Actual schema of data in the csv file
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -110,7 +114,7 @@ abstract class CSVDataSource extends Serializable {
   }
 }
 
-object CSVDataSource {
+object CSVDataSource extends Logging {
   def apply(options: CSVOptions): CSVDataSource = {
     if (options.multiLine) {
       MultiLineCSVDataSource
@@ -118,6 +122,84 @@ object CSVDataSource {
       TextInputCSVDataSource
     }
   }
+
+  /**
+   * Checks that column names in a CSV header and field names in the schema are the same
+   * by taking into account case sensitivity.
+   *
+   * @param schema - provided (or inferred) schema to which CSV must conform.
+   * @param columnNames - names of CSV columns that must be checked against to the schema.
+   * @param fileName - name of CSV file that are currently checked. It is used in error messages.
+   * @param enforceSchema - if it is `true`, column names are ignored otherwise the CSV column
+   *                        names are checked for conformance to the schema. In the case if
+   *                        the column name don't conform to the schema, an exception is thrown.
+   * @param caseSensitive - if it is set to `false`, comparison of column names and schema field
+   *                        names is not case sensitive.
+   */
+  def checkHeaderColumnNames(
+      schema: StructType,
+      columnNames: Array[String],
+      fileName: String,
+      enforceSchema: Boolean,
+      caseSensitive: Boolean): Unit = {
+    if (columnNames != null) {
+      val fieldNames = schema.map(_.name).toIndexedSeq
+      val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+      var errorMessage: Option[String] = None
+
+      if (headerLen == schemaSize) {
+        var i = 0
+        while (errorMessage.isEmpty && i < headerLen) {
+          var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+          if (!caseSensitive) {
+            nameInSchema = nameInSchema.toLowerCase
+            nameInHeader = nameInHeader.toLowerCase
+          }
+          if (nameInHeader != nameInSchema) {
+            errorMessage = Some(
+              s"""|CSV header does not conform to the schema.
+                  | Header: ${columnNames.mkString(", ")}
+                  | Schema: ${fieldNames.mkString(", ")}
+                  |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+                  |CSV file: $fileName""".stripMargin)
+          }
+          i += 1
+        }
+      } else {
+        errorMessage = Some(
+          s"""|Number of column in CSV header is not equal to number of fields in the schema:
+              | Header length: $headerLen, schema size: $schemaSize
+              |CSV file: $fileName""".stripMargin)
+      }
+
+      errorMessage.foreach { msg =>
+        if (enforceSchema) {
+          logWarning(msg)
+        } else {
+          throw new IllegalArgumentException(msg)
+        }
+      }
+    }
+  }
+
+  /**
+   * Checks that CSV header contains the same column names as fields names in the given schema
+   * by taking into account case sensitivity.
+   */
+  def checkHeader(
+      header: String,
+      parser: CsvParser,
+      schema: StructType,
+      fileName: String,
+      enforceSchema: Boolean,
+      caseSensitive: Boolean): Unit = {
+    checkHeaderColumnNames(
+        schema,
+        parser.parseLine(header),
+        fileName,
+        enforceSchema,
+        caseSensitive)
+  }
 }
 
 object TextInputCSVDataSource extends CSVDataSource {
@@ -127,7 +209,9 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      requiredSchema: StructType,
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -136,8 +220,24 @@ object TextInputCSVDataSource extends CSVDataSource {
       }
     }
 
-    val shouldDropHeader = parser.options.headerFlag && file.start == 0
-    UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
+    val hasHeader = parser.options.headerFlag && file.start == 0
+    if (hasHeader) {
+      // Checking that column names in the header are matched to field names of the schema.
+      // The header will be removed from lines.
+      // Note: if there are only comments in the first block, the header would probably
+      // be not extracted.
+      CSVUtils.extractHeader(lines, parser.options).foreach { header =>
+        CSVDataSource.checkHeader(
+          header,
+          parser.tokenizer,
+          dataSchema,
+          file.filePath,
+          parser.options.enforceSchema,
+          caseSensitive)
+      }
+    }
+
+    UnivocityParser.parseIterator(lines, parser, requiredSchema)
   }
 
   override def infer(
@@ -206,12 +306,24 @@ object MultiLineCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      requiredSchema: StructType,
+      dataSchema: StructType,
+      caseSensitive: Boolean): Iterator[InternalRow] = {
+    def checkHeader(header: Array[String]): Unit = {
+      CSVDataSource.checkHeaderColumnNames(
+        dataSchema,
+        header,
+        file.filePath,
+        parser.options.enforceSchema,
+        caseSensitive)
+    }
+
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, new Path(new URI(file.filePath))),
       parser.options.headerFlag,
       parser,
-      schema)
+      requiredSchema,
+      checkHeader)
   }
 
   override def infer(

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 21279d6..b90275d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,6 +130,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
           "df.filter($\"_corrupt_record\".isNotNull).count()."
       )
     }
+    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
 
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
@@ -137,7 +138,13 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
         parsedOptions)
-      CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
+      CSVDataSource(parsedOptions).readFile(
+        conf,
+        file,
+        parser,
+        requiredSchema,
+        dataSchema,
+        caseSensitive)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 7119189..fab8d62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -156,6 +156,12 @@ class CSVOptions(
   val samplingRatio =
     parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
 
+  /**
+   * Forcibly apply the specified or inferred schema to datasource files.
+   * If the option is enabled, headers of CSV files will be ignored.
+   */
+  val enforceSchema = getBool("enforceSchema", default = true)
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 9dae41b..1012e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -68,12 +68,8 @@ object CSVUtils {
     }
   }
 
-  /**
-   * Drop header line so that only data can remain.
-   * This is similar with `filterHeaderLine` above and currently being used in CSV reading path.
-   */
-  def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
-    val nonEmptyLines = if (options.isCommentSet) {
+  def skipComments(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    if (options.isCommentSet) {
       val commentPrefix = options.comment.toString
       iter.dropWhile { line =>
         line.trim.isEmpty || line.trim.startsWith(commentPrefix)
@@ -81,12 +77,20 @@ object CSVUtils {
     } else {
       iter.dropWhile(_.trim.isEmpty)
     }
-
-    if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
-    iter
   }
 
   /**
+   * Extracts header and moves iterator forward so that only data remains in it
+   */
+  def extractHeader(iter: Iterator[String], options: CSVOptions): Option[String] = {
+    val nonEmptyLines = skipComments(iter, options)
+    if (nonEmptyLines.hasNext) {
+      Some(nonEmptyLines.next())
+    } else {
+      None
+    }
+  }
+  /**
    * Helper method that converts string representation of a character to actual character.
   * It handles some Java escaped strings and throws exception if given string is longer than one
    * character.

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 4f00cc5..5f7d569 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -45,7 +45,7 @@ class UnivocityParser(
   // A `ValueConverter` is responsible for converting the given value to a desired type.
   private type ValueConverter = String => Any
 
-  private val tokenizer = {
+  val tokenizer = {
     val parserSetting = options.asParserSettings
     if (options.columnPruning && requiredSchema.length < dataSchema.length) {
       val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f)))
@@ -250,14 +250,15 @@ private[csv] object UnivocityParser {
       inputStream: InputStream,
       shouldDropHeader: Boolean,
       parser: UnivocityParser,
-      schema: StructType): Iterator[InternalRow] = {
+      schema: StructType,
+      checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
     val safeParser = new FailureSafeParser[Array[String]](
       input => Seq(parser.convert(input)),
       parser.options.parseMode,
       schema,
       parser.options.columnNameOfCorruptRecord)
-    convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
+    convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) { tokens =>
       safeParser.parse(tokens)
     }.flatten
   }
@@ -265,11 +266,14 @@ private[csv] object UnivocityParser {
   private def convertStream[T](
       inputStream: InputStream,
       shouldDropHeader: Boolean,
-      tokenizer: CsvParser)(convert: Array[String] => T) = new Iterator[T] {
+      tokenizer: CsvParser,
+      checkHeader: Array[String] => Unit = _ => ())(
+      convert: Array[String] => T) = new Iterator[T] {
     tokenizer.beginParsing(inputStream)
     private var nextRecord = {
       if (shouldDropHeader) {
-        tokenizer.parseNext()
+        val firstRecord = tokenizer.parseNext()
+        checkHeader(firstRecord)
       }
       tokenizer.parseNext()
     }
@@ -291,21 +295,11 @@ private[csv] object UnivocityParser {
    */
   def parseIterator(
       lines: Iterator[String],
-      shouldDropHeader: Boolean,
       parser: UnivocityParser,
       schema: StructType): Iterator[InternalRow] = {
     val options = parser.options
 
-    val linesWithoutHeader = if (shouldDropHeader) {
-      // Note that if there are only comments in the first block, the header would probably
-      // be not dropped.
-      CSVUtils.dropHeaderLine(lines, options)
-    } else {
-      lines
-    }
-
-    val filteredLines: Iterator[String] =
-      CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
+    val filteredLines: Iterator[String] = CSVUtils.filterCommentAndEmpty(lines, options)
 
     val safeParser = new FailureSafeParser[String](
       input => Seq(parser.parse(input)),

http://git-wip-us.apache.org/repos/asf/spark/blob/1d9338bb/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index afe10bd..d2f166c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -23,9 +23,13 @@ import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
 import java.util.Locale
 
+import scala.collection.JavaConverters._
+
 import org.apache.commons.lang3.time.FastDateFormat
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.log4j.{AppenderSkeleton, LogManager}
+import org.apache.log4j.spi.LoggingEvent
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
@@ -1410,4 +1414,192 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
       checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5)))
     }
   }
+
+  def checkHeader(multiLine: Boolean): Unit = {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      withTempPath { path =>
+        val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+        val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+        odf.write.option("header", true).csv(path.getCanonicalPath)
+        val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+        val exception = intercept[SparkException] {
+          spark.read
+            .schema(ischema)
+            .option("multiLine", multiLine)
+            .option("header", true)
+            .option("enforceSchema", false)
+            .csv(path.getCanonicalPath)
+            .collect()
+        }
+        assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+
+        val shortSchema = new StructType().add("f1", DoubleType)
+        val exceptionForShortSchema = intercept[SparkException] {
+          spark.read
+            .schema(shortSchema)
+            .option("multiLine", multiLine)
+            .option("header", true)
+            .option("enforceSchema", false)
+            .csv(path.getCanonicalPath)
+            .collect()
+        }
+        assert(exceptionForShortSchema.getMessage.contains(
+          "Number of column in CSV header is not equal to number of fields in the schema"))
+
+        val longSchema = new StructType()
+          .add("f1", DoubleType)
+          .add("f2", DoubleType)
+          .add("f3", DoubleType)
+
+        val exceptionForLongSchema = intercept[SparkException] {
+          spark.read
+            .schema(longSchema)
+            .option("multiLine", multiLine)
+            .option("header", true)
+            .option("enforceSchema", false)
+            .csv(path.getCanonicalPath)
+            .collect()
+        }
+        assert(exceptionForLongSchema.getMessage.contains("Header length: 2, schema size: 3"))
+
+        val caseSensitiveSchema = new StructType().add("F1", DoubleType).add("f2", DoubleType)
+        val caseSensitiveException = intercept[SparkException] {
+          spark.read
+            .schema(caseSensitiveSchema)
+            .option("multiLine", multiLine)
+            .option("header", true)
+            .option("enforceSchema", false)
+            .csv(path.getCanonicalPath)
+            .collect()
+        }
+        assert(caseSensitiveException.getMessage.contains(
+          "CSV header does not conform to the schema"))
+      }
+    }
+  }
+
+  test(s"SPARK-23786: Checking column names against schema in the multiline mode") {
+    checkHeader(multiLine = true)
+  }
+
+  test(s"SPARK-23786: Checking column names against schema in the per-line mode") {
+    checkHeader(multiLine = false)
+  }
+
+  test("SPARK-23786: CSV header must not be checked if it doesn't exist") {
+    withTempPath { path =>
+      val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+      val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+      odf.write.option("header", false).csv(path.getCanonicalPath)
+      val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+      val idf = spark.read
+          .schema(ischema)
+          .option("header", false)
+          .option("enforceSchema", false)
+          .csv(path.getCanonicalPath)
+
+      checkAnswer(idf, odf)
+    }
+  }
+
+  test("SPARK-23786: Ignore column name case if spark.sql.caseSensitive is false") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        val oschema = new StructType().add("A", StringType)
+        val odf = spark.createDataFrame(List(Row("0")).asJava, oschema)
+        odf.write.option("header", true).csv(path.getCanonicalPath)
+        val ischema = new StructType().add("a", StringType)
+        val idf = spark.read.schema(ischema)
+          .option("header", true)
+          .option("enforceSchema", false)
+          .csv(path.getCanonicalPath)
+        checkAnswer(idf, odf)
+      }
+    }
+  }
+
+  test("SPARK-23786: check header on parsing of dataset of strings") {
+    val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
+    val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)
+    val exception = intercept[IllegalArgumentException] {
+      spark.read.schema(ischema).option("header", true).option("enforceSchema", false).csv(ds)
+    }
+
+    assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+  }
+
+  test("SPARK-23786: enforce inferred schema") {
+    val expectedSchema = new StructType().add("_c0", DoubleType).add("_c1", StringType)
+    val withHeader = spark.read
+      .option("inferSchema", true)
+      .option("enforceSchema", false)
+      .option("header", true)
+      .csv(Seq("_c0,_c1", "1.0,a").toDS())
+    assert(withHeader.schema == expectedSchema)
+    checkAnswer(withHeader, Seq(Row(1.0, "a")))
+
+    // Ignore the inferSchema flag if an user sets a schema
+    val schema = new StructType().add("colA", DoubleType).add("colB", StringType)
+    val ds = spark.read
+      .option("inferSchema", true)
+      .option("enforceSchema", false)
+      .option("header", true)
+      .schema(schema)
+      .csv(Seq("colA,colB", "1.0,a").toDS())
+    assert(ds.schema == schema)
+    checkAnswer(ds, Seq(Row(1.0, "a")))
+
+    val exception = intercept[IllegalArgumentException] {
+      spark.read
+        .option("inferSchema", true)
+        .option("enforceSchema", false)
+        .option("header", true)
+        .schema(schema)
+        .csv(Seq("col1,col2", "1.0,a").toDS())
+    }
+    assert(exception.getMessage.contains("CSV header does not conform to the schema"))
+  }
+
+  test("SPARK-23786: warning should be printed if CSV header doesn't conform to schema")
{
+    class TestAppender extends AppenderSkeleton {
+      var events = new java.util.ArrayList[LoggingEvent]
+      override def close(): Unit = {}
+      override def requiresLayout: Boolean = false
+      protected def append(event: LoggingEvent): Unit = events.add(event)
+    }
+
+    val testAppender1 = new TestAppender
+    LogManager.getRootLogger.addAppender(testAppender1)
+    try {
+      val ds = Seq("columnA,columnB", "1.0,1000.0").toDS()
+      val ischema = new StructType().add("columnB", DoubleType).add("columnA", DoubleType)
+
+      spark.read.schema(ischema).option("header", true).option("enforceSchema", true).csv(ds)
+    } finally {
+      LogManager.getRootLogger.removeAppender(testAppender1)
+    }
+    assert(testAppender1.events.asScala
+      .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema")))
+
+    val testAppender2 = new TestAppender
+    LogManager.getRootLogger.addAppender(testAppender2)
+    try {
+      withTempPath { path =>
+        val oschema = new StructType().add("f1", DoubleType).add("f2", DoubleType)
+        val odf = spark.createDataFrame(List(Row(1.0, 1234.5)).asJava, oschema)
+        odf.write.option("header", true).csv(path.getCanonicalPath)
+        val ischema = new StructType().add("f2", DoubleType).add("f1", DoubleType)
+        spark.read
+          .schema(ischema)
+          .option("header", true)
+          .option("enforceSchema", true)
+          .csv(path.getCanonicalPath)
+          .collect()
+      }
+    } finally {
+      LogManager.getRootLogger.removeAppender(testAppender2)
+    }
+    assert(testAppender2.events.asScala
+      .exists(msg => msg.getRenderedMessage.contains("CSV header does not conform to the schema")))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

