Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SocketApp refactoring, ability to provide custom websocket config #2338

Merged
merged 7 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/advanced/websocket-server.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import zio.http.codec.PathCodec.string

object WebSocketAdvanced extends ZIOAppDefault {

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("end")) =>
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/basic/websocket.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import zio.http._
import zio.http.codec.PathCodec.string

object WebSocketEcho extends ZIOAppDefault {
private val socketApp: SocketApp[Any] =
private val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("FOO")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import zio.http.codec.PathCodec.string

object WebSocketAdvanced extends ZIOAppDefault {

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("end")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import zio.http._
import zio.http.codec.PathCodec.string

object WebSocketEcho extends ZIOAppDefault {
private val socketApp: SocketApp[Any] =
private val socketApp: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("FOO")) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object WebSocketReconnectingClient extends ZIOAppDefault {
val url = "ws://ws.vi-server.org/mirror"

// A promise is used to be able to notify application about websocket errors
def makeSocketApp(p: Promise[Nothing, Throwable]): SocketApp[Any] =
def makeSocketApp(p: Promise[Nothing, Throwable]): WebSocketApp[Any] =
Handler

// Listen for all websocket channel events
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ object WebSocketSimpleClient extends ZIOAppDefault {

val url = "ws://ws.vi-server.org/mirror"

val socketApp: SocketApp[Any] =
val socketApp: WebSocketApp[Any] =
Handler

// Listen for all websocket channel events
Expand Down
14 changes: 7 additions & 7 deletions zio-http-testkit/src/main/scala/zio/http/TestClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import zio.http.{Headers, Method, Scheme, Status, Version}
*/
final case class TestClient(
behavior: Ref[PartialFunction[Request, ZIO[Any, Response, Response]]],
serverSocketBehavior: Ref[SocketApp[Any]],
serverSocketBehavior: Ref[WebSocketApp[Any]],
) extends ZClient.Driver[Any, Throwable] {

/**
Expand Down Expand Up @@ -117,7 +117,7 @@ final case class TestClient(
version: Version,
url: URL,
headers: Headers,
app: SocketApp[Env1],
app: WebSocketApp[Env1],
)(implicit trace: Trace): ZIO[Env1 with Scope, Throwable, Response] = {
for {
env <- ZIO.environment[Env1]
Expand All @@ -127,13 +127,13 @@ final case class TestClient(
promise <- Promise.make[Nothing, Unit]
testChannelClient <- TestChannel.make(in, out, promise)
testChannelServer <- TestChannel.make(out, in, promise)
_ <- currentSocketBehavior.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).runZIO(testChannelServer).forkDaemon
_ <- currentSocketBehavior.handler.runZIO(testChannelClient).forkDaemon
_ <- app.provideEnvironment(env).handler.runZIO(testChannelServer).forkDaemon
} yield Response.status(Status.SwitchingProtocols)
}

def installSocketApp[Env1](
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: WebSocketApp[Any],
): ZIO[Env1, Nothing, Unit] =
for {
env <- ZIO.environment[Env1]
Expand Down Expand Up @@ -182,15 +182,15 @@ object TestClient {
ZIO.serviceWithZIO[TestClient](_.addHandler(handler))

def installSocketApp(
app: Handler[Any, Throwable, WebSocketChannel, Unit],
app: WebSocketApp[Any],
): ZIO[TestClient, Nothing, Unit] =
ZIO.serviceWithZIO[TestClient](_.installSocketApp(app))

val layer: ZLayer[Any, Nothing, TestClient & Client] =
ZLayer.scopedEnvironment {
for {
behavior <- Ref.make[PartialFunction[Request, ZIO[Any, Response, Response]]](PartialFunction.empty)
socketBehavior <- Ref.make[SocketApp[Any]](Handler.unit)
socketBehavior <- Ref.make[WebSocketApp[Any]](WebSocketApp.unit)
driver = TestClient(behavior, socketBehavior)
} yield ZEnvironment[TestClient, Client](driver, ZClient.fromDriver(driver))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ object SocketContractSpec extends ZIOSpecDefault {
def spec: Spec[Any, Any] =
suite("SocketOps")(
contract("Successful Multi-message application") { p =>
val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case Read(WebSocketFrame.Text("Hi Server")) =>
Expand All @@ -31,7 +31,7 @@ object SocketContractSpec extends ZIOSpecDefault {

socketServer
} { _ =>
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand Down Expand Up @@ -89,8 +89,8 @@ object SocketContractSpec extends ZIOSpecDefault {
private def contract(
name: String,
)(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
)(clientApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit]) = {
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
)(clientApp: Promise[Throwable, Unit] => WebSocketApp[Any]) = {
suite(name)(
test("Live") {
for {
Expand Down Expand Up @@ -123,7 +123,7 @@ object SocketContractSpec extends ZIOSpecDefault {
}

private def liveServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
): ZIO[Server, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
ZIO.serviceWithZIO[Server](server =>
for {
Expand All @@ -133,7 +133,7 @@ object SocketContractSpec extends ZIOSpecDefault {
)

private def testServerSetup(
serverApp: Promise[Throwable, Unit] => Handler[Any, Throwable, WebSocketChannel, Unit],
serverApp: Promise[Throwable, Unit] => WebSocketApp[Any],
): ZIO[TestClient, Nothing, (RuntimeFlags, Promise[Throwable, Unit])] =
for {
p <- Promise.make[Throwable, Unit]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ object TestClientSpec extends ZIOSpecDefault {
),
suite("socket ops")(
test("happy path") {
val socketClient: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketClient: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Client")) =>
Expand All @@ -77,7 +77,7 @@ object TestClientSpec extends ZIOSpecDefault {
}
}

val socketServer: Handler[Any, Throwable, WebSocketChannel, Unit] =
val socketServer: WebSocketApp[Any] =
Handler.webSocket { channel =>
channel.receiveAll {
case ChannelEvent.Read(WebSocketFrame.Text("Hi Server")) =>
Expand Down
2 changes: 1 addition & 1 deletion zio-http/src/main/scala/zio/http/ClientDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ trait ClientDriver {
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
enableKeepAlive: Boolean,
createSocketApp: () => SocketApp[Any],
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface]

Expand Down
51 changes: 4 additions & 47 deletions zio-http/src/main/scala/zio/http/Handler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,34 +222,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
)(implicit trace: Trace): Handler[R1, Err1, In1, Out] =
that.andThen(self)

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
ev: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self.asInstanceOf[SocketApp[R]])
}

/**
* Transforms the input of the handler before passing it on to the current
* Handler
Expand Down Expand Up @@ -621,21 +593,6 @@ sealed trait Handler[-R, +Err, -In, +Out] { self =>
HttpApp(Routes.singleton(handler.contramap[(Path, Request)](_._2)))
}

def toHttpAppWS(implicit err: Err <:< Throwable, in: WebSocketChannel <:< In, trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
ev1: Err <:< Throwable,
ev2: WebSocketChannel <:< In,
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.asInstanceOf[SocketApp[R]].provideEnvironment(env))
}

/**
* Takes some defects and converts them into failures.
*/
Expand Down Expand Up @@ -1102,10 +1059,10 @@ object Handler {
/**
* Constructs a handler from a function that uses a web socket.
*/
final def webSocket[Env, Err, Out](
f: WebSocketChannel => ZIO[Env, Err, Out],
): Handler[Env, Err, WebSocketChannel, Out] =
Handler.fromFunctionZIO(f)
final def webSocket[Env](
f: WebSocketChannel => ZIO[Env, Throwable, Any],
): WebSocketApp[Env] =
WebSocketApp(Handler.fromFunctionZIO(f))

final implicit class RequestHandlerSyntax[-R, +Err](val self: RequestHandler[R, Err])
extends HeaderModifier[RequestHandler[R, Err]] {
Expand Down
16 changes: 4 additions & 12 deletions zio-http/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ sealed trait Response extends HeaderOps[Response] { self =>
final def status(status: Status): Response =
self.copy(status = status)

private[zio] final def socketApp: Option[SocketApp[Any]] = self match {
private[zio] final def socketApp: Option[WebSocketApp[Any]] = self match {
case Response.GetApp(app) => Some(app)
case _ => None
}
Expand All @@ -140,7 +140,7 @@ object Response {
}

object GetApp {
def unapply(response: Response): Option[SocketApp[Any]] = response match {
def unapply(response: Response): Option[WebSocketApp[Any]] = response match {
case resp: SocketAppResponse => Some(resp.socketApp0)
case _ => None
}
Expand Down Expand Up @@ -186,7 +186,7 @@ object Response {
private[zio] class SocketAppResponse(
val body: Body,
val headers: Headers,
val socketApp0: SocketApp[Any],
val socketApp0: WebSocketApp[Any],
val status: Status,
) extends Response { self =>

Expand Down Expand Up @@ -357,18 +357,10 @@ object Response {
def fromServerSentEvents(data: ZStream[Any, Nothing, ServerSentEvent])(implicit trace: Trace): Response =
new BasicResponse(Body.fromStream(data.map(_.encode)), contentTypeEventStream, Status.Ok)

/**
* Creates a new response for the provided socket
*/
def fromSocket[R](
http: Handler[R, Throwable, WebSocketChannel, Any],
)(implicit trace: Trace): ZIO[R, Nothing, Response] =
fromSocketApp(http)

/**
* Creates a new response for the provided socket app
*/
def fromSocketApp[R](app: SocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = {
def fromSocketApp[R](app: WebSocketApp[R])(implicit trace: Trace): ZIO[R, Nothing, Response] = {
ZIO.environment[R].map { env =>
new SocketAppResponse(
Body.empty,
Expand Down
87 changes: 87 additions & 0 deletions zio-http/src/main/scala/zio/http/WebSocketApp.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package zio.http

import zio._

final case class WebSocketApp[-R](
handler: Handler[R, Throwable, WebSocketChannel, Any],
customConfig: Option[WebSocketConfig],
) { self =>

/**
* Creates a socket connection on the provided URL. Typically used to connect
* as a client.
*/
def connect(
url: String,
headers: Headers = Headers.empty,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.fromEither(URL.decode(url)).orDie.flatMap(connect(_, headers))

def connect(
url: URL,
headers: Headers,
)(implicit
trace: Trace,
): ZIO[R with Client with Scope, Throwable, Response] =
ZIO.serviceWithZIO[Client] { client =>
val client2 = if (url.isAbsolute) client.url(url) else client.addUrl(url)

client2.addHeaders(headers).socket(self)
}

def provideEnvironment(r: ZEnvironment[R])(implicit trace: Trace): WebSocketApp[Any] =
WebSocketApp(handler.provideEnvironment(r), customConfig)

def provideLayer[R0](layer: ZLayer[R0, Throwable, R])(implicit
trace: Trace,
): WebSocketApp[R0] =
WebSocketApp(handler.provideLayer(layer), customConfig)

def provideSomeEnvironment[R1](f: ZEnvironment[R1] => ZEnvironment[R])(implicit
trace: Trace,
): WebSocketApp[R1] =
WebSocketApp(handler.provideSomeEnvironment(f), customConfig)

def provideSomeLayer[R0, R1: Tag](
layer: ZLayer[R0, Throwable, R1],
)(implicit ev: R0 with R1 <:< R, trace: Trace): WebSocketApp[R0] =
WebSocketApp(handler.provideSomeLayer(layer), customConfig)

def tapErrorCauseZIO[R1 <: R](
f: Cause[Throwable] => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): WebSocketApp[R1] =
WebSocketApp(handler.tapErrorCauseZIO(f), customConfig)

/**
* Returns a Handler that effectfully peeks at the failure of this SocketApp.
*/
def tapErrorZIO[R1 <: R](
f: Throwable => ZIO[R1, Throwable, Any],
)(implicit trace: Trace): WebSocketApp[R1] =
self.tapErrorCauseZIO(cause => cause.failureOption.fold[ZIO[R1, Throwable, Any]](ZIO.unit)(f))

/**
* Creates a new response from a socket handler.
*/
def toResponse(implicit
trace: Trace,
): ZIO[R, Nothing, Response] =
ZIO.environment[R].flatMap { env =>
Response.fromSocketApp(self.provideEnvironment(env))
}

def toHttpAppWS(implicit trace: Trace): HttpApp[R] =
Handler.fromZIO(self.toResponse).toHttpApp

def withConfig(config: WebSocketConfig): WebSocketApp[R] =
copy(customConfig = Some(config))
}

object WebSocketApp {
def apply[R](handler: Handler[R, Throwable, WebSocketChannel, Any]): WebSocketApp[R] =
WebSocketApp(handler, None)

val unit: WebSocketApp[Any] = WebSocketApp(Handler.unit)
}
Loading
Loading