Skip to content

Commit

Permalink
Merge pull request #15 from softwaremill/source-stateful-map
Browse files Browse the repository at this point in the history
Add Source.mapStateful and Source.mapStatefulConcat combinators
  • Loading branch information
adamw authored Oct 17, 2023
2 parents 12e80dd + 3db1ded commit 02f5198
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 0 deletions.
108 changes: 108 additions & 0 deletions core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ox.*
import java.util.concurrent.{CountDownLatch, Semaphore}
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.collection.IterableOnce

trait SourceOps[+T] { this: Source[T] =>
// view ops (lazy)
Expand Down Expand Up @@ -311,6 +312,113 @@ trait SourceOps[+T] { this: Source[T] =>
def drain(): Unit = foreach(_ => ())

def applied[U](f: Source[T] => U): U = f(this)

/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results to
* the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel once this
* source is done.
*
* The `initializeState` function is called once when `statefulMap` is called.
*
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
* returned channel, while an empty value will be ignored.
*
* @param initializeState
* A function that initializes the state.
* @param f
* A function that transforms the element from this source and the state into a pair of the next state and the result which is sent
* sent to the returned channel.
* @param onComplete
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
* ignored.
* @return
* A source to which the results of applying `f` to the elements from this source would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s = Source.fromValues(1, 2, 3, 4, 5)
* s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)
* }
*
* scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15)
* }}}
*/
def mapStateful[S, U >: T](
initializeState: () => S
)(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
def resultToSome(s: S, t: T) =
val (newState, result) = f(s, t)
(newState, Some(result))

mapStatefulConcat(initializeState)(resultToSome, onComplete)

/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results one
* by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel
* once this source is done.
*
* The `initializeState` function is called once when `statefulMap` is called.
*
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
* returned channel, while an empty value will be ignored.
*
* @param initializeState
* A function that initializes the state.
* @param f
* A function that transforms the element from this source and the state into a pair of the next state and a
* [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty,
* nothing is sent to the returned channel.
* @param onComplete
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
* ignored.
* @return
* A source to which the results of applying `f` to the elements from this source would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)
* // deduplicate the values
* s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))
* }
*
* scala> val res0: List[Int] = List(1, 2, 3, 4, 5)
* }}}
*/
def mapStatefulConcat[S, U >: T](
initializeState: () => S
)(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
val c = StageCapacity.newChannel[U]
forkDaemon {
var state = initializeState()
repeatWhile {
receive() match
case ChannelClosed.Done =>
try
onComplete(state).foreach(c.send)
c.done()
catch case t: Throwable => c.error(t)
false
case ChannelClosed.Error(r) =>
c.error(r)
false
case t: T @unchecked =>
try
val (nextState, result) = f(state, t)
state = nextState
result.iterator.map(c.send).forall(_.isValue)
catch
case t: Throwable =>
c.error(t)
false
}
}
c
}

trait SourceCompanionOps:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers {

behavior of "Source.mapStatefulConcat"

it should "deduplicate" in scoped {
// given
val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)

// when
val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))

// then
s.toList shouldBe List(1, 2, 3, 4, 5)
}

it should "count consecutive" in scoped {
// given
val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple")

// when
val s = c.mapStatefulConcat(() => (Option.empty[String], 0))(
{ case ((previous, count), e) =>
previous match
case None => ((Some(e), 1), None)
case Some(`e`) => ((previous, count + 1), None)
case Some(_) => ((Some(e), 1), previous.map((_, count)))
},
{ case (previous, count) => previous.map((_, count)) }
)

// then
s.toList shouldBe List(
("apple", 3),
("banana", 1),
("orange", 2),
("apple", 1)
)
}

it should "propagate errors in the mapping function" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStatefulConcat(() => 0) { (index, element) =>
if (index < 2) (index + 1, Some(element))
else throw new RuntimeException("boom")
}

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}

it should "propagate errors in the completion callback" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom"))

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() shouldBe "c"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}
}
60 changes: 60 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers {

behavior of "Source.mapStateful"

it should "zip with index" in scoped {
val c = Source.fromValues("a", "b", "c")

val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index)))

s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2))
}

it should "calculate a running total" in scoped {
val c = Source.fromValues(1, 2, 3, 4, 5)

val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)

s.toList shouldBe List(0, 1, 3, 6, 10, 15)
}

it should "propagate errors in the mapping function" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStateful(() => 0) { (index, element) =>
if (index < 2) (index + 1, element)
else throw new RuntimeException("boom")
}

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}

it should "propagate errors in the completion callback" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom"))

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() shouldBe "c"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}
}

0 comments on commit 02f5198

Please sign in to comment.