diff --git a/zio-http/jvm/src/test/scala/zio/http/ForwardHeaderSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ForwardHeaderSpec.scala new file mode 100644 index 0000000000..59803c7eab --- /dev/null +++ b/zio-http/jvm/src/test/scala/zio/http/ForwardHeaderSpec.scala @@ -0,0 +1,30 @@ +package zio.http + +import zio._ +import zio.test._ + +object ForwardHeaderSpec extends ZIOSpecDefault { + override def spec: Spec[TestEnvironment with Scope, Any] = + suite("ForwardHeaderSpec")( + test("should forward headers") { + val routes = Routes( + Method.GET / "get" -> handler((_: Request) => + for { + client <- ZIO.service[Client] + response <- (client @@ ZClientAspect.forwardHeaders) + .batched(Request.post(url"http://localhost:8080/post", Body.empty)) + } yield response, + ), + Method.POST / "post" -> handler((req: Request) => Response.ok.addHeader(req.header(Header.Accept).get)), + ).sandbox @@ Middleware.forwardHeaders(Header.Accept) + + for { + _ <- Server.install(routes) + response <- Client.batched( + Request.get(url"http://localhost:8080/get").addHeader(Header.Accept(MediaType.application.json)), + ) + } yield assertTrue(response.headers(Header.Accept).contains(Header.Accept(MediaType.application.json))) + }, + ).provide(Client.default, Server.default) @@ TestAspect.withLiveClock + +} diff --git a/zio-http/shared/src/main/scala/zio/http/Header.scala b/zio-http/shared/src/main/scala/zio/http/Header.scala index 76b28571af..4a3d2c101b 100644 --- a/zio-http/shared/src/main/scala/zio/http/Header.scala +++ b/zio-http/shared/src/main/scala/zio/http/Header.scala @@ -51,7 +51,7 @@ sealed trait Header { object Header { sealed trait HeaderType { - type HeaderValue + type HeaderValue <: Header def name: String diff --git a/zio-http/shared/src/main/scala/zio/http/Middleware.scala b/zio-http/shared/src/main/scala/zio/http/Middleware.scala index 2e61b221e8..8023395fdf 100644 --- a/zio-http/shared/src/main/scala/zio/http/Middleware.scala +++ b/zio-http/shared/src/main/scala/zio/http/Middleware.scala @@ -174,6 +174,86 @@ object Middleware extends HandlerAspects { } } + def ensureHeader(header: Header.HeaderType)(make: => header.HeaderValue): Middleware[Any] = + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler { (req: Request) => + if (req.headers.contains(header.name)) h(req) + else h(req.addHeader(make)) + } + } + } + + def ensureHeader(headerName: String)(make: => String): Middleware[Any] = + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler { (req: Request) => + if (req.headers.contains(headerName)) h(req) + else h(req.addHeader(headerName, make)) + } + } + } + + private[http] case class ForwardedHeaders(headers: Headers) + + def forwardHeaders(header: Header.HeaderType, headers: Header.HeaderType*)(implicit + trace: Trace, + ): Middleware[Any] = { + val allHeaders = header +: headers + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler { (req: Request) => + val headerValues = ChunkBuilder.make[Header]() + headerValues.sizeHint(allHeaders.length) + var i = 0 + while (i < allHeaders.length) { + val name = allHeaders(i) + req.headers.get(name).foreach { value => + headerValues += value + } + i += 1 + } + RequestStore.update[ForwardedHeaders] { old => + ForwardedHeaders { + old.map(_.headers).getOrElse(Headers.empty) ++ + Headers.fromIterable(headerValues.result()) + } + } *> h(req) + } + } + } + } + + def forwardHeaders(headerName: String, headerNames: String*)(implicit trace: Trace): Middleware[Any] = { + val allHeaders = headerName +: headerNames + new Middleware[Any] { + def apply[Env1 <: Any, Err](routes: Routes[Env1, Err]): Routes[Env1, Err] = + routes.transform[Env1] { h => + handler { (req: Request) => + val headerValues = ChunkBuilder.make[Header]() + headerValues.sizeHint(allHeaders.length) + var i = 0 + while (i < allHeaders.length) { + val name = allHeaders(i) + req.headers.get(name).foreach { value => + headerValues += Header.Custom(name, value) + } + i += 1 + } + RequestStore.update[ForwardedHeaders] { old => + ForwardedHeaders { + old.map(_.headers).getOrElse(Headers.empty) ++ + Headers.fromIterable(headerValues.result()) + } + } *> h(req) + } + } + } + } + def logAnnotate(key: => String, value: => String)(implicit trace: Trace): Middleware[Any] = logAnnotate(LogAnnotation(key, value)) diff --git a/zio-http/shared/src/main/scala/zio/http/RequestStore.scala b/zio-http/shared/src/main/scala/zio/http/RequestStore.scala new file mode 100644 index 0000000000..b32ffbf4c0 --- /dev/null +++ b/zio-http/shared/src/main/scala/zio/http/RequestStore.scala @@ -0,0 +1,22 @@ +package zio.http + +import zio.{FiberRef, Tag, Unsafe, ZIO} + +object RequestStore { + + private[http] val requestStore: FiberRef[Map[Tag[_], Any]] = + FiberRef.unsafe.make[Map[Tag[_], Any]](Map.empty)(Unsafe.unsafe) + + def get[A: Tag]: ZIO[Any, Nothing, Option[A]] = + requestStore.get.map(_.get(implicitly[Tag[A]]).asInstanceOf[Option[A]]) + + def set[A: Tag](a: A): ZIO[Any, Nothing, Unit] = + requestStore.update(_.updated(implicitly[Tag[A]], a)) + + def update[A: Tag](a: Option[A] => A): ZIO[Any, Nothing, Unit] = + for { + current <- get[A] + _ <- set(a(current)) + } yield () + +} diff --git a/zio-http/shared/src/main/scala/zio/http/ZClientAspect.scala b/zio-http/shared/src/main/scala/zio/http/ZClientAspect.scala index 87721e9a0e..01ab47dc08 100644 --- a/zio-http/shared/src/main/scala/zio/http/ZClientAspect.scala +++ b/zio-http/shared/src/main/scala/zio/http/ZClientAspect.scala @@ -447,4 +447,49 @@ object ZClientAspect { } } } + + final def forwardHeaders: ZClientAspect[Nothing, Any, Nothing, Body, Nothing, Any, Nothing, Response] = + new ZClientAspect[Nothing, Any, Nothing, Body, Nothing, Any, Nothing, Response] { + override def apply[ + ReqEnv, + Env >: Nothing <: Any, + In >: Nothing <: Body, + Err >: Nothing <: Any, + Out >: Nothing <: Response, + ]( + client: ZClient[Env, ReqEnv, In, Err, Out], + ): ZClient[Env, ReqEnv, In, Err, Out] = + client.copy( + driver = new ZClient.Driver[Env, ReqEnv, Err] { + override def request( + version: Version, + method: Method, + url: URL, + headers: Headers, + body: Body, + sslConfig: Option[ClientSSLConfig], + proxy: Option[Proxy], + )(implicit trace: Trace): ZIO[Env & ReqEnv, Err, Response] = + RequestStore.get[Middleware.ForwardedHeaders].flatMap { + case Some(forwardedHeaders) => + client.driver + .request(version, method, url, headers ++ forwardedHeaders.headers, body, sslConfig, proxy) + case None => + client.driver.request(version, method, url, headers, body, sslConfig, proxy) + } + + override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])( + implicit + trace: Trace, + ev: ReqEnv =:= Scope, + ): ZIO[Env1 & ReqEnv, Err, Response] = + RequestStore.get[Middleware.ForwardedHeaders].flatMap { + case Some(forwardedHeaders) => + client.driver.socket(version, url, headers ++ forwardedHeaders.headers, app) + case None => + client.driver.socket(version, url, headers, app) + } + }, + ) + } }