From fef4f067fe5d70ffa739afed0f829b534130d459 Mon Sep 17 00:00:00 2001 From: Jacek Kunicki Date: Tue, 10 Oct 2023 16:08:26 +0200 Subject: [PATCH] Add Source.mapParUnordered combinator --- .../main/scala/ox/channels/SourceOps.scala | 29 +++- .../SourceOpsMapParUnorderedTest.scala | 137 ++++++++++++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 core/src/test/scala/ox/channels/SourceOpsMapParUnorderedTest.scala diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index a98ccf45..6494193b 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -110,7 +110,34 @@ trait SourceOps[+T] { this: Source[T] => closeScope.await() } - def mapParUnordered[U](parallelism: Int)(f: T => U)(using Ox, StageCapacity): Source[U] = ??? // TODO + def mapParUnordered[U](parallelism: Int)(f: T => U)(using Ox, StageCapacity): Source[U] = + val c = StageCapacity.newChannel[U] + val s = new Semaphore(parallelism) + forkDaemon { + supervised { + repeatWhile { + s.acquire() + receive() match + case ChannelClosed.Done => false + case e @ ChannelClosed.Error(r) => + c.error(r) + throw e.toThrowable + case t: T @unchecked => + fork { + try + c.send(f(t)) + s.release() + catch + case t: Throwable => + c.error(t) + throw t + } + true + } + } + c.done() + } + c def take(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.take(n)) diff --git a/core/src/test/scala/ox/channels/SourceOpsMapParUnorderedTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapParUnorderedTest.scala new file mode 100644 index 00000000..378c3f9e --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapParUnorderedTest.scala @@ -0,0 +1,137 @@ +package ox.channels + +import org.scalatest.concurrent.Eventually +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* +import ox.util.Trail + +import java.util.concurrent.atomic.AtomicInteger + +class SourceOpsMapParUnorderedTest extends AnyFlatSpec with Matchers with Eventually { + + behavior of "Source.mapParUnordered" + + for (parallelism <- 1 to 10) { + it should s"map over a source with parallelism limit $parallelism" in scoped { + // given + val s = Source.fromIterable(1 to 10) + val running = new AtomicInteger(0) + val maxRunning = new AtomicInteger(0) + + def f(i: Int) = + running.incrementAndGet() + try + Thread.sleep(100) + i * 2 + finally running.decrementAndGet() + + // update max running + fork { + var max = 0 + forever { + max = math.max(max, running.get()) + maxRunning.set(max) + Thread.sleep(10) + } + } + + // when + val result = s.mapParUnordered(parallelism)(f).toList + + // then + result should contain theSameElementsAs List(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) + maxRunning.get() shouldBe parallelism + } + } + + it should s"map over a source with parallelism limit 10 (stress test)" in scoped { + for (i <- 1 to 100) { + info(s"iteration $i") + + // given + val s = Source.fromIterable(1 to 10) + + def f(i: Int) = + Thread.sleep(50) + i * 2 + + // when + val result = s.mapParUnordered(10)(f).toList + + // then + result should contain theSameElementsAs List(2, 4, 6, 8, 10, 12, 14, 16, 18, 20) + } + } + + it should "propagate errors" in scoped { + // given + val s = Source.fromIterable(1 to 10) + val started = new AtomicInteger() + + // when + val s2 = s.mapParUnordered(3) { i => + started.incrementAndGet() + if i > 4 then throw new Exception("boom") + i * 2 + } + + // then + try + s2.toList + fail("should have thrown") + catch + case ChannelClosedException.Error(Some(reason)) if reason.getMessage == "boom" => + started.get() should be >= 4 + started.get() should be <= 7 // 4 successful + at most 3 taking up all the permits + } + + it should "cancel other running forks when there's an error" in scoped { + // given + val trail = Trail() + val s = Source.fromIterable(1 to 10) + + // when + val s2 = s.mapParUnordered(2) { i => + if i == 4 then + Thread.sleep(100) + trail.add("exception") + throw new Exception("boom") + else + Thread.sleep(200) + trail.add(s"done") + i * 2 + } + + // then + List(s2.receive(), s2.receive()) should contain only (2, 4) + s2.receive() should matchPattern { case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => } + s2.isError shouldBe true + + // checking if the forks aren't left running + Thread.sleep(200) + trail.get shouldBe Vector("done", "done", "exception") + } + + it should "emit downstream as soon as a value is ready, regardless of the incoming order" in scoped { + // given + val s = Source.fromIterable(1 to 5) + val delays = Map( + 1 -> 100, + 2 -> 10, + 3 -> 50, + 4 -> 500, + 5 -> 200 + ) + val expectedElements = delays.toList.sortBy(_._2).map(_._1) + + // when + val s2 = s.mapParUnordered(5) { i => + Thread.sleep(delays(i)) + i + } + + // then + s2.toList should contain theSameElementsInOrderAs expectedElements + } +}