diff --git a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala index 8627d8b1d2..48d4911e5e 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestChannel.scala @@ -8,13 +8,15 @@ case class TestChannel( out: Queue[WebSocketChannelEvent], promise: Promise[Nothing, Unit], ) extends WebSocketChannel { - def awaitShutdown: UIO[Unit] = + def awaitShutdown: UIO[Unit] = promise.await - def receive: Task[WebSocketChannelEvent] = + def receive: Task[WebSocketChannelEvent] = in.take - def send(in: WebSocketChannelEvent): Task[Unit] = + def send(in: WebSocketChannelEvent): Task[Unit] = out.offer(in).unit - def shutdown: UIO[Unit] = + def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] = + out.offerAll(in).unit + def shutdown: UIO[Unit] = in.offer(ChannelEvent.Unregistered) *> out.offer(ChannelEvent.Unregistered) *> promise.succeed(()).unit diff --git a/zio-http/src/main/scala/zio/http/Channel.scala b/zio-http/src/main/scala/zio/http/Channel.scala index 9a4e405d0c..e729629feb 100644 --- a/zio-http/src/main/scala/zio/http/Channel.scala +++ b/zio-http/src/main/scala/zio/http/Channel.scala @@ -40,6 +40,11 @@ trait Channel[-In, +Out] { self => */ def send(in: In): Task[Unit] + /** + * Send all messages to the channel. + */ + def sendAll(in: Iterable[In]): Task[Unit] + /** * Shut down the channel. */ @@ -51,13 +56,15 @@ trait Channel[-In, +Out] { self => */ final def contramap[In2](f: In2 => In): Channel[In2, Out] = new Channel[In2, Out] { - def awaitShutdown: UIO[Unit] = + def awaitShutdown: UIO[Unit] = self.awaitShutdown - def receive: Task[Out] = + def receive: Task[Out] = self.receive - def send(in: In2): Task[Unit] = + def send(in: In2): Task[Unit] = self.send(f(in)) - def shutdown: UIO[Unit] = + def sendAll(in: Iterable[In2]): Task[Unit] = + self.sendAll(in.map(f)) + def shutdown: UIO[Unit] = self.shutdown } @@ -67,13 +74,15 @@ trait Channel[-In, +Out] { self => */ final def map[Out2](f: Out => Out2): Channel[In, Out2] = new Channel[In, Out2] { - def awaitShutdown: UIO[Unit] = + def awaitShutdown: UIO[Unit] = self.awaitShutdown - def receive: Task[Out2] = + def receive: Task[Out2] = self.receive.map(f) - def send(in: In): Task[Unit] = + def send(in: In): Task[Unit] = self.send(in) - def shutdown: UIO[Unit] = + def sendAll(in: Iterable[In]): Task[Unit] = + self.sendAll(in) + def shutdown: UIO[Unit] = self.shutdown } diff --git a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala index 8823e3a319..8f072ac7f4 100644 --- a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala +++ b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala @@ -31,16 +31,26 @@ private[http] object WebSocketChannel { queue: Queue[WebSocketChannelEvent], ): WebSocketChannel = new WebSocketChannel { - def awaitShutdown: UIO[Unit] = + def awaitShutdown: UIO[Unit] = nettyChannel.awaitClose - def receive: Task[WebSocketChannelEvent] = + def receive: Task[WebSocketChannelEvent] = queue.take - def send(in: WebSocketChannelEvent): Task[Unit] = + def send(in: WebSocketChannelEvent): Task[Unit] = in match { case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message)) case _ => ZIO.unit } - def shutdown: UIO[Unit] = + def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] = + ZIO.suspendSucceed { + val iterator = in.iterator.collect { case Read(message) => message } + + ZIO.whileLoop(iterator.hasNext) { + val message = iterator.next() + if (iterator.hasNext) nettyChannel.write(frameToNetty(message)) + else nettyChannel.writeAndFlush(frameToNetty(message)) + }(_ => ()) + } + def shutdown: UIO[Unit] = nettyChannel.close(false).orDie }