From ba05419aae117b5217e2e5285758c3446a06630a Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Tue, 27 Aug 2024 17:34:59 +0300 Subject: [PATCH] Improve error handling for `collect` and `ignore` methods (#3056) * Improve error handling for `collect` and `ignore` methods * Revert unnecessary handling of error in `disableStreaming` --- .../zio/http/netty/NettyBodyWriter.scala | 2 +- .../zio/http/netty/NettyResponseEncoder.scala | 42 ++++++++-------- .../src/test/scala/zio/http/RequestSpec.scala | 50 ++++++++++++++++++- .../test/scala/zio/http/ResponseSpec.scala | 48 ++++++++++++++++++ .../shared/src/main/scala/zio/http/Body.scala | 29 +++++++++-- .../src/main/scala/zio/http/Request.scala | 30 +++++++---- .../src/main/scala/zio/http/Response.scala | 18 +++++-- .../src/main/scala/zio/http/ZClient.scala | 43 ++++++++-------- 8 files changed, 197 insertions(+), 65 deletions(-) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBodyWriter.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBodyWriter.scala index 22bc3aba20..8c36f13623 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyBodyWriter.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyBodyWriter.scala @@ -126,7 +126,7 @@ object NettyBodyWriter { case ChunkBody(data, _) => writeArray(data.toArray, isLast = true) None - case EmptyBody => + case EmptyBody | ErrorBody(_) => ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) None } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala index 6e3fdb9efd..4485005b36 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponseEncoder.scala @@ -29,29 +29,27 @@ import io.netty.handler.codec.http._ private[zio] object NettyResponseEncoder { private val dateHeaderCache = CachedDateHeader.default - def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse = { - val body = response.body - if (body.isComplete) { - assert(body.isInstanceOf[Body.UnsafeBytes], "expected completed body to implement UnsafeBytes") - fastEncode(response, body.asInstanceOf[Body.UnsafeBytes].unsafeAsArray) - } else { - val status = response.status - val jHeaders = Conversions.headersToNetty(response.headers) - val jStatus = Conversions.statusToNetty(status) - maybeAddDateHeader(jHeaders, status) - - response.body.knownContentLength match { - case Some(contentLength) => - jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength) - case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) => - () - case _ => - jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) - } - - new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders) + def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse = + response.body match { + case body: Body.UnsafeBytes => + fastEncode(response, body.unsafeAsArray) + case body => + val status = response.status + val jHeaders = Conversions.headersToNetty(response.headers) + val jStatus = Conversions.statusToNetty(status) + maybeAddDateHeader(jHeaders, status) + + body.knownContentLength match { + case Some(contentLength) => + jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength) + case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) => + () + case _ => + jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED) + } + + new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders) } - } def fastEncode(response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = { if (response.encoded eq null) { diff --git a/zio-http/jvm/src/test/scala/zio/http/RequestSpec.scala b/zio-http/jvm/src/test/scala/zio/http/RequestSpec.scala index 498c157d14..840ba34f91 100644 --- a/zio-http/jvm/src/test/scala/zio/http/RequestSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/RequestSpec.scala @@ -16,8 +16,10 @@ package zio.http -import zio.Scope import zio.test._ +import zio.{Chunk, Ref, Scope} + +import zio.stream.ZStream object RequestSpec extends ZIOHttpSpec { @@ -96,6 +98,52 @@ object RequestSpec extends ZIOHttpSpec { val actual = Request.get("https://foo.com/bar") assertTrue(actual == expected) }, + suite("ignore")( + test("consumes the stream") { + for { + flag <- Ref.make(false) + stream = ZStream.succeed(1.toByte) ++ ZStream.fromZIO(flag.set(true).as(2.toByte)) + response = Request(body = Body.fromStreamChunked(stream)) + _ <- response.ignoreBody + v <- flag.get + } yield assertTrue(v) + }, + test("ignores failures when consuming the stream") { + for { + flag1 <- Ref.make(false) + flag2 <- Ref.make(false) + stream = ZStream.succeed(1.toByte) ++ + ZStream.fromZIO(flag1.set(true).as(2.toByte)) ++ + ZStream.fail(new Throwable("boom")) ++ + ZStream.fromZIO(flag1.set(true).as(2.toByte)) + response = Request(body = Body.fromStreamChunked(stream)) + _ <- response.ignoreBody + v1 <- flag1.get + v2 <- flag2.get + } yield assertTrue(v1, !v2) + }, + ), + suite("collect")( + test("materializes the stream") { + val stream = ZStream.succeed(1.toByte) ++ ZStream.succeed(2.toByte) + val response = Request(body = Body.fromStreamChunked(stream)) + for { + newResp <- response.collect + body = newResp.body + bytes <- body.asChunk + } yield assertTrue(body.isComplete, body.isInstanceOf[Body.UnsafeBytes], bytes == Chunk[Byte](1, 2)) + }, + test("failures are preserved") { + val err = new Throwable("boom") + val stream = ZStream.succeed(1.toByte) ++ ZStream.fail(err) ++ ZStream.succeed(2.toByte) + val response = Request(body = Body.fromStreamChunked(stream)) + for { + newResp <- response.collect + body = newResp.body + bytes <- body.asChunk.either + } yield assertTrue(body.isComplete, body.isInstanceOf[Body.ErrorBody], bytes == Left(err)) + }, + ), ) } diff --git a/zio-http/jvm/src/test/scala/zio/http/ResponseSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ResponseSpec.scala index 984613c650..af31e67506 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ResponseSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ResponseSpec.scala @@ -20,6 +20,8 @@ import zio._ import zio.test.Assertion._ import zio.test._ +import zio.stream.ZStream + object ResponseSpec extends ZIOHttpSpec { def extractStatus(response: Response): Status = response.status private val location: URL = URL.decode("www.google.com").toOption.get @@ -101,5 +103,51 @@ object ResponseSpec extends ZIOHttpSpec { assertZIO(http.runZIO(()))(equalTo(ok)) }, ), + suite("ignore")( + test("consumes the stream") { + for { + flag <- Ref.make(false) + stream = ZStream.succeed(1.toByte) ++ ZStream.fromZIO(flag.set(true).as(2.toByte)) + response = Response(body = Body.fromStreamChunked(stream)) + _ <- response.ignoreBody + v <- flag.get + } yield assertTrue(v) + }, + test("ignores failures when consuming the stream") { + for { + flag1 <- Ref.make(false) + flag2 <- Ref.make(false) + stream = ZStream.succeed(1.toByte) ++ + ZStream.fromZIO(flag1.set(true).as(2.toByte)) ++ + ZStream.fail(new Throwable("boom")) ++ + ZStream.fromZIO(flag1.set(true).as(2.toByte)) + response = Response(body = Body.fromStreamChunked(stream)) + _ <- response.ignoreBody + v1 <- flag1.get + v2 <- flag2.get + } yield assertTrue(v1, !v2) + }, + ), + suite("collect")( + test("materializes the stream") { + val stream = ZStream.succeed(1.toByte) ++ ZStream.succeed(2.toByte) + val response = Response(body = Body.fromStreamChunked(stream)) + for { + newResp <- response.collect + body = newResp.body + bytes <- body.asChunk + } yield assertTrue(body.isComplete, body.isInstanceOf[Body.UnsafeBytes], bytes == Chunk[Byte](1, 2)) + }, + test("failures are preserved") { + val err = new Throwable("boom") + val stream = ZStream.succeed(1.toByte) ++ ZStream.fail(err) ++ ZStream.succeed(2.toByte) + val response = Response(body = Body.fromStreamChunked(stream)) + for { + newResp <- response.collect + body = newResp.body + bytes <- body.asChunk.either + } yield assertTrue(body.isComplete, body.isInstanceOf[Body.ErrorBody], bytes == Left(err)) + }, + ), ) } diff --git a/zio-http/shared/src/main/scala/zio/http/Body.scala b/zio-http/shared/src/main/scala/zio/http/Body.scala index 732b69bf48..aed21f7fd6 100644 --- a/zio-http/shared/src/main/scala/zio/http/Body.scala +++ b/zio-http/shared/src/main/scala/zio/http/Body.scala @@ -182,8 +182,8 @@ trait Body { self => /** * Materializes the body of the request into memory */ - def materialize(implicit trace: Trace): Task[Body] = - asArray.map(Body.ArrayBody(_, self.contentType)) + def materialize(implicit trace: Trace): UIO[Body] = + asArray.foldCause(Body.ErrorBody(_), Body.ArrayBody(_, self.contentType)) /** * Returns the media type for this Body @@ -448,14 +448,14 @@ object Body { private[zio] trait UnsafeBytes extends Body { self => private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte] - final override def materialize(implicit trace: Trace): Task[Body] = Exit.succeed(self) + final override def materialize(implicit trace: Trace): UIO[Body] = Exit.succeed(self) } /** * Helper to create empty Body */ - private[zio] object EmptyBody extends Body with UnsafeBytes { + private[zio] case object EmptyBody extends Body with UnsafeBytes { override def asArray(implicit trace: Trace): Task[Array[Byte]] = zioEmptyArray @@ -476,6 +476,27 @@ object Body { override def knownContentLength: Option[Long] = Some(0L) } + private[zio] final case class ErrorBody(cause: Cause[Throwable]) extends Body { + + override def asArray(implicit trace: Trace): Task[Array[Byte]] = Exit.failCause(cause) + + override def asChunk(implicit trace: Trace): Task[Chunk[Byte]] = Exit.failCause(cause) + + override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] = ZStream.failCause(cause) + + override def isComplete: Boolean = true + + override def isEmpty: Boolean = true + + override def toString: String = "Body.failed" + + override def contentType(newContentType: Body.ContentType): Body = this + + override def contentType: Option[Body.ContentType] = None + + override def knownContentLength: Option[Long] = Some(0L) + } + private[zio] final case class ChunkBody( data: Chunk[Byte], override val contentType: Option[Body.ContentType] = None, diff --git a/zio-http/shared/src/main/scala/zio/http/Request.scala b/zio-http/shared/src/main/scala/zio/http/Request.scala index baa1af4898..71b78d2dca 100644 --- a/zio-http/shared/src/main/scala/zio/http/Request.scala +++ b/zio-http/shared/src/main/scala/zio/http/Request.scala @@ -68,14 +68,17 @@ final case class Request( def addTrailingSlash: Request = self.copy(url = self.url.addTrailingSlash) /** - * Collects the potentially streaming body of the request into a single chunk. + * Collects the potentially streaming body of the response into a single + * chunk. + * + * Any errors that occur from the collection of the body will be caught and + * propagated to the Body */ - def collect(implicit trace: Trace): ZIO[Any, Throwable, Request] = - if (self.body.isComplete) ZIO.succeed(self) - else - self.body.asChunk.map { bytes => - self.copy(body = Body.fromChunk(bytes)) - } + def collect(implicit trace: Trace): ZIO[Any, Nothing, Request] = + self.body.materialize.map { b => + if (b eq self.body) self + else self.copy(body = b) + } def dropLeadingSlash: Request = updateURL(_.dropLeadingSlash) @@ -84,9 +87,16 @@ final case class Request( */ def dropTrailingSlash: Request = updateURL(_.dropTrailingSlash) - /** Consumes the streaming body fully and then drops it */ - def ignoreBody(implicit trace: Trace): ZIO[Any, Throwable, Request] = - self.collect.map(_.copy(body = Body.empty)) + /** + * Consumes the streaming body fully and then discards it while also ignoring + * any failures + */ + def ignoreBody(implicit trace: Trace): ZIO[Any, Nothing, Request] = { + val out = self.copy(body = Body.empty) + val body0 = self.body + if (body0.isComplete) Exit.succeed(out) + else body0.asStream.runDrain.ignore.as(out) + } def patch(p: Request.Patch): Request = self.copy(headers = self.headers ++ p.addHeaders, url = self.url.addQueryParams(p.addQueryParams)) diff --git a/zio-http/shared/src/main/scala/zio/http/Response.scala b/zio-http/shared/src/main/scala/zio/http/Response.scala index e3b2c8db68..cd05ea37ab 100644 --- a/zio-http/shared/src/main/scala/zio/http/Response.scala +++ b/zio-http/shared/src/main/scala/zio/http/Response.scala @@ -50,16 +50,26 @@ final case class Response( /** * Collects the potentially streaming body of the response into a single * chunk. + * + * Any errors that occur from the collection of the body will be caught and + * propagated to the Body */ - def collect(implicit trace: Trace): ZIO[Any, Throwable, Response] = + def collect(implicit trace: Trace): ZIO[Any, Nothing, Response] = self.body.materialize.map { b => if (b eq self.body) self else self.copy(body = b) } - /** Consumes the streaming body fully and then drops it */ - def ignoreBody(implicit trace: Trace): ZIO[Any, Throwable, Response] = - self.collect.map(_.copy(body = Body.empty)) + /** + * Consumes the streaming body fully and then discards it while also ignoring + * any failures + */ + def ignoreBody(implicit trace: Trace): ZIO[Any, Nothing, Response] = { + val out = self.copy(body = Body.empty) + val body0 = self.body + if (body0.isComplete) Exit.succeed(out) + else body0.asStream.runDrain.ignore.as(out) + } def patch(p: Response.Patch)(implicit trace: Trace): Response = p.apply(self) diff --git a/zio-http/shared/src/main/scala/zio/http/ZClient.scala b/zio-http/shared/src/main/scala/zio/http/ZClient.scala index fcae6cfc7e..a4d965dbfa 100644 --- a/zio-http/shared/src/main/scala/zio/http/ZClient.scala +++ b/zio-http/shared/src/main/scala/zio/http/ZClient.scala @@ -83,7 +83,7 @@ final case class ZClient[-Env, ReqEnv, -In, +Err, +Out]( def batched( request: Request, - )(implicit trace: Trace, ev1: ReqEnv =:= Scope, ev2: Err <:< Throwable, ev3: Body <:< In): ZIO[Env, Throwable, Out] = + )(implicit trace: Trace, ev1: ReqEnv =:= Scope, ev3: Body <:< In): ZIO[Env, Err, Out] = batched.apply(request) /** @@ -93,10 +93,10 @@ final case class ZClient[-Env, ReqEnv, -In, +Err, +Out]( * response is streaming, it will await for it to be fully collected before * resuming. */ - def batched(implicit ev1: ReqEnv =:= Scope, ev2: Err <:< Throwable): ZClient[Env, Any, In, Throwable, Out] = - self.transform[Env, Any, In, Throwable, Out]( - self.bodyEncoder.widenError[Throwable], - self.bodyDecoder.widenError[Throwable], + def batched(implicit ev1: ReqEnv =:= Scope): ZClient[Env, Any, In, Err, Out] = + self.transform[Env, Any, In, Err, Out]( + self.bodyEncoder, + self.bodyDecoder, self.driver.disableStreaming, ) @@ -119,7 +119,7 @@ final case class ZClient[-Env, ReqEnv, -In, +Err, +Out]( refineOrDie { case e if !f(e) => e } @deprecated("use `batched` instead", since = "3.0.0") - def disableStreaming(implicit ev1: ReqEnv =:= Scope, ev2: Err <:< Throwable): ZClient[Env, Any, In, Throwable, Out] = + def disableStreaming(implicit ev1: ReqEnv =:= Scope): ZClient[Env, Any, In, Err, Out] = batched def get(suffix: String)(implicit ev: Body <:< In, trace: Trace): ZIO[Env & ReqEnv, Err, Out] = @@ -226,16 +226,14 @@ final case class ZClient[-Env, ReqEnv, -In, +Err, +Out]( * Executes an HTTP request and transforms the response into a `ZStream` using * the provided function */ - def stream[R, A](request: Request)(f: Response => ZStream[R, Throwable, A])(implicit + def stream[R, E0 >: Err, A](request: Request)(f: Out => ZStream[R, E0, A])(implicit trace: Trace, ev1: Body <:< In, - ev2: Out <:< Response, - ev3: Err <:< Throwable, - ev4: ReqEnv =:= Scope, - ): ZStream[R & Env, Throwable, A] = ZStream.unwrapScoped[R & Env] { + ev2: ReqEnv =:= Scope, + ): ZStream[R & Env, E0, A] = ZStream.unwrapScoped[R & Env] { self .request(request) - .asInstanceOf[ZIO[R & Env & Scope, Throwable, Response]] + .asInstanceOf[ZIO[R & Env & Scope, Err, Out]] .fold(ZStream.fail(_), f) } @@ -326,8 +324,8 @@ object ZClient extends ZClientPlatformSpecific { */ def streamingWith[R, A](request: Request)(f: Response => ZStream[R, Throwable, A])(implicit trace: Trace, - ): ZStream[Client & R, Throwable, A] = - ZStream.serviceWithStream[Client](_.stream[R, A](request)(f)) + ): ZStream[R & Client, Throwable, A] = + ZStream.serviceWithStream[Client](_.stream(request)(f)) def socket[R](socketApp: WebSocketApp[R])(implicit trace: Trace): ZIO[R with Client & Scope, Throwable, Response] = ZIO.serviceWithZIO[Client](c => c.socket(socketApp)) @@ -407,10 +405,10 @@ object ZClient extends ZClientPlatformSpecific { final def apply(request: Request)(implicit trace: Trace): ZIO[Env & ReqEnv, Err, Response] = self.request(request.version, request.method, request.url, request.headers, request.body, None, None) - def disableStreaming(implicit ev1: ReqEnv =:= Scope, ev2: Err <:< Throwable): Driver[Env, Any, Throwable] = - new Driver[Env, Any, Throwable] { + def disableStreaming(implicit ev1: ReqEnv =:= Scope): Driver[Env, Any, Err] = + new Driver[Env, Any, Err] { - private val self0 = self.asInstanceOf[Driver[Env, Scope, Throwable]] + private val self0 = self.asInstanceOf[Driver[Env, Scope, Err]] override def request( version: Version, @@ -420,17 +418,16 @@ object ZClient extends ZClientPlatformSpecific { body: Body, sslConfig: Option[ClientSSLConfig], proxy: Option[Proxy], - )(implicit trace: Trace): ZIO[Env, Throwable, Response] = - ZIO - .scoped[Env] { - self0.request(version, method, url, headers, body, sslConfig, proxy).flatMap(_.collect) - } + )(implicit trace: Trace): ZIO[Env, Err, Response] = + ZIO.scoped[Env] { + self0.request(version, method, url, headers, body, sslConfig, proxy).flatMap(_.collect) + } // This should never be possible to invoke unless the user unsafely casted the Driver environment override def socket[Env1 <: Env](version: Version, url: URL, headers: Headers, app: WebSocketApp[Env1])(implicit trace: Trace, ev: Any =:= Scope, - ): ZIO[Env1 & Any, Throwable, Response] = + ): ZIO[Env1 & Any, Err, Response] = ZIO.die(new UnsupportedOperationException("Streaming is disabled")) }