diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 590aadee..508cd21a 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -244,6 +244,62 @@ trait SourceOps[+T] { this: Source[T] => def drain(): Unit = foreach(_ => ()) def applied[U](f: Source[T] => U): U = f(this) + + /** Applies the given mapping function `f`, using additional mutable state, to each element received from this source, and sends the + * results to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel + * once this source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into an optional pair of the next state and the result. If `f` + * returns a nonempty value, th result will be sent to the returned channel, otherwise it will be ignored. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 3, 4, 5) + * s.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply) + * } + * + * scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15) + * }}} + */ + def statefulMap[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, Option[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + val c = StageCapacity.newChannel[U] + forkDaemon { + var state = initializeState() + repeatWhile { + receive() match + case ChannelClosed.Done => + onComplete(state).foreach(c.send) + c.done() + false + case ChannelClosed.Error(r) => + c.error(r) + false + case t: T @unchecked => + val (nextState, result) = f(state, t) + state = nextState + result.map(c.send(_).isValue).getOrElse(true) + } + } + c } trait SourceCompanionOps: diff --git a/core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala b/core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala new file mode 100644 index 00000000..87a2e25d --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala @@ -0,0 +1,37 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsStatefulMapTest extends AnyFlatSpec with Matchers { + + behavior of "Source.statefulMap" + + it should "zip with index" in scoped { + val c = Source.fromValues("a", "b", "c") + + val s = c.statefulMap(() => 0)((index, element) => (index + 1, Some((element, index)))) + + s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2)) + } + + it should "calculate a running total" in scoped { + val c = Source.fromValues(1, 2, 3, 4, 5) + + val s = c.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply) + + s.toList shouldBe List(0, 1, 3, 6, 10, 15) + } + + it should "deduplicate" in scoped { + val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + + val s = c.statefulMap(() => Set.empty[Int])((alreadySeen, element) => + val result = Option.unless(alreadySeen.contains(element))(element) + (alreadySeen + element, result) + ) + + s.toList shouldBe List(1, 2, 3, 4, 5) + } +}