Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SocketApp refactoring, ability to provide custom websocket config #2338

Merged
merged 7 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions zio-http-testkit/src/main/scala/zio/http/TestClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ final case class TestClient(
promise <- Promise.make[Nothing, Unit]
testChannelClient <- TestChannel.make(in, out, promise)
testChannelServer <- TestChannel.make(out, in, promise)
_ <- currentSocketBehavior.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).runZIO(testChannelServer).forkDaemon
_ <- currentSocketBehavior.handler.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).handler.runZIO(testChannelServer).forkDaemon
} yield Response.status(Status.SwitchingProtocols)
}

def installSocketApp[Env1](
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: SocketApp[Any],
): ZIO[Env1, Nothing, Unit] =
for {
env <- ZIO.environment[Env1]
Expand Down Expand Up @@ -182,15 +182,15 @@ object TestClient {
ZIO.serviceWithZIO[TestClient](_.addHandler(handler))

def installSocketApp(
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: SocketApp[Any],
): ZIO[TestClient, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.installSocketApp(app))

val layer: ZLayer[Any, Nothing, TestClient & Client] =
ZLayer.scopedEnvironment {
for {
behavior <- Ref.make[PartialFunction[Request, ZIO[Any, Response, Response]]](PartialFunction.empty)
socketBehavior <- Ref.make[SocketApp[Any]](Handler.unit)
socketBehavior <- Ref.make[SocketApp[Any]](SocketApp.unit)
driver = TestClient(behavior, socketBehavior)
} yield ZEnvironment[TestClient, Client](driver, ZClient.fromDriver(driver))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object SocketContractSpec extends ZIOSpecDefault {
def spec: Spec[Any, Any] =
suite("SocketOps")(
contract("Successful Multi-message application") { p =>
val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: SocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("Hi Server")) =>
Expand All @@ -31,7 +31,7 @@ object SocketContractSpec extends ZIOSpecDefault {

socketServer
} { _ =>
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: SocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand Down Expand Up @@ -89,8 +89,8 @@ object SocketContractSpec extends ZIOSpecDefault {
private def contract(
name: String,
)(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
)(clientApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit]) = {
serverApp: Promise[Throwable, Unit] => SocketApp[Any],
)(clientApp: Promise[Throwable, Unit] => SocketApp[Any]) = {
suite(name)(
test("Live") {
for {
Expand Down Expand Up @@ -123,7 +123,7 @@ object SocketContractSpec extends ZIOSpecDefault {
}

private def liveServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => SocketApp[Any],
): ZIO[Server, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
ZIO.serviceWithZIO[Server](server =>
for {
Expand All @@ -133,7 +133,7 @@ object SocketContractSpec extends ZIOSpecDefault {
)

private def testServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => SocketApp[Any],
): ZIO[TestClient, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
for {
p <- Promise.make[Throwable, Unit]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object TestClientSpec extends ZIOSpecDefault {
),
suite("socket ops")(
test("happy path") {
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: SocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand All @@ -77,7 +77,7 @@ object TestClientSpec extends ZIOSpecDefault {
}
}

val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: SocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Server")) =>
Expand Down
63 changes: 16 additions & 47 deletions zio-http/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,34 +222,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
)(implicit trace: Trace): Handler[R1, Err1, In1, Out] =
that.andThen(self)

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
ev: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self.asInstanceOf[SocketApp[R]])
}

/**
* Transforms the input of the handler before passing it on to the current
* Handler
Expand Down Expand Up @@ -621,21 +593,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
HttpApp(Routes.singleton(handler.contramap[(Path, Request)](_._2)))
}

def toHttpAppWS(implicit err: Err <:< Throwable, in: WebSocketChannel <:< In, trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.asInstanceOf[SocketApp[R]].provideEnvironment(env))
}

/**
* Takes some defects and converts them into failures.
*/
Expand Down Expand Up @@ -727,6 +684,18 @@ object Handler {
def badRequest(message: => String): Handler[Any, Nothing, Any, Response] =
error(Status.BadRequest, message)

/**
* Constructs a handler from two functions, one that configures web socket and
* another that uses a web socket.
*
* If the config function returns with None, the server configuration is used.
*/
final def customWebSocket[Env](
config: Request => ZIO[Env, Throwable, Option[WebSocketConfig]],
f: WebSocketChannel => ZIO[Env, Throwable, Any],
): SocketApp[Env] =
SocketApp(Handler.fromFunctionZIO(f), Handler.fromFunctionZIO(config))

/**
* Returns a handler that dies with the specified `Throwable`. This method can
* be used for terminating an handler because a defect has been detected in
Expand Down Expand Up @@ -1102,10 +1071,10 @@ object Handler {
/**
* Constructs a handler from a function that uses a web socket.
*/
final def webSocket[Env, Err, Out](
f: WebSocketChannel => ZIO[Env, Err, Out],
): Handler[Env, Err, WebSocketChannel, Out] =
Handler.fromFunctionZIO(f)
final def webSocket[Env](
f: WebSocketChannel => ZIO[Env, Throwable, Any],
): SocketApp[Env] =
SocketApp(Handler.fromFunctionZIO(f))

final implicit class RequestHandlerSyntax[-R, +Err](val self: RequestHandler[R, Err])
extends HeaderModifier[RequestHandler[R, Err]] {
Expand Down
8 changes: 0 additions & 8 deletions zio-http/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,6 @@ object Response {
def fromServerSentEvents(data: ZStream[Any, Nothing, ServerSentEvent])(implicit trace: Trace): Response =
new BasicResponse(Body.fromStream(data.map(_.encode)), contentTypeEventStream, Status.Ok)

/**
* Creates a new response for the provided socket
*/
def fromSocket[R](
http: Handler[R, Throwable, WebSocketChannel, Any],
)(implicit trace: Trace): ZIO[R, Nothing, Response] =
fromSocketApp(http)

/**
* Creates a new response for the provided socket app
*/
Expand Down
84 changes: 84 additions & 0 deletions zio-http/src/main/scala/zio/http/SocketApp.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package zio.http

import zio._

final case class SocketApp[-R](
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can rename this WebSocketApp while we're at it.

handler: Handler[R, Throwable, WebSocketChannel, Any],
customConfig: Handler[R, Throwable, Request, Option[WebSocketConfig]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too powerful. I'd prefer Option[WebSocketConfig], or if that cannot be done, something like <subset of Request> => WebSocketConfig.

But it feels like the config should be done at the route pattern level, e.g.:

val route = Method.GET / "connect" -> handler(socketApp.withConfig(...))

You already know the route and pattern there and can set config on a per-route basis. You indeed have access to the full request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The author of the issue wanted to have a way to customize the handshaking based on the request (actual contents of some headers)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<subset of Request> => WebSocketConfig would be great, especially if the result was passed into Handler.webSocket. As it is right now there's no way to negotiate the websocket sub-protocol for a single endpoint (such as choosing between graphql-ws and graphql-transport-ws), which would be really nice to do.

) { self =>

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self)
}

def provideEnvironment(r: ZEnvironment[R])(implicit trace: Trace): SocketApp[Any] =
SocketApp(handler.provideEnvironment(r), customConfig.provideEnvironment(r))

def provideLayer[R0](layer: ZLayer[R0, Throwable, R])(implicit
trace: Trace,
): SocketApp[R0] =
SocketApp(handler.provideLayer(layer), customConfig.provideLayer(layer))

def provideSomeEnvironment[R1](f: ZEnvironment[R1] => ZEnvironment[R])(implicit
trace: Trace,
): SocketApp[R1] =
SocketApp(handler.provideSomeEnvironment(f), customConfig.provideSomeEnvironment(f))

def provideSomeLayer[R0, R1: Tag](
layer: ZLayer[R0, Throwable, R1],
)(implicit ev: R0 with R1 <:< R, trace: Trace): SocketApp[R0] =
SocketApp(handler.provideSomeLayer(layer), customConfig.provideSomeLayer(layer))

def tapErrorCauseZIO[R1 <: R](
f: Cause[Throwable] => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): SocketApp[R1] =
SocketApp(handler.tapErrorCauseZIO(f), customConfig.tapErrorCauseZIO(f))

/**
* Returns a Handler that effectfully peeks at the failure of this SocketApp.
*/
def tapErrorZIO[R1 <: R](
f: Throwable => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): SocketApp[R1] =
self.tapErrorCauseZIO(cause => cause.failureOption.fold[ZIO[R1, Throwable, Any]](ZIO.unit)(f))

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.provideEnvironment(env))
}

def toHttpAppWS(implicit trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp
}

object SocketApp {
def apply[R](handler: Handler[R, Throwable, WebSocketChannel, Any]): SocketApp[R] =
SocketApp(handler, Handler.succeed(None))

val unit: SocketApp[Any] = SocketApp(Handler.unit)
}
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/ZClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ object ZClient {
val request = Request(version, method, url, headers, body, None)
val cfg = sslConfig.fold(config)(config.ssl)

requestAsync(request, cfg, () => Handler.unit, None)
requestAsync(request, cfg, () => SocketApp.unit, None)
}

def socket[Env1](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ final case class NettyClientDriver private (
nettyChannel = NettyChannel.make[JWebSocketFrame](channel)
webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
app = createSocketApp()
_ <- app.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
customConfig <- app.customConfig.runZIO(req)
_ <- app.handler.runZIO(webSocketChannel).ignoreLogged.interruptible.forkScoped
} yield {
val pipeline = channel.pipeline()
val toRemove: mutable.Set[ChannelHandler] = new mutable.HashSet[ChannelHandler]()
Expand All @@ -85,7 +86,7 @@ final case class NettyClientDriver private (

val headers = Conversions.headersToNetty(req.headers)
val config = NettySocketProtocol
.clientBuilder(webSocketConfig)
.clientBuilder(customConfig.getOrElse(webSocketConfig))
.customHeaders(headers)
.webSocketUri(req.url.encode)
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,26 @@ private[zio] final case class ServerInboundHandler(
runtime: NettyRuntime,
): Unit = {
val app = res.socketApp

jReq match {
case jReq: FullHttpRequest =>
val queue = runtime.runtime(ctx).unsafe.run(Queue.unbounded[WebSocketChannelEvent]).getOrThrowFiberFailure()
val webSocketApp = app.getOrElse(SocketApp.unit)
val request = makeZioRequest(ctx, jReq)
val customConfig =
runtime.runtime(ctx).unsafe.run { webSocketApp.customConfig.runZIO(request) }.getOrThrowFiberFailure()
runtime.runtime(ctx).unsafe.run {
val nettyChannel = NettyChannel.make[JWebSocketFrame](ctx.channel())
val webSocketChannel = WebSocketChannel.make(nettyChannel, queue)
val webSocketApp = app.getOrElse(Handler.unit)
webSocketApp.runZIO(webSocketChannel).ignoreLogged.forkDaemon
webSocketApp.handler.runZIO(webSocketChannel).ignoreLogged.forkDaemon
}
ctx
.channel()
.pipeline()
.addLast(
new WebSocketServerProtocolHandler(NettySocketProtocol.serverBuilder(config.webSocketConfig).build()),
new WebSocketServerProtocolHandler(
NettySocketProtocol.serverBuilder(customConfig.getOrElse(config.webSocketConfig)).build(),
),
)
.addLast(Names.WebSocketHandler, new WebSocketAppHandler(runtime, queue, None))

Expand Down
2 changes: 0 additions & 2 deletions zio-http/src/main/scala/zio/http/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ package object http extends UrlInterpolator {
type Client = ZClient[Any, Body, Throwable, Response]
def Client: ZClient.type = ZClient

type SocketApp[-R] = Handler[R, Throwable, WebSocketChannel, Any]

/**
* A channel that allows websocket frames to be written to it.
*/
Expand Down
5 changes: 3 additions & 2 deletions zio-http/src/test/scala/zio/http/WebSocketSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ object WebSocketSpec extends HttpRunnableSpec {
} yield assertCompletes
} @@ nonFlaky,
test("Multiple websocket upgrades") {
val app = Handler.succeed(WebSocketFrame.text("BAR")).toHttpAppWS.deployWS
val app =
Handler.webSocket(channel => channel.send(ChannelEvent.Read(WebSocketFrame.text("BAR")))).toHttpAppWS.deployWS
val codes = ZIO
.foreach(1 to 1024)(_ => app.runZIO(Handler.unit).map(_.status))
.foreach(1 to 1024)(_ => app.runZIO(SocketApp.unit).map(_.status))
.map(_.count(_ == Status.SwitchingProtocols))

assertZIO(codes)(equalTo(1024))
Expand Down
Loading