spark-reviews mailing list archives

From andrewor14 <...@git.apache.org>
Subject [GitHub] spark pull request: SPARK-2045 Sort-based shuffle
Date Wed, 30 Jul 2014 01:42:52 GMT
Github user andrewor14 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/1499#discussion_r15563349
  
    --- Diff: core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala ---
    @@ -0,0 +1,667 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.util.collection
    +
    +import java.io._
    +import java.util.Comparator
    +
    +import scala.collection.mutable.ArrayBuffer
    +import scala.collection.mutable
    +
    +import com.google.common.io.ByteStreams
    +
    +import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
    +import org.apache.spark.serializer.Serializer
    +import org.apache.spark.storage.BlockId
    +
    +/**
    + * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
    + * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then
    + * optionally sorts keys within each partition using a custom Comparator. Can output a single
    + * partitioned file with a different byte range for each partition, suitable for shuffle fetches.
    + *
    + * If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
    + *
    + * @param aggregator optional Aggregator with combine functions to use for merging data
    + * @param partitioner optional Partitioner; if given, sort by partition ID and then key
    + * @param ordering optional Ordering to sort keys within each partition; should be a total ordering
    + * @param serializer serializer to use when spilling to disk
    + *
    + * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
    + * want the output keys to be sorted. In a map task without map-side combine for example, you
    + * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
    + * want to do combining, having an Ordering is more efficient than not having it.
    + *
    + * At a high level, this class works as follows:
    + *
    + * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
    + *   we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
    + *   we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
    + *   avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
    + *
    + * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
    + *   by partition ID and possibly second by key or by hash code of the key, if we want to do
    + *   aggregation. For each file, we track how many objects were in each partition in memory, so we
    + *   don't have to write out the partition ID for every element.
    + *
    + * - When the user requests an iterator, the spilled files are merged, along with any remaining
    + *   in-memory data, using the same sort order defined above (unless both sorting and aggregation
    + *   are disabled). If we need to aggregate by key, we either use a total ordering from the
    + *   ordering parameter, or read the keys with the same hash code and compare them with each other
    + *   for equality to merge values.
    + *
    + * - Users are expected to call stop() at the end to delete all the intermediate files.
    + */
    +private[spark] class ExternalSorter[K, V, C](
    +    aggregator: Option[Aggregator[K, V, C]] = None,
    +    partitioner: Option[Partitioner] = None,
    +    ordering: Option[Ordering[K]] = None,
    +    serializer: Option[Serializer] = None) extends Logging {
    +
    +  private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
    +  private val shouldPartition = numPartitions > 1
    +
    +  private val blockManager = SparkEnv.get.blockManager
    +  private val diskBlockManager = blockManager.diskBlockManager
    +  private val ser = Serializer.getSerializer(serializer)
    +  private val serInstance = ser.newInstance()
    +
    +  private val conf = SparkEnv.get.conf
    +  private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
    +  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
    +  private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
    +
    +  private def getPartition(key: K): Int = {
    +    if (shouldPartition) partitioner.get.getPartition(key) else 0
    +  }
    +
    +  // Data structures to store in-memory objects before we spill. Depending on whether we have an
    +  // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
    +  // store them in an array buffer.
    +  var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
    +  var buffer = new SizeTrackingPairBuffer[(Int, K), C]
    +
    +  // Number of pairs read from input since last spill; note that we count them even if a value is
    +  // merged with a previous key in case we're doing something like groupBy where the result grows
    +  private var elementsRead = 0L
    +
    +  // What threshold of elementsRead we start estimating map size at.
    +  private val trackMemoryThreshold = 1000
    +
    +  // Spilling statistics
    +  private var spillCount = 0
    +  private var _memoryBytesSpilled = 0L
    +  private var _diskBytesSpilled = 0L
    +
    +  // Collective memory threshold shared across all running tasks
    +  private val maxMemoryThreshold = {
    +    val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
    +    val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
    +    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
    +  }
    +
    +  // How much of the shared memory pool this collection has claimed
    +  private var myMemoryThreshold = 0L
    +
    +  // A comparator for keys K that orders them within a partition to allow partial aggregation.
    +  // Can be a partial ordering by hash code if a total ordering is not provided through by the
    +  // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
    +  // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
    +  // Note that we ignore this if no aggregator is given.
    +  private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
    +    override def compare(a: K, b: K): Int = {
    +      val h1 = if (a == null) 0 else a.hashCode()
    +      val h2 = if (b == null) 0 else b.hashCode()
    +      h1 - h2
    +    }
    +  })
    +
    +  // A comparator for (Int, K) elements that orders them by partition and then possibly by key
    +  private val partitionKeyComparator: Comparator[(Int, K)] = {
    +    if (ordering.isDefined || aggregator.isDefined) {
    +      // Sort by partition ID then key comparator
    +      new Comparator[(Int, K)] {
    +        override def compare(a: (Int, K), b: (Int, K)): Int = {
    +          val partitionDiff = a._1 - b._1
    +          if (partitionDiff != 0) {
    +            partitionDiff
    +          } else {
    +            keyComparator.compare(a._2, b._2)
    +          }
    +        }
    +      }
    +    } else {
    +      // Just sort it by partition ID
    +      new Comparator[(Int, K)] {
    +        override def compare(a: (Int, K), b: (Int, K)): Int = {
    +          a._1 - b._1
    +        }
    +      }
    +    }
    +  }
    +
    +  // Information about a spilled file. Includes sizes in bytes of "batches" written by the
    +  // serializer as we periodically reset its stream, as well as number of elements in each
    +  // partition, used to efficiently keep track of partitions when merging.
    +  private[this] case class SpilledFile(
    +    file: File,
    +    blockId: BlockId,
    +    serializerBatchSizes: Array[Long],
    +    elementsPerPartition: Array[Long])
    +  private val spills = new ArrayBuffer[SpilledFile]
    +
    +  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    +    // TODO: stop combining if we find that the reduction factor isn't high
    +    val shouldCombine = aggregator.isDefined
    +
    +    if (shouldCombine) {
    +      // Combine values in-memory first using our AppendOnlyMap
    +      val mergeValue = aggregator.get.mergeValue
    +      val createCombiner = aggregator.get.createCombiner
    +      var kv: Product2[K, V] = null
    +      val update = (hadValue: Boolean, oldValue: C) => {
    +        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    +      }
    +      while (records.hasNext) {
    +        elementsRead += 1
    +        kv = records.next()
    +        map.changeValue((getPartition(kv._1), kv._1), update)
    +        maybeSpill(usingMap = true)
    +      }
    +    } else {
    +      // Stick values into our buffer
    +      while (records.hasNext) {
    +        elementsRead += 1
    +        val kv = records.next()
    +        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
    +        maybeSpill(usingMap = false)
    +      }
    +    }
    +  }
    +
    +  private def maybeSpill(usingMap: Boolean): Unit = {
    +    if (!spillingEnabled) {
    +      return
    +    }
    +
    +    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
    +
    +    if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
    +        collection.estimateSize() >= myMemoryThreshold)
    +    {
    +      // TODO: This logic doesn't work if there are two external collections being used in the same
    +      // task (e.g. to read shuffle output and write it out into another shuffle).
    +
    +      val currentSize = collection.estimateSize()
    +      var shouldSpill = false
    +      val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
    +
    +      // Atomically check whether there is sufficient memory in the global pool for
    +      // us to double our threshold
    +      shuffleMemoryMap.synchronized {
    +        val threadId = Thread.currentThread().getId
    +        val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
    +        val availableMemory = maxMemoryThreshold -
    +          (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
    +
    +        // Try to allocate at least 2x more memory, otherwise spill
    +        shouldSpill = availableMemory < currentSize * 2
    +        if (!shouldSpill) {
    +          shuffleMemoryMap(threadId) = currentSize * 2
    +          myMemoryThreshold = currentSize * 2
    +        }
    +      }
    +      // Do not hold lock during spills
    +      if (shouldSpill) {
    +        spill(currentSize, usingMap)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
    +   *
    +   * @param usingMap whether we're using a map or buffer as our current in-memory collection
    +   */
    +  private def spill(memorySize: Long, usingMap: Boolean): Unit = {
    +    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
    +    val memorySize = collection.estimateSize()
    +
    +    spillCount += 1
    +    logWarning("Spilling in-memory batch of %d MB to disk (%d spill%s so far)"
    +      .format(memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
    +    val (blockId, file) = diskBlockManager.createTempBlock()
    +    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
    +    var objectsWritten = 0   // Objects written since the last flush
    +
    +    // List of batch sizes (bytes) in the order they are written to disk
    +    val batchSizes = new ArrayBuffer[Long]
    +
    +    // How many elements we have in each partition
    +    val elementsPerPartition = new Array[Long](numPartitions)
    +
    +    // Flush the disk writer's contents to disk, and update relevant variables
    +    def flush() = {
    +      writer.commit()
    +      val bytesWritten = writer.bytesWritten
    +      batchSizes.append(bytesWritten)
    +      _diskBytesSpilled += bytesWritten
    +    }
    +
    +    try {
    +      val it = collection.destructiveSortedIterator(partitionKeyComparator)
    +      while (it.hasNext) {
    +        val elem = it.next()
    +        val partitionId = elem._1._1
    +        val key = elem._1._2
    +        val value = elem._2
    +        writer.write(key)
    +        writer.write(value)
    +        elementsPerPartition(partitionId) += 1
    +        objectsWritten += 1
    +
    +        if (objectsWritten == serializerBatchSize) {
    +          flush()
    +          objectsWritten = 0
    +          writer.close()
    +          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
    +        }
    +      }
    +      if (objectsWritten > 0) {
    +        flush()
    +      }
    +      writer.close()
    +    } catch {
    +      case e: Exception =>
    +        writer.close()
    +        file.delete()
    +        throw e
    +    }
    +
    +    if (usingMap) {
    +      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
    +    } else {
    +      buffer = new SizeTrackingPairBuffer[(Int, K), C]
    +    }
    +
    +    // Reset the amount of shuffle memory used by this map in the global pool
    +    val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
    +    shuffleMemoryMap.synchronized {
    +      shuffleMemoryMap(Thread.currentThread().getId) = 0
    +    }
    +    myMemoryThreshold = 0
    +
    +    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
    +    _memoryBytesSpilled += memorySize
    +  }
    +
    +  /**
    +   * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
    +   * inside each partition. This can be used to either write out a new file or return data to
    +   * the user.
    +   *
    +   * Returns an iterator over all the data written to this object, grouped by partition. For each
    +   * partition we then have an iterator over its contents, and these are expected to be accessed
    +   * in order (you can't "skip ahead" to one partition without reading the previous one).
    +   * Guaranteed to return a key-value pair for each partition, in order of partition ID.
    +   */
    +  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    +      : Iterator[(Int, Iterator[Product2[K, C]])] = {
    +    val readers = spills.map(new SpillReader(_))
    +    val inMemBuffered = inMemory.buffered
    +    (0 until numPartitions).iterator.map { p =>
    +      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    +      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    +      if (aggregator.isDefined) {
    +        // Perform partial aggregation across partitions
    +        (p, mergeWithAggregation(
    +          iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    +      } else if (ordering.isDefined) {
    +        // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
    +        // sort the elements without trying to merge them
    +        (p, mergeSort(iterators, ordering.get))
    +      } else {
    +        (p, iterators.iterator.flatten)
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
    +   */
    +  private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
    +      : Iterator[Product2[K, C]] =
    +  {
    +    val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
    +    type Iter = BufferedIterator[Product2[K, C]]
    +    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
    +      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
    +    })
    +    heap.enqueue(bufferedIters: _*)
    +    new Iterator[Product2[K, C]] {
    +      override def hasNext: Boolean = !heap.isEmpty
    +
    +      override def next(): Product2[K, C] = {
    +        if (!hasNext) {
    +          throw new NoSuchElementException
    +        }
    +        val firstBuf = heap.dequeue()
    +        val firstPair = firstBuf.next()
    +        if (firstBuf.hasNext) {
    +          heap.enqueue(firstBuf)
    +        }
    +        firstPair
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
    +   * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
    +   * (e.g. when we sort objects by hash code and different keys may compare as equal although
    +   * they're not), we still merge them by doing equality tests for all keys that compare as equal.
    +   */
    +  private def mergeWithAggregation(
    +      iterators: Seq[Iterator[Product2[K, C]]],
    +      mergeCombiners: (C, C) => C,
    +      comparator: Comparator[K],
    +      totalOrder: Boolean)
    +      : Iterator[Product2[K, C]] =
    +  {
    +    if (!totalOrder) {
    +      // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
    +      // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
    +      // need to read all keys considered equal by the ordering at once and compare them.
    +      new Iterator[Iterator[Product2[K, C]]] {
    +        val sorted = mergeSort(iterators, comparator).buffered
    +
    +        // Buffers reused across elements to decrease memory allocation
    +        val keys = new ArrayBuffer[K]
    +        val combiners = new ArrayBuffer[C]
    +
    +        override def hasNext: Boolean = sorted.hasNext
    +
    +        override def next(): Iterator[Product2[K, C]] = {
    +          if (!hasNext) {
    +            throw new NoSuchElementException
    +          }
    +          keys.clear()
    +          combiners.clear()
    +          val firstPair = sorted.next()
    +          keys += firstPair._1
    +          combiners += firstPair._2
    +          val key = firstPair._1
    +          while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
    +            val pair = sorted.next()
    +            var i = 0
    +            var foundKey = false
    +            while (i < keys.size && !foundKey) {
    +              if (keys(i) == pair._1) {
    +                combiners(i) = mergeCombiners(combiners(i), pair._2)
    +                foundKey = true
    +              }
    +              i += 1
    +            }
    +            if (!foundKey) {
    +              keys += pair._1
    +              combiners += pair._2
    +            }
    +          }
    +
    +          // Note that we return an iterator of elements since we could've had many keys marked
    +          // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
    +          keys.iterator.zip(combiners.iterator)
    +        }
    +      }.flatMap(i => i)
    +    } else {
    +      // We have a total ordering. This means we can merge objects one by one as we read them
    +      // from the iterators, without buffering all the ones that are "equal" to a given key.
    +      // We do so with code similar to mergeSort, except our Iterator.next combines together all
    +      // the elements with the given key.
    +      val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
    +      type Iter = BufferedIterator[Product2[K, C]]
    +      val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
    +        override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
    +      })
    +      heap.enqueue(bufferedIters: _*)
    +      new Iterator[Product2[K, C]] {
    +        override def hasNext: Boolean = !heap.isEmpty
    +
    +        override def next(): Product2[K, C] = {
    +          if (!hasNext) {
    +            throw new NoSuchElementException
    +          }
    +          val firstBuf = heap.dequeue()
    +          val firstPair = firstBuf.next()
    +          val k = firstPair._1
    +          var c = firstPair._2
    +          if (firstBuf.hasNext) {
    +            heap.enqueue(firstBuf)
    +          }
    +          var shouldStop = false
    +          while (!heap.isEmpty && !shouldStop) {
    +            shouldStop = true  // Stop unless we find another element with the same key
    +            val newBuf = heap.dequeue()
    +            while (newBuf.hasNext && newBuf.head._1 == k) {
    +              val elem = newBuf.next()
    +              c = mergeCombiners(c, elem._2)
    +              shouldStop = false
    +            }
    +            if (newBuf.hasNext) {
    +              heap.enqueue(newBuf)
    +            }
    +          }
    --- End diff --
    
    As you said in the comments, this is very similar to your code in `mergeSort`.
    It would probably be cleaner to abstract the shared logic and add a
    `shouldCombine: Boolean` parameter to `mergeSort` for the extra logic here.
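
    For illustration, the shared logic might look roughly like the sketch below.
    It is a minimal standalone version and not code from the pull request: the
    `MergeSketch` object, the plain `(K, C)` tuples (instead of `Product2[K, C]`),
    and the extra `mergeCombiners` argument are assumptions made to keep the
    example self-contained. It also assumes the comparator is a total ordering,
    as in the branch this comment is attached to, so keys that compare as equal
    are treated as the same key.

```scala
import java.util.Comparator
import scala.collection.mutable

// Hypothetical helper, not part of the PR under review.
object MergeSketch {
  /**
   * Merge-sorts iterators that are each already sorted by key. When shouldCombine
   * is true, all values whose keys compare as equal are folded together with
   * mergeCombiners; when it is false, every pair is emitted unchanged.
   */
  def mergeSort[K, C](
      iterators: Seq[Iterator[(K, C)]],
      comparator: Comparator[K],
      shouldCombine: Boolean,
      mergeCombiners: (C, C) => C): Iterator[(K, C)] = {
    type Iter = BufferedIterator[(K, C)]
    // PriorityQueue dequeues its largest element, so negate the comparison
    // to get the iterator with the smallest head key first.
    val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
      override def compare(x: Iter, y: Iter): Int =
        -comparator.compare(x.head._1, y.head._1)
    })
    heap.enqueue(iterators.filter(_.hasNext).map(_.buffered): _*)

    new Iterator[(K, C)] {
      override def hasNext: Boolean = heap.nonEmpty

      override def next(): (K, C) = {
        if (!hasNext) {
          throw new NoSuchElementException
        }
        val firstBuf = heap.dequeue()
        val (k, first) = firstBuf.next()
        var c = first
        if (firstBuf.hasNext) {
          heap.enqueue(firstBuf)
        }
        if (shouldCombine) {
          // Keep pulling from whichever iterator is now smallest while its
          // head key still compares as equal to k.
          while (heap.nonEmpty && comparator.compare(heap.head.head._1, k) == 0) {
            val buf = heap.dequeue()
            c = mergeCombiners(c, buf.next()._2)
            if (buf.hasNext) {
              heap.enqueue(buf)
            }
          }
        }
        (k, c)
      }
    }
  }
}
```

    With something along these lines, the total-order branch of
    `mergeWithAggregation` could delegate to `mergeSort` with
    `shouldCombine = true`, while the existing callers would pass `false`. An
    `Option[(C, C) => C]` in place of the Boolean plus function pair would be
    another way to express the same thing.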


