Skip to content

Commit

Permalink
Add Source.statefulMap combinator
Browse files Browse the repository at this point in the history
  • Loading branch information
rucek committed Oct 6, 2023
1 parent 8efe299 commit 2990190
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
56 changes: 56 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,62 @@ trait SourceOps[+T] { this: Source[T] =>
def drain(): Unit = foreach(_ => ())

def applied[U](f: Source[T] => U): U = f(this)

/** Applies the given mapping function `f`, using additional mutable state, to each element received from this source, and sends the
* results to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel
* once this source is done.
*
* The `initializeState` function is called once when `statefulMap` is called.
*
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
* returned channel, while an empty value will be ignored.
*
* @param initializeState
* A function that initializes the state.
* @param f
* A function that transforms the element from this source and the state into an optional pair of the next state and the result. If `f`
* returns a nonempty value, th result will be sent to the returned channel, otherwise it will be ignored.
* @param onComplete
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
* ignored.
* @return
* A source to which the results of applying `f` to the elements from this source would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s = Source.fromValues(1, 2, 3, 4, 5)
* s.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply)
* }
*
* scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15)
* }}}
*/
def statefulMap[S, U >: T](
initializeState: () => S
)(f: (S, T) => (S, Option[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
val c = StageCapacity.newChannel[U]
forkDaemon {
var state = initializeState()
repeatWhile {
receive() match
case ChannelClosed.Done =>
onComplete(state).foreach(c.send)
c.done()
false
case ChannelClosed.Error(r) =>
c.error(r)
false
case t: T @unchecked =>
val (nextState, result) = f(state, t)
state = nextState
result.map(c.send(_).isValue).getOrElse(true)
}
}
c
}

trait SourceCompanionOps:
Expand Down
37 changes: 37 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsStatefulMapTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsStatefulMapTest extends AnyFlatSpec with Matchers {

behavior of "Source.statefulMap"

it should "zip with index" in scoped {
val c = Source.fromValues("a", "b", "c")

val s = c.statefulMap(() => 0)((index, element) => (index + 1, Some((element, index))))

s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2))
}

it should "calculate a running total" in scoped {
val c = Source.fromValues(1, 2, 3, 4, 5)

val s = c.statefulMap(() => 0)((sum, element) => (sum + element, Some(sum)), Some.apply)

s.toList shouldBe List(0, 1, 3, 6, 10, 15)
}

it should "deduplicate" in scoped {
val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)

val s = c.statefulMap(() => Set.empty[Int])((alreadySeen, element) =>
val result = Option.unless(alreadySeen.contains(element))(element)
(alreadySeen + element, result)
)

s.toList shouldBe List(1, 2, 3, 4, 5)
}
}

0 comments on commit 2990190

Please sign in to comment.