Skip to content

Commit

Permalink
Use user-defined content-length for HEAD responses (#3084)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Sep 5, 2024
1 parent 825ef5d commit 6c68278
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,51 +14,55 @@
* limitations under the License.
*/

package zio.http.netty
package zio.http.netty.server

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.http._
import zio.http.netty.CachedDateHeader
import zio.http.netty.model.Conversions

import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.http._

private[zio] object NettyResponseEncoder {
private object NettyResponseEncoder {
private val dateHeaderCache = CachedDateHeader.default

def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse =
def encode(method: Method, response: Response)(implicit unsafe: Unsafe): HttpResponse =
response.body match {
case body: Body.UnsafeBytes =>
fastEncode(response, body.unsafeAsArray)
fastEncode(method, response, body.unsafeAsArray)
case body =>
val status = response.status
val jHeaders = Conversions.headersToNetty(response.headers)
val jStatus = Conversions.statusToNetty(status)
maybeAddDateHeader(jHeaders, status)

body.knownContentLength match {
case Some(contentLength) =>
jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength)
case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) =>
()
case _ =>
jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
}

val hasContentLength = jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH)

// See https://github.com/zio/zio-http/issues/3080
if (method == Method.HEAD && hasContentLength) ()
else
body.knownContentLength match {
case Some(contentLength) =>
jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength)
case _ if !hasContentLength =>
jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
case _ =>
()
}
new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders)
}

def fastEncode(response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = {
def fastEncode(method: Method, response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = {
if (response.encoded eq null) {
response.encoded = doEncode(response, bytes)
response.encoded = doEncode(method, response, bytes)
}
response.encoded.asInstanceOf[FullHttpResponse]
}

private def doEncode(response: Response, bytes: Array[Byte]): FullHttpResponse = {
private def doEncode(method: Method, response: Response, bytes: Array[Byte]): FullHttpResponse = {
val jHeaders = Conversions.headersToNetty(response.headers)
val status = response.status
maybeAddDateHeader(jHeaders, status)
Expand All @@ -67,8 +71,13 @@ private[zio] object NettyResponseEncoder {

val jContent = Unpooled.wrappedBuffer(bytes)

// The content-length MUST match the length of the content we are sending, so we ignore any user-provided value
jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, bytes.length)
/*
* The content-length MUST match the length of the content we are sending,
* except for HEAD requests where the content-length must equal the length
* of the content we would have sent if this was a GET request.
*/
if (method == Method.HEAD && jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH)) ()
else jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, bytes.length)

new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, jStatus, jContent, jHeaders, EmptyHttpHeaders.INSTANCE)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ private[zio] final case class ServerInboundHandler(
val throwable = jReq.decoderResult().cause()
attemptFastWrite(
ctx,
Conversions.methodFromNetty(jReq.method()),
Response.fromThrowable(throwable, runtime.getRef(ErrorResponseConfig.configRef)),
)
releaseRequest()
} else {
val req = makeZioRequest(ctx, jReq)
val exit = app(req)
if (attemptImmediateWrite(ctx, exit)) {
if (attemptImmediateWrite(ctx, req.method, exit)) {
releaseRequest()
} else {
writeResponse(ctx, runtime, exit, req)(releaseRequest)
Expand Down Expand Up @@ -140,11 +141,12 @@ private[zio] final case class ServerInboundHandler(

private def attemptFastWrite(
ctx: ChannelHandlerContext,
method: Method,
response: Response,
): Boolean = {

def fastEncode(response: Response, bytes: Array[Byte]) = {
val jResponse = NettyResponseEncoder.fastEncode(response, bytes)
val jResponse = NettyResponseEncoder.fastEncode(method, response, bytes)
val djResponse = jResponse.retainedDuplicate()
ctx.writeAndFlush(djResponse, ctx.voidPromise())
true
Expand Down Expand Up @@ -173,7 +175,7 @@ private[zio] final case class ServerInboundHandler(
upgradeToWebSocket(ctx, request, socketApp, runtime).as(None)
case _ =>
ZIO.attempt {
val jResponse = NettyResponseEncoder.encode(response)
val jResponse = NettyResponseEncoder.encode(request.method, response)

if (!jResponse.isInstanceOf[FullHttpResponse]) {

Expand All @@ -197,11 +199,12 @@ private[zio] final case class ServerInboundHandler(

private def attemptImmediateWrite(
ctx: ChannelHandlerContext,
method: Method,
exit: ZIO[Any, Response, Response],
): Boolean = {
exit match {
case Exit.Success(response) if response ne null =>
attemptFastWrite(ctx, response)
attemptFastWrite(ctx, method, response)
case _ => false
}
}
Expand Down Expand Up @@ -310,12 +313,12 @@ private[zio] final case class ServerInboundHandler(
NettyFutureExecutor.executed(ctx.channel().close())

def writeResponse(response: Response): Task[Unit] =
if (attemptFastWrite(ctx, response)) {
if (attemptFastWrite(ctx, req.method, response)) {
Exit.unit
} else {
attemptFullWrite(ctx, runtime, response, req).foldCauseZIO(
cause => {
attemptFastWrite(ctx, withDefaultErrorResponse(cause.squash))
attemptFastWrite(ctx, req.method, withDefaultErrorResponse(cause.squash))
Exit.unit
},
{
Expand Down
25 changes: 25 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/ServerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,31 @@ object ServerSpec extends HttpRunnableSpec {
.contentLength
.run()
assertZIO(res)(isSome(equalTo(Header.ContentLength(10L))))
} +
test("provided content-length is used for HEAD requests") {
val res =
Handler.ok
.addHeader(Header.ContentLength(4L))
.sandbox
.toRoutes
.deploy
.contentLength
.run(method = Method.HEAD)
assertZIO(res)(isSome(equalTo(Header.ContentLength(4L))))
} +
test("provided content-length is used for HEAD requests with stream body") {
// NOTE: Unlikely use-case, but just in case some 3rd party integration
// uses streams as a generalised way to provide content
val res =
Handler
.fromStream(ZStream.empty, 0L)
.addHeader(Header.ContentLength(4L))
.sandbox
.toRoutes
.deploy
.contentLength
.run(method = Method.HEAD)
assertZIO(res)(isSome(equalTo(Header.ContentLength(4L))))
}
},
),
Expand Down

0 comments on commit 6c68278

Please sign in to comment.