diff --git a/core/src/main/scala/ox/channels/SourceOps.scala b/core/src/main/scala/ox/channels/SourceOps.scala index 831c61ce..3ab9e2ea 100644 --- a/core/src/main/scala/ox/channels/SourceOps.scala +++ b/core/src/main/scala/ox/channels/SourceOps.scala @@ -3,9 +3,8 @@ package ox.channels import ox.* import java.util.concurrent.{CountDownLatch, Semaphore} -import scala.collection.mutable +import scala.collection.{IterableOnce, mutable} import scala.concurrent.duration.FiniteDuration -import scala.collection.IterableOnce trait SourceOps[+T] { this: Source[T] => // view ops (lazy) @@ -203,6 +202,25 @@ trait SourceOps[+T] { this: Source[T] => def take(n: Int)(using Ox, StageCapacity): Source[T] = transform(_.take(n)) + /** Sends elements to the returned channel until predicate `f` is satisfied (returns `true`). Note that when the predicate `f` is not + * satisfied (returns `false`), subsequent elements are dropped even if they could still satisfy it. + * + * @param f + * A predicate function. + * @example + * {{{ + * import ox.* + * import ox.channels.Source + * + * scoped { + * Source.empty[Int].takeWhile(_ > 3).toList // List() + * Source.fromValues(1, 2, 3).takeWhile(_ < 3).toList // List(1, 2) + * Source.fromValues(3, 2, 1).takeWhile(_ < 3).toList // List() + * } + * }}} + */ + def takeWhile(f: T => Boolean)(using Ox, StageCapacity): Source[T] = transform(_.takeWhile(f)) + /** Drops `n` elements from this source and forwards subsequent elements to the returned channel. * * @param n diff --git a/core/src/test/scala/ox/channels/SourceOpsTakeWhileTest.scala b/core/src/test/scala/ox/channels/SourceOpsTakeWhileTest.scala new file mode 100644 index 00000000..5bdbb5f5 --- /dev/null +++ b/core/src/test/scala/ox/channels/SourceOpsTakeWhileTest.scala @@ -0,0 +1,24 @@ +package ox.channels + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* + +class SourceOpsTakeWhileTest extends AnyFlatSpec with Matchers { + behavior of "Source.takeWhile" + + it should "not take from the empty source" in supervised { + val s = Source.empty[Int] + s.takeWhile(_ < 3).toList shouldBe List.empty + } + + it should "take as long as predicate is satisfied" in supervised { + val s = Source.fromValues(1, 2, 3) + s.takeWhile(_ < 3).toList shouldBe List(1, 2) + } + + it should "not take if predicate fails for first or more elements" in supervised { + val s = Source.fromValues(3, 2, 1) + s.takeWhile(_ < 3).toList shouldBe List() + } +}