Skip to content

Commit

Permalink
Implement Channel#sendAll (#2315)
Browse files Browse the repository at this point in the history
* implement channel send all

* fix tests

* format

* write and flush

* ignore flaky test

* cleanup

* fix compilation error
  • Loading branch information
adamgfraser authored Jul 28, 2023
1 parent 77e1c8f commit 18b9a96
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
10 changes: 6 additions & 4 deletions zio-http-testkit/src/main/scala/zio/http/TestChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ case class TestChannel(
out: Queue[WebSocketChannelEvent],
promise: Promise[Nothing, Unit],
) extends WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
promise.await
def receive: Task[WebSocketChannelEvent] =
def receive: Task[WebSocketChannelEvent] =
in.take
def send(in: WebSocketChannelEvent): Task[Unit] =
def send(in: WebSocketChannelEvent): Task[Unit] =
out.offer(in).unit
def shutdown: UIO[Unit] =
def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
out.offerAll(in).unit
def shutdown: UIO[Unit] =
in.offer(ChannelEvent.Unregistered) *>
out.offer(ChannelEvent.Unregistered) *>
promise.succeed(()).unit
Expand Down
25 changes: 17 additions & 8 deletions zio-http/src/main/scala/zio/http/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ trait Channel[-In, +Out] { self =>
*/
def send(in: In): Task[Unit]

/**
* Send all messages to the channel.
*/
def sendAll(in: Iterable[In]): Task[Unit]

/**
* Shut down the channel.
*/
Expand All @@ -51,13 +56,15 @@ trait Channel[-In, +Out] { self =>
*/
final def contramap[In2](f: In2 => In): Channel[In2, Out] =
new Channel[In2, Out] {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
self.awaitShutdown
def receive: Task[Out] =
def receive: Task[Out] =
self.receive
def send(in: In2): Task[Unit] =
def send(in: In2): Task[Unit] =
self.send(f(in))
def shutdown: UIO[Unit] =
def sendAll(in: Iterable[In2]): Task[Unit] =
self.sendAll(in.map(f))
def shutdown: UIO[Unit] =
self.shutdown
}

Expand All @@ -67,13 +74,15 @@ trait Channel[-In, +Out] { self =>
*/
final def map[Out2](f: Out => Out2): Channel[In, Out2] =
new Channel[In, Out2] {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
self.awaitShutdown
def receive: Task[Out2] =
def receive: Task[Out2] =
self.receive.map(f)
def send(in: In): Task[Unit] =
def send(in: In): Task[Unit] =
self.send(in)
def shutdown: UIO[Unit] =
def sendAll(in: Iterable[In]): Task[Unit] =
self.sendAll(in)
def shutdown: UIO[Unit] =
self.shutdown
}

Expand Down
18 changes: 14 additions & 4 deletions zio-http/src/main/scala/zio/http/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,26 @@ private[http] object WebSocketChannel {
queue: Queue[WebSocketChannelEvent],
): WebSocketChannel =
new WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
nettyChannel.awaitClose
def receive: Task[WebSocketChannelEvent] =
def receive: Task[WebSocketChannelEvent] =
queue.take
def send(in: WebSocketChannelEvent): Task[Unit] =
def send(in: WebSocketChannelEvent): Task[Unit] =
in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
case _ => ZIO.unit
}
def shutdown: UIO[Unit] =
def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
ZIO.suspendSucceed {
val iterator = in.iterator.collect { case Read(message) => message }

ZIO.whileLoop(iterator.hasNext) {
val message = iterator.next()
if (iterator.hasNext) nettyChannel.write(frameToNetty(message))
else nettyChannel.writeAndFlush(frameToNetty(message))
}(_ => ())
}
def shutdown: UIO[Unit] =
nettyChannel.close(false).orDie
}

Expand Down

0 comments on commit 18b9a96

Please sign in to comment.