spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From joshro...@apache.org
Subject spark git commit: [SPARK-9393] [SQL] Fix several error-handling bugs in ScriptTransform operator
Date Tue, 28 Jul 2015 23:04:52 GMT
Repository: spark
Updated Branches:
  refs/heads/master 21825529e -> 59b92add7


[SPARK-9393] [SQL] Fix several error-handling bugs in ScriptTransform operator

Spark SQL's ScriptTransformation operator has several serious bugs that make debugging fairly difficult:

- If an exception is thrown in the writer thread, the child process is not killed. This leads
to a deadlock, because the reader thread blocks while waiting for input that will
never arrive.
- TaskContext is not propagated to the writer thread, which may cause errors in upstream pipelined
operators.
- Exceptions that occur in the writer thread are not propagated to the main reader thread,
which may cause upstream errors to be silently ignored instead of failing the job. This can
lead to silently incorrect query results.
- The writer thread is not a daemon thread, but it should be.

In addition, the code in this file is extremely messy:

- Many fields are nullable, but the nullability isn't clearly explained.
- Many variable names are confusing: for instance, there are variables named `iter` and `iterator`
that are defined in the same scope.
- Some code was misindented.
- The `*serdeClass` variables are actually expected to be single-quoted strings, which is
really confusing: I feel that this parsing / extraction should be performed in the analyzer,
not in the operator itself.
- There were no unit tests for the operator itself, only end-to-end tests.

This pull request addresses these issues, borrowing some error-handling techniques from PySpark's
PythonRDD.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7710 from JoshRosen/script-transform and squashes the following commits:

16c44e2 [Josh Rosen] Update some comments
983f200 [Josh Rosen] Use unescapeSQLString instead of stripQuotes
6a06a8c [Josh Rosen] Clean up handling of quotes in serde class name
494cde0 [Josh Rosen] Propagate TaskContext to writer thread
323bb2b [Josh Rosen] Fix error-swallowing bug
b31258d [Josh Rosen] Rename iterator variables to disambiguate.
88278de [Josh Rosen] Split ScriptTransformation writer thread into own class.
8b162b6 [Josh Rosen] Add failing test which demonstrates exception masking issue
4ee36a2 [Josh Rosen] Kill script transform subprocess when error occurs in input writer.
bd4c948 [Josh Rosen] Skip launching of external command for empty partitions.
b43e4ec [Josh Rosen] Clean up nullability in ScriptTransformation
fa18d26 [Josh Rosen] Add basic unit test for script transform with 'cat' command.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/59b92add
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/59b92add
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/59b92add

