Skip to content

Commit

Permalink
Improve error handling for collect and ignore methods (#3056)
Browse files Browse the repository at this point in the history
* Improve error handling for `collect` and `ignore` methods

* Revert unnecessary handling of error in `disableStreaming`
  • Loading branch information
kyri-petrou authored Aug 27, 2024
1 parent 8bd449c commit ba05419
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ object NettyBodyWriter {
case ChunkBody(data, _) =>
writeArray(data.toArray, isLast = true)
None
case EmptyBody =>
case EmptyBody | ErrorBody(_) =>
ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT)
None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,27 @@ import io.netty.handler.codec.http._
private[zio] object NettyResponseEncoder {
private val dateHeaderCache = CachedDateHeader.default

def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse = {
val body = response.body
if (body.isComplete) {
assert(body.isInstanceOf[Body.UnsafeBytes], "expected completed body to implement UnsafeBytes")
fastEncode(response, body.asInstanceOf[Body.UnsafeBytes].unsafeAsArray)
} else {
val status = response.status
val jHeaders = Conversions.headersToNetty(response.headers)
val jStatus = Conversions.statusToNetty(status)
maybeAddDateHeader(jHeaders, status)

response.body.knownContentLength match {
case Some(contentLength) =>
jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength)
case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) =>
()
case _ =>
jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
}

new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders)
def encode(response: Response)(implicit unsafe: Unsafe): HttpResponse =
response.body match {
case body: Body.UnsafeBytes =>
fastEncode(response, body.unsafeAsArray)
case body =>
val status = response.status
val jHeaders = Conversions.headersToNetty(response.headers)
val jStatus = Conversions.statusToNetty(status)
maybeAddDateHeader(jHeaders, status)

body.knownContentLength match {
case Some(contentLength) =>
jHeaders.set(HttpHeaderNames.CONTENT_LENGTH, contentLength)
case _ if jHeaders.contains(HttpHeaderNames.CONTENT_LENGTH) =>
()
case _ =>
jHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED)
}

new DefaultHttpResponse(HttpVersion.HTTP_1_1, jStatus, jHeaders)
}
}

