diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyResponseEncoder.scala similarity index 59% rename from zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala rename to zio-http/jvm/src/main/scala/zio/http/netty/server/NettyResponseEncoder.scala index 4485005b36..cb6470068f 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/NettyResponseEncoder.scala @@ -14,51 +14,55 @@ * limitations under the License. */ -package zio.http.netty +package zio.http.netty.server import zio._ import zio.stacktracer.TracingImplicits.disableAutoTrace import zio.http._ +import zio.http.netty.CachedDateHeader import zio.http.netty.model.Conversions import io.netty.buffer.Unpooled -import io.netty.channel.ChannelHandlerContext import io.netty.handler.codec.http._ -private[zio] object NettyResponseEncoder { +private object NettyResponseEncoder { private val dateHeaderCache = CachedDateHeader.default - def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse = + def encode(method: Method, response: Response)(implicit unsafe: Unsafe): HttpResponse = response.body match { case body: Body.UnsafeBytes => - fastEncode(response, body.unsafeAsArray) + fastEncode(method, response, body.unsafeAsArray) case body => val status = response.status val jHeaders = Conversions.headersToNetty(response.headers) val jStatus = Conversions.statusToNetty(status) maybeAddDateHeader(jHeaders, status) - body.knownContentLength match { - case Some(contentLength) => - jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength) - case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) => - () - case _ => - jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) - } - + val hasContentLength = jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) + + // See https://github.com/zio/zio-http/issues/3080 + if (method == Method.HEAD && hasContentLength) () + else + body.knownContentLength match { + case Some(contentLength) => + jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength) + case _ if !hasContentLength => + jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + case _ => + () + } new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders) } - def fastEncode(response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = { + def fastEncode(method: Method, response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = { if (response.encoded eq null) { - response.encoded = doEncode(response, bytes) + response.encoded = doEncode(method, response, bytes) } response.encoded.asInstanceOf[FullHttpResponse] } - private def doEncode(response: Response, bytes: Array[Byte]): FullHttpResponse = { + private def doEncode(method: Method, response: Response, bytes: Array[Byte]): FullHttpResponse = { val jHeaders = Conversions.headersToNetty(response.headers) val status = response.status maybeAddDateHeader(jHeaders, status) @@ -67,8 +71,13 @@ private[zio] object NettyResponseEncoder { val jContent = Unpooled.wrappedBuffer(bytes) - // The content-length MUST match the length of the content we are sending, so we ignore any user-provided value - jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, bytes.length) + /* + * The content-length MUST match the length of the content we are sending, + * except for HEAD requests where the content-length must equal the length + * of the content we would have sent if this was a GET request. + */ + if (method == Method.HEAD && jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH)) () + else jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, bytes.length) new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, jStatus, jContent, jHeaders, EmptyHttpHeaders.INSTANCE) } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index 07a3a84448..d5b95e859a 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -82,13 +82,14 @@ private[zio] final case class ServerInboundHandler( val throwable = jReq.decoderResult().cause() attemptFastWrite( ctx, + Conversions.methodFromNetty(jReq.method()), Response.fromThrowable(throwable, runtime.getRef(ErrorResponseConfig.configRef)), ) releaseRequest() } else { val req = makeZioRequest(ctx, jReq) val exit = app(req) - if (attemptImmediateWrite(ctx, exit)) { + if (attemptImmediateWrite(ctx, req.method, exit)) { releaseRequest() } else { writeResponse(ctx, runtime, exit, req)(releaseRequest) @@ -140,11 +141,12 @@ private[zio] final case class ServerInboundHandler( private def attemptFastWrite( ctx: ChannelHandlerContext, + method: Method, response: Response, ): Boolean = { def fastEncode(response: Response, bytes: Array[Byte]) = { - val jResponse = NettyResponseEncoder.fastEncode(response, bytes) + val jResponse = NettyResponseEncoder.fastEncode(method, response, bytes) val djResponse = jResponse.retainedDuplicate() ctx.writeAndFlush(djResponse, ctx.voidPromise()) true @@ -173,7 +175,7 @@ private[zio] final case class ServerInboundHandler( upgradeToWebSocket(ctx, request, socketApp, runtime).as(None) case _ => ZIO.attempt { - val jResponse = NettyResponseEncoder.encode(response) + val jResponse = NettyResponseEncoder.encode(request.method, response) if (!jResponse.isInstanceOf[FullHttpResponse]) { @@ -197,11 +199,12 @@ private[zio] final case class ServerInboundHandler( private def attemptImmediateWrite( ctx: ChannelHandlerContext, + method: Method, exit: ZIO[Any, Response, Response], ): Boolean = { exit match { case Exit.Success(response) if response ne null => - attemptFastWrite(ctx, response) + attemptFastWrite(ctx, method, response) case _ => false } } @@ -310,12 +313,12 @@ private[zio] final case class ServerInboundHandler( NettyFutureExecutor.executed(ctx.channel().close()) def writeResponse(response: Response): Task[Unit] = - if (attemptFastWrite(ctx, response)) { + if (attemptFastWrite(ctx, req.method, response)) { Exit.unit } else { attemptFullWrite(ctx, runtime, response, req).foldCauseZIO( cause => { - attemptFastWrite(ctx, withDefaultErrorResponse(cause.squash)) + attemptFastWrite(ctx, req.method, withDefaultErrorResponse(cause.squash)) Exit.unit }, { diff --git a/zio-http/jvm/src/test/scala/zio/http/ServerSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ServerSpec.scala index 036e86fd5e..8e6a41b965 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ServerSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ServerSpec.scala @@ -432,6 +432,31 @@ object ServerSpec extends HttpRunnableSpec { .contentLength .run() assertZIO(res)(isSome(equalTo(Header.ContentLength(10L)))) + } + + test("provided content-length is used for HEAD requests") { + val res = + Handler.ok + .addHeader(Header.ContentLength(4L)) + .sandbox + .toRoutes + .deploy + .contentLength + .run(method = Method.HEAD) + assertZIO(res)(isSome(equalTo(Header.ContentLength(4L)))) + } + + test("provided content-length is used for HEAD requests with stream body") { + // NOTE: Unlikely use-case, but just in case some 3rd party integration + // uses streams as a generalised way to provide content + val res = + Handler + .fromStream(ZStream.empty, 0L) + .addHeader(Header.ContentLength(4L)) + .sandbox + .toRoutes + .deploy + .contentLength + .run(method = Method.HEAD) + assertZIO(res)(isSome(equalTo(Header.ContentLength(4L)))) } }, ),