Skip to content

Commit

Permalink
Merge pull request #426 from typelevel/wip-fix-stream-ingest
Browse files Browse the repository at this point in the history
runtime: request(n) in stream ingest fix
  • Loading branch information
ahjohannessen authored Oct 26, 2021
2 parents f32bb7f + 034058a commit 51d3496
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 27 deletions.
4 changes: 2 additions & 2 deletions runtime/src/main/scala/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (

private def mkStreamListenerR(md: Metadata): Resource[F, Fs2StreamClientCallListener[F, Response]] = {

val acquire = start(Fs2StreamClientCallListener.create[F, Response](request, dispatcher, options.prefetchN), md) <*
request(options.prefetchN)
val acquire =
start(Fs2StreamClientCallListener.create[F, Response](request, dispatcher, options.prefetchN), md) <* request(1)

val release = handleExitCase(cancelSucceed = true)

Expand Down
30 changes: 12 additions & 18 deletions runtime/src/main/scala/client/StreamIngest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package grpc
package client

import cats.implicits._
import cats.effect.kernel.{Concurrent, Ref}
import cats.effect.Concurrent
import cats.effect.std.Queue

private[client] trait StreamIngest[F[_], T] {
Expand All @@ -39,39 +39,33 @@ private[client] object StreamIngest {
request: Int => F[Unit],
prefetchN: Int
): F[StreamIngest[F, T]] =
(Concurrent[F].ref(prefetchN), Queue.unbounded[F, Either[GrpcStatus, T]])
.mapN((d, q) => create[F, T](request, prefetchN, d, q))
Queue
.unbounded[F, Either[GrpcStatus, T]]
.map(q => create[F, T](request, prefetchN, q))

def create[F[_], T](
request: Int => F[Unit],
prefetchN: Int,
demand: Ref[F, Int],
queue: Queue[F, Either[GrpcStatus, T]]
)(implicit F: Concurrent[F]): StreamIngest[F, T] = new StreamIngest[F, T] {

val limit: Int =
math.max(1, prefetchN)

val ensureMessages: F[Unit] =
queue.size.flatMap(qs => request(1).whenA(qs < limit))

def onMessage(msg: T): F[Unit] =
decreaseDemandBy(1) *> queue.offer(msg.asRight)
queue.offer(msg.asRight) *> ensureMessages

def onClose(status: GrpcStatus): F[Unit] =
queue.offer(status.asLeft)

def ensureMessages(nextWhenEmpty: Int): F[Unit] =
(demand.get, queue.size).mapN((cd, qs) => fetch(nextWhenEmpty).whenA((cd + qs) < 1)).flatten

def decreaseDemandBy(n: Int): F[Unit] =
demand.update(d => math.max(d - n, 0))

def increaseDemandBy(n: Int): F[Unit] =
demand.update(_ + n)

def fetch(n: Int): F[Unit] =
request(n) *> increaseDemandBy(n)

val messages: Stream[F, T] = {

val run: F[Option[T]] =
queue.take.flatMap {
case Right(v) => v.some.pure[F] <* ensureMessages(prefetchN)
case Right(v) => ensureMessages *> v.some.pure[F]
case Left(GrpcStatus(status, trailers)) =>
if (!status.isOk) F.raiseError(status.asRuntimeException(trailers))
else none[T].pure[F]
Expand Down
8 changes: 4 additions & 4 deletions runtime/src/test/scala/client/ClientSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class ClientSuite extends Fs2GrpcSuite {
tc.tick()
assertEquals(result.value, Some(Success(List(1, 2, 3))))
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 2)
assertEquals(dummy.requested, 3)

}

Expand All @@ -215,7 +215,7 @@ class ClientSuite extends Fs2GrpcSuite {

assertEquals(result.value, Some(Success(List(1, 2))))
assertEquals(dummy.messagesSent.size, 1)
assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)
}

runTest0("stream to streamingToStreaming") { (tc, io, d) =>
Expand Down Expand Up @@ -243,7 +243,7 @@ class ClientSuite extends Fs2GrpcSuite {
tc.tick()
assertEquals(result.value, Some(Success(List(1, 2, 3))))
assertEquals(dummy.messagesSent.size, 5)
assertEquals(dummy.requested, 2)
assertEquals(dummy.requested, 3)

}

Expand Down Expand Up @@ -308,7 +308,7 @@ class ClientSuite extends Fs2GrpcSuite {
Status.INTERNAL
)
assertEquals(dummy.messagesSent.size, 5)
assertEquals(dummy.requested, 2)
assertEquals(dummy.requested, 3)

}

Expand Down
7 changes: 4 additions & 3 deletions runtime/src/test/scala/client/StreamIngestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ class StreamIngestSuite extends CatsEffectSuite with CatsEffectFunFixtures {
}

run(prefetchN = 1, takeN = 1, expectedReq = 1, expectedCount = 1) *>
run(prefetchN = 2, takeN = 1, expectedReq = 0, expectedCount = 1) *>
run(prefetchN = 1024, takeN = 1024, expectedReq = 1024, expectedCount = 1024) *>
run(prefetchN = 1024, takeN = 1023, expectedReq = 0, expectedCount = 1023)
run(prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1) *>
run(prefetchN = 2, takeN = 2, expectedReq = 3, expectedCount = 2) *>
run(prefetchN = 1024, takeN = 1024, expectedReq = 2047, expectedCount = 1024) *>
run(prefetchN = 1024, takeN = 1023, expectedReq = 2046, expectedCount = 1023)

}

Expand Down

0 comments on commit 51d3496

Please sign in to comment.