diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index db6710b7..878e8c4e 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -5,6 +5,7 @@ import ox.* import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore} import scala.collection.mutable import scala.concurrent.duration.FiniteDuration +import scala.collection.IterableOnce trait SourceOps[+T] { this: Source[T] => // view ops (lazy) @@ -257,8 +258,8 @@ trait SourceOps[+T] { this: Source[T] => * @param initializeState * A function that initializes the state. * @param f - * A function that transforms the element from this source and the state into an optional pair of the next state and the result. If `f` - * returns a nonempty value, th result will be sent to the returned channel, otherwise it will be ignored. + * A function that transforms the element from this source and the state into a pair of the next state and the result which is sent + * sent to the returned channel. * @param onComplete * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is * ignored. @@ -272,15 +273,59 @@ trait SourceOps[+T] { this: Source[T] => * * scoped { * val s = Source.fromValues(1, 2, 3, 4, 5) - * s.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply) + * s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) * } * * scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15) * }}} */ - def statefulMap[S, U >: T]( + def mapStateful[S, U >: T]( initializeState: () => S - )(f: (S, T) => (S, Option[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + )(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = + def resultToSome(s: S, t: T) = + val (newState, result) = f(s, t) + (newState, Some(result)) + + mapStatefulConcat(initializeState)(resultToSome, onComplete) + + /** Applies the given mapping function `f`, using additional mutable state, to each element received from this source, and sends the + * results one by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned + * channel once this source is done. + * + * The `initializeState` function is called once when `statefulMap` is called. + * + * The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the + * returned channel, while an empty value will be ignored. + * + * @param initializeState + * A function that initializes the state. + * @param f + * A function that transforms the element from this source and the state into a pair of the next state and a + * [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty, + * nothing is sent to the returned channel. + * @param onComplete + * A function that transforms the final state into an optional element sent to the returned channel. By default the final state is + * ignored. + * @return + * A source to which the results of applying `f` to the elements from this source would be sent. + * @example + * {{{ + * scala> + * import ox.* + * import ox.channels.Source + * + * scoped { + * val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + * // deduplicate the values + * s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + * } + * + * scala> val res0: List[Int] = List(1, 2, 3, 4, 5) + * }}} + */ + def mapStatefulConcat[S, U >: T]( + initializeState: () => S + )(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] = val c = StageCapacity.newChannel[U] forkDaemon { var state = initializeState() @@ -290,8 +335,7 @@ trait SourceOps[+T] { this: Source[T] => try onComplete(state).foreach(c.send) c.done() - catch - case t: Throwable => c.error(t) + catch case t: Throwable => c.error(t) false case ChannelClosed.Error(r) => c.error(r) @@ -300,9 +344,9 @@ trait SourceOps[+T] { this: Source[T] => try val (nextState, result) = f(state, t) state = nextState - result.map(c.send(_).isValue).getOrElse(true) + result.iterator.map(c.send).forall(_.isValue) catch - case t: Throwable => + case t: Throwable => c.error(t) false } @@ -483,7 +527,7 @@ trait SourceCompanionOps: StageCapacity ): Source[T] = sources match - case Nil => Source.empty + case Nil => Source.empty case single :: Nil => single case _ => val c = StageCapacity.newChannel[T] diff --git a/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala new file mode 100644 index 00000000..b5a5d5f1 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulConcatTest.scala @@ -0,0 +1,79 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers { + + behavior of "Source.mapStatefulConcat" + + it should "deduplicate" in scoped { + // given + val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) + + // when + val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e))) + + // then + s.toList shouldBe List(1, 2, 3, 4, 5) + } + + it should "count consecutive" in scoped { + // given + val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple") + + // when + val s = c.mapStatefulConcat(() => (Option.empty[String], 0))( + { case ((previous, count), e) => + previous match + case None => ((Some(e), 1), None) + case Some(`e`) => ((previous, count + 1), None) + case Some(_) => ((Some(e), 1), previous.map((_, count))) + }, + { case (previous, count) => previous.map((_, count)) } + ) + + // then + s.toList shouldBe List( + ("apple", 3), + ("banana", 1), + ("orange", 2), + ("apple", 1) + ) + } + + it should "propagate errors in the mapping function" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0) { (index, element) => + if (index < 2) (index + 1, Some(element)) + else throw new RuntimeException("boom") + } + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } + + it should "propagate errors in the completion callback" in scoped { + // given + val c = Source.fromValues("a", "b", "c") + + // when + val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom")) + + // then + s.receive() shouldBe "a" + s.receive() shouldBe "b" + s.receive() shouldBe "c" + s.receive() should matchPattern { + case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" => + } + } +} diff --git a/core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala similarity index 58% rename from core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala rename to core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala index 54d49363..4b080fae 100644 --- a/core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala +++ b/core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala @@ -4,14 +4,14 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import ox.* -class SourceOpsStatefulMapTest extends AnyFlatSpec with Matchers { +class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers { - behavior of "Source.statefulMap" + behavior of "Source.mapStateful" it should "zip with index" in scoped { val c = Source.fromValues("a", "b", "c") - val s = c.statefulMap(() => 0)((index, element) => (index + 1, Some((element, index)))) + val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index))) s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2)) } @@ -19,29 +19,18 @@ class SourceOpsStatefulMapTest extends AnyFlatSpec with Matchers { it should "calculate a running total" in scoped { val c = Source.fromValues(1, 2, 3, 4, 5) - val s = c.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply) + val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply) s.toList shouldBe List(0, 1, 3, 6, 10, 15) } - it should "deduplicate" in scoped { - val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5) - - val s = c.statefulMap(() => Set.empty[Int])((alreadySeen, element) => - val result = Option.unless(alreadySeen.contains(element))(element) - (alreadySeen + element, result) - ) - - s.toList shouldBe List(1, 2, 3, 4, 5) - } - it should "propagate errors in the mapping function" in scoped { // given val c = Source.fromValues("a", "b", "c") // when - val s = c.statefulMap(() => 0) { (index, element) => - if (index < 2) (index + 1, Some(element)) + val s = c.mapStateful(() => 0) { (index, element) => + if (index < 2) (index + 1, element) else throw new RuntimeException("boom") } @@ -58,7 +47,7 @@ class SourceOpsStatefulMapTest extends AnyFlatSpec with Matchers { val c = Source.fromValues("a", "b", "c") // when - val s = c.statefulMap(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom")) + val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom")) // then s.receive() shouldBe "a"