diff --git a/docs/examples/advanced/websocket-server.md b/docs/examples/advanced/websocket-server.md index c72d0a311b..2534d5e9de 100644 --- a/docs/examples/advanced/websocket-server.md +++ b/docs/examples/advanced/websocket-server.md @@ -13,7 +13,7 @@ import zio.http.codec.PathCodec.string object WebSocketAdvanced extends ZIOAppDefault { - val socketApp: SocketApp[Any] = + val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case Read(WebSocketFrame.Text("end")) => diff --git a/docs/examples/basic/websocket.md b/docs/examples/basic/websocket.md index 21f6cdae62..045f3e4f6b 100644 --- a/docs/examples/basic/websocket.md +++ b/docs/examples/basic/websocket.md @@ -12,7 +12,7 @@ import zio.http._ import zio.http.codec.PathCodec.string object WebSocketEcho extends ZIOAppDefault { - private val socketApp: SocketApp[Any] = + private val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case Read(WebSocketFrame.Text("FOO")) => diff --git a/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala b/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala index c2c6946e29..382fcdbc7a 100644 --- a/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala +++ b/zio-http-example/src/main/scala/example/WebSocketAdvanced.scala @@ -8,7 +8,7 @@ import zio.http.codec.PathCodec.string object WebSocketAdvanced extends ZIOAppDefault { - val socketApp: SocketApp[Any] = + val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case Read(WebSocketFrame.Text("end")) => diff --git a/zio-http-example/src/main/scala/example/WebSocketEcho.scala b/zio-http-example/src/main/scala/example/WebSocketEcho.scala index 075c6321f8..df3e34a379 100644 --- a/zio-http-example/src/main/scala/example/WebSocketEcho.scala +++ b/zio-http-example/src/main/scala/example/WebSocketEcho.scala @@ -7,7 +7,7 @@ import zio.http._ import zio.http.codec.PathCodec.string object WebSocketEcho extends ZIOAppDefault { - private val socketApp: SocketApp[Any] = + private val socketApp: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case Read(WebSocketFrame.Text("FOO")) => diff --git a/zio-http-example/src/main/scala/example/WebSocketReconnectingClient.scala b/zio-http-example/src/main/scala/example/WebSocketReconnectingClient.scala index 60c93380ae..a33ef8499a 100644 --- a/zio-http-example/src/main/scala/example/WebSocketReconnectingClient.scala +++ b/zio-http-example/src/main/scala/example/WebSocketReconnectingClient.scala @@ -10,7 +10,7 @@ object WebSocketReconnectingClient extends ZIOAppDefault { val url = "ws://ws.vi-server.org/mirror" // A promise is used to be able to notify application about websocket errors - def makeSocketApp(p: Promise[Nothing, Throwable]): SocketApp[Any] = + def makeSocketApp(p: Promise[Nothing, Throwable]): WebSocketApp[Any] = Handler // Listen for all websocket channel events diff --git a/zio-http-example/src/main/scala/example/WebSocketSimpleClient.scala b/zio-http-example/src/main/scala/example/WebSocketSimpleClient.scala index e1f538cf42..ef7f4903ff 100644 --- a/zio-http-example/src/main/scala/example/WebSocketSimpleClient.scala +++ b/zio-http-example/src/main/scala/example/WebSocketSimpleClient.scala @@ -9,7 +9,7 @@ object WebSocketSimpleClient extends ZIOAppDefault { val url = "ws://ws.vi-server.org/mirror" - val socketApp: SocketApp[Any] = + val socketApp: WebSocketApp[Any] = Handler // Listen for all websocket channel events diff --git a/zio-http-testkit/src/main/scala/zio/http/TestClient.scala b/zio-http-testkit/src/main/scala/zio/http/TestClient.scala index d6a2599090..4c28c3fe81 100644 --- a/zio-http-testkit/src/main/scala/zio/http/TestClient.scala +++ b/zio-http-testkit/src/main/scala/zio/http/TestClient.scala @@ -14,7 +14,7 @@ import zio.http.{Headers, Method, Scheme, Status, Version} */ final case class TestClient( behavior: Ref[PartialFunction[Request, ZIO[Any, Response, Response]]], - serverSocketBehavior: Ref[SocketApp[Any]], + serverSocketBehavior: Ref[WebSocketApp[Any]], ) extends ZClient.Driver[Any, Throwable] { /** @@ -117,7 +117,7 @@ final case class TestClient( version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 with Scope, Throwable, Response] = { for { env <- ZIO.environment[Env1] @@ -127,13 +127,13 @@ final case class TestClient( promise <- Promise.make[Nothing, Unit] testChannelClient <- TestChannel.make(in, out, promise) testChannelServer <- TestChannel.make(out, in, promise) - _ <- currentSocketBehavior.runZIO(testChannelClient).forkDaemon - _ <- app.provideEnvironment(env).runZIO(testChannelServer).forkDaemon + _ <- currentSocketBehavior.handler.runZIO(testChannelClient).forkDaemon + _ <- app.provideEnvironment(env).handler.runZIO(testChannelServer).forkDaemon } yield Response.status(Status.SwitchingProtocols) } def installSocketApp[Env1]( - app: Handler[Any, Throwable, WebSocketChannel, Unit], + app: WebSocketApp[Any], ): ZIO[Env1, Nothing, Unit] = for { env <- ZIO.environment[Env1] @@ -182,7 +182,7 @@ object TestClient { ZIO.serviceWithZIO[TestClient](_.addHandler(handler)) def installSocketApp( - app: Handler[Any, Throwable, WebSocketChannel, Unit], + app: WebSocketApp[Any], ): ZIO[TestClient, Nothing, Unit] = ZIO.serviceWithZIO[TestClient](_.installSocketApp(app)) @@ -190,7 +190,7 @@ object TestClient { ZLayer.scopedEnvironment { for { behavior <- Ref.make[PartialFunction[Request, ZIO[Any, Response, Response]]](PartialFunction.empty) - socketBehavior <- Ref.make[SocketApp[Any]](Handler.unit) + socketBehavior <- Ref.make[WebSocketApp[Any]](WebSocketApp.unit) driver = TestClient(behavior, socketBehavior) } yield ZEnvironment[TestClient, Client](driver, ZClient.fromDriver(driver)) } diff --git a/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala b/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala index c2653745be..880e76405d 100644 --- a/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala +++ b/zio-http-testkit/src/test/scala/zio/http/SocketContractSpec.scala @@ -14,7 +14,7 @@ object SocketContractSpec extends ZIOSpecDefault { def spec: Spec[Any, Any] = suite("SocketOps")( contract("Successful Multi-message application") { p => - val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] = + val socketServer: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case Read(WebSocketFrame.Text("Hi Server")) => @@ -31,7 +31,7 @@ object SocketContractSpec extends ZIOSpecDefault { socketServer } { _ => - val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] = + val socketClient: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) => @@ -89,8 +89,8 @@ object SocketContractSpec extends ZIOSpecDefault { private def contract( name: String, )( - serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit], - )(clientApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit]) = { + serverApp: Promise[Throwable, Unit] => WebSocketApp[Any], + )(clientApp: Promise[Throwable, Unit] => WebSocketApp[Any]) = { suite(name)( test("Live") { for { @@ -123,7 +123,7 @@ object SocketContractSpec extends ZIOSpecDefault { } private def liveServerSetup( - serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit], + serverApp: Promise[Throwable, Unit] => WebSocketApp[Any], ): ZIO[Server, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] = ZIO.serviceWithZIO[Server](server => for { @@ -133,7 +133,7 @@ object SocketContractSpec extends ZIOSpecDefault { ) private def testServerSetup( - serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit], + serverApp: Promise[Throwable, Unit] => WebSocketApp[Any], ): ZIO[TestClient, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] = for { p <- Promise.make[Throwable, Unit] diff --git a/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala b/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala index 74d238a323..0b0c6a8795 100644 --- a/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala +++ b/zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala @@ -66,7 +66,7 @@ object TestClientSpec extends ZIOSpecDefault { ), suite("socket ops")( test("happy path") { - val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] = + val socketClient: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) => @@ -77,7 +77,7 @@ object TestClientSpec extends ZIOSpecDefault { } } - val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] = + val socketServer: WebSocketApp[Any] = Handler.webSocket { channel => channel.receiveAll { case ChannelEvent.Read(WebSocketFrame.Text("Hi Server")) => diff --git a/zio-http/src/main/scala/zio/http/ClientDriver.scala b/zio-http/src/main/scala/zio/http/ClientDriver.scala index 48e2bdfe72..14831f94ff 100644 --- a/zio-http/src/main/scala/zio/http/ClientDriver.scala +++ b/zio-http/src/main/scala/zio/http/ClientDriver.scala @@ -31,7 +31,7 @@ trait ClientDriver { onResponse: Promise[Throwable, Response], onComplete: Promise[Throwable, ChannelState], enableKeepAlive: Boolean, - createSocketApp: () => SocketApp[Any], + createSocketApp: () => WebSocketApp[Any], webSocketConfig: WebSocketConfig, )(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] diff --git a/zio-http/src/main/scala/zio/http/Handler.scala b/zio-http/src/main/scala/zio/http/Handler.scala index 93927d4a99..73fc66fa1d 100644 --- a/zio-http/src/main/scala/zio/http/Handler.scala +++ b/zio-http/src/main/scala/zio/http/Handler.scala @@ -222,34 +222,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self => )(implicit trace: Trace): Handler[R1, Err1, In1, Out] = that.andThen(self) - /** - * Creates a socket connection on the provided URL. Typically used to connect - * as a client. - */ - def connect( - url: String, - headers: Headers = Headers.empty, - )(implicit - ev: Err <:< Throwable, - ev2: WebSocketChannel <:< In, - trace: Trace, - ): ZIO[R with Client with Scope, Throwable, Response] = - ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers)) - - def connect( - url: URL, - headers: Headers, - )(implicit - ev1: Err <:< Throwable, - ev2: WebSocketChannel <:< In, - trace: Trace, - ): ZIO[R with Client with Scope, Throwable, Response] = - ZIO.serviceWithZIO[Client] { client => - val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url) - - client2.addHeaders(headers).socket(self.asInstanceOf[SocketApp[R]]) - } - /** * Transforms the input of the handler before passing it on to the current * Handler @@ -621,21 +593,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self => HttpApp(Routes.singleton(handler.contramap[(Path, Request)](_._2))) } - def toHttpAppWS(implicit err: Err <:< Throwable, in: WebSocketChannel <:< In, trace: Trace): HttpApp[R] = - Handler.fromZIO(self.toResponse).toHttpApp - - /** - * Creates a new response from a socket handler. - */ - def toResponse(implicit - ev1: Err <:< Throwable, - ev2: WebSocketChannel <:< In, - trace: Trace, - ): ZIO[R, Nothing, Response] = - ZIO.environment[R].flatMap { env => - Response.fromSocketApp(self.asInstanceOf[SocketApp[R]].provideEnvironment(env)) - } - /** * Takes some defects and converts them into failures. */ @@ -1102,10 +1059,10 @@ object Handler { /** * Constructs a handler from a function that uses a web socket. */ - final def webSocket[Env, Err, Out]( - f: WebSocketChannel => ZIO[Env, Err, Out], - ): Handler[Env, Err, WebSocketChannel, Out] = - Handler.fromFunctionZIO(f) + final def webSocket[Env]( + f: WebSocketChannel => ZIO[Env, Throwable, Any], + ): WebSocketApp[Env] = + WebSocketApp(Handler.fromFunctionZIO(f)) final implicit class RequestHandlerSyntax[-R, +Err](val self: RequestHandler[R, Err]) extends HeaderModifier[RequestHandler[R, Err]] { diff --git a/zio-http/src/main/scala/zio/http/Response.scala b/zio-http/src/main/scala/zio/http/Response.scala index 4a608a31bc..1def87457c 100644 --- a/zio-http/src/main/scala/zio/http/Response.scala +++ b/zio-http/src/main/scala/zio/http/Response.scala @@ -114,7 +114,7 @@ sealed trait Response extends HeaderOps[Response] { self => final def status(status: Status): Response = self.copy(status = status) - private[zio] final def socketApp: Option[SocketApp[Any]] = self match { + private[zio] final def socketApp: Option[WebSocketApp[Any]] = self match { case Response.GetApp(app) => Some(app) case _ => None } @@ -140,7 +140,7 @@ object Response { } object GetApp { - def unapply(response: Response): Option[SocketApp[Any]] = response match { + def unapply(response: Response): Option[WebSocketApp[Any]] = response match { case resp: SocketAppResponse => Some(resp.socketApp0) case _ => None } @@ -186,7 +186,7 @@ object Response { private[zio] class SocketAppResponse( val body: Body, val headers: Headers, - val socketApp0: SocketApp[Any], + val socketApp0: WebSocketApp[Any], val status: Status, ) extends Response { self => @@ -357,18 +357,10 @@ object Response { def fromServerSentEvents(data: ZStream[Any, Nothing, ServerSentEvent])(implicit trace: Trace): Response = new BasicResponse(Body.fromStream(data.map(_.encode)), contentTypeEventStream, Status.Ok) - /** - * Creates a new response for the provided socket - */ - def fromSocket[R]( - http: Handler[R, Throwable, WebSocketChannel, Any], - )(implicit trace: Trace): ZIO[R, Nothing, Response] = - fromSocketApp(http) - /** * Creates a new response for the provided socket app */ - def fromSocketApp[R](app: SocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = { + def fromSocketApp[R](app: WebSocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = { ZIO.environment[R].map { env => new SocketAppResponse( Body.empty, diff --git a/zio-http/src/main/scala/zio/http/WebSocketApp.scala b/zio-http/src/main/scala/zio/http/WebSocketApp.scala new file mode 100644 index 0000000000..c914ca7ff0 --- /dev/null +++ b/zio-http/src/main/scala/zio/http/WebSocketApp.scala @@ -0,0 +1,87 @@ +package zio.http + +import zio._ + +final case class WebSocketApp[-R]( + handler: Handler[R, Throwable, WebSocketChannel, Any], + customConfig: Option[WebSocketConfig], +) { self => + + /** + * Creates a socket connection on the provided URL. Typically used to connect + * as a client. + */ + def connect( + url: String, + headers: Headers = Headers.empty, + )(implicit + trace: Trace, + ): ZIO[R with Client with Scope, Throwable, Response] = + ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers)) + + def connect( + url: URL, + headers: Headers, + )(implicit + trace: Trace, + ): ZIO[R with Client with Scope, Throwable, Response] = + ZIO.serviceWithZIO[Client] { client => + val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url) + + client2.addHeaders(headers).socket(self) + } + + def provideEnvironment(r: ZEnvironment[R])(implicit trace: Trace): WebSocketApp[Any] = + WebSocketApp(handler.provideEnvironment(r), customConfig) + + def provideLayer[R0](layer: ZLayer[R0, Throwable, R])(implicit + trace: Trace, + ): WebSocketApp[R0] = + WebSocketApp(handler.provideLayer(layer), customConfig) + + def provideSomeEnvironment[R1](f: ZEnvironment[R1] => ZEnvironment[R])(implicit + trace: Trace, + ): WebSocketApp[R1] = + WebSocketApp(handler.provideSomeEnvironment(f), customConfig) + + def provideSomeLayer[R0, R1: Tag]( + layer: ZLayer[R0, Throwable, R1], + )(implicit ev: R0 with R1 <:< R, trace: Trace): WebSocketApp[R0] = + WebSocketApp(handler.provideSomeLayer(layer), customConfig) + + def tapErrorCauseZIO[R1 <: R]( + f: Cause[Throwable] => ZIO[R1, Throwable, Any], + )(implicit trace: Trace): WebSocketApp[R1] = + WebSocketApp(handler.tapErrorCauseZIO(f), customConfig) + + /** + * Returns a Handler that effectfully peeks at the failure of this SocketApp. + */ + def tapErrorZIO[R1 <: R]( + f: Throwable => ZIO[R1, Throwable, Any], + )(implicit trace: Trace): WebSocketApp[R1] = + self.tapErrorCauseZIO(cause => cause.failureOption.fold[ZIO[R1, Throwable, Any]](ZIO.unit)(f)) + + /** + * Creates a new response from a socket handler. + */ + def toResponse(implicit + trace: Trace, + ): ZIO[R, Nothing, Response] = + ZIO.environment[R].flatMap { env => + Response.fromSocketApp(self.provideEnvironment(env)) + } + + def toHttpAppWS(implicit trace: Trace): HttpApp[R] = + Handler.fromZIO(self.toResponse).toHttpApp + + def withConfig(config: WebSocketConfig): WebSocketApp[R] = + copy(customConfig = Some(config)) +} + +object WebSocketApp { + def apply[R](handler: Handler[R, Throwable, WebSocketChannel, Any]): WebSocketApp[R] = + WebSocketApp(handler, None) + + val unit: WebSocketApp[Any] = WebSocketApp(Handler.unit) +} diff --git a/zio-http/src/main/scala/zio/http/ZClient.scala b/zio-http/src/main/scala/zio/http/ZClient.scala index 0f32687c7c..4d84e85a0e 100644 --- a/zio-http/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/src/main/scala/zio/http/ZClient.scala @@ -205,7 +205,7 @@ final case class ZClient[-Env, -In, +Err, +Out]( def scheme(scheme: Scheme): ZClient[Env, In, Err, Out] = copy(url = url.scheme(scheme)) - def socket[Env1 <: Env](app: SocketApp[Env1])(implicit trace: Trace): ZIO[Env1 & Scope, Err, Out] = + def socket[Env1 <: Env](app: WebSocketApp[Env1])(implicit trace: Trace): ZIO[Env1 & Scope, Err, Out] = driver .socket( Version.Default, @@ -291,7 +291,7 @@ object ZClient { def request(request: Request): ZIO[Client & Scope, Throwable, Response] = ZIO.serviceWithZIO[Client](c => c(request)) - def socket[R](socketApp: SocketApp[R]): ZIO[R with Client & Scope, Throwable, Response] = + def socket[R](socketApp: WebSocketApp[R]): ZIO[R with Client & Scope, Throwable, Response] = ZIO.serviceWithZIO[Client](c => c.socket(socketApp)) trait BodyDecoder[-Env, +Err, +Out] { self => @@ -387,7 +387,7 @@ object ZClient { version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] = self0 .socket( @@ -420,7 +420,7 @@ object ZClient { version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Err2, Response] = self .socket( @@ -450,7 +450,7 @@ object ZClient { version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Err2, Response] = self .socket( @@ -490,7 +490,7 @@ object ZClient { version: Version, url: URL, headers: Headers, - app: SocketApp[Env2], + app: WebSocketApp[Env2], )(implicit trace: Trace): ZIO[Env2 & Scope, Err1, Response] = self .socket( @@ -506,7 +506,7 @@ object ZClient { version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Err, Response] def widenError[E1](implicit ev: Err <:< E1): Driver[Env, E1] = self.asInstanceOf[Driver[Env, E1]] @@ -633,14 +633,14 @@ object ZClient { val request = Request(version, method, url, headers, body, None) val cfg = sslConfig.fold(config)(config.ssl) - requestAsync(request, cfg, () => Handler.unit, None) + requestAsync(request, cfg, () => WebSocketApp.unit, None) } def socket[Env1]( version: Version, url: URL, headers: Headers, - app: SocketApp[Env1], + app: WebSocketApp[Env1], )(implicit trace: Trace): ZIO[Env1 & Scope, Throwable, Response] = for { env <- ZIO.environment[Env1] @@ -668,7 +668,7 @@ object ZClient { private def requestAsync( request: Request, clientConfig: Config, - createSocketApp: () => SocketApp[Any], + createSocketApp: () => WebSocketApp[Any], outerScope: Option[Scope], )(implicit trace: Trace, diff --git a/zio-http/src/main/scala/zio/http/ZClientAspect.scala b/zio-http/src/main/scala/zio/http/ZClientAspect.scala index 054a38e61b..c6f48a6be5 100644 --- a/zio-http/src/main/scala/zio/http/ZClientAspect.scala +++ b/zio-http/src/main/scala/zio/http/ZClientAspect.scala @@ -155,8 +155,8 @@ object ZClientAspect { .flatMap(_._2) .unsandbox - override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: SocketApp[Env1])(implicit - trace: Trace, + override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])( + implicit trace: Trace, ): ZIO[Env1 with Scope, Err, Response] = client.driver.socket(version, url, headers, app) } @@ -299,8 +299,8 @@ object ZClientAspect { .unsandbox } - override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: SocketApp[Env1])(implicit - trace: Trace, + override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])( + implicit trace: Trace, ): ZIO[Env1 with Scope, Err, Response] = client.driver.socket(version, url, headers, app) } diff --git a/zio-http/src/main/scala/zio/http/netty/client/NettyClientDriver.scala b/zio-http/src/main/scala/zio/http/netty/client/NettyClientDriver.scala index 807baddadf..0ee2a68316 100644 --- a/zio-http/src/main/scala/zio/http/netty/client/NettyClientDriver.scala +++ b/zio-http/src/main/scala/zio/http/netty/client/NettyClientDriver.scala @@ -49,7 +49,7 @@ final case class NettyClientDriver private ( onResponse: Promise[Throwable, Response], onComplete: Promise[Throwable, ChannelState], enableKeepAlive: Boolean, - createSocketApp: () => SocketApp[Any], + createSocketApp: () => WebSocketApp[Any], webSocketConfig: WebSocketConfig, )(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] = { NettyRequestEncoder.encode(req).flatMap { jReq => @@ -68,7 +68,7 @@ final case class NettyClientDriver private ( nettyChannel = NettyChannel.make[JWebSocketFrame](channel) webSocketChannel = WebSocketChannel.make(nettyChannel, queue) app = createSocketApp() - _ <- app.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped + _ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped } yield { val pipeline = channel.pipeline() val toRemove: mutable.Set[ChannelHandler] = new mutable.HashSet[ChannelHandler]() @@ -85,7 +85,7 @@ final case class NettyClientDriver private ( val headers = Conversions.headersToNetty(req.headers) val config = NettySocketProtocol - .clientBuilder(webSocketConfig) + .clientBuilder(app.customConfig.getOrElse(webSocketConfig)) .customHeaders(headers) .webSocketUri(req.url.encode) .build() diff --git a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 3a59ffb705..dc685a855b 100644 --- a/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -288,20 +288,23 @@ private[zio] final case class ServerInboundHandler( runtime: NettyRuntime, ): Unit = { val app = res.socketApp + jReq match { case jReq: FullHttpRequest => val queue = runtime.runtime(ctx).unsafe.run(Queue.unbounded[WebSocketChannelEvent]).getOrThrowFiberFailure() + val webSocketApp = app.getOrElse(WebSocketApp.unit) runtime.runtime(ctx).unsafe.run { val nettyChannel = NettyChannel.make[JWebSocketFrame](ctx.channel()) val webSocketChannel = WebSocketChannel.make(nettyChannel, queue) - val webSocketApp = app.getOrElse(Handler.unit) - webSocketApp.runZIO(webSocketChannel).ignoreLogged.forkDaemon + webSocketApp.handler.runZIO(webSocketChannel).ignoreLogged.forkDaemon } ctx .channel() .pipeline() .addLast( - new WebSocketServerProtocolHandler(NettySocketProtocol.serverBuilder(config.webSocketConfig).build()), + new WebSocketServerProtocolHandler( + NettySocketProtocol.serverBuilder(webSocketApp.customConfig.getOrElse(config.webSocketConfig)).build(), + ), ) .addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None)) diff --git a/zio-http/src/main/scala/zio/http/package.scala b/zio-http/src/main/scala/zio/http/package.scala index 7e23367b65..5ac8a96abe 100644 --- a/zio-http/src/main/scala/zio/http/package.scala +++ b/zio-http/src/main/scala/zio/http/package.scala @@ -57,8 +57,6 @@ package object http extends UrlInterpolator { type Client = ZClient[Any, Body, Throwable, Response] def Client: ZClient.type = ZClient - type SocketApp[-R] = Handler[R, Throwable, WebSocketChannel, Any] - /** * A channel that allows websocket frames to be written to it. */ diff --git a/zio-http/src/test/scala/zio/http/WebSocketSpec.scala b/zio-http/src/test/scala/zio/http/WebSocketSpec.scala index 73f7ef1c86..8bbafb2564 100644 --- a/zio-http/src/test/scala/zio/http/WebSocketSpec.scala +++ b/zio-http/src/test/scala/zio/http/WebSocketSpec.scala @@ -111,9 +111,10 @@ object WebSocketSpec extends HttpRunnableSpec { } yield assertCompletes } @@ nonFlaky, test("Multiple websocket upgrades") { - val app = Handler.succeed(WebSocketFrame.text("BAR")).toHttpAppWS.deployWS + val app = + Handler.webSocket(channel => channel.send(ChannelEvent.Read(WebSocketFrame.text("BAR")))).toHttpAppWS.deployWS val codes = ZIO - .foreach(1 to 1024)(_ => app.runZIO(Handler.unit).map(_.status)) + .foreach(1 to 1024)(_ => app.runZIO(WebSocketApp.unit).map(_.status)) .map(_.count(_ == Status.SwitchingProtocols)) assertZIO(codes)(equalTo(1024)) diff --git a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala index c4640f854b..a1133b934f 100644 --- a/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala +++ b/zio-http/src/test/scala/zio/http/internal/HttpRunnableSpec.scala @@ -87,13 +87,13 @@ abstract class HttpRunnableSpec extends ZIOSpecDefault { self => } yield response def deployWS - : Handler[R with Client with DynamicServer with Scope, Throwable, SocketApp[Client with Scope], Response] = + : Handler[R with Client with DynamicServer with Scope, Throwable, WebSocketApp[Client with Scope], Response] = for { id <- Handler.fromZIO(DynamicServer.deploy[R](app)) rawUrl <- Handler.fromZIO(DynamicServer.wsURL) url <- Handler.fromEither(URL.decode(rawUrl)).orDie client <- Handler.fromZIO(ZIO.service[Client]) - response <- Handler.fromFunctionZIO[SocketApp[Client with Scope]] { app => + response <- Handler.fromFunctionZIO[WebSocketApp[Client with Scope]] { app => ZIO.scoped[Client with Scope]( client .url(url)