Branch: refs/heads/master
Commit: 59b92add7cc9cca1eaf0c558edb7c4add66c284f
Parents: 2182552
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Tue Jul 28 16:04:48 2015 -0700
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Tue Jul 28 16:04:48 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkPlanTest.scala     |  27 +-
 .../org/apache/spark/sql/hive/HiveQl.scala      |  10 +-
 .../hive/execution/ScriptTransformation.scala   | 280 ++++++++++++-------
 .../execution/ScriptTransformationSuite.scala   | 123 ++++++++
 4 files changed, 317 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/59b92add/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 6a8f394..f46855e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
 
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
@@ -33,11 +33,13 @@ import scala.util.control.NonFatal
  */
 class SparkPlanTest extends SparkFunSuite {
 
+  protected def sqlContext: SQLContext = TestSQLContext
+
   /**
    * Creates a DataFrame from a local Seq of Product.
    */
   implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder
= {
-    TestSQLContext.implicits.localSeqToDataFrameHolder(data)
+    sqlContext.implicits.localSeqToDataFrameHolder(data)
   }
 
   /**
@@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite {
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
       sortAnswers: Boolean = true): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext)
match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite {
       planFunction: SparkPlan => SparkPlan,
       expectedPlanFunction: SparkPlan => SparkPlan,
       sortAnswers: Boolean = true): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match
{
+    SparkPlanTest.checkAnswer(
+        input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -147,13 +150,14 @@ object SparkPlanTest {
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
       expectedPlanFunction: SparkPlan => SparkPlan,
-      sortAnswers: Boolean): Option[String] = {
+      sortAnswers: Boolean,
+      sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
 
     val expectedAnswer: Seq[Row] = try {
-      executePlan(expectedOutputPlan)
+      executePlan(expectedOutputPlan, sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -168,7 +172,7 @@ object SparkPlanTest {
     }
 
     val actualAnswer: Seq[Row] = try {
-      executePlan(outputPlan)
+      executePlan(outputPlan, sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -207,12 +211,13 @@ object SparkPlanTest {
       input: Seq[DataFrame],
       planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row],
-      sortAnswers: Boolean): Option[String] = {
+      sortAnswers: Boolean,
+      sqlContext: SQLContext): Option[String] = {
 
     val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
 
     val sparkAnswer: Seq[Row] = try {
-      executePlan(outputPlan)
+      executePlan(outputPlan, sqlContext)
     } catch {
       case NonFatal(e) =>
         val errorMessage =
@@ -275,10 +280,10 @@ object SparkPlanTest {
     }
   }
 
-  private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+  private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+    val resolvedPlan = sqlContext.prepareForExecution.execute(
       outputPlan transform {
         case plan: SparkPlan =>
           val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap

http://git-wip-us.apache.org/repos/asf/spark/blob/59b92add/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 2f79b0a..e6df64d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
             }
 
             def matchSerDe(clause: Seq[ASTNode])
-              : (Seq[(String, String)], String, Seq[(String, String)]) = clause match {
+              : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match
{
               case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
                 val rowFormat = propsClause.map {
                   case Token(name, Token(value, Nil) :: Nil) => (name, value)
                 }
-                (rowFormat, "", Nil)
+                (rowFormat, None, Nil)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
-                (Nil, serdeClass, Nil)
+                (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
                 Token("TOK_TABLEPROPERTIES",
@@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                   case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) ::
Nil) =>
                     (name, value)
                 }
-                (Nil, serdeClass, serdeProps)
+                (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
 
-              case Nil => (Nil, "", Nil)
+              case Nil => (Nil, None, Nil)
             }
 
             val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)

http://git-wip-us.apache.org/repos/asf/spark/blob/59b92add/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 205e622..741c705 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -17,15 +17,18 @@
 
 package org.apache.spark.sql.hive.execution
 
-import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader}
+import java.io._
 import java.util.Properties
+import javax.annotation.Nullable
 
 import scala.collection.JavaConversions._
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
 
+import org.apache.spark.{TaskContext, Logging}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
@@ -56,21 +59,53 @@ case class ScriptTransformation(
   override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitions { iter =>
+    def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val cmd = List("/bin/bash", "-c", script)
       val builder = new ProcessBuilder(cmd)
-      // We need to start threads connected to the process pipeline:
-      // 1) The error msg generated by the script process would be hidden.
-      // 2) If the error msg is too big to chock up the buffer, the input logic would be
hung
+
       val proc = builder.start()
       val inputStream = proc.getInputStream
       val outputStream = proc.getOutputStream
       val errorStream = proc.getErrorStream
-      val reader = new BufferedReader(new InputStreamReader(inputStream))
 
-      val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output)
+      // In order to avoid deadlocks, we need to consume the error output of the child process.
+      // To avoid issues caused by large error output, we use a circular buffer to limit
the amount
+      // of error output that we retain. See SPARK-7862 for more discussion of the deadlock
/ hang
+      // that motivates this.
+      val stderrBuffer = new CircularBuffer(2048)
+      new RedirectThread(
+        errorStream,
+        stderrBuffer,
+        "Thread-ScriptTransformation-STDERR-Consumer").start()
+
+      val outputProjection = new InterpretedProjection(input, child.output)
+
+      // This nullability is a performance optimization in order to avoid an Option.foreach()
call
+      // inside of a loop
+      @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null,
null))
+
+      // This new thread will consume the ScriptTransformation's input rows and write them
to the
+      // external process. That process's output will be read by this current thread.
+      val writerThread = new ScriptTransformationWriterThread(
+        inputIterator,
+        outputProjection,
+        inputSerde,
+        inputSoi,
+        ioschema,
+        outputStream,
+        proc,
+        stderrBuffer,
+        TaskContext.get()
+      )
+
+      // This nullability is a performance optimization in order to avoid an Option.foreach()
call
+      // inside of a loop
+      @Nullable val (outputSerde, outputSoi) = {
+        ioschema.initOutputSerDe(output).getOrElse((null, null))
+      }
 
-      val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors
{
+      val reader = new BufferedReader(new InputStreamReader(inputStream))
+      val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors
{
         var cacheRow: InternalRow = null
         var curLine: String = null
         var eof: Boolean = false
@@ -79,12 +114,26 @@ case class ScriptTransformation(
           if (outputSerde == null) {
             if (curLine == null) {
               curLine = reader.readLine()
-              curLine != null
+              if (curLine == null) {
+                if (writerThread.exception.isDefined) {
+                  throw writerThread.exception.get
+                }
+                false
+              } else {
+                true
+              }
             } else {
               true
             }
           } else {
-            !eof
+            if (eof) {
+              if (writerThread.exception.isDefined) {
+                throw writerThread.exception.get
+              }
+              false
+            } else {
+              true
+            }
           }
         }
 
@@ -110,11 +159,11 @@ case class ScriptTransformation(
               }
               i += 1
             })
-            return mutableRow
+            mutableRow
           } catch {
             case e: EOFException =>
               eof = true
-              return null
+              null
           }
         }
 
@@ -146,49 +195,83 @@ case class ScriptTransformation(
         }
       }
 
-      val (inputSerde, inputSoi) = ioschema.initInputSerDe(input)
-      val dataOutputStream = new DataOutputStream(outputStream)
-      val outputProjection = new InterpretedProjection(input, child.output)
+      writerThread.start()
 
-      // TODO make the 2048 configurable?
-      val stderrBuffer = new CircularBuffer(2048)
-      // Consume the error stream from the pipeline, otherwise it will be blocked if
-      // the pipeline is full.
-      new RedirectThread(errorStream, // input stream from the pipeline
-        stderrBuffer,                 // output to a circular buffer
-        "Thread-ScriptTransformation-STDERR-Consumer").start()
+      outputIterator
+    }
 
-      // Put the write(output to the pipeline) into a single thread
-      // and keep the collector as remain in the main thread.
-      // otherwise it will causes deadlock if the data size greater than
-      // the pipeline / buffer capacity.
-      new Thread(new Runnable() {
-        override def run(): Unit = {
-          Utils.tryWithSafeFinally {
-            iter
-              .map(outputProjection)
-              .foreach { row =>
-              if (inputSerde == null) {
-                val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
-                  ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
-
-                outputStream.write(data)
-              } else {
-                val writable = inputSerde.serialize(
-                  row.asInstanceOf[GenericInternalRow].values, inputSoi)
-                prepareWritable(writable).write(dataOutputStream)
-              }
-            }
-            outputStream.close()
-          } {
-            if (proc.waitFor() != 0) {
-              logError(stderrBuffer.toString) // log the stderr circular buffer
-            }
-          }
-        }
-      }, "Thread-ScriptTransformation-Feed").start()
+    child.execute().mapPartitions { iter =>
+      if (iter.hasNext) {
+        processIterator(iter)
+      } else {
+        // If the input iterator has no rows then do not launch the external script.
+        Iterator.empty
+      }
+    }
+  }
+}
 
-      iterator
+private class ScriptTransformationWriterThread(
+    iter: Iterator[InternalRow],
+    outputProjection: Projection,
+    @Nullable inputSerde: AbstractSerDe,
+    @Nullable inputSoi: ObjectInspector,
+    ioschema: HiveScriptIOSchema,
+    outputStream: OutputStream,
+    proc: Process,
+    stderrBuffer: CircularBuffer,
+    taskContext: TaskContext
+  ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
+
+  setDaemon(true)
+
+  @volatile private var _exception: Throwable = null
+
+  /** Contains the exception thrown while writing the parent iterator to the external process.
*/
+  def exception: Option[Throwable] = Option(_exception)
+
+  override def run(): Unit = Utils.logUncaughtExceptions {
+    TaskContext.setTaskContext(taskContext)
+
+    val dataOutputStream = new DataOutputStream(outputStream)
+
+    // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
+    // let's use a variable to record whether the `finally` block was hit due to an exception
+    var threwException: Boolean = true
+    try {
+      iter.map(outputProjection).foreach { row =>
+        if (inputSerde == null) {
+          val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"),
+            ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8")
+          outputStream.write(data)
+        } else {
+          val writable = inputSerde.serialize(
+            row.asInstanceOf[GenericInternalRow].values, inputSoi)
+          prepareWritable(writable).write(dataOutputStream)
+        }
+      }
+      outputStream.close()
+      threwException = false
+    } catch {
+      case NonFatal(e) =>
+        // An error occurred while writing input, so kill the child process. According to
the
+        // Javadoc this call will not throw an exception:
+        _exception = e
+        proc.destroy()
+        throw e
+    } finally {
+      try {
+        if (proc.waitFor() != 0) {
+          logError(stderrBuffer.toString) // log the stderr circular buffer
+        }
+      } catch {
+        case NonFatal(exceptionFromFinallyBlock) =>
+          if (!threwException) {
+            throw exceptionFromFinallyBlock
+          } else {
+            log.error("Exception in finally block", exceptionFromFinallyBlock)
+          }
+      }
     }
   }
 }
