diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 6494193b..5872c5cf 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -2,7 +2,7 @@ package ox.channels import ox.* -import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore} +import java.util.concurrent.{CountDownLatch, Semaphore} import scala.collection.mutable import scala.concurrent.duration.FiniteDuration @@ -194,6 +194,46 @@ trait SourceOps[+T] { this: Source[T] => } c + /** Combines elements from this and other sources into tuples handling early completion of either source with defaults. + * + * @param other + * A source of elements to be combined with. + * @param thisDefault + * A default element to be used in the result tuple when the other source is longer. + * @param otherDefault + * A default element to be used in the result tuple when the current source is longer. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[Int].zipAll(Source.empty[String], -1, "foo").toList // List() + * Source.empty[Int].zipAll(Source.fromValues("a"), -1, "foo").toList // List((-1, "a")) + * Source.fromValues(1).zipAll(Source.empty[String], -1, "foo").toList // List((1, "foo")) + * Source.fromValues(1).zipAll(Source.fromValues("a"), -1, "foo").toList // List((1, "a")) + * } + * }}} + */ + def zipAll[U >: T, V](other: Source[V], thisDefault: U, otherDefault: V)(using Ox, StageCapacity): Source[(U, V)] = + val c = StageCapacity.newChannel[(U, V)] + + def receiveFromOther(thisElement: U, otherClosedHandler: () => Boolean): Boolean = + other.receive() match + case ChannelClosed.Done => otherClosedHandler() + case ChannelClosed.Error(r) => c.error(r); false + case v: V @unchecked => c.send(thisElement, v); true + + forkDaemon { + repeatWhile { + receive() match + case ChannelClosed.Done => receiveFromOther(thisDefault, () => { c.done(); false }) + case ChannelClosed.Error(r) => c.error(r); false + case t: T @unchecked => receiveFromOther(t, () => { c.send(t, otherDefault); true }) + } + } + c + // /** Sends a given number of elements (determined byc `segmentSize`) from this source to the returned channel, then sends the same number diff --git a/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala b/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala new file mode 100644 index 00000000..6824d8d3 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala @@ -0,0 +1,51 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsZipAllTest extends AnyFlatSpec with Matchers { + behavior of "Source.zipAll" + + it should "not emit any element when both channels are empty" in scoped { + val s = Source.empty[Int] + val other = Source.empty[String] + + s.zipAll(other, -1, "foo").toList shouldBe List.empty + } + + it should "emit this element when other channel is empty" in scoped { + val s = Source.fromValues(1) + val other = Source.empty[String] + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "foo")) + } + + it should "emit other element when this channel is empty" in scoped { + val s = Source.empty[Int] + val other = Source.fromValues("a") + + s.zipAll(other, -1, "foo").toList shouldBe List((-1, "a")) + } + + it should "emit matching elements when both channels are of the same size" in scoped { + val s = Source.fromValues(1, 2) + val other = Source.fromValues("a", "b") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (2, "b")) + } + + it should "emit default for other channel if this channel is longer" in scoped { + val s = Source.fromValues(1, 2, 3) + val other = Source.fromValues("a") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (2, "foo"), (3, "foo")) + } + + it should "emit default for this channel if other channel is longer" in scoped { + val s = Source.fromValues(1) + val other = Source.fromValues("a", "b", "c") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (-1, "b"), (-1, "c")) + } +}