Skip to content

Commit

Permalink
Fix issue with websocket send
Browse files Browse the repository at this point in the history
  • Loading branch information
vigoo committed Jul 28, 2023
1 parent 500cbb7 commit 02c8e15
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ object WebSocketAdvanced extends ZIOAppDefault {

// Echo the same message 10 times if it's not "foo" or "bar"
case Read(WebSocketFrame.Text(text)) =>
channel.send(Read(WebSocketFrame.text(text))).repeatN(10)
channel
.send(Read(WebSocketFrame.text(s"echo $text")))
.repeatN(10)
.catchSomeCause { case cause =>
ZIO.logErrorCause(s"failed sending", cause)
}

// Send a "greeting" message to the server once the connection is established
case UserEventTriggered(UserEvent.HandshakeComplete) =>
Expand Down
14 changes: 10 additions & 4 deletions zio-http/src/main/scala/zio/http/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,24 @@ private[http] object WebSocketChannel {
queue: Queue[WebSocketChannelEvent],
): WebSocketChannel =
new WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
nettyChannel.awaitClose
def receive: Task[WebSocketChannelEvent] =

def receive: Task[WebSocketChannelEvent] =
queue.take
def send(in: WebSocketChannelEvent): Task[Unit] =

def send(in: WebSocketChannelEvent): Task[Unit] = {
in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
case _ => ZIO.unit
}
}

def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
ZIO.suspendSucceed {
val iterator = in.iterator.collect { case Read(message) => message }

println(s"sendAll")
ZIO.whileLoop(iterator.hasNext) {
val message = iterator.next()
if (iterator.hasNext) nettyChannel.write(frameToNetty(message))
Expand All @@ -54,7 +59,7 @@ private[http] object WebSocketChannel {
nettyChannel.close(false).orDie
}

private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame =
private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame = {
frame match {
case b: WebSocketFrame.Binary =>
new BinaryWebSocketFrame(b.isFinal, 0, Unpooled.wrappedBuffer(b.bytes.toArray))
Expand All @@ -71,4 +76,5 @@ private[http] object WebSocketChannel {
case c: WebSocketFrame.Continuation =>
new ContinuationWebSocketFrame(c.isFinal, 0, Unpooled.wrappedBuffer(c.buffer.toArray))
}
}
}
28 changes: 16 additions & 12 deletions zio-http/src/main/scala/zio/http/netty/NettyChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ package zio.http.netty
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Task, Trace, UIO, ZIO}

import zio.http.Channel

import io.netty.channel.{Channel => JChannel, ChannelFuture => JChannelFuture}
final case class NettyChannel[-A](
private val channel: JChannel,
private val convert: A => Any,
) {
self =>

private def foreach[S](await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = {
if (await) NettyFutureExecutor.executed(run(channel))
else ZIO.attempt(run(channel): Unit)
private def run(await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = {
if (await) {
NettyFutureExecutor.executed(run(channel))
} else {
ZIO.attempt(run(channel): Unit)
}
}

def autoRead(flag: Boolean)(implicit trace: Trace): UIO[Unit] =
Expand All @@ -41,7 +42,8 @@ final case class NettyChannel[-A](
()
}

def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) { _.close() }
def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] =
run(await) { _.close() }

def contramap[A1](f: A1 => A): NettyChannel[A1] = copy(convert = convert.compose(f))

Expand All @@ -53,13 +55,15 @@ final case class NettyChannel[-A](

def read(implicit trace: Trace): UIO[Unit] = ZIO.succeed(channel.read(): Unit)

def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) {
_.write(convert(msg))
}
def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] =
run(await) {
_.write(convert(msg))
}

def writeAndFlush(msg: A, await: Boolean = true)(implicit trace: Trace): Task[Unit] = foreach(await) {
_.writeAndFlush(convert(msg))
}
def writeAndFlush(msg: => A, await: Boolean = true)(implicit trace: Trace): Task[Unit] =
run(await) { ch =>
ch.writeAndFlush(convert(msg))
}
}

object NettyChannel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ object NettyFutureExecutor {
def make[A](jFuture: => Future[A])(implicit trace: Trace): Task[NettyFutureExecutor[A]] =
ZIO.succeed(new NettyFutureExecutor(jFuture))

def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] = make(jFuture).flatMap(_.execute.unit)
def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] =
make(jFuture).flatMap(_.execute.unit)

def scoped[A](jFuture: => Future[A])(implicit trace: Trace): ZIO[Scope, Throwable, Unit] =
make(jFuture).flatMap(_.scoped.unit)
Expand Down

0 comments on commit 02c8e15

Please sign in to comment.