diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 590aadee..96a98961 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -5,6 +5,7 @@ import ox.* import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore} import scala.collection.mutable import scala.concurrent.duration.FiniteDuration +import scala.collection.IterableOnce trait SourceOps[+T] { this: Source[T] => // view ops (lazy) @@ -244,6 +245,113 @@ trait SourceOps[+T] { this: Source[T] => def drain(): Unit = foreach(_ => ()) def applied[U](f: Source[T] => U): U = f(this) + + /** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results to + * the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel once this + * source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into a pair of the next state and the result which is sent + * sent to the returned channel. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 3, 4, 5) + * s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) + * } + * + * scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15) + * }}} + */ + def mapStateful[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + def resultToSome(s: S, t: T) = + val (newState, result) = f(s, t) + (newState, Some(result)) + + mapStatefulConcat(initializeState)(resultToSome, onComplete) + + /** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results one + * by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel + * once this source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into a pair of the next state and a + * [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty, + * nothing is sent to the returned channel. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + * // deduplicate the values + * s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + * } + * + * scala> val res0: List[Int] = List(1, 2, 3, 4, 5) + * }}} + */ + def mapStatefulConcat[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + val c = StageCapacity.newChannel[U] + forkDaemon { + var state = initializeState() + repeatWhile { + receive() match + case ChannelClosed.Done => + try + onComplete(state).foreach(c.send) + c.done() + catch case t: Throwable => c.error(t) + false + case ChannelClosed.Error(r) => + c.error(r) + false + case t: T @unchecked => + try + val (nextState, result) = f(state, t) + state = nextState + result.iterator.map(c.send).forall(_.isValue) + catch + case t: Throwable => + c.error(t) + false + } + } + c } trait SourceCompanionOps: @@ -419,7 +527,7 @@ trait SourceCompanionOps: StageCapacity ): Source[T] = sources match - case Nil => Source.empty + case Nil => Source.empty case single :: Nil => single case _ => val c = StageCapacity.newChannel[T] diff --git a/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala new file mode 100644 index 00000000..b5a5d5f1 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala @@ -0,0 +1,79 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers { + + behavior of "Source.mapStatefulConcat" + + it should "deduplicate" in scoped { + // given + val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + + // when + val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + + // then + s.toList shouldBe List(1, 2, 3, 4, 5) + } + + it should "count consecutive" in scoped { + // given + val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple") + + // when + val s = c.mapStatefulConcat(() => (Option.empty[String], 0))( + { case ((previous, count), e) => + previous match + case None => ((Some(e), 1), None) + case Some(`e`) => ((previous, count + 1), None) + case Some(_) => ((Some(e), 1), previous.map((_, count))) + }, + { case (previous, count) => previous.map((_, count)) } + ) + + // then + s.toList shouldBe List( + ("apple", 3), + ("banana", 1), + ("orange", 2), + ("apple", 1) + ) + } + + it should "propagate errors in the mapping function" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0) { (index, element) => + if (index < 2) (index + 1, Some(element)) + else throw new RuntimeException("boom") + } + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } + + it should "propagate errors in the completion callback" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom")) + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() shouldBe "c" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala new file mode 100644 index 00000000..4b080fae --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala @@ -0,0 +1,60 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers { + + behavior of "Source.mapStateful" + + it should "zip with index" in scoped { + val c = Source.fromValues("a", "b", "c") + + val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index))) + + s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2)) + } + + it should "calculate a running total" in scoped { + val c = Source.fromValues(1, 2, 3, 4, 5) + + val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) + + s.toList shouldBe List(0, 1, 3, 6, 10, 15) + } + + it should "propagate errors in the mapping function" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStateful(() => 0) { (index, element) => + if (index < 2) (index + 1, element) + else throw new RuntimeException("boom") + } + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } + + it should "propagate errors in the completion callback" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom")) + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() shouldBe "c" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } +}