spark-commits mailing list archives

From yh...@apache.org
Subject spark git commit: [SPARK-10310] [SQL] Fixes script transformation field/line delimiters
Date Wed, 23 Sep 2015 03:09:55 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 26187ab74 -> 73d062184


[SPARK-10310] [SQL] Fixes script transformation field/line delimiters

**Please attribute this PR to `Zhichao Li <zhichao.li@intel.com>`.**

This PR is based on PR #8476 authored by zhichao-li. It fixes SPARK-10310 by adding the field
delimiter SerDe property to the default `LazySimpleSerDe` and by enabling the default record
reader/writer classes.
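
With the tab field delimiter set on the default SerDe, a plain `TRANSFORM` call now round-trips
multi-column rows correctly. A minimal example of the query shape this covers, adapted from the new
`SQLQuerySuite` test added below (it assumes a two-column table `test(a, b)`):

```sql
FROM (
  FROM test SELECT TRANSFORM(a, b)
  USING 'python src/test/resources/data/scripts/test_transform.py "\t"'
  AS (c STRING, d STRING)
) t
SELECT c
```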

Currently, we only support `LazySimpleSerDe`, used together with `TextRecordReader` and
`TextRecordWriter`; customizing the record reader/writer via the `RECORDREADER`/`RECORDWRITER`
clauses is not yet supported and should be addressed in separate PR(s).
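
When `LazySimpleSerDe` is named explicitly, the default `TextRecordReader`/`TextRecordWriter` are
still used and the configured `field.delim` is forwarded to the SerDe. A sketch of that usage,
taken from the new `SQLQuerySuite` test (same two-column `test(a, b)` table as above):

```sql
FROM test
SELECT TRANSFORM(a, b)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
WITH SERDEPROPERTIES('field.delim' = '|')
USING 'python src/test/resources/data/scripts/test_transform.py "|"'
AS (c STRING, d STRING)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
WITH SERDEPROPERTIES('field.delim' = '|')
```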

Author: Cheng Lian <lian@databricks.com>

Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters.

(cherry picked from commit 84f81e035e1dab1b42c36563041df6ba16e7b287)
Signed-off-by: Yin Huai <yhuai@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/73d06218
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/73d06218
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/73d06218

Branch: refs/heads/branch-1.5
Commit: 73d062184dfcb22e2ae6377ac4a71a9b766bd105
Parents: 26187ab
Author: Zhichao Li <zhichao.li@intel.com>
Authored: Tue Sep 22 19:41:57 2015 -0700
Committer: Yin Huai <yhuai@databricks.com>
Committed: Tue Sep 22 20:09:46 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/HiveQl.scala      | 52 ++++++++++--
 .../hive/execution/ScriptTransformation.scala   | 87 +++++++++++++++-----
 .../resources/data/scripts/test_transform.py    |  6 ++
 .../sql/hive/execution/SQLQuerySuite.scala      | 39 +++++++++
 .../execution/ScriptTransformationSuite.scala   |  2 +
 5 files changed, 158 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/73d06218/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ad33dee..c2b0055 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.hive.ql.lib.Node
 import org.apache.hadoop.hive.ql.parse._
 import org.apache.hadoop.hive.ql.plan.PlanUtils
 import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.AnalysisException
@@ -880,16 +881,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                   AttributeReference("value", StringType)()), true)
             }
 
-            def matchSerDe(clause: Seq[ASTNode])
-              : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
+            type SerDeInfo = (
+              Seq[(String, String)],  // Input row format information
+              Option[String],         // Optional input SerDe class
+              Seq[(String, String)],  // Input SerDe properties
+              Boolean                 // Whether to use default record reader/writer
+            )
+
+            def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match {
               case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
                 val rowFormat = propsClause.map {
                   case Token(name, Token(value, Nil) :: Nil) => (name, value)
                 }
-                (rowFormat, None, Nil)
+                (rowFormat, None, Nil, false)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
-                (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
+                (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
                 Token("TOK_TABLEPROPERTIES",
@@ -899,20 +906,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                     (BaseSemanticAnalyzer.unescapeSQLString(name),
                       BaseSemanticAnalyzer.unescapeSQLString(value))
                 }
-                (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
 
-              case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
+                // SPARK-10310: Special cases LazySimpleSerDe
+                // TODO Fully supports user-defined record reader/writer classes
+                val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass)
+                val useDefaultRecordReaderWriter =
+                  unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName
+                (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter)
+
+              case Nil =>
+                // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here
+                val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t")
+                (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true)
             }
 
-            val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
-            val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause)
+            val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) =
+              matchSerDe(inputSerdeClause)
+
+            val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) =
+              matchSerDe(outputSerdeClause)
 
             val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)
 
+            // TODO Adds support for user-defined record reader/writer classes
+            val recordReaderClass = if (useDefaultRecordReader) {
+              Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER))
+            } else {
+              None
+            }
+
+            val recordWriterClass = if (useDefaultRecordWriter) {
+              Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER))
+            } else {
+              None
+            }
+
             val schema = HiveScriptIOSchema(
               inRowFormat, outRowFormat,
               inSerdeClass, outSerdeClass,
-              inSerdeProps, outSerdeProps, schemaLess)
+              inSerdeProps, outSerdeProps,
+              recordReaderClass, recordWriterClass,
+              schemaLess)
 
             Some(
               logical.ScriptTransformation(

http://git-wip-us.apache.org/repos/asf/spark/blob/73d06218/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index ade2745..8eaadd8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -21,23 +21,25 @@ import java.io._
 import java.util.Properties
 import javax.annotation.Nullable
 
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.io.Writable
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
 import org.apache.spark.{Logging, TaskContext}
 
 /**
@@ -58,15 +60,18 @@ case class ScriptTransformation(
 
   override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
+  private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
+
   protected override def doExecute(): RDD[InternalRow] = {
     def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val cmd = List("/bin/bash", "-c", script)
-      val builder = new ProcessBuilder(cmd)
+      val builder = new ProcessBuilder(cmd.asJava)
 
       val proc = builder.start()
       val inputStream = proc.getInputStream
       val outputStream = proc.getOutputStream
       val errorStream = proc.getErrorStream
+      val localHiveConf = serializedHiveConf.value
 
       // In order to avoid deadlocks, we need to consume the error output of the child process.
       // To avoid issues caused by large error output, we use a circular buffer to limit the amount
@@ -96,7 +101,8 @@ case class ScriptTransformation(
         outputStream,
         proc,
         stderrBuffer,
-        TaskContext.get()
+        TaskContext.get(),
+        localHiveConf
       )
 
       // This nullability is a performance optimization in order to avoid an Option.foreach() call
@@ -109,6 +115,10 @@ case class ScriptTransformation(
       val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
         var curLine: String = null
         val scriptOutputStream = new DataInputStream(inputStream)
+
+        @Nullable val scriptOutputReader =
+          ioschema.recordReader(scriptOutputStream, localHiveConf).orNull
+
         var scriptOutputWritable: Writable = null
         val reusedWritableObject: Writable = if (null != outputSerde) {
           outputSerde.getSerializedClass().newInstance
@@ -134,15 +144,25 @@ case class ScriptTransformation(
             }
           } else if (scriptOutputWritable == null) {
             scriptOutputWritable = reusedWritableObject
-            try {
-              scriptOutputWritable.readFields(scriptOutputStream)
-              true
-            } catch {
-              case _: EOFException =>
-                if (writerThread.exception.isDefined) {
-                  throw writerThread.exception.get
-                }
+
+            if (scriptOutputReader != null) {
+              if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
+                writerThread.exception.foreach(throw _)
                 false
+              } else {
+                true
+              }
+            } else {
+              try {
+                scriptOutputWritable.readFields(scriptOutputStream)
+                true
+              } catch {
+                case _: EOFException =>
+                  if (writerThread.exception.isDefined) {
+                    throw writerThread.exception.get
+                  }
+                  false
+              }
             }
           } else {
             true
@@ -172,10 +192,10 @@ case class ScriptTransformation(
             val fieldList = outputSoi.getAllStructFieldRefs()
             var i = 0
             while (i < dataList.size()) {
-              if (dataList(i) == null) {
+              if (dataList.get(i) == null) {
                 mutableRow.setNullAt(i)
               } else {
-                mutableRow(i) = unwrap(dataList(i), fieldList(i).getFieldObjectInspector)
+                mutableRow(i) = unwrap(dataList.get(i), fieldList.get(i).getFieldObjectInspector)
               }
               i += 1
             }
@@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread(
     outputStream: OutputStream,
     proc: Process,
     stderrBuffer: CircularBuffer,
-    taskContext: TaskContext
+    taskContext: TaskContext,
+    conf: Configuration
   ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
 
   setDaemon(true)
@@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread(
     TaskContext.setTaskContext(taskContext)
 
     val dataOutputStream = new DataOutputStream(outputStream)
+    @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull
 
     // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
     // let's use a variable to record whether the `finally` block was hit due to an exception
@@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread(
         } else {
           val writable = inputSerde.serialize(
             row.asInstanceOf[GenericInternalRow].values, inputSoi)
-          prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+
+          if (scriptInputWriter != null) {
+            scriptInputWriter.write(writable)
+          } else {
+            prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
+          }
         }
       }
       outputStream.close()
@@ -290,6 +317,8 @@ case class HiveScriptIOSchema (
     outputSerdeClass: Option[String],
     inputSerdeProps: Seq[(String, String)],
     outputSerdeProps: Seq[(String, String)],
+    recordReaderClass: Option[String],
+    recordWriterClass: Option[String],
     schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
 
   private val defaultFormat = Map(
@@ -307,7 +336,7 @@ case class HiveScriptIOSchema (
       val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
       val fieldObjectInspectors = columnTypes.map(toInspector)
       val objectInspector = ObjectInspectorFactory
-        .getStandardStructObjectInspector(columns, fieldObjectInspectors)
+        .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava)
         .asInstanceOf[ObjectInspector]
       (serde, objectInspector)
     }
@@ -342,9 +371,29 @@ case class HiveScriptIOSchema (
     propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
 
     val properties = new Properties()
-    properties.putAll(propsMap)
+    properties.putAll(propsMap.asJava)
     serde.initialize(null, properties)
 
     serde
   }
+
+  def recordReader(
+      inputStream: InputStream,
+      conf: Configuration): Option[RecordReader] = {
+    recordReaderClass.map { klass =>
+      val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
+      val props = new Properties()
+      props.putAll(outputSerdeProps.toMap.asJava)
+      instance.initialize(inputStream, conf, props)
+      instance
+    }
+  }
+
+  def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = {
+    recordWriterClass.map { klass =>
+      val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
+      instance.initialize(outputStream, conf)
+      instance
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/73d06218/sql/hive/src/test/resources/data/scripts/test_transform.py
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py b/sql/hive/src/test/resources/data/scripts/test_transform.py
new file mode 100755
index 0000000..ac6d11d
--- /dev/null
+++ b/sql/hive/src/test/resources/data/scripts/test_transform.py
@@ -0,0 +1,6 @@
+import sys
+
+delim = sys.argv[1]
+
+for row in sys.stdin:
+    print(delim.join([w + '#' for w in row[:-1].split(delim)]))

http://git-wip-us.apache.org/repos/asf/spark/blob/73d06218/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 3eab66e..5c5e3c5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1187,4 +1187,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
 
     checkAnswer(df, Row("text inside layer 2") :: Nil)
   }
+
+  test("SPARK-10310: " +
+    "script transformation using default input/output SerDe and record reader/writer") {
+    sqlContext
+      .range(5)
+      .selectExpr("id AS a", "id AS b")
+      .registerTempTable("test")
+
+    checkAnswer(
+      sql(
+        """FROM(
+          |  FROM test SELECT TRANSFORM(a, b)
+          |  USING 'python src/test/resources/data/scripts/test_transform.py "\t"'
+          |  AS (c STRING, d STRING)
+          |) t
+          |SELECT c
+        """.stripMargin),
+      (0 until 5).map(i => Row(i + "#")))
+  }
+
+  test("SPARK-10310: script transformation using LazySimpleSerDe") {
+    sqlContext
+      .range(5)
+      .selectExpr("id AS a", "id AS b")
+      .registerTempTable("test")
+
+    val df = sql(
+      """FROM test
+        |SELECT TRANSFORM(a, b)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+        |WITH SERDEPROPERTIES('field.delim' = '|')
+        |USING 'python src/test/resources/data/scripts/test_transform.py "|"'
+        |AS (c STRING, d STRING)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+        |WITH SERDEPROPERTIES('field.delim' = '|')
+      """.stripMargin)
+
+    checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#")))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/73d06218/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 9aca40f..c7edcff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -41,6 +41,8 @@ class ScriptTransformationSuite extends SparkPlanTest {
     outputSerdeClass = None,
     inputSerdeProps = Seq.empty,
     outputSerdeProps = Seq.empty,
+    recordReaderClass = None,
+    recordWriterClass = None,
     schemaLess = false
   )
 