@@ -200,33 +283,43 @@ private[hive]
 case class HiveScriptIOSchema (
     inputRowFormat: Seq[(String, String)],
     outputRowFormat: Seq[(String, String)],
-    inputSerdeClass: String,
-    outputSerdeClass: String,
+    inputSerdeClass: Option[String],
+    outputSerdeClass: Option[String],
     inputSerdeProps: Seq[(String, String)],
     outputSerdeProps: Seq[(String, String)],
     schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
 
-  val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"),
-                          ("TOK_TABLEROWFORMATLINES", "\n"))
+  private val defaultFormat = Map(
+    ("TOK_TABLEROWFORMATFIELD", "\t"),
+    ("TOK_TABLEROWFORMATLINES", "\n")
+  )
 
   val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
   val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
 
 
-  def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = {
-    val (columns, columnTypes) = parseAttrs(input)
-    val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps)
-    (serde, initInputSoi(serde, columns, columnTypes))
+  def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] =
{
+    inputSerdeClass.map { serdeClass =>
+      val (columns, columnTypes) = parseAttrs(input)
+      val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps)
+      val fieldObjectInspectors = columnTypes.map(toInspector)
+      val objectInspector = ObjectInspectorFactory
+        .getStandardStructObjectInspector(columns, fieldObjectInspectors)
+        .asInstanceOf[ObjectInspector]
+      (serde, objectInspector)
+    }
   }
 
