Skip to content

Commit

Permalink
Merge pull request #17 from softwaremill/source-map-par-unordered
Browse files Browse the repository at this point in the history
Add Source.mapParUnordered combinator
  • Loading branch information
adamw authored Oct 10, 2023
2 parents f9adb7d + fef4f06 commit 161557a
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 1 deletion.
29 changes: 28 additions & 1 deletion core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,34 @@ trait SourceOps[+T] { this: Source[T] =>
closeScope.await()
}

def mapParUnordered[U](parallelism: Int)(f: T => U)(using Ox, StageCapacity): Source[U] = ??? // TODO
def mapParUnordered[U](parallelism: Int)(f: T => U)(using Ox, StageCapacity): Source[U] =
val c = StageCapacity.newChannel[U]
val s = new Semaphore(parallelism)
forkDaemon {
supervised {
repeatWhile {
s.acquire()
receive() match
case ChannelClosed.Done => false
case e @ ChannelClosed.Error(r) =>
c.error(r)
throw e.toThrowable
case t: T @unchecked =>
fork {
try
c.send(f(t))
s.release()
catch
case t: Throwable =>
c.error(t)
throw t
}
true
}
}
c.done()
}
c

def take(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.take(n))

Expand Down
137 changes: 137 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsMapParUnorderedTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package ox.channels

import org.scalatest.concurrent.Eventually
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*
import ox.util.Trail

import java.util.concurrent.atomic.AtomicInteger

class SourceOpsMapParUnorderedTest extends AnyFlatSpec with Matchers with Eventually {

behavior of "Source.mapParUnordered"

for (parallelism <- 1 to 10) {
it should s"map over a source with parallelism limit $parallelism" in scoped {
// given
val s = Source.fromIterable(1 to 10)
val running = new AtomicInteger(0)
val maxRunning = new AtomicInteger(0)

def f(i: Int) =
running.incrementAndGet()
try
Thread.sleep(100)
i * 2
finally running.decrementAndGet()

// update max running
fork {
var max = 0
forever {
max = math.max(max, running.get())
maxRunning.set(max)
Thread.sleep(10)
}
}

// when
val result = s.mapParUnordered(parallelism)(f).toList

// then
result should contain theSameElementsAs List(2, 4, 6, 8, 10, 12, 14, 16, 18, 20)
maxRunning.get() shouldBe parallelism
}
}

it should s"map over a source with parallelism limit 10 (stress test)" in scoped {
for (i <- 1 to 100) {
info(s"iteration $i")

// given
val s = Source.fromIterable(1 to 10)

def f(i: Int) =
Thread.sleep(50)
i * 2

// when
val result = s.mapParUnordered(10)(f).toList

// then
result should contain theSameElementsAs List(2, 4, 6, 8, 10, 12, 14, 16, 18, 20)
}
}

it should "propagate errors" in scoped {
// given
val s = Source.fromIterable(1 to 10)
val started = new AtomicInteger()

// when
val s2 = s.mapParUnordered(3) { i =>
started.incrementAndGet()
if i > 4 then throw new Exception("boom")
i * 2
}

// then
try
s2.toList
fail("should have thrown")
catch
case ChannelClosedException.Error(Some(reason)) if reason.getMessage == "boom" =>
started.get() should be >= 4
started.get() should be <= 7 // 4 successful + at most 3 taking up all the permits
}

it should "cancel other running forks when there's an error" in scoped {
// given
val trail = Trail()
val s = Source.fromIterable(1 to 10)

// when
val s2 = s.mapParUnordered(2) { i =>
if i == 4 then
Thread.sleep(100)
trail.add("exception")
throw new Exception("boom")
else
Thread.sleep(200)
trail.add(s"done")
i * 2
}

// then
List(s2.receive(), s2.receive()) should contain only (2, 4)
s2.receive() should matchPattern { case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => }
s2.isError shouldBe true

// checking if the forks aren't left running
Thread.sleep(200)
trail.get shouldBe Vector("done", "done", "exception")
}

it should "emit downstream as soon as a value is ready, regardless of the incoming order" in scoped {
// given
val s = Source.fromIterable(1 to 5)
val delays = Map(
1 -> 100,
2 -> 10,
3 -> 50,
4 -> 500,
5 -> 200
)
val expectedElements = delays.toList.sortBy(_._2).map(_._1)

// when
val s2 = s.mapParUnordered(5) { i =>
Thread.sleep(delays(i))
i
}

// then
s2.toList should contain theSameElementsInOrderAs expectedElements
}
}

0 comments on commit 161557a

Please sign in to comment.