Skip to content

Commit

Permalink
Add middlewares to forward header and ensure the existence of a header (
Browse files Browse the repository at this point in the history
#2808)

* Add middlewares to forward header and ensure the existence of a header

* Fix Scala 3

* Migrate main changes
  • Loading branch information
987Nabil authored Sep 5, 2024
1 parent ecce294 commit 4087df4
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 1 deletion.
30 changes: 30 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/ForwardHeaderSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package zio.http

import zio._
import zio.test._

object ForwardHeaderSpec extends ZIOSpecDefault {
override def spec: Spec[TestEnvironment with Scope, Any] =
suite("ForwardHeaderSpec")(
test("should forward headers") {
val routes = Routes(
Method.GET / "get" -> handler((_: Request) =>
for {
client <- ZIO.service[Client]
response <- (client @@ ZClientAspect.forwardHeaders)
.batched(Request.post(url"http://localhost:8080/post", Body.empty))
} yield response,
),
Method.POST / "post" -> handler((req: Request) => Response.ok.addHeader(req.header(Header.Accept).get)),
).sandbox @@ Middleware.forwardHeaders(Header.Accept)

for {
_ <- Server.install(routes)
response <- Client.batched(
Request.get(url"http://localhost:8080/get").addHeader(Header.Accept(MediaType.application.json)),
)
} yield assertTrue(response.headers(Header.Accept).contains(Header.Accept(MediaType.application.json)))
},
).provide(Client.default, Server.default) @@ TestAspect.withLiveClock

}
2 changes: 1 addition & 1 deletion zio-http/shared/src/main/scala/zio/http/Header.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ sealed trait Header {
object Header {

sealed trait HeaderType {
type HeaderValue
type HeaderValue <: Header

def name: String

Expand Down
80 changes: 80 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/Middleware.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,86 @@ object Middleware extends HandlerAspects {
}
}

def ensureHeader(header: Header.HeaderType)(make: => header.HeaderValue): Middleware[Any] =
new Middleware[Any] {
def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
routes.transform[Env1] { h =>
handler { (req: Request) =>
if (req.headers.contains(header.name)) h(req)
else h(req.addHeader(make))
}
}
}

def ensureHeader(headerName: String)(make: => String): Middleware[Any] =
new Middleware[Any] {
def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
routes.transform[Env1] { h =>
handler { (req: Request) =>
if (req.headers.contains(headerName)) h(req)
else h(req.addHeader(headerName, make))
}
}
}

private[http] case class ForwardedHeaders(headers: Headers)

def forwardHeaders(header: Header.HeaderType, headers: Header.HeaderType*)(implicit
trace: Trace,
): Middleware[Any] = {
val allHeaders = header +: headers
new Middleware[Any] {
def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
routes.transform[Env1] { h =>
handler { (req: Request) =>
val headerValues = ChunkBuilder.make[Header]()
headerValues.sizeHint(allHeaders.length)
var i = 0
while (i < allHeaders.length) {
val name = allHeaders(i)
req.headers.get(name).foreach { value =>
headerValues += value
}
i += 1
}
RequestStore.update[ForwardedHeaders] { old =>
ForwardedHeaders {
old.map(_.headers).getOrElse(Headers.empty) ++
Headers.fromIterable(headerValues.result())
}
} *> h(req)
}
}
}
}

def forwardHeaders(headerName: String, headerNames: String*)(implicit trace: Trace): Middleware[Any] = {
val allHeaders = headerName +: headerNames
new Middleware[Any] {
def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] =
routes.transform[Env1] { h =>
handler { (req: Request) =>
val headerValues = ChunkBuilder.make[Header]()
headerValues.sizeHint(allHeaders.length)
var i = 0
while (i < allHeaders.length) {
val name = allHeaders(i)
req.headers.get(name).foreach { value =>
headerValues += Header.Custom(name, value)
}
i += 1
}
RequestStore.update[ForwardedHeaders] { old =>
ForwardedHeaders {
old.map(_.headers).getOrElse(Headers.empty) ++
Headers.fromIterable(headerValues.result())
}
} *> h(req)
}
}
}
}

def logAnnotate(key: => String, value: => String)(implicit trace: Trace): Middleware[Any] =
logAnnotate(LogAnnotation(key, value))

Expand Down
22 changes: 22 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/RequestStore.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package zio.http

import zio.{FiberRef, Tag, Unsafe, ZIO}

object RequestStore {

private[http] val requestStore: FiberRef[Map[Tag[_], Any]] =
FiberRef.unsafe.make[Map[Tag[_], Any]](Map.empty)(Unsafe.unsafe)

def get[A: Tag]: ZIO[Any, Nothing, Option[A]] =
requestStore.get.map(_.get(implicitly[Tag[A]]).asInstanceOf[Option[A]])

def set[A: Tag](a: A): ZIO[Any, Nothing, Unit] =
requestStore.update(_.updated(implicitly[Tag[A]], a))

def update[A: Tag](a: Option[A] => A): ZIO[Any, Nothing, Unit] =
for {
current <- get[A]
_ <- set(a(current))
} yield ()

}
45 changes: 45 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/ZClientAspect.scala
Original file line number Diff line number Diff line change
Expand Up @@ -447,4 +447,49 @@ object ZClientAspect {
}
}
}

final def forwardHeaders: ZClientAspect[Nothing, Any, Nothing, Body, Nothing, Any, Nothing, Response] =
new ZClientAspect[Nothing, Any, Nothing, Body, Nothing, Any, Nothing, Response] {
override def apply[
ReqEnv,
Env >: Nothing <: Any,
In >: Nothing <: Body,
Err >: Nothing <: Any,
Out >: Nothing <: Response,
](
client: ZClient[Env, ReqEnv, In, Err, Out],
): ZClient[Env, ReqEnv, In, Err, Out] =
client.copy(
driver = new ZClient.Driver[Env, ReqEnv, Err] {
override def request(
version: Version,
method: Method,
url: URL,
headers: Headers,
body: Body,
sslConfig: Option[ClientSSLConfig],
proxy: Option[Proxy],
)(implicit trace: Trace): ZIO[Env & ReqEnv, Err, Response] =
RequestStore.get[Middleware.ForwardedHeaders].flatMap {
case Some(forwardedHeaders) =>
client.driver
.request(version, method, url, headers ++ forwardedHeaders.headers, body, sslConfig, proxy)
case None =>
client.driver.request(version, method, url, headers, body, sslConfig, proxy)
}

override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])(
implicit
trace: Trace,
ev: ReqEnv =:= Scope,
): ZIO[Env1 & ReqEnv, Err, Response] =
RequestStore.get[Middleware.ForwardedHeaders].flatMap {
case Some(forwardedHeaders) =>
client.driver.socket(version, url, headers ++ forwardedHeaders.headers, app)
case None =>
client.driver.socket(version, url, headers, app)
}
},
)
}
}

0 comments on commit 4087df4

Please sign in to comment.