diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 664269e5..3ab9e2ea 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -3,7 +3,7 @@ package ox.channels import ox.* import java.util.concurrent.{CountDownLatch, Semaphore} -import scala.collection.mutable +import scala.collection.{IterableOnce, mutable} import scala.concurrent.duration.FiniteDuration trait SourceOps[+T] { this: Source[T] => @@ -34,6 +34,67 @@ trait SourceOps[+T] { this: Source[T] => } c2 + /** Intersperses this source with provided element and forwards it to the returned channel. + * + * @param inject + * An element to be injected between the stream elements. + * @return + * A source, onto which elements will be injected. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[String].intersperse(", ").toList // List() + * Source.fromValues("foo").intersperse(", ").toList // List(foo) + * Source.fromValues("foo", "bar").intersperse(", ").toList // List(foo, ", ", bar) + * } + * }}} + */ + def intersperse[U >: T](inject: U)(using Ox, StageCapacity): Source[U] = + intersperse(None, inject, None) + + /** Intersperses this source with start, end and provided elements and forwards it to the returned channel. + * + * @param start + * An element to be prepended to the stream. + * @param inject + * An element to be injected between the stream elements. + * @param end + * An element to be appended to the end of the stream. + * @return + * A source, onto which elements will be injected. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[String].intersperse("[", ", ", "]").toList // List([, ]) + * Source.fromValues("foo").intersperse("[", ", ", "]").toList // List([, foo, ]) + * Source.fromValues("foo", "bar").intersperse("[", ", ", "]").toList // List([, foo, ", ", bar, ]) + * } + * }}} + */ + def intersperse[U >: T](start: U, inject: U, end: U)(using Ox, StageCapacity): Source[U] = + intersperse(Some(start), inject, Some(end)) + + private def intersperse[U >: T](start: Option[U], inject: U, end: Option[U])(using Ox, StageCapacity): Source[U] = + val c = StageCapacity.newChannel[U] + forkDaemon { + start.foreach(c.send) + var firstEmitted = false + repeatWhile { + receive() match + case ChannelClosed.Done => end.foreach(c.send); c.done(); false + case ChannelClosed.Error(e) => c.error(e); false + case v: U @unchecked if !firstEmitted => firstEmitted = true; c.send(v); true + case v: U @unchecked => c.send(inject); c.send(v); true + } + } + c + /** Applies the given mapping function `f` to each element received from this source, and sends the results to the returned channel. At * most `parallelism` invocations of `f` are run in parallel. * @@ -160,6 +221,24 @@ trait SourceOps[+T] { this: Source[T] => */ def takeWhile(f: T => Boolean)(using Ox, StageCapacity): Source[T] = transform(_.takeWhile(f)) + /** Drops `n` elements from this source and forwards subsequent elements to the returned channel. + * + * @param n + * Number of elements to be dropped. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[Int].drop(1).toList // List() + * Source.fromValues(1, 2, 3).drop(1).toList // List(2 ,3) + * Source.fromValues(1).drop(2).toList // List() + * } + * }}} + */ + def drop(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.drop(n)) + def filter(f: T => Boolean)(using Ox, StageCapacity): Source[T] = transform(_.filter(f)) def transform[U](f: Iterator[T] => Iterator[U])(using Ox, StageCapacity): Source[U] = @@ -213,6 +292,46 @@ trait SourceOps[+T] { this: Source[T] => } c + /** Combines elements from this and other sources into tuples handling early completion of either source with defaults. + * + * @param other + * A source of elements to be combined with. + * @param thisDefault + * A default element to be used in the result tuple when the other source is longer. + * @param otherDefault + * A default element to be used in the result tuple when the current source is longer. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[Int].zipAll(Source.empty[String], -1, "foo").toList // List() + * Source.empty[Int].zipAll(Source.fromValues("a"), -1, "foo").toList // List((-1, "a")) + * Source.fromValues(1).zipAll(Source.empty[String], -1, "foo").toList // List((1, "foo")) + * Source.fromValues(1).zipAll(Source.fromValues("a"), -1, "foo").toList // List((1, "a")) + * } + * }}} + */ + def zipAll[U >: T, V](other: Source[V], thisDefault: U, otherDefault: V)(using Ox, StageCapacity): Source[(U, V)] = + val c = StageCapacity.newChannel[(U, V)] + + def receiveFromOther(thisElement: U, otherClosedHandler: () => Boolean): Boolean = + other.receive() match + case ChannelClosed.Done => otherClosedHandler() + case ChannelClosed.Error(r) => c.error(r); false + case v: V @unchecked => c.send(thisElement, v); true + + forkDaemon { + repeatWhile { + receive() match + case ChannelClosed.Done => receiveFromOther(thisDefault, () => { c.done(); false }) + case ChannelClosed.Error(r) => c.error(r); false + case t: T @unchecked => receiveFromOther(t, () => { c.send(t, otherDefault); true }) + } + } + c + // /** Sends a given number of elements (determined byc `segmentSize`) from this source to the returned channel, then sends the same number @@ -290,6 +409,113 @@ trait SourceOps[+T] { this: Source[T] => def drain(): Unit = foreach(_ => ()) def applied[U](f: Source[T] => U): U = f(this) + + /** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results to + * the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel once this + * source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into a pair of the next state and the result which is sent + * sent to the returned channel. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 3, 4, 5) + * s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) + * } + * + * scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15) + * }}} + */ + def mapStateful[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + def resultToSome(s: S, t: T) = + val (newState, result) = f(s, t) + (newState, Some(result)) + + mapStatefulConcat(initializeState)(resultToSome, onComplete) + + /** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results one + * by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel + * once this source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into a pair of the next state and a + * [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty, + * nothing is sent to the returned channel. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + * // deduplicate the values + * s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + * } + * + * scala> val res0: List[Int] = List(1, 2, 3, 4, 5) + * }}} + */ + def mapStatefulConcat[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + val c = StageCapacity.newChannel[U] + forkDaemon { + var state = initializeState() + repeatWhile { + receive() match + case ChannelClosed.Done => + try + onComplete(state).foreach(c.send) + c.done() + catch case t: Throwable => c.error(t) + false + case ChannelClosed.Error(r) => + c.error(r) + false + case t: T @unchecked => + try + val (nextState, result) = f(state, t) + state = nextState + result.iterator.map(c.send).forall(_.isValue) + catch + case t: Throwable => + c.error(t) + false + } + } + c } trait SourceCompanionOps: diff --git a/core/src/test/scala/ox/channels/SourceOpsDropTest.scala b/core/src/test/scala/ox/channels/SourceOpsDropTest.scala new file mode 100644 index 00000000..bcb760ed --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsDropTest.scala @@ -0,0 +1,29 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsDropTest extends AnyFlatSpec with Matchers { + behavior of "Source.drop" + + it should "not drop from the empty source" in supervised { + val s = Source.empty[Int] + s.drop(1).toList shouldBe List.empty + } + + it should "drop elements from the source" in supervised { + val s = Source.fromValues(1, 2, 3) + s.drop(2).toList shouldBe List(3) + } + + it should "return empty source when more elements than source length was dropped" in supervised { + val s = Source.fromValues(1, 2) + s.drop(3).toList shouldBe List.empty + } + + it should "not drop when 'n == 0'" in supervised { + val s = Source.fromValues(1, 2, 3) + s.drop(0).toList shouldBe List(1, 2, 3) + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsIntersperseTest.scala b/core/src/test/scala/ox/channels/SourceOpsIntersperseTest.scala new file mode 100644 index 00000000..6a8a30cc --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsIntersperseTest.scala @@ -0,0 +1,39 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsIntersperseTest extends AnyFlatSpec with Matchers { + behavior of "Source.intersperse" + + it should "intersperse with inject only over an empty source" in supervised { + val s = Source.empty[String] + s.intersperse(", ").toList shouldBe List.empty + } + + it should "intersperse with inject only over a source with one element" in supervised { + val s = Source.fromValues("foo") + s.intersperse(", ").toList shouldBe List("foo") + } + + it should "intersperse with inject only over a source with multiple elements" in supervised { + val s = Source.fromValues("foo", "bar") + s.intersperse(", ").toList shouldBe List("foo", ", ", "bar") + } + + it should "intersperse with start, inject and end over an empty source" in supervised { + val s = Source.empty[String] + s.intersperse("[", ", ", "]").toList shouldBe List("[", "]") + } + + it should "intersperse with start, inject and end over a source with one element" in supervised { + val s = Source.fromValues("foo") + s.intersperse("[", ", ", "]").toList shouldBe List("[", "foo", "]") + } + + it should "intersperse with start, inject and end over a source with multiple elements" in supervised { + val s = Source.fromValues("foo", "bar") + s.intersperse("[", ", ", "]").toList shouldBe List("[", "foo", ", ", "bar", "]") + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala new file mode 100644 index 00000000..b5a5d5f1 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala @@ -0,0 +1,79 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers { + + behavior of "Source.mapStatefulConcat" + + it should "deduplicate" in scoped { + // given + val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + + // when + val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + + // then + s.toList shouldBe List(1, 2, 3, 4, 5) + } + + it should "count consecutive" in scoped { + // given + val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple") + + // when + val s = c.mapStatefulConcat(() => (Option.empty[String], 0))( + { case ((previous, count), e) => + previous match + case None => ((Some(e), 1), None) + case Some(`e`) => ((previous, count + 1), None) + case Some(_) => ((Some(e), 1), previous.map((_, count))) + }, + { case (previous, count) => previous.map((_, count)) } + ) + + // then + s.toList shouldBe List( + ("apple", 3), + ("banana", 1), + ("orange", 2), + ("apple", 1) + ) + } + + it should "propagate errors in the mapping function" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0) { (index, element) => + if (index < 2) (index + 1, Some(element)) + else throw new RuntimeException("boom") + } + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } + + it should "propagate errors in the completion callback" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom")) + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() shouldBe "c" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala new file mode 100644 index 00000000..4b080fae --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala @@ -0,0 +1,60 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers { + + behavior of "Source.mapStateful" + + it should "zip with index" in scoped { + val c = Source.fromValues("a", "b", "c") + + val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index))) + + s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2)) + } + + it should "calculate a running total" in scoped { + val c = Source.fromValues(1, 2, 3, 4, 5) + + val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) + + s.toList shouldBe List(0, 1, 3, 6, 10, 15) + } + + it should "propagate errors in the mapping function" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStateful(() => 0) { (index, element) => + if (index < 2) (index + 1, element) + else throw new RuntimeException("boom") + } + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } + + it should "propagate errors in the completion callback" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom")) + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() shouldBe "c" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala b/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala new file mode 100644 index 00000000..6824d8d3 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsZipAllTest.scala @@ -0,0 +1,51 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsZipAllTest extends AnyFlatSpec with Matchers { + behavior of "Source.zipAll" + + it should "not emit any element when both channels are empty" in scoped { + val s = Source.empty[Int] + val other = Source.empty[String] + + s.zipAll(other, -1, "foo").toList shouldBe List.empty + } + + it should "emit this element when other channel is empty" in scoped { + val s = Source.fromValues(1) + val other = Source.empty[String] + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "foo")) + } + + it should "emit other element when this channel is empty" in scoped { + val s = Source.empty[Int] + val other = Source.fromValues("a") + + s.zipAll(other, -1, "foo").toList shouldBe List((-1, "a")) + } + + it should "emit matching elements when both channels are of the same size" in scoped { + val s = Source.fromValues(1, 2) + val other = Source.fromValues("a", "b") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (2, "b")) + } + + it should "emit default for other channel if this channel is longer" in scoped { + val s = Source.fromValues(1, 2, 3) + val other = Source.fromValues("a") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (2, "foo"), (3, "foo")) + } + + it should "emit default for this channel if other channel is longer" in scoped { + val s = Source.fromValues(1) + val other = Source.fromValues("a", "b", "c") + + s.zipAll(other, -1, "foo").toList shouldBe List((1, "a"), (-1, "b"), (-1, "c")) + } +}