Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with websocket send #2341

Merged
merged 2 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ object WebSocketAdvanced extends ZIOAppDefault {

// Echo the same message 10 times if it's not "foo" or "bar"
case Read(WebSocketFrame.Text(text)) =>
channel.send(Read(WebSocketFrame.text(text))).repeatN(10)
channel
.send(Read(WebSocketFrame.text(s"echo $text")))
.repeatN(10)
.catchSomeCause { case cause =>
ZIO.logErrorCause(s"failed sending", cause)
}

// Send a "greeting" message to the server once the connection is established
case UserEventTriggered(UserEvent.HandshakeComplete) =>
Expand Down
14 changes: 10 additions & 4 deletions zio-http/src/main/scala/zio/http/WebSocketChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,24 @@ private[http] object WebSocketChannel {
queue: Queue[WebSocketChannelEvent],
): WebSocketChannel =
new WebSocketChannel {
def awaitShutdown: UIO[Unit] =
def awaitShutdown: UIO[Unit] =
nettyChannel.awaitClose
def receive: Task[WebSocketChannelEvent] =

def receive: Task[WebSocketChannelEvent] =
queue.take
def send(in: WebSocketChannelEvent): Task[Unit] =

def send(in: WebSocketChannelEvent): Task[Unit] = {
in match {
case Read(message) => nettyChannel.writeAndFlush(frameToNetty(message))
case _ => ZIO.unit
}
}

def sendAll(in: Iterable[WebSocketChannelEvent]): Task[Unit] =
ZIO.suspendSucceed {
val iterator = in.iterator.collect { case Read(message) => message }

println(s"sendAll")
ZIO.whileLoop(iterator.hasNext) {
val message = iterator.next()
if (iterator.hasNext) nettyChannel.write(frameToNetty(message))
Expand All @@ -54,7 +59,7 @@ private[http] object WebSocketChannel {
nettyChannel.close(false).orDie
}

private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame =
private def frameToNetty(frame: WebSocketFrame): JWebSocketFrame = {
frame match {
case b: WebSocketFrame.Binary =>
new BinaryWebSocketFrame(b.isFinal, 0, Unpooled.wrappedBuffer(b.bytes.toArray))
Expand All @@ -71,4 +76,5 @@ private[http] object WebSocketChannel {
case c: WebSocketFrame.Continuation =>
new ContinuationWebSocketFrame(c.isFinal, 0, Unpooled.wrappedBuffer(c.buffer.toArray))
}
}
}
28 changes: 16 additions & 12 deletions zio-http/src/main/scala/zio/http/netty/NettyChannel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,19 @@ package zio.http.netty
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Task, Trace, UIO, ZIO}

import zio.http.Channel

import io.netty.channel.{Channel => JChannel, ChannelFuture => JChannelFuture}
final case class NettyChannel[-A](
private val channel: JChannel,
private val convert: A => Any,
) {
self =>

private def foreach[S](await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = {
if (await) NettyFutureExecutor.executed(run(channel))
else ZIO.attempt(run(channel): Unit)
private def run(await: Boolean)(run: JChannel => JChannelFuture)(implicit trace: Trace): Task[Unit] = {
if (await) {
NettyFutureExecutor.executed(run(channel))
} else {
ZIO.attempt(run(channel): Unit)
}
}

def autoRead(flag: Boolean)(implicit trace: Trace): UIO[Unit] =
Expand All @@ -41,7 +42,8 @@ final case class NettyChannel[-A](
()
}

def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) { _.close() }
def close(await: Boolean = false)(implicit trace: Trace): Task[Unit] =
run(await) { _.close() }

def contramap[A1](f: A1 => A): NettyChannel[A1] = copy(convert = convert.compose(f))

Expand All @@ -53,13 +55,15 @@ final case class NettyChannel[-A](

def read(implicit trace: Trace): UIO[Unit] = ZIO.succeed(channel.read(): Unit)

def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] = foreach(await) {
_.write(convert(msg))
}
def write(msg: A, await: Boolean = false)(implicit trace: Trace): Task[Unit] =
run(await) {
_.write(convert(msg))
}

def writeAndFlush(msg: A, await: Boolean = true)(implicit trace: Trace): Task[Unit] = foreach(await) {
_.writeAndFlush(convert(msg))
}
def writeAndFlush(msg: => A, await: Boolean = true)(implicit trace: Trace): Task[Unit] =
run(await) { ch =>
ch.writeAndFlush(convert(msg))
}
}

object NettyChannel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ object NettyFutureExecutor {
def make[A](jFuture: => Future[A])(implicit trace: Trace): Task[NettyFutureExecutor[A]] =
ZIO.succeed(new NettyFutureExecutor(jFuture))

def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] = make(jFuture).flatMap(_.execute.unit)
def executed[A](jFuture: => Future[A])(implicit trace: Trace): Task[Unit] =
make(jFuture).flatMap(_.execute.unit)

def scoped[A](jFuture: => Future[A])(implicit trace: Trace): ZIO[Scope, Throwable, Unit] =
make(jFuture).flatMap(_.scoped.unit)
Expand Down
Loading