diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 7d113f64..590aadee 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -3,6 +3,7 @@ package ox.channels import ox.* import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore} +import scala.collection.mutable import scala.concurrent.duration.FiniteDuration trait SourceOps[+T] { this: Source[T] => @@ -171,8 +172,8 @@ trait SourceOps[+T] { this: Source[T] => /** Sends a given number of elements (determined byc `segmentSize`) from this source to the returned channel, then sends the same number * of elements from the `other` source and repeats. The order of elements in both sources is preserved. * - * If one of the sources is closed before the other, the behavior depends on the `eagerCancel` flag. When set to `true`, the other source - * is cancelled immediately, otherwise the remaining elements from the other source are sent to the returned channel. + * If one of the sources is done before the other, the behavior depends on the `eagerCancel` flag. When set to `true`, the returned + * channel is completed immediately, otherwise the remaining elements from the other source are sent to the returned channel. * * Must be run within a scope, since a child fork is created which receives from both sources and sends to the resulting channel. * @@ -201,41 +202,7 @@ trait SourceOps[+T] { this: Source[T] => * }}} */ def interleave[U >: T](other: Source[U], segmentSize: Int = 1, eagerComplete: Boolean = false)(using Ox, StageCapacity): Source[U] = - val c = StageCapacity.newChannel[U] - - forkDaemon { - var source: Source[U] = this - var counter = 0 - var neitherCompleted = true - - def switchSource(): Unit = { - if (source == this) source = other else source = this - counter = 0 - } - - repeatWhile { - source.receive() match - case ChannelClosed.Done => - // if one source has completed, either complete the resulting source immediately if eagerComplete is set, or: - // - continue with the other source if it hasn't completed yet, or - // - complete the resulting source if both input sources have completed - if (neitherCompleted && !eagerComplete) { - neitherCompleted = false - switchSource() - true - } else { - c.done() - false - } - case ChannelClosed.Error(r) => c.error(r); false - case value: U @unchecked => - counter += 1 - // after reaching segmentSize, only switch to the other source if it hasn't completed yet - if (counter == segmentSize && neitherCompleted) switchSource() - c.send(value).isValue - } - } - c + Source.interleaveAll(List(this, other), segmentSize, eagerComplete) /** Invokes the given function for each received element. Blocks until the channel is done. * @throws ChannelClosedException @@ -406,3 +373,89 @@ trait SourceCompanionOps: catch case t: Throwable => c.error(t) } c + + def empty[T]: Source[T] = + val c = DirectChannel() + c.done() + c + + /** Sends a given number of elements (determined byc `segmentSize`) from each source in `sources` to the returned channel and repeats. The + * order of elements in all sources is preserved. + * + * If any of the sources is done before the others, the behavior depends on the `eagerCancel` flag. When set to `true`, the returned + * channel is completed immediately, otherwise the interleaving continues with the remaining non-completed sources. Once all but one + * sources are complete, the elements of the remaining non-complete source are sent to the returned channel. + * + * Must be run within a scope, since a child fork is created which receives from the subsequent sources and sends to the resulting + * channel. + * + * @param sources + * The sources whose elements will be interleaved. + * @param segmentSize + * The number of elements sent from each source before switching to the next one. Default is 1. + * @param eagerComplete + * If `true`, the returned channel is completed as soon as any of the sources completes. If 'false`, the interleaving continues with + * the remaining non-completed sources. + * @return + * A source to which the interleaved elements from both sources would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8) + * val s2 = Source.fromValues(10, 20, 30) + * val s3 = Source.fromValues(100, 200, 300, 400, 500) + * Source.interleaveAll(List(s1, s2, s3), segmentSize = 2, eagerComplete = true).toList + * } + * + * scala> val res0: List[Int] = List(1, 2, 10, 20, 100, 200, 3, 4, 30) + * }}} + */ + def interleaveAll[T](sources: Seq[Source[T]], segmentSize: Int = 1, eagerComplete: Boolean = false)(using + Ox, + StageCapacity + ): Source[T] = + sources match + case Nil => Source.empty + case single :: Nil => single + case _ => + val c = StageCapacity.newChannel[T] + + forkDaemon { + val availableSources = mutable.ArrayBuffer.from(sources) + var currentSourceIndex = 0 + var elementsRead = 0 + + def completeCurrentSource(): Unit = + availableSources.remove(currentSourceIndex) + currentSourceIndex = if (currentSourceIndex == 0) availableSources.size - 1 else currentSourceIndex - 1 + + def switchToNextSource(): Unit = + currentSourceIndex = (currentSourceIndex + 1) % availableSources.size + elementsRead = 0 + + repeatWhile { + availableSources(currentSourceIndex).receive() match + case ChannelClosed.Done => + completeCurrentSource() + + if (eagerComplete || availableSources.isEmpty) + c.done() + false + else + switchToNextSource() + true + case ChannelClosed.Error(r) => + c.error(r) + false + case value: T @unchecked => + elementsRead += 1 + // after reaching segmentSize, only switch to next source if there's any other available + if (elementsRead == segmentSize && availableSources.size > 1) switchToNextSource() + c.send(value).isValue + } + } + c diff --git a/core/src/test/scala/ox/channels/SourceOpsEmptyTest.scala b/core/src/test/scala/ox/channels/SourceOpsEmptyTest.scala new file mode 100644 index 00000000..52448a43 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsEmptyTest.scala @@ -0,0 +1,18 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsEmptyTest extends AnyFlatSpec with Matchers { + + behavior of "Source.empty" + + it should "be done" in scoped { + Source.empty.isDone shouldBe true + } + + it should "be empty" in scoped { + Source.empty.toList shouldBe empty + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsInterleaveAllTest.scala b/core/src/test/scala/ox/channels/SourceOpsInterleaveAllTest.scala new file mode 100644 index 00000000..0c18288f --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsInterleaveAllTest.scala @@ -0,0 +1,54 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsInterleaveAllTest extends AnyFlatSpec with Matchers { + + behavior of "Source.interleaveAll" + + it should "interleave no sources" in scoped { + val s = Source.interleaveAll(List.empty) + + s.toList shouldBe empty + } + + it should "interleave a single source" in scoped { + val c = Source.fromValues(1, 2, 3) + + val s = Source.interleaveAll(List(c)) + + s.toList shouldBe List(1, 2, 3) + } + + it should "interleave multiple sources" in scoped { + val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8) + val c2 = Source.fromValues(10, 20, 30) + val c3 = Source.fromValues(100, 200, 300, 400, 500) + + val s = Source.interleaveAll(List(c1, c2, c3)) + + s.toList shouldBe List(1, 10, 100, 2, 20, 200, 3, 30, 300, 4, 400, 5, 500, 6, 7, 8) + } + + it should "interleave multiple sources using custom segment size" in scoped { + val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8) + val c2 = Source.fromValues(10, 20, 30) + val c3 = Source.fromValues(100, 200, 300, 400, 500) + + val s = Source.interleaveAll(List(c1, c2, c3), segmentSize = 2) + + s.toList shouldBe List(1, 2, 10, 20, 100, 200, 3, 4, 30, 300, 400, 5, 6, 500, 7, 8) + } + + it should "interleave multiple sources using custom segment size and complete eagerly" in scoped { + val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8) + val c2 = Source.fromValues(10, 20, 30) + val c3 = Source.fromValues(100, 200, 300, 400, 500) + + val s = Source.interleaveAll(List(c1, c2, c3), segmentSize = 2, eagerComplete = true) + + s.toList shouldBe List(1, 2, 10, 20, 100, 200, 3, 4, 30) + } +}