Skip to content

Commit

Permalink
Merge pull request #13 from softwaremill/source-interleave-all
Browse files Browse the repository at this point in the history
Add Source.interleaveAll combinator
  • Loading branch information
adamw authored Oct 6, 2023
2 parents 4c419ae + a92e6ce commit 8efe299
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 37 deletions.
127 changes: 90 additions & 37 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ox.channels
import ox.*

import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore}
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration

trait SourceOps[+T] { this: Source[T] =>
Expand Down Expand Up @@ -171,8 +172,8 @@ trait SourceOps[+T] { this: Source[T] =>
/** Sends a given number of elements (determined byc `segmentSize`) from this source to the returned channel, then sends the same number
* of elements from the `other` source and repeats. The order of elements in both sources is preserved.
*
* If one of the sources is closed before the other, the behavior depends on the `eagerCancel` flag. When set to `true`, the other source
* is cancelled immediately, otherwise the remaining elements from the other source are sent to the returned channel.
* If one of the sources is done before the other, the behavior depends on the `eagerCancel` flag. When set to `true`, the returned
* channel is completed immediately, otherwise the remaining elements from the other source are sent to the returned channel.
*
* Must be run within a scope, since a child fork is created which receives from both sources and sends to the resulting channel.
*
Expand Down Expand Up @@ -201,41 +202,7 @@ trait SourceOps[+T] { this: Source[T] =>
* }}}
*/
def interleave[U >: T](other: Source[U], segmentSize: Int = 1, eagerComplete: Boolean = false)(using Ox, StageCapacity): Source[U] =
val c = StageCapacity.newChannel[U]

forkDaemon {
var source: Source[U] = this
var counter = 0
var neitherCompleted = true

def switchSource(): Unit = {
if (source == this) source = other else source = this
counter = 0
}

repeatWhile {
source.receive() match
case ChannelClosed.Done =>
// if one source has completed, either complete the resulting source immediately if eagerComplete is set, or:
// - continue with the other source if it hasn't completed yet, or
// - complete the resulting source if both input sources have completed
if (neitherCompleted && !eagerComplete) {
neitherCompleted = false
switchSource()
true
} else {
c.done()
false
}
case ChannelClosed.Error(r) => c.error(r); false
case value: U @unchecked =>
counter += 1
// after reaching segmentSize, only switch to the other source if it hasn't completed yet
if (counter == segmentSize && neitherCompleted) switchSource()
c.send(value).isValue
}
}
c
Source.interleaveAll(List(this, other), segmentSize, eagerComplete)

/** Invokes the given function for each received element. Blocks until the channel is done.
* @throws ChannelClosedException
Expand Down Expand Up @@ -406,3 +373,89 @@ trait SourceCompanionOps:
catch case t: Throwable => c.error(t)
}
c

def empty[T]: Source[T] =
val c = DirectChannel()
c.done()
c

/** Sends a given number of elements (determined byc `segmentSize`) from each source in `sources` to the returned channel and repeats. The
* order of elements in all sources is preserved.
*
* If any of the sources is done before the others, the behavior depends on the `eagerCancel` flag. When set to `true`, the returned
* channel is completed immediately, otherwise the interleaving continues with the remaining non-completed sources. Once all but one
* sources are complete, the elements of the remaining non-complete source are sent to the returned channel.
*
* Must be run within a scope, since a child fork is created which receives from the subsequent sources and sends to the resulting
* channel.
*
* @param sources
* The sources whose elements will be interleaved.
* @param segmentSize
* The number of elements sent from each source before switching to the next one. Default is 1.
* @param eagerComplete
* If `true`, the returned channel is completed as soon as any of the sources completes. If 'false`, the interleaving continues with
* the remaining non-completed sources.
* @return
* A source to which the interleaved elements from both sources would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8)
* val s2 = Source.fromValues(10, 20, 30)
* val s3 = Source.fromValues(100, 200, 300, 400, 500)
* Source.interleaveAll(List(s1, s2, s3), segmentSize = 2, eagerComplete = true).toList
* }
*
* scala> val res0: List[Int] = List(1, 2, 10, 20, 100, 200, 3, 4, 30)
* }}}
*/
def interleaveAll[T](sources: Seq[Source[T]], segmentSize: Int = 1, eagerComplete: Boolean = false)(using
Ox,
StageCapacity
): Source[T] =
sources match
case Nil => Source.empty
case single :: Nil => single
case _ =>
val c = StageCapacity.newChannel[T]

forkDaemon {
val availableSources = mutable.ArrayBuffer.from(sources)
var currentSourceIndex = 0
var elementsRead = 0

def completeCurrentSource(): Unit =
availableSources.remove(currentSourceIndex)
currentSourceIndex = if (currentSourceIndex == 0) availableSources.size - 1 else currentSourceIndex - 1

def switchToNextSource(): Unit =
currentSourceIndex = (currentSourceIndex + 1) % availableSources.size
elementsRead = 0

repeatWhile {
availableSources(currentSourceIndex).receive() match
case ChannelClosed.Done =>
completeCurrentSource()

if (eagerComplete || availableSources.isEmpty)
c.done()
false
else
switchToNextSource()
true
case ChannelClosed.Error(r) =>
c.error(r)
false
case value: T @unchecked =>
elementsRead += 1
// after reaching segmentSize, only switch to next source if there's any other available
if (elementsRead == segmentSize && availableSources.size > 1) switchToNextSource()
c.send(value).isValue
}
}
c
18 changes: 18 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsEmptyTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsEmptyTest extends AnyFlatSpec with Matchers {

behavior of "Source.empty"

it should "be done" in scoped {
Source.empty.isDone shouldBe true
}

it should "be empty" in scoped {
Source.empty.toList shouldBe empty
}
}
54 changes: 54 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsInterleaveAllTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsInterleaveAllTest extends AnyFlatSpec with Matchers {

behavior of "Source.interleaveAll"

it should "interleave no sources" in scoped {
val s = Source.interleaveAll(List.empty)

s.toList shouldBe empty
}

it should "interleave a single source" in scoped {
val c = Source.fromValues(1, 2, 3)

val s = Source.interleaveAll(List(c))

s.toList shouldBe List(1, 2, 3)
}

it should "interleave multiple sources" in scoped {
val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8)
val c2 = Source.fromValues(10, 20, 30)
val c3 = Source.fromValues(100, 200, 300, 400, 500)

val s = Source.interleaveAll(List(c1, c2, c3))

s.toList shouldBe List(1, 10, 100, 2, 20, 200, 3, 30, 300, 4, 400, 5, 500, 6, 7, 8)
}

it should "interleave multiple sources using custom segment size" in scoped {
val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8)
val c2 = Source.fromValues(10, 20, 30)
val c3 = Source.fromValues(100, 200, 300, 400, 500)

val s = Source.interleaveAll(List(c1, c2, c3), segmentSize = 2)

s.toList shouldBe List(1, 2, 10, 20, 100, 200, 3, 4, 30, 300, 400, 5, 6, 500, 7, 8)
}

it should "interleave multiple sources using custom segment size and complete eagerly" in scoped {
val c1 = Source.fromValues(1, 2, 3, 4, 5, 6, 7, 8)
val c2 = Source.fromValues(10, 20, 30)
val c3 = Source.fromValues(100, 200, 300, 400, 500)

val s = Source.interleaveAll(List(c1, c2, c3), segmentSize = 2, eagerComplete = true)

s.toList shouldBe List(1, 2, 10, 20, 100, 200, 3, 4, 30)
}
}

0 comments on commit 8efe299

Please sign in to comment.