Skip to content

Commit

Permalink
SocketApp refactoring, ability to provide custom websocket config (#2338
Browse files Browse the repository at this point in the history
)
  • Loading branch information
vigoo authored Aug 5, 2023
1 parent 52e535c commit d1997c6
Show file tree
Hide file tree
Showing 20 changed files with 145 additions and 107 deletions.
2 changes: 1 addition & 1 deletion docs/examples/advanced/websocket-server.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import zio.http.codec.PathCodec.string

object WebSocketAdvanced extends ZIOAppDefault {

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("end")) =>
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/basic/websocket.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import zio.http._
import zio.http.codec.PathCodec.string

object WebSocketEcho extends ZIOAppDefault {
private val socketApp: SocketApp[Any] =
private val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("FOO")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.http.codec.PathCodec.string

object WebSocketAdvanced extends ZIOAppDefault {

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("end")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import zio.http._
import zio.http.codec.PathCodec.string

object WebSocketEcho extends ZIOAppDefault {
private val socketApp: SocketApp[Any] =
private val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("FOO")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object WebSocketReconnectingClient extends ZIOAppDefault {
val url = "ws://ws.vi-server.org/mirror"

// A promise is used to be able to notify application about websocket errors
def makeSocketApp(p: Promise[Nothing, Throwable]): SocketApp[Any] =
def makeSocketApp(p: Promise[Nothing, Throwable]): WebSocketApp[Any] =
Handler

// Listen for all websocket channel events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object WebSocketSimpleClient extends ZIOAppDefault {

val url = "ws://ws.vi-server.org/mirror"

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler

// Listen for all websocket channel events
Expand Down
14 changes: 7 additions & 7 deletions zio-http-testkit/src/main/scala/zio/http/TestClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import zio.http.{Headers, Method, Scheme, Status, Version}
*/
final case class TestClient(
behavior: Ref[PartialFunction[Request, ZIO[Any, Response, Response]]],
serverSocketBehavior: Ref[SocketApp[Any]],
serverSocketBehavior: Ref[WebSocketApp[Any]],
) extends ZClient.Driver[Any, Throwable] {

/**
Expand Down Expand Up @@ -117,7 +117,7 @@ final case class TestClient(
version: Version,
url: URL,
headers: Headers,
app: SocketApp[Env1],
app: WebSocketApp[Env1],
)(implicit trace: Trace): ZIO[Env1 with Scope, Throwable, Response] = {
for {
env <- ZIO.environment[Env1]
Expand All @@ -127,13 +127,13 @@ final case class TestClient(
promise <- Promise.make[Nothing, Unit]
testChannelClient <- TestChannel.make(in, out, promise)
testChannelServer <- TestChannel.make(out, in, promise)
_ <- currentSocketBehavior.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).runZIO(testChannelServer).forkDaemon
_ <- currentSocketBehavior.handler.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).handler.runZIO(testChannelServer).forkDaemon
} yield Response.status(Status.SwitchingProtocols)
}

def installSocketApp[Env1](
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: WebSocketApp[Any],
): ZIO[Env1, Nothing, Unit] =
for {
env <- ZIO.environment[Env1]
Expand Down Expand Up @@ -182,15 +182,15 @@ object TestClient {
ZIO.serviceWithZIO[TestClient](_.addHandler(handler))

def installSocketApp(
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: WebSocketApp[Any],
): ZIO[TestClient, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.installSocketApp(app))

val layer: ZLayer[Any, Nothing, TestClient & Client] =
ZLayer.scopedEnvironment {
for {
behavior <- Ref.make[PartialFunction[Request, ZIO[Any, Response, Response]]](PartialFunction.empty)
socketBehavior <- Ref.make[SocketApp[Any]](Handler.unit)
socketBehavior <- Ref.make[WebSocketApp[Any]](WebSocketApp.unit)
driver = TestClient(behavior, socketBehavior)
} yield ZEnvironment[TestClient, Client](driver, ZClient.fromDriver(driver))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object SocketContractSpec extends ZIOSpecDefault {
def spec: Spec[Any, Any] =
suite("SocketOps")(
contract("Successful Multi-message application") { p =>
val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("Hi Server")) =>
Expand All @@ -31,7 +31,7 @@ object SocketContractSpec extends ZIOSpecDefault {

socketServer
} { _ =>
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand Down Expand Up @@ -89,8 +89,8 @@ object SocketContractSpec extends ZIOSpecDefault {
private def contract(
name: String,
)(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
)(clientApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit]) = {
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
)(clientApp: Promise[Throwable, Unit] => WebSocketApp[Any]) = {
suite(name)(
test("Live") {
for {
Expand Down Expand Up @@ -123,7 +123,7 @@ object SocketContractSpec extends ZIOSpecDefault {
}

private def liveServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
): ZIO[Server, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
ZIO.serviceWithZIO[Server](server =>
for {
Expand All @@ -133,7 +133,7 @@ object SocketContractSpec extends ZIOSpecDefault {
)

private def testServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
): ZIO[TestClient, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
for {
p <- Promise.make[Throwable, Unit]
Expand Down
4 changes: 2 additions & 2 deletions zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object TestClientSpec extends ZIOSpecDefault {
),
suite("socket ops")(
test("happy path") {
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand All @@ -77,7 +77,7 @@ object TestClientSpec extends ZIOSpecDefault {
}
}

val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Server")) =>
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/ClientDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ trait ClientDriver {
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
enableKeepAlive: Boolean,
createSocketApp: () => SocketApp[Any],
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface]

Expand Down
51 changes: 4 additions & 47 deletions zio-http/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,34 +222,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
)(implicit trace: Trace): Handler[R1, Err1, In1, Out] =
that.andThen(self)

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
ev: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self.asInstanceOf[SocketApp[R]])
}

/**
* Transforms the input of the handler before passing it on to the current
* Handler
Expand Down Expand Up @@ -621,21 +593,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
HttpApp(Routes.singleton(handler.contramap[(Path, Request)](_._2)))
}

def toHttpAppWS(implicit err: Err <:< Throwable, in: WebSocketChannel <:< In, trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.asInstanceOf[SocketApp[R]].provideEnvironment(env))
}

/**
* Takes some defects and converts them into failures.
*/
Expand Down Expand Up @@ -1102,10 +1059,10 @@ object Handler {
/**
* Constructs a handler from a function that uses a web socket.
*/
final def webSocket[Env, Err, Out](
f: WebSocketChannel => ZIO[Env, Err, Out],
): Handler[Env, Err, WebSocketChannel, Out] =
Handler.fromFunctionZIO(f)
final def webSocket[Env](
f: WebSocketChannel => ZIO[Env, Throwable, Any],
): WebSocketApp[Env] =
WebSocketApp(Handler.fromFunctionZIO(f))

final implicit class RequestHandlerSyntax[-R, +Err](val self: RequestHandler[R, Err])
extends HeaderModifier[RequestHandler[R, Err]] {
Expand Down
16 changes: 4 additions & 12 deletions zio-http/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ sealed trait Response extends HeaderOps[Response] { self =>
final def status(status: Status): Response =
self.copy(status = status)

private[zio] final def socketApp: Option[SocketApp[Any]] = self match {
private[zio] final def socketApp: Option[WebSocketApp[Any]] = self match {
case Response.GetApp(app) => Some(app)
case _ => None
}
Expand All @@ -140,7 +140,7 @@ object Response {
}

object GetApp {
def unapply(response: Response): Option[SocketApp[Any]] = response match {
def unapply(response: Response): Option[WebSocketApp[Any]] = response match {
case resp: SocketAppResponse => Some(resp.socketApp0)
case _ => None
}
Expand Down Expand Up @@ -186,7 +186,7 @@ object Response {
private[zio] class SocketAppResponse(
val body: Body,
val headers: Headers,
val socketApp0: SocketApp[Any],
val socketApp0: WebSocketApp[Any],
val status: Status,
) extends Response { self =>

Expand Down Expand Up @@ -357,18 +357,10 @@ object Response {
def fromServerSentEvents(data: ZStream[Any, Nothing, ServerSentEvent])(implicit trace: Trace): Response =
new BasicResponse(Body.fromStream(data.map(_.encode)), contentTypeEventStream, Status.Ok)

/**
* Creates a new response for the provided socket
*/
def fromSocket[R](
http: Handler[R, Throwable, WebSocketChannel, Any],
)(implicit trace: Trace): ZIO[R, Nothing, Response] =
fromSocketApp(http)

/**
* Creates a new response for the provided socket app
*/
def fromSocketApp[R](app: SocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = {
def fromSocketApp[R](app: WebSocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = {
ZIO.environment[R].map { env =>
new SocketAppResponse(
Body.empty,
Expand Down
87 changes: 87 additions & 0 deletions zio-http/src/main/scala/zio/http/WebSocketApp.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zio.http

import zio._

final case class WebSocketApp[-R](
handler: Handler[R, Throwable, WebSocketChannel, Any],
customConfig: Option[WebSocketConfig],
) { self =>

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self)
}

def provideEnvironment(r: ZEnvironment[R])(implicit trace: Trace): WebSocketApp[Any] =
WebSocketApp(handler.provideEnvironment(r), customConfig)

def provideLayer[R0](layer: ZLayer[R0, Throwable, R])(implicit
trace: Trace,
): WebSocketApp[R0] =
WebSocketApp(handler.provideLayer(layer), customConfig)

def provideSomeEnvironment[R1](f: ZEnvironment[R1] => ZEnvironment[R])(implicit
trace: Trace,
): WebSocketApp[R1] =
WebSocketApp(handler.provideSomeEnvironment(f), customConfig)

def provideSomeLayer[R0, R1: Tag](
layer: ZLayer[R0, Throwable, R1],
)(implicit ev: R0 with R1 <:< R, trace: Trace): WebSocketApp[R0] =
WebSocketApp(handler.provideSomeLayer(layer), customConfig)

def tapErrorCauseZIO[R1 <: R](
f: Cause[Throwable] => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): WebSocketApp[R1] =
WebSocketApp(handler.tapErrorCauseZIO(f), customConfig)

/**
* Returns a Handler that effectfully peeks at the failure of this SocketApp.
*/
def tapErrorZIO[R1 <: R](
f: Throwable => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): WebSocketApp[R1] =
self.tapErrorCauseZIO(cause => cause.failureOption.fold[ZIO[R1, Throwable, Any]](ZIO.unit)(f))

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.provideEnvironment(env))
}

def toHttpAppWS(implicit trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp

def withConfig(config: WebSocketConfig): WebSocketApp[R] =
copy(customConfig = Some(config))
}

object WebSocketApp {
def apply[R](handler: Handler[R, Throwable, WebSocketChannel, Any]): WebSocketApp[R] =
WebSocketApp(handler, None)

val unit: WebSocketApp[Any] = WebSocketApp(Handler.unit)
}
Loading

0 comments on commit d1997c6

Please sign in to comment.