From: rxin@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-9031] Merge BlockObjectWriter and DiskBlockObjectWriter to remove abstract class
Date: Tue, 14 Jul 2015 19:56:21 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master 8fb3a65cb -> d267c2834


[SPARK-9031] Merge BlockObjectWriter and DiskBlockObjectWriter to remove abstract class

BlockObjectWriter has only one concrete non-test subclass, DiskBlockObjectWriter. To simplify the code in preparation for other refactorings, this patch removes the base class so that only DiskBlockObjectWriter remains. Although multiple BlockObjectWriter implementations may once have been planned, none ever materialized, so the extra layer of abstraction is unnecessary.

Author: Josh Rosen

Closes #7391 from JoshRosen/shuffle-write-interface-refactoring and squashes the following commits:

c418e33 [Josh Rosen] Fix compilation
5047995 [Josh Rosen] Fix comments
d5dc548 [Josh Rosen] Update references in comments
89dc797 [Josh Rosen] Rename test suite.
5755918 [Josh Rosen] Remove unnecessary val in case class
1607c91 [Josh Rosen] Merge BlockObjectWriter and DiskBlockObjectWriter
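For a concrete feel for the API that survives the merge, here is a minimal sketch of driving the consolidated writer, modeled on the constructor call used in DiskBlockObjectWriterSuite further down in this patch. This is illustrative only, not part of the patch: the object name, temp-directory handling, and key/value payload are invented, and the snippet assumes it lives under the org.apache.spark.storage package, since the writer is private[spark].

```scala
package org.apache.spark.storage

import java.io.File

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils

// Hypothetical example object, not part of the patch.
object DiskBlockObjectWriterExample {
  def main(args: Array[String]): Unit = {
    val tempDir = Utils.createTempDir()
    val file = new File(tempDir, "example-block")
    val writeMetrics = new ShuffleWriteMetrics()

    // Same constructor arguments the test suite uses: an identity
    // compressStream (no compression) and syncWrites = true.
    val writer = new DiskBlockObjectWriter(
      new TestBlockId("example"),
      file,
      new JavaSerializer(new SparkConf()).newInstance(),
      1024,       // serializer buffer size
      os => os,   // compressStream
      true,       // syncWrites
      writeMetrics)

    // write(key, value) lazily opens the underlying streams on first use.
    for (i <- 1 to 100) {
      writer.write(Long.box(i), Long.box(i * 2))
    }

    // commitAndClose() flushes and closes the streams and fixes the final
    // position; only after that call is fileSegment() valid.
    writer.commitAndClose()
    val segment = writer.fileSegment()
    println(s"committed ${segment.length} bytes to ${file.getName}")

    Utils.deleteRecursively(tempDir)
  }
}
```

Calling fileSegment() before commitAndClose() raises an IllegalStateException, which the renamed test suite below verifies explicitly.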
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d267c283
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d267c283
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d267c283

Branch: refs/heads/master
Commit: d267c2834a639aaebd0559355c6a82613abb689b
Parents: 8fb3a65
Author: Josh Rosen
Authored: Tue Jul 14 12:56:17 2015 -0700
Committer: Reynold Xin
Committed: Tue Jul 14 12:56:17 2015 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |   8 +-
 .../unsafe/UnsafeShuffleExternalSorter.java     |   2 +-
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   4 +-
 .../shuffle/FileShuffleBlockResolver.scala      |   8 +-
 .../shuffle/IndexShuffleBlockResolver.scala     |   2 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   4 +-
 .../org/apache/spark/storage/BlockManager.scala |   2 +-
 .../spark/storage/BlockObjectWriter.scala       | 256 -------------------
 .../spark/storage/DiskBlockObjectWriter.scala   | 234 +++++++++++++++++
 .../spark/util/collection/ChainedBuffer.scala   |   2 +-
 .../spark/util/collection/ExternalSorter.scala  |   4 +-
 .../util/collection/PartitionedPairBuffer.scala |   1 -
 .../PartitionedSerializedPairBuffer.scala       |   5 +-
 .../WritablePartitionedPairCollection.scala     |   8 +-
 .../BypassMergeSortShuffleWriterSuite.scala     |   4 +-
 .../spark/storage/BlockObjectWriterSuite.scala  | 173 -------------
 .../storage/DiskBlockObjectWriterSuite.scala    | 173 +++++++++++++
 .../PartitionedSerializedPairBufferSuite.scala  |  52 ++--
 18 files changed, 459 insertions(+), 483 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d3d6280..0b8b604 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -75,7 +75,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< private final Serializer serializer; /** Array of file writers, one for each partition */ - private BlockObjectWriter[] partitionWriters; + private DiskBlockObjectWriter[] partitionWriters; public BypassMergeSortShuffleWriter( SparkConf conf, @@ -101,7 +101,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< } final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); - partitionWriters = new BlockObjectWriter[numPartitions]; + partitionWriters = new DiskBlockObjectWriter[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2 tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -121,7 +121,7 @@ final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter< partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { writer.commitAndClose(); } } @@ -169,7 +169,7 @@ final class
BypassMergeSortShuffleWriter implements SortShuffleFileWriter< if (partitionWriters != null) { try { final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); - for (BlockObjectWriter writer : partitionWriters) { + for (DiskBlockObjectWriter writer : partitionWriters) { // This method explicitly does _not_ throw exceptions: writer.revertPartialWritesAndClose(); if (!diskBlockManager.getFile(writer.blockId()).delete()) { http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 5628957..1d46043 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -157,7 +157,7 @@ final class UnsafeShuffleExternalSorter { // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this // after SPARK-5581 is fixed. - BlockObjectWriter writer; + DiskBlockObjectWriter writer; // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index b8d6665..71eed29 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -26,7 +26,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; -import org.apache.spark.storage.BlockObjectWriter; +import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempLocalBlockId; import org.apache.spark.unsafe.PlatformDependent; @@ -47,7 +47,7 @@ final class UnsafeSorterSpillWriter { private final File file; private final BlockId blockId; private final int numRecordsToWrite; - private BlockObjectWriter writer; + private DiskBlockObjectWriter writer; private int numRecordsSpilled = 0; public UnsafeSorterSpillWriter( http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6c3b308..f6a96d8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVecto /** A group of writers for a 
ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { - val writers: Array[BlockObjectWriter] + val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ def releaseWriters(success: Boolean) @@ -113,15 +113,15 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() - val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { + val writers: Array[DiskBlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize, writeMetrics) } } else { - Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => + Array.tabulate[DiskBlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) val blockFile = blockManager.diskBlockManager.getFile(blockId) // Because of previous failures, the shuffle file may already exist on this machine. http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6..fae6955 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -114,7 +114,7 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB } private[spark] object IndexShuffleBlockResolver { - // No-op reduce ID used in interactions with disk store and BlockObjectWriter. + // No-op reduce ID used in interactions with disk store and DiskBlockObjectWriter. // The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort // shuffle outputs for several reduces are glommed into a single file. // TODO: Avoid this entirely by having the DiskBlockObjectWriter not require a BlockId. http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index eb87cee..41df70c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -22,7 +22,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle._ -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter private[spark] class HashShuffleWriter[K, V]( shuffleBlockResolver: FileShuffleBlockResolver, @@ -102,7 +102,7 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. 
Get the size of each bucket block (total block size). - val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => writer.commitAndClose() writer.fileSegment().length } http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/BlockManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1beafa1..8649367 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -648,7 +648,7 @@ private[spark] class BlockManager( file: File, serializerInstance: SerializerInstance, bufferSize: Int, - writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { + writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala deleted file mode 100644 index 7eeabd1..0000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} -import java.nio.channels.FileChannel - -import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializerInstance, SerializationStream} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.util.Utils - -/** - * An interface for writing JVM objects to some underlying storage. This interface allows - * appending data to an existing block, and can guarantee atomicity in the case of faults - * as it allows the caller to revert partial writes. - * - * This interface does not support concurrent writes. Also, once the writer has - * been opened, it cannot be reopened again. 
- */ -private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream { - - def open(): BlockObjectWriter - - def close() - - def isOpen: Boolean - - /** - * Flush the partial writes and commit them as a single atomic block. - */ - def commitAndClose(): Unit - - /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function - * when there are runtime exceptions. This method will not throw, though it may be - * unsuccessful in truncating written data. - */ - def revertPartialWritesAndClose() - - /** - * Writes a key-value pair. - */ - def write(key: Any, value: Any) - - /** - * Notify the writer that a record worth of bytes has been written with OutputStream#write. - */ - def recordWritten() - - /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment -} - -/** - * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. - */ -private[spark] class DiskBlockObjectWriter( - blockId: BlockId, - file: File, - serializerInstance: SerializerInstance, - bufferSize: Int, - compressStream: OutputStream => OutputStream, - syncWrites: Boolean, - // These write metrics concurrently shared with other active BlockObjectWriter's who - // are themselves performing writes. All updates must be relative. - writeMetrics: ShuffleWriteMetrics) - extends BlockObjectWriter(blockId) - with Logging -{ - - /** The file channel, used for repositioning / truncating the file. */ - private var channel: FileChannel = null - private var bs: OutputStream = null - private var fos: FileOutputStream = null - private var ts: TimeTrackingOutputStream = null - private var objOut: SerializationStream = null - private var initialized = false - private var hasBeenClosed = false - private var commitAndCloseHasBeenCalled = false - - /** - * Cursors used to represent positions in the file. - * - * xxxxxxxx|--------|--- | - * ^ ^ ^ - * | | finalPosition - * | reportedPosition - * initialPosition - * - * initialPosition: Offset in the file where we start writing. Immutable. - * reportedPosition: Position at the time of the last update to the write metrics. - * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. - * -----: Current writes to the underlying file. - * xxxxx: Existing contents of the file. - */ - private val initialPosition = file.length() - private var finalPosition: Long = -1 - private var reportedPosition = initialPosition - - /** - * Keep track of number of records written and also use this to periodically - * output bytes written since the latter is expensive to do for each record. - */ - private var numRecordsWritten = 0 - - override def open(): BlockObjectWriter = { - if (hasBeenClosed) { - throw new IllegalStateException("Writer already closed. 
Cannot be reopened.") - } - fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(writeMetrics, fos) - channel = fos.getChannel() - bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializerInstance.serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - Utils.tryWithSafeFinally { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - } { - objOut.close() - } - - channel = null - bs = null - fos = null - ts = null - objOut = null - initialized = false - hasBeenClosed = true - } - } - - override def isOpen: Boolean = objOut != null - - override def commitAndClose(): Unit = { - if (initialized) { - // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the - // serializer stream and the lower level stream. - objOut.flush() - bs.flush() - close() - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) - } else { - finalPosition = file.length() - } - commitAndCloseHasBeenCalled = true - } - - // Discard current writes. We do this by flushing the outstanding writes and then - // truncating the file to its initial position. - override def revertPartialWritesAndClose() { - try { - if (initialized) { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - objOut.flush() - bs.flush() - close() - } - - val truncateStream = new FileOutputStream(file, true) - try { - truncateStream.getChannel.truncate(initialPosition) - } finally { - truncateStream.close() - } - } catch { - case e: Exception => - logError("Uncaught exception while reverting partial writes to file " + file, e) - } - } - - override def write(key: Any, value: Any) { - if (!initialized) { - open() - } - - objOut.writeKey(key) - objOut.writeValue(value) - recordWritten() - } - - override def write(b: Int): Unit = throw new UnsupportedOperationException() - - override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { - if (!initialized) { - open() - } - - bs.write(kvBytes, offs, len) - } - - override def recordWritten(): Unit = { - numRecordsWritten += 1 - writeMetrics.incShuffleRecordsWritten(1) - - if (numRecordsWritten % 32 == 0) { - updateBytesWritten() - } - } - - override def fileSegment(): FileSegment = { - if (!commitAndCloseHasBeenCalled) { - throw new IllegalStateException( - "fileSegment() is only valid after commitAndClose() has been called") - } - new FileSegment(file, initialPosition, finalPosition - initialPosition) - } - - /** - * Report the number of bytes written in this writer's shuffle write metrics. - * Note that this is only valid before the underlying streams are closed. 
- */ - private def updateBytesWritten() { - val pos = channel.position() - writeMetrics.incShuffleBytesWritten(pos - reportedPosition) - reportedPosition = pos - } - - // For testing - private[spark] override def flush() { - objOut.flush() - bs.flush() - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala new file mode 100644 index 0000000..49d9154 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import java.nio.channels.FileChannel + +import org.apache.spark.Logging +import org.apache.spark.serializer.{SerializerInstance, SerializationStream} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.util.Utils + +/** + * A class for writing JVM objects directly to a file on disk. This class allows data to be appended + * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to + * revert partial writes. + * + * This class does not support concurrent writes. Also, once the writer has been opened it cannot be + * reopened again. + */ +private[spark] class DiskBlockObjectWriter( + val blockId: BlockId, + file: File, + serializerInstance: SerializerInstance, + bufferSize: Int, + compressStream: OutputStream => OutputStream, + syncWrites: Boolean, + // These write metrics concurrently shared with other active DiskBlockObjectWriters who + // are themselves performing writes. All updates must be relative. + writeMetrics: ShuffleWriteMetrics) + extends OutputStream + with Logging { + + /** The file channel, used for repositioning / truncating the file. */ + private var channel: FileChannel = null + private var bs: OutputStream = null + private var fos: FileOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private var initialized = false + private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false + + /** + * Cursors used to represent positions in the file. + * + * xxxxxxxx|--------|--- | + * ^ ^ ^ + * | | finalPosition + * | reportedPosition + * initialPosition + * + * initialPosition: Offset in the file where we start writing. Immutable. + * reportedPosition: Position at the time of the last update to the write metrics. 
+ * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. + * -----: Current writes to the underlying file. + * xxxxx: Existing contents of the file. + */ + private val initialPosition = file.length() + private var finalPosition: Long = -1 + private var reportedPosition = initialPosition + + /** + * Keep track of number of records written and also use this to periodically + * output bytes written since the latter is expensive to do for each record. + */ + private var numRecordsWritten = 0 + + def open(): DiskBlockObjectWriter = { + if (hasBeenClosed) { + throw new IllegalStateException("Writer already closed. Cannot be reopened.") + } + fos = new FileOutputStream(file, true) + ts = new TimeTrackingOutputStream(writeMetrics, fos) + channel = fos.getChannel() + bs = compressStream(new BufferedOutputStream(ts, bufferSize)) + objOut = serializerInstance.serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + Utils.tryWithSafeFinally { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) + } + } { + objOut.close() + } + + channel = null + bs = null + fos = null + ts = null + objOut = null + initialized = false + hasBeenClosed = true + } + } + + def isOpen: Boolean = objOut != null + + /** + * Flush the partial writes and commit them as a single atomic block. + */ + def commitAndClose(): Unit = { + if (initialized) { + // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the + // serializer stream and the lower level stream. + objOut.flush() + bs.flush() + close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() + } + commitAndCloseHasBeenCalled = true + } + + + /** + * Reverts writes that haven't been flushed yet. Callers should invoke this function + * when there are runtime exceptions. This method will not throw, though it may be + * unsuccessful in truncating written data. + */ + def revertPartialWritesAndClose() { + // Discard current writes. We do this by flushing the outstanding writes and then + // truncating the file to its initial position. + try { + if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) + objOut.flush() + bs.flush() + close() + } + + val truncateStream = new FileOutputStream(file, true) + try { + truncateStream.getChannel.truncate(initialPosition) + } finally { + truncateStream.close() + } + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) + } + } + + /** + * Writes a key-value pair. + */ + def write(key: Any, value: Any) { + if (!initialized) { + open() + } + + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + override def write(b: Int): Unit = throw new UnsupportedOperationException() + + override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { + if (!initialized) { + open() + } + + bs.write(kvBytes, offs, len) + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. 
+ */ + def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incShuffleRecordsWritten(1) + + if (numRecordsWritten % 32 == 0) { + updateBytesWritten() + } + } + + /** + * Returns the file segment of committed data that this Writer has written. + * This is only valid after commitAndClose() has been called. + */ + def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } + new FileSegment(file, initialPosition, finalPosition - initialPosition) + } + + /** + * Report the number of bytes written in this writer's shuffle write metrics. + * Note that this is only valid before the underlying streams are closed. + */ + private def updateBytesWritten() { + val pos = channel.position() + writeMetrics.incShuffleBytesWritten(pos - reportedPosition) + reportedPosition = pos + } + + // For testing + private[spark] override def flush() { + objOut.flush() + bs.flush() + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala index 516aaa4..ae60f3b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala @@ -37,7 +37,7 @@ private[spark] class ChainedBuffer(chunkSize: Int) { private var _size: Long = 0 /** - * Feed bytes from this buffer into a BlockObjectWriter. + * Feed bytes from this buffer into a DiskBlockObjectWriter. * * @param pos Offset in the buffer to read from. * @param os OutputStream to read into. 
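The ChainedBuffer change above points at the second way data reaches the writer after the merge: callers that already hold serialized bytes skip the key/value path and use the OutputStream-style write(bytes, offset, length), then call recordWritten() once per logical record so the shuffle write metrics stay meaningful (the byte counter is refreshed from the file channel position every 32 records and finalized by commitAndClose()). A minimal hedged sketch of that pattern follows; the helper object and its method are invented for illustration and are not part of the patch.

```scala
package org.apache.spark.storage

// Hypothetical helper, not part of the patch: shows the raw-byte write path
// used by callers such as ChainedBuffer and UnsafeSorterSpillWriter.
private[spark] object RawByteSpillSketch {
  def spillSerializedRecords(
      writer: DiskBlockObjectWriter,
      records: Iterator[Array[Byte]]): Unit = {
    records.foreach { recordBytes =>
      // Already-serialized bytes go straight to the buffered (and possibly
      // compressed) output stream, bypassing the SerializationStream.
      writer.write(recordBytes, 0, recordBytes.length)
      // Tell the writer one logical record is complete: this bumps the record
      // metric immediately and refreshes the byte metric every 32 records.
      writer.recordWritten()
    }
    writer.commitAndClose()
  }
}
```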
http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 757dec6..ba7ec83 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -30,7 +30,7 @@ import org.apache.spark._ import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter} -import org.apache.spark.storage.{BlockId, BlockObjectWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -250,7 +250,7 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 var spillMetrics: ShuffleWriteMetrics = null - var writer: BlockObjectWriter = null + var writer: DiskBlockObjectWriter = null def openWriter(): Unit = { assert (writer == null && spillMetrics == null) spillMetrics = new ShuffleWriteMetrics http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index 04bb7fc..f5844d5 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala index ae9a487..87a786b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala @@ -21,9 +21,8 @@ import java.io.InputStream import java.nio.IntBuffer import java.util.Comparator -import org.apache.spark.SparkEnv import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance} -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._ /** @@ -136,7 +135,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V]( // current position in the meta buffer in ints var pos = 0 - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { val keyStart = getKeyStartPos(metaBuffer, pos) val keyValLen = 
metaBuffer.get(pos + KEY_VAL_LEN) pos += RECORD_SIZE http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index 7bc5989..38848e9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -19,7 +19,7 @@ package org.apache.spark.util.collection import java.util.Comparator -import org.apache.spark.storage.BlockObjectWriter +import org.apache.spark.storage.DiskBlockObjectWriter /** * A common interface for size-tracking collections of key-value pairs that @@ -51,7 +51,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: BlockObjectWriter): Unit = { + def writeNext(writer: DiskBlockObjectWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -91,11 +91,11 @@ private[spark] object WritablePartitionedPairCollection { } /** - * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element + * Iterator that writes elements to a DiskBlockObjectWriter instead of returning them. Each element * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: BlockObjectWriter): Unit + def writeNext(writer: DiskBlockObjectWriter): Unit def hasNext(): Boolean http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 542f8f4..cc7342f 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -68,8 +68,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte any[SerializerInstance], anyInt(), any[ShuffleWriteMetrics] - )).thenAnswer(new Answer[BlockObjectWriter] { - override def answer(invocation: InvocationOnMock): BlockObjectWriter = { + )).thenAnswer(new Answer[DiskBlockObjectWriter] { + override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments new DiskBlockObjectWriter( args(0).asInstanceOf[BlockId], http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala deleted file mode 100644 index 7bdea72..0000000 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation 
(ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.storage - -import java.io.File - -import org.scalatest.BeforeAndAfterEach - -import org.apache.spark.SparkConf -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.util.Utils - -class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { - - var tempDir: File = _ - - override def beforeEach(): Unit = { - tempDir = Utils.createTempDir() - } - - override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - } - - test("verify write metrics") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - - writer.write(Long.box(20), Long.box(30)) - // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) - // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) - // After 32 writes, metrics should update - for (i <- 0 until 32) { - writer.flush() - writer.write(Long.box(i), Long.box(i)) - } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) - writer.commitAndClose() - assert(file.length() == writeMetrics.shuffleBytesWritten) - } - - test("verify write metrics on revert") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - - writer.write(Long.box(20), Long.box(30)) - // Record metrics update on every write - assert(writeMetrics.shuffleRecordsWritten === 1) - // Metrics don't update on every write - assert(writeMetrics.shuffleBytesWritten == 0) - // After 32 writes, metrics should update - for (i <- 0 until 32) { - writer.flush() - writer.write(Long.box(i), Long.box(i)) - } - assert(writeMetrics.shuffleBytesWritten > 0) - assert(writeMetrics.shuffleRecordsWritten === 33) - writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleBytesWritten == 0) - assert(writeMetrics.shuffleRecordsWritten == 0) - } - - test("Reopening a closed block writer") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - - writer.open() - writer.close() - intercept[IllegalStateException] { - writer.open() - } - } - - test("calling revertPartialWritesAndClose() on a closed block writer should have no 
effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - for (i <- 1 to 1000) { - writer.write(i, i) - } - writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - assert(writeMetrics.shuffleRecordsWritten === 1000) - writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - } - - test("commitAndClose() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - for (i <- 1 to 1000) { - writer.write(i, i) - } - writer.commitAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 1000) - writer.commitAndClose() - assert(writeMetrics.shuffleRecordsWritten === 1000) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) - } - - test("revertPartialWritesAndClose() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - for (i <- 1 to 1000) { - writer.write(i, i) - } - writer.revertPartialWritesAndClose() - val bytesWritten = writeMetrics.shuffleBytesWritten - val writeTime = writeMetrics.shuffleWriteTime - assert(writeMetrics.shuffleRecordsWritten === 0) - writer.revertPartialWritesAndClose() - assert(writeMetrics.shuffleRecordsWritten === 0) - assert(writeMetrics.shuffleBytesWritten === bytesWritten) - assert(writeMetrics.shuffleWriteTime === writeTime) - } - - test("fileSegment() can only be called after commitAndClose() has been called") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - for (i <- 1 to 1000) { - writer.write(i, i) - } - intercept[IllegalStateException] { - writer.fileSegment() - } - writer.close() - } - - test("commitAndClose() without ever opening or writing") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.commitAndClose() - assert(writer.fileSegment().length === 0) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala new file mode 100644 index 0000000..66af6e1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import java.io.File + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.Utils + +class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + var tempDir: File = _ + + override def beforeEach(): Unit = { + tempDir = Utils.createTempDir() + } + + override def afterEach(): Unit = { + Utils.deleteRecursively(tempDir) + } + + test("verify write metrics") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20), Long.box(30)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i), Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) + writer.commitAndClose() + assert(file.length() == writeMetrics.shuffleBytesWritten) + } + + test("verify write metrics on revert") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20), Long.box(30)) + // Record metrics update on every write + assert(writeMetrics.shuffleRecordsWritten === 1) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i), Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + assert(writeMetrics.shuffleRecordsWritten === 33) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleBytesWritten == 0) + assert(writeMetrics.shuffleRecordsWritten == 0) + } + + test("Reopening a closed block writer") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.open() + writer.close() + intercept[IllegalStateException] { + writer.open() + } + } + + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") 
{ + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + } + + test("commitAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.commitAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 1000) + writer.commitAndClose() + assert(writeMetrics.shuffleRecordsWritten === 1000) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("revertPartialWritesAndClose() should be idempotent") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + writer.revertPartialWritesAndClose() + val bytesWritten = writeMetrics.shuffleBytesWritten + val writeTime = writeMetrics.shuffleWriteTime + assert(writeMetrics.shuffleRecordsWritten === 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleRecordsWritten === 0) + assert(writeMetrics.shuffleBytesWritten === bytesWritten) + assert(writeMetrics.shuffleWriteTime === writeTime) + } + + test("fileSegment() can only be called after commitAndClose() has been called") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + for (i <- 1 to 1000) { + writer.write(i, i) + } + intercept[IllegalStateException] { + writer.fileSegment() + } + writer.close() + } + + test("commitAndClose() without ever opening or writing") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + writer.commitAndClose() + assert(writer.fileSegment().length === 0) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/d267c283/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala index 6d2459d..3b67f62 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala @@ -17,15 +17,20 @@ package org.apache.spark.util.collection -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import com.google.common.io.ByteStreams +import org.mockito.Matchers.any +import org.mockito.Mockito._ +import org.mockito.Mockito.RETURNS_SMART_NULLS +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.Matchers._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{FileSegment, BlockObjectWriter} +import org.apache.spark.storage.DiskBlockObjectWriter class PartitionedSerializedPairBufferSuite extends SparkFunSuite { test("OrderedInputStream single record") { @@ -79,13 +84,13 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { val struct = SomeStruct("something", 5) buffer.insert(4, 10, struct) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) stream.readObject[AnyRef]() should be (10) stream.readObject[AnyRef]() should be (struct) } @@ -101,7 +106,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { buffer.insert(5, 3, struct3) val it = buffer.destructiveSortedWritablePartitionedIterator(None) - val writer = new SimpleBlockObjectWriter + val (writer, baos) = createMockWriter() assert(it.hasNext) it.nextPartition should be (4) it.writeNext(writer) @@ -113,7 +118,7 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { it.writeNext(writer) assert(!it.hasNext) - val stream = serializerInstance.deserializeStream(writer.getInputStream) + val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray)) val iter = stream.asIterator iter.next() should be (2) iter.next() should be (struct2) @@ -123,26 +128,21 @@ class PartitionedSerializedPairBufferSuite extends SparkFunSuite { iter.next() should be (struct1) assert(!iter.hasNext) } -} - -case class SomeStruct(val str: String, val num: Int) - -class SimpleBlockObjectWriter extends BlockObjectWriter(null) { - val baos = new ByteArrayOutputStream() - override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = { - baos.write(bytes, offs, len) + def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = { + val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS) + val baos = new ByteArrayOutputStream() + when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + val args = invocationOnMock.getArguments + val bytes = args(0).asInstanceOf[Array[Byte]] + val offset = args(1).asInstanceOf[Int] + val length = args(2).asInstanceOf[Int] + baos.write(bytes, offset, length) + } + }) + (writer, baos) } - - def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray) - - override def open(): BlockObjectWriter = this - override def close(): Unit = { } - override def isOpen: Boolean = true - override def commitAndClose(): Unit = { } - override def 
revertPartialWritesAndClose(): Unit = { } - override def fileSegment(): FileSegment = null - override def write(key: Any, value: Any): Unit = { } - override def recordWritten(): Unit = { } - override def write(b: Int): Unit = { } } + +case class SomeStruct(str: String, num: Int)
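One behavior worth calling out from the cursor bookkeeping documented in DiskBlockObjectWriter above: because the writer records the file's length before it starts appending, revertPartialWritesAndClose() can truncate the file back to that initial offset and subtract the bytes and records it had added to the shuffle write metrics. A small hedged sketch of that guarantee; the pre-existing eight bytes, the single record, and the object name are invented for illustration and are not part of the patch.

```scala
package org.apache.spark.storage

import java.io.{File, FileOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils

// Hypothetical example object, not part of the patch.
object RevertPartialWritesSketch {
  def main(args: Array[String]): Unit = {
    val tempDir = Utils.createTempDir()
    val file = new File(tempDir, "appended-block")

    // Pre-existing contents: 8 bytes already on disk, so initialPosition = 8.
    val existing = new FileOutputStream(file)
    existing.write(new Array[Byte](8))
    existing.close()

    val metrics = new ShuffleWriteMetrics()
    val writer = new DiskBlockObjectWriter(new TestBlockId("revert"), file,
      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, metrics)

    writer.write(Long.box(1), Long.box(2))

    // Simulate a failed task: the un-committed bytes are discarded by
    // truncating the file back to initialPosition, and the bytes/records
    // counted so far are rolled back out of the write metrics.
    writer.revertPartialWritesAndClose()
    assert(file.length() == 8)
    assert(metrics.shuffleRecordsWritten == 0)

    Utils.deleteRecursively(tempDir)
  }
}
```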