def fastEncode(response: Response, bytes: Array[Byte])(implicit unsafe: Unsafe): FullHttpResponse = {
if (response.encoded eq null) {
Expand Down
50 changes: 49 additions & 1 deletion zio-http/jvm/src/test/scala/zio/http/RequestSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package zio.http

import zio.Scope
import zio.test._
import zio.{Chunk, Ref, Scope}

import zio.stream.ZStream

object RequestSpec extends ZIOHttpSpec {

Expand Down Expand Up @@ -96,6 +98,52 @@ object RequestSpec extends ZIOHttpSpec {
val actual = Request.get("https://foo.com/bar")
assertTrue(actual == expected)
},
suite("ignore")(
test("consumes the stream") {
for {
flag <- Ref.make(false)
stream = ZStream.succeed(1.toByte) ++ ZStream.fromZIO(flag.set(true).as(2.toByte))
response = Request(body = Body.fromStreamChunked(stream))
_ <- response.ignoreBody
v <- flag.get
} yield assertTrue(v)
},
test("ignores failures when consuming the stream") {
for {
flag1 <- Ref.make(false)
flag2 <- Ref.make(false)
stream = ZStream.succeed(1.toByte) ++
ZStream.fromZIO(flag1.set(true).as(2.toByte)) ++
ZStream.fail(new Throwable("boom")) ++
ZStream.fromZIO(flag1.set(true).as(2.toByte))
response = Request(body = Body.fromStreamChunked(stream))
_ <- response.ignoreBody
v1 <- flag1.get
v2 <- flag2.get
} yield assertTrue(v1, !v2)
},
),
suite("collect")(
test("materializes the stream") {
val stream = ZStream.succeed(1.toByte) ++ ZStream.succeed(2.toByte)
val response = Request(body = Body.fromStreamChunked(stream))
for {
newResp <- response.collect
body = newResp.body
bytes <- body.asChunk
} yield assertTrue(body.isComplete, body.isInstanceOf[Body.UnsafeBytes], bytes == Chunk[Byte](1, 2))
},
test("failures are preserved") {
val err = new Throwable("boom")
val stream = ZStream.succeed(1.toByte) ++ ZStream.fail(err) ++ ZStream.succeed(2.toByte)
val response = Request(body = Body.fromStreamChunked(stream))
for {
newResp <- response.collect
body = newResp.body
bytes <- body.asChunk.either
} yield assertTrue(body.isComplete, body.isInstanceOf[Body.ErrorBody], bytes == Left(err))
},
),
)

}
48 changes: 48 additions & 0 deletions zio-http/jvm/src/test/scala/zio/http/ResponseSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import zio._
import zio.test.Assertion._
import zio.test._

import zio.stream.ZStream

object ResponseSpec extends ZIOHttpSpec {
def extractStatus(response: Response): Status = response.status
private val location: URL = URL.decode("www.google.com").toOption.get
Expand Down Expand Up @@ -101,5 +103,51 @@ object ResponseSpec extends ZIOHttpSpec {
assertZIO(http.runZIO(()))(equalTo(ok))
},
),
suite("ignore")(
test("consumes the stream") {
for {
flag <- Ref.make(false)
stream = ZStream.succeed(1.toByte) ++ ZStream.fromZIO(flag.set(true).as(2.toByte))
response = Response(body = Body.fromStreamChunked(stream))
_ <- response.ignoreBody
v <- flag.get
} yield assertTrue(v)
},
test("ignores failures when consuming the stream") {
for {
flag1 <- Ref.make(false)
flag2 <- Ref.make(false)
stream = ZStream.succeed(1.toByte) ++
ZStream.fromZIO(flag1.set(true).as(2.toByte)) ++
ZStream.fail(new Throwable("boom")) ++
ZStream.fromZIO(flag1.set(true).as(2.toByte))
response = Response(body = Body.fromStreamChunked(stream))
_ <- response.ignoreBody
v1 <- flag1.get
v2 <- flag2.get
} yield assertTrue(v1, !v2)
},
),
suite("collect")(
test("materializes the stream") {
val stream = ZStream.succeed(1.toByte) ++ ZStream.succeed(2.toByte)
val response = Response(body = Body.fromStreamChunked(stream))
for {
newResp <- response.collect
body = newResp.body
bytes <- body.asChunk
} yield assertTrue(body.isComplete, body.isInstanceOf[Body.UnsafeBytes], bytes == Chunk[Byte](1, 2))
},
test("failures are preserved") {
val err = new Throwable("boom")
val stream = ZStream.succeed(1.toByte) ++ ZStream.fail(err) ++ ZStream.succeed(2.toByte)
val response = Response(body = Body.fromStreamChunked(stream))
for {
newResp <- response.collect
body = newResp.body
bytes <- body.asChunk.either
} yield assertTrue(body.isComplete, body.isInstanceOf[Body.ErrorBody], bytes == Left(err))
},
),
)
}
29 changes: 25 additions & 4 deletions zio-http/shared/src/main/scala/zio/http/Body.scala
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ trait Body { self =>
/**
* Materializes the body of the request into memory
*/
def materialize(implicit trace: Trace): Task[Body] =
asArray.map(Body.ArrayBody(_, self.contentType))
def materialize(implicit trace: Trace): UIO[Body] =
asArray.foldCause(Body.ErrorBody(_), Body.ArrayBody(_, self.contentType))

/**
* Returns the media type for this Body
Expand Down Expand Up @@ -448,14 +448,14 @@ object Body {
private[zio] trait UnsafeBytes extends Body { self =>
private[zio] def unsafeAsArray(implicit unsafe: Unsafe): Array[Byte]

final override def materialize(implicit trace: Trace): Task[Body] = Exit.succeed(self)
final override def materialize(implicit trace: Trace): UIO[Body] = Exit.succeed(self)
}

/**
* Helper to create empty Body
*/

private[zio] object EmptyBody extends Body with UnsafeBytes {
private[zio] case object EmptyBody extends Body with UnsafeBytes {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = zioEmptyArray

Expand All @@ -476,6 +476,27 @@ object Body {
override def knownContentLength: Option[Long] = Some(0L)
}

private[zio] final case class ErrorBody(cause: Cause[Throwable]) extends Body {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = Exit.failCause(cause)

override def asChunk(implicit trace: Trace): Task[Chunk[Byte]] = Exit.failCause(cause)

override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] = ZStream.failCause(cause)

override def isComplete: Boolean = true

override def isEmpty: Boolean = true

override def toString: String = "Body.failed"

override def contentType(newContentType: Body.ContentType): Body = this

override def contentType: Option[Body.ContentType] = None

override def knownContentLength: Option[Long] = Some(0L)
}

private[zio] final case class ChunkBody(
data: Chunk[Byte],
override val contentType: Option[Body.ContentType] = None,
Expand Down
30 changes: 20 additions & 10 deletions zio-http/shared/src/main/scala/zio/http/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ final case class Request(
def addTrailingSlash: Request = self.copy(url = self.url.addTrailingSlash)

/**
* Collects the potentially streaming body of the request into a single chunk.
* Collects the potentially streaming body of the response into a single
* chunk.
*
* Any errors that occur from the collection of the body will be caught and
* propagated to the Body
*/
def collect(implicit trace: Trace): ZIO[Any, Throwable, Request] =
if (self.body.isComplete) ZIO.succeed(self)
else
self.body.asChunk.map { bytes =>
self.copy(body = Body.fromChunk(bytes))
}
def collect(implicit trace: Trace): ZIO[Any, Nothing, Request] =
self.body.materialize.map { b =>
if (b eq self.body) self
else self.copy(body = b)
}

def dropLeadingSlash: Request = updateURL(_.dropLeadingSlash)

Expand All @@ -84,9 +87,16 @@ final case class Request(
*/
def dropTrailingSlash: Request = updateURL(_.dropTrailingSlash)

/** Consumes the streaming body fully and then drops it */
def ignoreBody(implicit trace: Trace): ZIO[Any, Throwable, Request] =
self.collect.map(_.copy(body = Body.empty))
/**
* Consumes the streaming body fully and then discards it while also ignoring
* any failures
*/
def ignoreBody(implicit trace: Trace): ZIO[Any, Nothing, Request] = {
val out = self.copy(body = Body.empty)
val body0 = self.body
if (body0.isComplete) Exit.succeed(out)
else body0.asStream.runDrain.ignore.as(out)
}

def patch(p: Request.Patch): Request =
self.copy(headers = self.headers ++ p.addHeaders, url = self.url.addQueryParams(p.addQueryParams))
Expand Down
18 changes: 14 additions & 4 deletions zio-http/shared/src/main/scala/zio/http/Response.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,26 @@ final case class Response(
/**
* Collects the potentially streaming body of the response into a single
* chunk.
*
* Any errors that occur from the collection of the body will be caught and
* propagated to the Body
*/
def collect(implicit trace: Trace): ZIO[Any, Throwable, Response] =
def collect(implicit trace: Trace): ZIO[Any, Nothing, Response] =
self.body.materialize.map { b =>
if (b eq self.body) self
else self.copy(body = b)
}

/** Consumes the streaming body fully and then drops it */
def ignoreBody(implicit trace: Trace): ZIO[Any, Throwable, Response] =
self.collect.map(_.copy(body = Body.empty))
/**
* Consumes the streaming body fully and then discards it while also ignoring
* any failures
*/
def ignoreBody(implicit trace: Trace): ZIO[Any, Nothing, Response] = {
val out = self.copy(body = Body.empty)
val body0 = self.body
if (body0.isComplete) Exit.succeed(out)
else body0.asStream.runDrain.ignore.as(out)
}

def patch(p: Response.Patch)(implicit trace: Trace): Response = p.apply(self)

Expand Down
Loading

0 comments on commit ba05419

Please sign in to comment.