Github user zsxwing commented on a diff in the pull request:
https://github.com/apache/spark/pull/9373#discussion_r44601858
--- Diff: streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala ---
@@ -198,6 +197,45 @@ class FileBasedWriteAheadLogSuite
import WriteAheadLogSuite._
+ test("FileBasedWriteAheadLog - seqToParIterator") {
+ /*
+ If the setting `closeFileAfterWrite` is enabled, we start generating a very large
number of
+ files. This causes recovery to take a very long time. In order to make it quicker,
we
+ parallelized the reading of these files. This test makes sure that we limit the
number of
+ open files to the size of the number of threads in our thread pool rather than the
size of
+ the list of files.
+ */
+ val numThreads = 8
+ val tpool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "wal-test-thread-pool")
+ class GetMaxCounter {
+ private val value = new AtomicInteger()
+ @volatile private var max: Int = 0
+ def increment(): Unit = synchronized {
+ val atInstant = value.incrementAndGet()
+ if (atInstant > max) max = atInstant
+ }
+ def decrement(): Unit = synchronized { value.decrementAndGet() }
+ def get(): Int = synchronized { value.get() }
+ def getMax(): Int = synchronized { max }
+ }
+ try {
+ val testSeq = 1 to 64
+ val counter = new GetMaxCounter()
+ def handle(value: Int): Iterator[Int] = {
+ new CompletionIterator[Int, Iterator[Int]](Iterator(value)) {
+ counter.increment()
+ override def completion() { counter.decrement() }
+ }
+ }
+ val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq,
handle)
+ assert(iterator.toSeq === testSeq)
+ assert(counter.getMax() > 1) // make sure we are doing a parallel computation!
--- End diff --
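For context, the open-iterator bound described in the test's comment comes from how `seqToParIterator` schedules the handlers. Here is a minimal sketch of that idea, assuming the input is walked lazily in pool-sized groups via a Scala parallel collection; the name `boundedParIterator` is hypothetical, an illustration rather than the committed Spark code:

```Scala
import java.util.concurrent.ThreadPoolExecutor

import scala.collection.parallel.ThreadPoolTaskSupport

// Sketch only: map each pool-sized group through the handler on the shared
// pool. Because `grouped` returns a lazy Iterator, the next group's handler
// iterators are created only after the previous group's output has been
// consumed, so roughly poolSize iterators are open at any one time.
def boundedParIterator[I, O](
    tpool: ThreadPoolExecutor, // assumed to be a fixed-size pool
    source: Seq[I],
    handler: I => Iterator[O]): Iterator[O] = {
  val taskSupport = new ThreadPoolTaskSupport(tpool)
  source.grouped(tpool.getCorePoolSize).flatMap { group =>
    val parGroup = group.par
    parGroup.tasksupport = taskSupport
    parGroup.map(handler).seq // this group's iterators, created in parallel
  }.flatten
}
```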
Here is the fix we discussed:
```Scala
try {
  val latch = new CountDownLatch(1)
  val testSeq = 1 to 1000
  val counter = new GetMaxCounter()
  def handle(value: Int): Iterator[Int] = {
    new CompletionIterator[Int, Iterator[Int]](Iterator(value)) {
      counter.increment()
      latch.await(10, TimeUnit.SECONDS)
      override def completion() { counter.decrement() }
    }
  }
  @volatile var collected: Seq[Int] = Nil
  val t = new Thread() {
    override def run() {
      val iterator = FileBasedWriteAheadLog.seqToParIterator[Int, Int](tpool, testSeq, handle)
      collected = iterator.toSeq
    }
  }
  t.start()
  eventually(Eventually.timeout(10.seconds)) {
    // make sure we are doing a parallel computation!
    assert(counter.getMax() > 1)
  }
  latch.countDown()
  t.join(10000)
  assert(collected === testSeq)
  // make sure we didn't open too many Iterators
  assert(counter.getMax() <= numThreads)
} finally {
  tpool.shutdownNow()
}
```
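The `CountDownLatch` is what makes the parallelism assertion deterministic: every handler blocks in the `CompletionIterator` constructor right after incrementing the counter, so the in-flight count cannot drop back down before the `eventually` block has observed `getMax() > 1`. Consuming the iterator on a separate thread keeps the test thread free to run that assertion and then release the latch, and the final `getMax() <= numThreads` check verifies the open-iterator bound the test is really about.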