-  def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = {
-    val (columns, columnTypes) = parseAttrs(output)
-    val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps)
-    (serde, initOutputputSoi(serde))
+  def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)]
= {
+    outputSerdeClass.map { serdeClass =>
+      val (columns, columnTypes) = parseAttrs(output)
+      val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
+      val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
+      (serde, structObjectInspector)
+    }
   }
 
-  def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
-
+  private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = {
     val columns = attrs.map {
       case aref: AttributeReference => aref.name
       case e: NamedExpression => e.name
@@ -242,52 +335,25 @@ case class HiveScriptIOSchema (
     (columns, columnTypes)
   }
 
-  def initSerDe(serdeClassName: String, columns: Seq[String],
-    columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = {
+  private def initSerDe(
+      serdeClassName: String,
+      columns: Seq[String],
+      columnTypes: Seq[DataType],
+      serdeProps: Seq[(String, String)]): AbstractSerDe = {
 
-    val serde: AbstractSerDe = if (serdeClassName != "") {
-      val trimed_class = serdeClassName.split("'")(1)
-      Utils.classForName(trimed_class)
-        .newInstance.asInstanceOf[AbstractSerDe]
-    } else {
-      null
-    }
+    val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
 
-    if (serde != null) {
-      val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
+    val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")
 
-      var propsMap = serdeProps.map(kv => {
-        (kv._1.split("'")(1), kv._2.split("'")(1))
-      }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
-      propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
+    var propsMap = serdeProps.map(kv => {
+      (kv._1.split("'")(1), kv._2.split("'")(1))
+    }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
+    propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)
 
-      val properties = new Properties()
-      properties.putAll(propsMap)
-      serde.initialize(null, properties)
-    }
+    val properties = new Properties()
+    properties.putAll(propsMap)
+    serde.initialize(null, properties)
 
     serde
   }
-
-  def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType])
-    : ObjectInspector = {
-
-    if (inputSerde != null) {
-      val fieldObjectInspectors = columnTypes.map(toInspector(_))
-      ObjectInspectorFactory
-        .getStandardStructObjectInspector(columns, fieldObjectInspectors)
-        .asInstanceOf[ObjectInspector]
-    } else {
-      null
-    }
-  }
-
-  def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = {
-    if (outputSerde != null) {
-      outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector]
-    } else {
-      null
-    }
-  }
 }
-

http://git-wip-us.apache.org/repos/asf/spark/blob/59b92add/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
new file mode 100644
index 0000000..0875232
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.types.StringType
+
+class ScriptTransformationSuite extends SparkPlanTest {
+
+  override def sqlContext: SQLContext = TestHive
+
+  private val noSerdeIOSchema = HiveScriptIOSchema(
+    inputRowFormat = Seq.empty,
+    outputRowFormat = Seq.empty,
+    inputSerdeClass = None,
+    outputSerdeClass = None,
+    inputSerdeProps = Seq.empty,
+    outputSerdeProps = Seq.empty,
+    schemaLess = false
+  )
+
+  private val serdeIOSchema = noSerdeIOSchema.copy(
+    inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
+    outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
+  )
+
+  test("cat without SerDe") {
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    checkAnswer(
+      rowsDf,
+      (child: SparkPlan) => new ScriptTransformation(
+        input = Seq(rowsDf.col("a").expr),
+        script = "cat",
+        output = Seq(AttributeReference("a", StringType)()),
+        child = child,
+        ioschema = noSerdeIOSchema
+      )(TestHive),
+      rowsDf.collect())
+  }
+
+  test("cat with LazySimpleSerDe") {
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    checkAnswer(
+      rowsDf,
+      (child: SparkPlan) => new ScriptTransformation(
+        input = Seq(rowsDf.col("a").expr),
+        script = "cat",
+        output = Seq(AttributeReference("a", StringType)()),
+        child = child,
+        ioschema = serdeIOSchema
+      )(TestHive),
+      rowsDf.collect())
+  }
+
+  test("script transformation should not swallow errors from upstream operators (no serde)")
{
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    val e = intercept[TestFailedException] {
+      checkAnswer(
+        rowsDf,
+        (child: SparkPlan) => new ScriptTransformation(
+          input = Seq(rowsDf.col("a").expr),
+          script = "cat",
+          output = Seq(AttributeReference("a", StringType)()),
+          child = ExceptionInjectingOperator(child),
+          ioschema = noSerdeIOSchema
+        )(TestHive),
+        rowsDf.collect())
+    }
+    assert(e.getMessage().contains("intentional exception"))
+  }
+
+  test("script transformation should not swallow errors from upstream operators (with serde)")
{
+    val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a")
+    val e = intercept[TestFailedException] {
+      checkAnswer(
+        rowsDf,
+        (child: SparkPlan) => new ScriptTransformation(
+          input = Seq(rowsDf.col("a").expr),
+          script = "cat",
+          output = Seq(AttributeReference("a", StringType)()),
+          child = ExceptionInjectingOperator(child),
+          ioschema = serdeIOSchema
+        )(TestHive),
+        rowsDf.collect())
+    }
+    assert(e.getMessage().contains("intentional exception"))
+  }
+}
+
+private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode {
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().map { x =>
+      assert(TaskContext.get() != null) // Make sure that TaskContext is defined.
+      Thread.sleep(1000) // This sleep gives the external process time to start.
+      throw new IllegalArgumentException("intentional exception")
+    }
+  }
+  override def output: Seq[Attribute] = child.output
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message