Skip to content

Commit

Permalink
Do Not Continue Reading From Web Socket After Terminal Event (#2441)
Browse files Browse the repository at this point in the history
do not continue reading after terminal event
  • Loading branch information
adamgfraser authored Sep 24, 2023
1 parent e27c016 commit 7489570
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 17 deletions.
12 changes: 12 additions & 0 deletions zio-http-testkit/src/main/scala/zio/http/TestChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ case class TestChannel(
promise.await
def receive(implicit trace: Trace): Task[WebSocketChannelEvent] =
in.take
def receiveAll[Env, Err](f: WebSocketChannelEvent => ZIO[Env, Err, Any])(implicit
trace: Trace,
): ZIO[Env, Err, Unit] = {
lazy val loop: ZIO[Env, Err, Unit] =
in.take.flatMap {
case event @ ChannelEvent.ExceptionCaught(_) => f(event).unit
case event @ ChannelEvent.Unregistered => f(event).unit
case event => f(event) *> ZIO.yieldNow *> loop
}

loop
}
def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] =
out.offer(in).unit
def sendAll(in: Iterable[WebSocketChannelEvent])(implicit trace: Trace): Task[Unit] =
Expand Down
37 changes: 20 additions & 17 deletions zio-http/src/main/scala/zio/http/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ trait Channel[-In, +Out] { self =>
*/
def receive(implicit trace: Trace): Task[Out]

/**
* Reads all messages from the channel, handling them with the specified
* function.
*/
def receiveAll[Env, Err](f: Out => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit]

/**
* Send a message to the channel.
*/
Expand All @@ -57,15 +63,17 @@ trait Channel[-In, +Out] { self =>
*/
final def contramap[In2](f: In2 => In): Channel[In2, Out] =
new Channel[In2, Out] {
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
self.awaitShutdown
def receive(implicit trace: Trace): Task[Out] =
def receive(implicit trace: Trace): Task[Out] =
self.receive
def send(in: In2)(implicit trace: Trace): Task[Unit] =
def receiveAll[Env, Err](g: Out => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit] =
self.receiveAll(g)
def send(in: In2)(implicit trace: Trace): Task[Unit] =
self.send(f(in))
def sendAll(in: Iterable[In2])(implicit trace: Trace): Task[Unit] =
def sendAll(in: Iterable[In2])(implicit trace: Trace): Task[Unit] =
self.sendAll(in.map(f))
def shutdown(implicit trace: Trace): UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
self.shutdown
}

Expand All @@ -75,22 +83,17 @@ trait Channel[-In, +Out] { self =>
*/
final def map[Out2](f: Out => Out2)(implicit trace: Trace): Channel[In, Out2] =
new Channel[In, Out2] {
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
def awaitShutdown(implicit trace: Trace): UIO[Unit] =
self.awaitShutdown
def receive(implicit trace: Trace): Task[Out2] =
def receive(implicit trace: Trace): Task[Out2] =
self.receive.map(f)
def send(in: In)(implicit trace: Trace): Task[Unit] =
def receiveAll[Env, Err](g: Out2 => ZIO[Env, Err, Any])(implicit trace: Trace): ZIO[Env, Err, Unit] =
self.receiveAll(f andThen g)
def send(in: In)(implicit trace: Trace): Task[Unit] =
self.send(in)
def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit] =
def sendAll(in: Iterable[In])(implicit trace: Trace): Task[Unit] =
self.sendAll(in)
def shutdown(implicit trace: Trace): UIO[Unit] =
def shutdown(implicit trace: Trace): UIO[Unit] =
self.shutdown
}

/**
* Reads all messages from the channel, handling them with the specified
* function.
*/
final def receiveAll[Env](f: Out => ZIO[Env, Throwable, Any])(implicit trace: Trace): ZIO[Env, Throwable, Nothing] =
receive.flatMap(f).forever
}
13 changes: 13 additions & 0 deletions zio-http/src/main/scala/zio/http/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ private[http] object WebSocketChannel {
def receive(implicit trace: Trace): Task[WebSocketChannelEvent] =
queue.take

def receiveAll[Env, Err](f: WebSocketChannelEvent => ZIO[Env, Err, Any])(implicit
trace: Trace,
): ZIO[Env, Err, Unit] = {
lazy val loop: ZIO[Env, Err, Unit] =
queue.take.flatMap {
case event @ ChannelEvent.ExceptionCaught(_) => f(event).unit
case event @ ChannelEvent.Unregistered => f(event).unit
case event => f(event) *> ZIO.yieldNow *> loop
}

loop
}

def send(in: WebSocketChannelEvent)(implicit trace: Trace): Task[Unit] = {
in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
Expand Down

0 comments on commit 7489570

Please sign in to comment.