spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-24991][SQL] use InternalRow in DataSourceWriter
Date Mon, 06 Aug 2018 07:52:10 GMT
Repository: spark
Updated Branches:
  refs/heads/master 327bb3007 -> ac527b520


[SPARK-24991][SQL] use InternalRow in DataSourceWriter

## What changes were proposed in this pull request?

A follow-up of #21118.

Since we use `InternalRow` in the read API of data source v2, we should do the same thing
for the write API.
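
For illustration, a minimal sketch of a sink written against the `InternalRow`-based write API after this change (the `Counting*` names are invented for the example, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical sink that only counts rows; shows the InternalRow-based factory chain.
class CountingWriter extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new CountingWriterFactory
  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}

class CountingWriterFactory extends DataWriterFactory[InternalRow] {
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new CountingDataWriter
}

class CountingDataWriter extends DataWriter[InternalRow] {
  private var count = 0L
  override def write(record: InternalRow): Unit = count += 1
  override def commit(): WriterCommitMessage = new WriterCommitMessage {}
  override def abort(): Unit = ()
}
```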

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21948 from cloud-fan/row-write.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ac527b52
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ac527b52
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ac527b52

Branch: refs/heads/master
Commit: ac527b5205ec2826677e2b7ad0d424aa976bce81
Parents: 327bb30
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Mon Aug 6 15:52:01 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Mon Aug 6 15:52:01 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/kafka010/KafkaStreamWriter.scala  |  4 +-
 .../sql/sources/v2/writer/DataSourceWriter.java |  4 +-
 .../spark/sql/sources/v2/writer/DataWriter.java |  4 +-
 .../sources/v2/writer/DataWriterFactory.java    |  5 +-
 .../v2/writer/SupportsWriteInternalRow.java     | 41 -----------
 .../datasources/v2/WriteToDataSourceV2.scala    | 30 +-------
 .../streaming/MicroBatchExecution.scala         | 10 +--
 .../continuous/ContinuousWriteRDD.scala         |  6 +-
 .../WriteToContinuousDataSourceExec.scala       | 12 +---
 .../streaming/sources/ConsoleWriter.scala       | 11 ++-
 .../sources/ForeachWriterProvider.scala         | 10 +--
 .../streaming/sources/MicroBatchWriter.scala    | 21 +-----
 .../sources/PackedRowWriterFactory.scala        | 15 ++--
 .../execution/streaming/sources/memoryV2.scala  | 33 +++++----
 .../execution/streaming/MemorySinkV2Suite.scala | 18 +++--
 .../sql/sources/v2/DataSourceV2Suite.scala      |  7 --
 .../sources/v2/SimpleWritableDataSource.scala   | 72 ++------------------
 17 files changed, 73 insertions(+), 230 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
index 32923dc..5f0802b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter with SupportsWriteInternalRow {
+  extends StreamWriter {
 
   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
 
-  override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
+  override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
index 7eedc85..385fc29 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.sources.v2.writer;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.StreamWriteSupport;
 import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
    */
-  DataWriterFactory<Row> createWriterFactory();
+  DataWriterFactory<InternalRow> createWriterFactory();
 
   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
index 1626c00..27dc5ea 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
@@ -53,9 +53,7 @@ import org.apache.spark.annotation.InterfaceStability;
  * successfully, and have a way to revert committed data writers without the commit message, because
  * Spark only accepts the commit message that arrives first and ignore others.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
- * that mix in {@link SupportsWriteInternalRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
  */
 @InterfaceStability.Evolving
 public interface DataWriter<T> {

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index 0932ff8..3d337b6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -33,7 +33,10 @@ import org.apache.spark.annotation.InterfaceStability;
 public interface DataWriterFactory<T> extends Serializable {
 
   /**
-   * Returns a data writer to do the actual writing work.
+   * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data
+   * object instance when sending data to the data writer, for better performance. Data writers
+   * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a
+   * list.
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
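
The note above means a writer that buffers rows must copy them before keeping a reference. A small sketch of that pattern, assuming an example-only `BufferedRowsCommitMessage`; the `PackedRowDataWriter` change further down in this patch does the same thing:

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Example-only commit message carrying the buffered rows.
case class BufferedRowsCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage

class BufferingDataWriter extends DataWriter[InternalRow] {
  private val buffer = mutable.Buffer[InternalRow]()

  // Spark may hand us the same InternalRow instance for every record, so copy before keeping it.
  override def write(record: InternalRow): Unit = buffer.append(record.copy())

  override def commit(): WriterCommitMessage = BufferedRowsCommitMessage(buffer.toArray)

  override def abort(): Unit = buffer.clear()
}
```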

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
deleted file mode 100644
index d2cf7e0..0000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2.writer;
-
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
-
-/**
- * A mix-in interface for {@link DataSourceWriter}. Data source writers can implement this
- * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side.
- * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get
- * changed in the future Spark versions.
- */
-
-@InterfaceStability.Unstable
-public interface SupportsWriteInternalRow extends DataSourceWriter {
-
-  @Override
-  default DataWriterFactory<Row> createWriterFactory() {
-    throw new IllegalStateException(
-      "createWriterFactory should not be called with SupportsWriteInternalRow.");
-  }
-
-  DataWriterFactory<InternalRow> createInternalRowWriterFactory();
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index b1148c0..0399970 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writeTask = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
     val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging {
     })
   }
 }
-
-class InternalRowDataWriterFactory(
-    rowWriterFactory: DataWriterFactory[Row],
-    schema: StructType) extends DataWriterFactory[InternalRow] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[InternalRow] = {
-    new InternalRowDataWriter(
-      rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
-      RowEncoder.apply(schema).resolveAndBind())
-  }
-}
-
-class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
-  extends DataWriter[InternalRow] {
-
-  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
-
-  override def commit(): WriterCommitMessage = rowWriter.commit()
-
-  override def abort(): Unit = rowWriter.abort()
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index abb807d..c759f5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.{Clock, Utils}
 
@@ -498,12 +497,7 @@ class MicroBatchExecution(
           newAttributePlan.schema,
           outputMode,
           new DataSourceOptions(extraOptions.asJava))
-        if (writer.isInstanceOf[SupportsWriteInternalRow]) {
-          WriteToDataSourceV2(
-            new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        } else {
-          WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        }
+        WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 76f3f5b..967dbe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -17,13 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
-import java.util.concurrent.atomic.AtomicLong
-
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
 import org.apache.spark.util.Utils
 
 /**
@@ -47,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
       SparkEnv.get)
     EpochTracker.initializeCurrentEpoch(
       context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
-
     while (!context.isInterrupted() && !context.isCompleted()) {
       var dataWriter: DataWriter[InternalRow] = null
       // write the data and commit this writer.

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index e0af3a2..927d3a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
 import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.util.Utils
 
 /**
  * The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writerFactory = writer.createWriterFactory()
     val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
 
     logInfo(s"Start processing data source writer: $writer. " +

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
index d276403..fd45ba5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get
 
-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
     println(printMessage)
     println("-------------------------------------------")
     // scalastyle:off println
-    spark
-      .createDataFrame(rows.toList.asJava, schema)
+    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
       .show(numRowsToShow, isTruncated)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
index bc9b6d9..e8ce21c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
-import org.apache.spark.sql.{Encoder, ForeachWriter, SparkSession}
+import org.apache.spark.sql.{ForeachWriter, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.python.PythonForeachWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
@@ -46,11 +46,11 @@ case class ForeachWriterProvider[T](
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new StreamWriter with SupportsWriteInternalRow {
+    new StreamWriter {
       override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
       override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
 
-      override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
+      override def createWriterFactory(): DataWriterFactory[InternalRow] = {
         val rowConverter: InternalRow => T = converter match {
           case Left(enc) =>
             val boundEnc = enc.resolveAndBind(

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
index 56f7ff2..d023a35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala
@@ -17,9 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
-import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
 
 /**
@@ -34,21 +33,5 @@ class MicroBatchWriter(batchId: Long, writer: StreamWriter) extends DataSourceWr
 
   override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)
 
-  override def createWriterFactory(): DataWriterFactory[Row] = writer.createWriterFactory()
-}
-
-class InternalRowMicroBatchWriter(batchId: Long, writer: StreamWriter)
-  extends DataSourceWriter with SupportsWriteInternalRow {
-  override def commit(messages: Array[WriterCommitMessage]): Unit = {
-    writer.commit(batchId, messages)
-  }
-
-  override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages)
-
-  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] =
-    writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => throw new IllegalStateException(
-        "InternalRowMicroBatchWriter should only be created with base writer support")
-    }
+  override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory()
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
index b501d90..f26e11d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources
 import scala.collection.mutable
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage}
 
 /**
@@ -30,11 +30,11 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
  * Note that, because it sends all rows to the driver, this factory will generally be unsuitable
  * for production-quality sinks. It's intended for use in tests.
  */
-case object PackedRowWriterFactory extends DataWriterFactory[Row] {
+case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] {
   override def createDataWriter(
       partitionId: Int,
       taskId: Long,
-      epochId: Long): DataWriter[Row] = {
+      epochId: Long): DataWriter[InternalRow] = {
     new PackedRowDataWriter()
   }
 }
@@ -43,15 +43,16 @@ case object PackedRowWriterFactory extends DataWriterFactory[Row] {
  * Commit message for a [[PackedRowDataWriter]], containing all the rows written in the most
  * recent interval.
  */
-case class PackedRowCommitMessage(rows: Array[Row]) extends WriterCommitMessage
+case class PackedRowCommitMessage(rows: Array[InternalRow]) extends WriterCommitMessage
 
 /**
  * A simple [[DataWriter]] that just sends all the rows it's received as a commit message.
  */
-class PackedRowDataWriter() extends DataWriter[Row] with Logging {
-  private val data = mutable.Buffer[Row]()
+class PackedRowDataWriter() extends DataWriter[InternalRow] with Logging {
+  private val data = mutable.Buffer[InternalRow]()
 
-  override def write(row: Row): Unit = data.append(row)
+  // Spark reuses the same `InternalRow` instance, here we copy it before buffer it.
+  override def write(row: InternalRow): Unit = data.append(row.copy())
 
   override def commit(): PackedRowCommitMessage = {
     val msg = PackedRowCommitMessage(data.toArray)

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index f2a35a9..afacb2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -25,6 +25,8 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
@@ -46,7 +48,7 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
       schema: StructType,
       mode: OutputMode,
       options: DataSourceOptions): StreamWriter = {
-    new MemoryStreamWriter(this, mode)
+    new MemoryStreamWriter(this, mode, schema)
   }
 
   private case class AddedData(batchId: Long, data: Array[Row])
@@ -115,12 +117,13 @@ class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkB
   override def toString(): String = "MemorySinkV2"
 }
 
-case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {}
+case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
+  extends WriterCommitMessage {}
 
-class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
+class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType)
   extends DataSourceWriter with Logging {
 
-  override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
+  override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
   def commit(messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
@@ -134,10 +137,10 @@ class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode)
   }
 }
 
-class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
+class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType)
   extends StreamWriter {
 
-  override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode)
+  override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     val newRows = messages.flatMap {
@@ -151,22 +154,26 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode)
   }
 }
 
-case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
+case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)
+  extends DataWriterFactory[InternalRow] {
+
   override def createDataWriter(
       partitionId: Int,
       taskId: Long,
-      epochId: Long): DataWriter[Row] = {
-    new MemoryDataWriter(partitionId, outputMode)
+      epochId: Long): DataWriter[InternalRow] = {
+    new MemoryDataWriter(partitionId, outputMode, schema)
   }
 }
 
-class MemoryDataWriter(partition: Int, outputMode: OutputMode)
-  extends DataWriter[Row] with Logging {
+class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType)
+  extends DataWriter[InternalRow] with Logging {
 
   private val data = mutable.Buffer[Row]()
 
-  override def write(row: Row): Unit = {
-    data.append(row)
+  private val encoder = RowEncoder(schema).resolveAndBind()
+
+  override def write(row: InternalRow): Unit = {
+    data.append(encoder.fromRow(row))
   }
 
   override def commit(): MemoryWriterCommitMessage = {
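
`MemorySinkV2` still stores external `Row`s, so the writer above converts each incoming `InternalRow` back with a resolved `RowEncoder`. A standalone sketch of that conversion (the schema here is arbitrary; `fromRow` is the `ExpressionEncoder` API at the time of this commit):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.StructType

// Convert an InternalRow back to an external Row, as MemoryDataWriter does per record.
val schema = new StructType().add("i", "int")
val encoder = RowEncoder(schema).resolveAndBind()

val internal = InternalRow(1)
val external: Row = encoder.fromRow(internal)  // Row(1)
```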

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
index 9be22d9..b4d9b68 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
@@ -20,16 +20,19 @@ package org.apache.spark.sql.execution.streaming
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
+import org.apache.spark.sql.types.StructType
 
 class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
   test("data writer") {
     val partition = 1234
-    val writer = new MemoryDataWriter(partition, OutputMode.Append())
-    writer.write(Row(1))
-    writer.write(Row(2))
-    writer.write(Row(44))
+    val writer = new MemoryDataWriter(
+      partition, OutputMode.Append(), new StructType().add("i", "int"))
+    writer.write(InternalRow(1))
+    writer.write(InternalRow(2))
+    writer.write(InternalRow(44))
     val msg = writer.commit()
     assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
     assert(msg.partition == partition)
@@ -40,7 +43,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
   test("continuous writer") {
     val sink = new MemorySinkV2
-    val writer = new MemoryStreamWriter(sink, OutputMode.Append())
+    val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int"))
     writer.commit(0,
       Array(
         MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
@@ -62,7 +65,8 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
 
   test("microbatch writer") {
     val sink = new MemorySinkV2
-    new MemoryWriter(sink, 0, OutputMode.Append()).commit(
+    val schema = new StructType().add("i", "int")
+    new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit(
       Array(
         MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
         MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
@@ -70,7 +74,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
       ))
     assert(sink.latestBatchId.contains(0))
     assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
-    new MemoryWriter(sink, 19, OutputMode.Append()).commit(
+    new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit(
       Array(
         MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
         MemoryWriterCommitMessage(0, Seq(Row(33)))

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index b6e594d..fef53e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -242,13 +242,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
         assert(e2.getMessage.contains("Writing job aborted"))
         // make sure we don't have partial data.
         assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
-
-        // test internal row writer
-        spark.range(5).select('id, -'id).write.format(cls.getName)
-          .option("path", path).option("internal", "true").mode("overwrite").save()
-        checkAnswer(
-          spark.read.format(cls.getName).option("path", path).load(),
-          spark.range(5).select('id, -'id))
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ac527b52/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index 183d039..e1b8e9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path}
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
 import org.apache.spark.sql.sources.v2.writer._
@@ -65,9 +65,9 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
   }
 
   class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter {
-    override def createWriterFactory(): DataWriterFactory[Row] = {
+    override def createWriterFactory(): DataWriterFactory[InternalRow] = {
       SimpleCounter.resetCounter
-      new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
+      new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
     }
 
     override def onDataWriterCommit(message: WriterCommitMessage): Unit = {
@@ -97,18 +97,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     }
   }
 
-  class InternalRowWriter(jobId: String, path: String, conf: Configuration)
-    extends Writer(jobId, path, conf) with SupportsWriteInternalRow {
-
-    override def createWriterFactory(): DataWriterFactory[Row] = {
-      throw new IllegalArgumentException("not expected!")
-    }
-
-    override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = {
-      new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf))
-    }
-  }
-
   override def createReader(options: DataSourceOptions): DataSourceReader = {
     val path = new Path(options.get("path").get())
     val conf = SparkContext.getActive.get.hadoopConfiguration
@@ -124,7 +112,6 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
     assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false))
 
     val path = new Path(options.get("path").get())
-    val internal = options.get("internal").isPresent
     val conf = SparkContext.getActive.get.hadoopConfiguration
     val fs = path.getFileSystem(conf)
 
@@ -142,17 +129,8 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS
       fs.delete(path, true)
     }
 
-    Optional.of(createWriter(jobId, path, conf, internal))
-  }
-
-  private def createWriter(
-      jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceWriter = {
     val pathStr = path.toUri.toString
-    if (internal) {
-      new InternalRowWriter(jobId, pathStr, conf)
-    } else {
-      new Writer(jobId, pathStr, conf)
-    }
+    Optional.of(new Writer(jobId, pathStr, conf))
   }
 }
 
@@ -204,43 +182,7 @@ private[v2] object SimpleCounter {
   }
 }
 
-class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
-  extends DataWriterFactory[Row] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[Row] = {
-    val jobPath = new Path(new Path(path, "_temporary"), jobId)
-    val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
-    val fs = filePath.getFileSystem(conf.value)
-    new SimpleCSVDataWriter(fs, filePath)
-  }
-}
-
-class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] {
-
-  private val out = fs.create(file)
-
-  override def write(record: Row): Unit = {
-    out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n")
-  }
-
-  override def commit(): WriterCommitMessage = {
-    out.close()
-    null
-  }
-
-  override def abort(): Unit = {
-    try {
-      out.close()
-    } finally {
-      fs.delete(file, false)
-    }
-  }
-}
-
-class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
+class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration)
   extends DataWriterFactory[InternalRow] {
 
   override def createDataWriter(
@@ -250,11 +192,11 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali
     val jobPath = new Path(new Path(path, "_temporary"), jobId)
     val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
     val fs = filePath.getFileSystem(conf.value)
-    new InternalRowCSVDataWriter(fs, filePath)
+    new CSVDataWriter(fs, filePath)
   }
 }
 
-class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {
+class CSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] {
 
   private val out = fs.create(file)
 



