Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Source.mapStateful and Source.mapStatefulConcat combinators #15

Merged
merged 4 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion core/src/main/scala/ox/channels/SourceOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ox.*
import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingQueue, Semaphore}
import scala.collection.mutable
import scala.concurrent.duration.FiniteDuration
import scala.collection.IterableOnce

trait SourceOps[+T] { this: Source[T] =>
// view ops (lazy)
Expand Down Expand Up @@ -244,6 +245,113 @@ trait SourceOps[+T] { this: Source[T] =>
def drain(): Unit = foreach(_ => ())

def applied[U](f: Source[T] => U): U = f(this)

/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results to
* the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel once this
* source is done.
*
* The `initializeState` function is called once when `statefulMap` is called.
*
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
* returned channel, while an empty value will be ignored.
*
* @param initializeState
* A function that initializes the state.
* @param f
* A function that transforms the element from this source and the state into a pair of the next state and the result which is sent
* sent to the returned channel.
* @param onComplete
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
* ignored.
* @return
* A source to which the results of applying `f` to the elements from this source would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s = Source.fromValues(1, 2, 3, 4, 5)
* s.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)
* }
*
* scala> val res0: List[Int] = List(0, 1, 3, 6, 10, 15)
* }}}
*/
def mapStateful[S, U >: T](
initializeState: () => S
)(f: (S, T) => (S, U), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
def resultToSome(s: S, t: T) =
val (newState, result) = f(s, t)
(newState, Some(result))

mapStatefulConcat(initializeState)(resultToSome, onComplete)

/** Applies the given mapping function `f`, using additional state, to each element received from this source, and sends the results one
* by one to the returned channel. Optionally sends an additional element, possibly based on the final state, to the returned channel
* once this source is done.
*
* The `initializeState` function is called once when `statefulMap` is called.
*
* The `onComplete` function is called once when this source is done. If it returns a non-empty value, the value will be sent to the
* returned channel, while an empty value will be ignored.
*
* @param initializeState
* A function that initializes the state.
* @param f
* A function that transforms the element from this source and the state into a pair of the next state and a
* [[scala.collection.IterableOnce]] of results which are sent one by one to the returned channel. If the result of `f` is empty,
* nothing is sent to the returned channel.
* @param onComplete
* A function that transforms the final state into an optional element sent to the returned channel. By default the final state is
* ignored.
* @return
* A source to which the results of applying `f` to the elements from this source would be sent.
* @example
* {{{
* scala>
* import ox.*
* import ox.channels.Source
*
* scoped {
* val s = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)
* // deduplicate the values
* s.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))
* }
*
* scala> val res0: List[Int] = List(1, 2, 3, 4, 5)
* }}}
*/
def mapStatefulConcat[S, U >: T](
initializeState: () => S
)(f: (S, T) => (S, IterableOnce[U]), onComplete: S => Option[U] = (_: S) => None)(using Ox, StageCapacity): Source[U] =
val c = StageCapacity.newChannel[U]
forkDaemon {
var state = initializeState()
repeatWhile {
receive() match
case ChannelClosed.Done =>
try
onComplete(state).foreach(c.send)
c.done()
catch case t: Throwable => c.error(t)
false
case ChannelClosed.Error(r) =>
c.error(r)
false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think here we already have the desired error handling, so no changes should be necessary :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're still checking the result of c.send below in result.iterator.map(c.send).forall(_.isValue), although we assume that send should not fail since we control c, so we should probably always return true after sending, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, though this at worst is dead code - not that there's a possiblity of incorrect behavior. We can change this though:

  • ignoring the return value probably needs a comment that we know what we are doing - and why
  • in larger code chunks (which are harder to follow), we might opt for the defensive solution anyway

case t: T @unchecked =>
try
val (nextState, result) = f(state, t)
state = nextState
result.iterator.map(c.send).forall(_.isValue)
catch
case t: Throwable =>
c.error(t)
false
}
}
c
}

trait SourceCompanionOps:
Expand Down Expand Up @@ -419,7 +527,7 @@ trait SourceCompanionOps:
StageCapacity
): Source[T] =
sources match
case Nil => Source.empty
case Nil => Source.empty
case single :: Nil => single
case _ =>
val c = StageCapacity.newChannel[T]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsMapStatefulConcatTest extends AnyFlatSpec with Matchers {

behavior of "Source.mapStatefulConcat"

it should "deduplicate" in scoped {
// given
val c = Source.fromValues(1, 2, 2, 3, 2, 4, 3, 1, 5)

// when
val s = c.mapStatefulConcat(() => Set.empty[Int])((s, e) => (s + e, Option.unless(s.contains(e))(e)))

// then
s.toList shouldBe List(1, 2, 3, 4, 5)
}

it should "count consecutive" in scoped {
// given
val c = Source.fromValues("apple", "apple", "apple", "banana", "orange", "orange", "apple")

// when
val s = c.mapStatefulConcat(() => (Option.empty[String], 0))(
{ case ((previous, count), e) =>
previous match
case None => ((Some(e), 1), None)
case Some(`e`) => ((previous, count + 1), None)
case Some(_) => ((Some(e), 1), previous.map((_, count)))
},
{ case (previous, count) => previous.map((_, count)) }
)

// then
s.toList shouldBe List(
("apple", 3),
("banana", 1),
("orange", 2),
("apple", 1)
)
}

it should "propagate errors in the mapping function" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStatefulConcat(() => 0) { (index, element) =>
if (index < 2) (index + 1, Some(element))
else throw new RuntimeException("boom")
}

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}

it should "propagate errors in the completion callback" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStatefulConcat(() => 0)((index, element) => (index + 1, Some(element)), _ => throw new RuntimeException("boom"))

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() shouldBe "c"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}
}
60 changes: 60 additions & 0 deletions core/src/test/scala/ox/channels/SourceOpsMapStatefulTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package ox.channels

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.*

class SourceOpsMapStatefulTest extends AnyFlatSpec with Matchers {

behavior of "Source.mapStateful"

it should "zip with index" in scoped {
val c = Source.fromValues("a", "b", "c")

val s = c.mapStateful(() => 0)((index, element) => (index + 1, (element, index)))

s.toList shouldBe List(("a", 0), ("b", 1), ("c", 2))
}

it should "calculate a running total" in scoped {
val c = Source.fromValues(1, 2, 3, 4, 5)

val s = c.mapStateful(() => 0)((sum, element) => (sum + element, sum), Some.apply)

s.toList shouldBe List(0, 1, 3, 6, 10, 15)
}

it should "propagate errors in the mapping function" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStateful(() => 0) { (index, element) =>
if (index < 2) (index + 1, element)
else throw new RuntimeException("boom")
}

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}

it should "propagate errors in the completion callback" in scoped {
// given
val c = Source.fromValues("a", "b", "c")

// when
val s = c.mapStateful(() => 0)((index, element) => (index + 1, element), _ => throw new RuntimeException("boom"))

// then
s.receive() shouldBe "a"
s.receive() shouldBe "b"
s.receive() shouldBe "c"
s.receive() should matchPattern {
case ChannelClosed.Error(Some(reason)) if reason.getMessage == "boom" =>
}
}
}