diff --git a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala index 76cd528bb7..9d35b0754f 100644 --- a/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala +++ b/zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala @@ -18,8 +18,10 @@ package zio.http.codec import scala.util.control.NoStackTrace -import zio.Cause import zio.stacktracer.TracingImplicits.disableAutoTrace +import zio.{Cause, Chunk} + +import zio.schema.validation.ValidationError import zio.http.{Path, Status} @@ -28,34 +30,44 @@ sealed trait HttpCodecError extends Exception with NoStackTrace { def message: String } object HttpCodecError { - final case class MissingHeader(headerName: String) extends HttpCodecError { + final case class MissingHeader(headerName: String) extends HttpCodecError { def message = s"Missing header $headerName" } - final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError { + final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError { def message = s"Expected $expected but found $actual" } - final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError { + final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Expected to find ${textCodec} but found pre-mature end to the path ${path}" } - final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError { + final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError { def message = s"Malformed path ${path} failed to decode using $pathCodec: $error" } - final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError { + final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError { def message = s"Expected status code ${expected} but found ${actual}" } - final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { + final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed header $headerName failed to decode using $textCodec" } - final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { + final case class MissingQueryParam(queryParamName: String) extends HttpCodecError { def message = s"Missing query parameter $queryParamName" } - final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError { + final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError { def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec" } - final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError { + final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError { def message = s"Malformed request body failed to decode: $details" } - final case class CustomError(message: String) extends HttpCodecError + final case class InvalidEntity(details: String, cause: Chunk[ValidationError] = Chunk.empty) extends HttpCodecError { + def message = s"A well-formed entity failed validation: $details" + } + object InvalidEntity { + def wrap(errors: Chunk[ValidationError]): InvalidEntity = + InvalidEntity( + errors.foldLeft("")((acc, err) => acc + err.message + "\n"), + errors, + ) + } + final case class CustomError(message: String) extends HttpCodecError def isHttpCodecError(cause: Cause[Any]): Boolean = { !cause.isFailure && cause.defects.forall(e => e.isInstanceOf[HttpCodecError]) diff --git a/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala b/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala index 28b79fc4b9..9dfb7bc0f7 100644 --- a/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala +++ b/zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala @@ -19,11 +19,12 @@ package zio.http.codec.internal import zio._ import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.stream.ZStream +import zio.stream.{ZPipeline, ZStream} import zio.schema._ import zio.schema.codec.BinaryCodec +import zio.http.codec.HttpCodecError import zio.http.{Body, FormField, MediaType} /** @@ -92,13 +93,12 @@ private[internal] object BodyCodec { final case class Single[A](schema: Schema[A], mediaType: Option[MediaType], name: Option[String]) extends BodyCodec[A] { - def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] = { + def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] = if (schema == Schema[Unit]) ZIO.unit.asInstanceOf[IO[Throwable, A]] else body.asChunk.flatMap { chunk => ZIO.fromEither(codec.decode(chunk)) - } - } + }.flatMap(validateZIO(schema)) def encodeToBody(value: A, codec: BinaryCodec[A])(implicit trace: Trace): Body = Body.fromChunk(codec.encode(value)) @@ -111,11 +111,23 @@ private[internal] object BodyCodec { def decodeFromBody(body: Body, codec: BinaryCodec[E])(implicit trace: Trace, ): IO[Throwable, ZStream[Any, Nothing, E]] = - ZIO.succeed((body.asStream >>> codec.streamDecoder).orDie) + ZIO.succeed((body.asStream >>> codec.streamDecoder >>> validateStream(schema)).orDie) def encodeToBody(value: ZStream[Any, Nothing, E], codec: BinaryCodec[E])(implicit trace: Trace): Body = Body.fromStream(value >>> codec.streamEncoder) type Element = E } + + private[internal] def validateZIO[A](schema: Schema[A])(e: A)(implicit trace: Trace): ZIO[Any, HttpCodecError, A] = { + val errors = Schema.validate(e)(schema) + if (errors.isEmpty) ZIO.succeed(e) + else ZIO.fail(HttpCodecError.InvalidEntity.wrap(errors)) + } + + private[internal] def validateStream[E](schema: Schema[E])(implicit + trace: Trace, + ): ZPipeline[Any, HttpCodecError, E, E] = + ZPipeline.mapZIO(validateZIO(schema)) + } diff --git a/zio-http/src/test/scala/zio/http/codec/internal/BodyCodecSpec.scala b/zio-http/src/test/scala/zio/http/codec/internal/BodyCodecSpec.scala new file mode 100644 index 0000000000..fa8f095d3d --- /dev/null +++ b/zio-http/src/test/scala/zio/http/codec/internal/BodyCodecSpec.scala @@ -0,0 +1,68 @@ +package zio.http.codec.internal + +import zio._ +import zio.test._ + +import zio.stream.{ZSink, ZStream} + +import zio.schema._ +import zio.schema.annotation.validate +import zio.schema.validation.Validation + +import zio.http.codec.HttpCodecError + +object BodyCodecSpec extends ZIOSpecDefault { + import BodyCodec._ + + case class User( + @validate(Validation.greaterThan(0)) + id: Int, + @validate(Validation.minLength(2) && Validation.maxLength(64)) + name: String, + ) + object User { + val schema: Schema[User] = DeriveSchema.gen[User] + } + + def spec = suite("BodyCodecSpec")( + suite("validateZIO")( + test("returns a valid entity") { + val valid = User(12, "zio") + + for { + actual <- validateZIO(User.schema)(valid) + } yield assertTrue(valid == actual) + } + + test("fails with HttpCodecError for invalid entity") { + val invalid = User(-4, "z") + val validated = BodyCodec.validateZIO(User.schema)(invalid) + + assertZIO(validated.exit)(Assertion.failsWithA[HttpCodecError.InvalidEntity]) + }, + ), + suite("validateStream")( + test("returns all valid entities") { + val users = Chunk( + User(1, "Will"), + User(2, "Ammon"), + ) + val valids = ZStream.fromChunk(users) + + for { + validatedUsers <- valids.via(validateStream(User.schema)).runCollect + } yield assertTrue(validatedUsers == users) + }, + test("fails with HttpCodecError for invalid entity") { + val users = Chunk( + User(1, "Will"), + User(-5, "Ammon"), + ) + val invalid = ZStream.fromChunk(users) + + for { + validatedUsers <- invalid.via(validateStream(User.schema)).runCollect.exit + } yield assert(validatedUsers)(Assertion.failsWithA[HttpCodecError.InvalidEntity]) + }, + ), + ) +} diff --git a/zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala b/zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala index 224d546a35..b5dfd0e69f 100644 --- a/zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala +++ b/zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala @@ -23,7 +23,9 @@ import zio.test._ import zio.stream.ZStream +import zio.schema.annotation.validate import zio.schema.codec.{DecodeError, JsonCodec} +import zio.schema.validation.Validation import zio.schema.{DeriveSchema, Schema, StandardType} import zio.http.Header.ContentType @@ -38,6 +40,11 @@ object EndpointSpec extends ZIOHttpSpec { case class NewPost(value: String) + case class User( + @validate(Validation.greaterThan(0)) + id: Int, + ) + def spec = suite("EndpointSpec")( suite("handler")( test("simple request") { @@ -547,6 +554,39 @@ object EndpointSpec extends ZIOHttpSpec { body2 == "{\"message\":\"something went wrong\"}", ) }, + test("validation occurs automatically on schema") { + + implicit val schema: Schema[User] = DeriveSchema.gen[User] + + val routes = + Endpoint(POST / "users") + .in[User] + .out[String] + .implement { + Handler.fromFunctionZIO { _ => + ZIO.succeed("User ID is greater than 0") + } + } + .handleErrorCause { case cause => + Response.text("Caught: " + cause.defects.headOption.fold("no known cause")(d => d.getMessage)) + } + + val request1 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":0}""")) + val request2 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":1}""")) + + for { + response1 <- routes.toHttpApp.runZIO(request1) + body1 <- response1.body.asString.orDie + + response2 <- routes.toHttpApp.runZIO(request2) + body2 <- response2.body.asString.orDie + } yield assertTrue( + extractStatus(response1) == Status.BadRequest, + body1 == "", + extractStatus(response2) == Status.Ok, + body2 == "\"User ID is greater than 0\"", + ) + }, ), suite("byte stream input/output")( test("responding with a byte stream") {