diff --git a/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala b/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala index 9d41d307b0..c2c6946e29 100644 --- a/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala +++ b/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala @@ -24,7 +24,12 @@ object WebSocketAdvanced extends ZIOAppDefault { // Echo the same message 10 times if it's not "foo" or "bar" case Read(WebSocketFrame.Text(text)) => - channel.send(Read(WebSocketFrame.text(text))).repeatN(10) + channel + .send(Read(WebSocketFrame.text(s"echo $text"))) + .repeatN(10) + .catchSomeCause { case cause => + ZIO.logErrorCause(s"failed sending", cause) + } // Send a "greeting" message to the server once the connection is established case UserEventTriggered(UserEvent.HandshakeComplete) => diff --git a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala index 8f072ac7f4..6f726c9fa9 100644 --- a/zio-http/src/main/scala/zio/http/WebSocketChannel.scala +++ b/zio-http/src/main/scala/zio/http/WebSocketChannel.scala @@ -31,19 +31,24 @@ private[http] object WebSocketChannel { queue: Queue[WebSocketChannelEvent], ): WebSocketChannel = new WebSocketChannel { - def awaitShutdown: UIO[Unit] = + def awaitShutdown: UIO[Unit] = nettyChannel.awaitClose - def receive: Task[WebSocketChannelEvent] = + + def receive: Task[WebSocketChannelEvent] = queue.take - def send(in: WebSocketChannelEvent): Task[Unit] = + + def send(in: WebSocketChannelEvent): Task[Unit] = { in match { case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message)) case _ => ZIO.unit } + } + def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] = ZIO.suspendSucceed { val iterator = in.iterator.collect { case Read(message) => message } + println(s"sendAll") ZIO.whileLoop(iterator.hasNext) { val message = iterator.next() if (iterator.hasNext) nettyChannel.write(frameToNetty(message)) @@ -54,7 +59,7 @@ private[http] object WebSocketChannel { nettyChannel.close(false).orDie } - private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame = + private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame = { frame match { case b: WebSocketFrame.Binary => new BinaryWebSocketFrame(b.isFinal, 0, Unpooled.wrappedBuffer(b.bytes.toArray)) @@ -71,4 +76,5 @@ private[http] object WebSocketChannel { case c: WebSocketFrame.Continuation => new ContinuationWebSocketFrame(c.isFinal, 0, Unpooled.wrappedBuffer(c.buffer.toArray)) } + } } diff --git a/zio-http/src/main/scala/zio/http/netty/NettyChannel.scala b/zio-http/src/main/scala/zio/http/netty/NettyChannel.scala index 411215a855..a978775877 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyChannel.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyChannel.scala @@ -19,8 +19,6 @@ package zio.http.netty import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.{Task, Trace, UIO, ZIO} -import zio.http.Channel - import io.netty.channel.{Channel => JChannel, ChannelFuture => JChannelFuture} final case class NettyChannel[-A]( private val channel: JChannel, @@ -28,9 +26,12 @@ final case class NettyChannel[-A]( ) { self => - private def foreach[S](await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = { - if (await) NettyFutureExecutor.executed(run(channel)) - else ZIO.attempt(run(channel): Unit) + private def run(await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = { + if (await) { + NettyFutureExecutor.executed(run(channel)) + } else { + ZIO.attempt(run(channel): Unit) + } } def autoRead(flag: Boolean)(implicit trace: Trace): UIO[Unit] = @@ -41,7 +42,8 @@ final case class NettyChannel[-A]( () } - def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) { _.close() } + def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] = + run(await) { _.close() } def contramap[A1](f: A1 => A): NettyChannel[A1] = copy(convert = convert.compose(f)) @@ -53,13 +55,15 @@ final case class NettyChannel[-A]( def read(implicit trace: Trace): UIO[Unit] = ZIO.succeed(channel.read(): Unit) - def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) { - _.write(convert(msg)) - } + def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] = + run(await) { + _.write(convert(msg)) + } - def writeAndFlush(msg: A, await: Boolean = true)(implicit trace: Trace): Task[Unit] = foreach(await) { - _.writeAndFlush(convert(msg)) - } + def writeAndFlush(msg: => A, await: Boolean = true)(implicit trace: Trace): Task[Unit] = + run(await) { ch => + ch.writeAndFlush(convert(msg)) + } } object NettyChannel { diff --git a/zio-http/src/main/scala/zio/http/netty/NettyFutureExecutor.scala b/zio-http/src/main/scala/zio/http/netty/NettyFutureExecutor.scala index b57676c3eb..d42d9cf7e7 100644 --- a/zio-http/src/main/scala/zio/http/netty/NettyFutureExecutor.scala +++ b/zio-http/src/main/scala/zio/http/netty/NettyFutureExecutor.scala @@ -66,7 +66,8 @@ object NettyFutureExecutor { def make[A](jFuture: => Future[A])(implicit trace: Trace): Task[NettyFutureExecutor[A]] = ZIO.succeed(new NettyFutureExecutor(jFuture)) - def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] = make(jFuture).flatMap(_.execute.unit) + def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] = + make(jFuture).flatMap(_.execute.unit) def scoped[A](jFuture: => Future[A])(implicit trace: Trace): ZIO[Scope, Throwable, Unit] = make(jFuture).flatMap(_.scoped.unit)