Skip to content

Commit

Permalink
Automatically validate request body using Schema.validate (#2360)
Browse files Browse the repository at this point in the history
feat(endpoint): automatically validate body

Uses the `Schema.validate` method from `zio-schema` to validate the body
of a request based on the schema of the input type. If schema validation
fails, `HttpCodecError.InvalidEntity` is passed to the error channel,
indicating that the body is well-formed, but contains invalid values in
one or more fields.  A `wrap` method is included for `InvalidEntity`
which simplifies the process of converting `Schema.validate`'s
output (`Chunk[ValidationError]`) to `InvalidEntity` in a consistent
way.

resolves #2274
  • Loading branch information
williamareynolds authored Aug 12, 2023
1 parent fa76fe1 commit 37fe777
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 16 deletions.
34 changes: 23 additions & 11 deletions zio-http/src/main/scala/zio/http/codec/HttpCodecError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package zio.http.codec

import scala.util.control.NoStackTrace

import zio.Cause
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Cause, Chunk}

import zio.schema.validation.ValidationError

import zio.http.{Path, Status}

Expand All @@ -28,34 +30,44 @@ sealed trait HttpCodecError extends Exception with NoStackTrace {
def message: String
}
object HttpCodecError {
final case class MissingHeader(headerName: String) extends HttpCodecError {
final case class MissingHeader(headerName: String) extends HttpCodecError {
def message = s"Missing header $headerName"
}
final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError {
final case class MalformedMethod(expected: zio.http.Method, actual: zio.http.Method) extends HttpCodecError {
def message = s"Expected $expected but found $actual"
}
final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError {
final case class PathTooShort(path: Path, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Expected to find ${textCodec} but found pre-mature end to the path ${path}"
}
final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError {
final case class MalformedPath(path: Path, pathCodec: PathCodec[_], error: String) extends HttpCodecError {
def message = s"Malformed path ${path} failed to decode using $pathCodec: $error"
}
final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError {
final case class MalformedStatus(expected: Status, actual: Status) extends HttpCodecError {
def message = s"Expected status code ${expected} but found ${actual}"
}
final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError {
final case class MalformedHeader(headerName: String, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Malformed header $headerName failed to decode using $textCodec"
}
final case class MissingQueryParam(queryParamName: String) extends HttpCodecError {
final case class MissingQueryParam(queryParamName: String) extends HttpCodecError {
def message = s"Missing query parameter $queryParamName"
}
final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError {
final case class MalformedQueryParam(queryParamName: String, textCodec: TextCodec[_]) extends HttpCodecError {
def message = s"Malformed query parameter $queryParamName failed to decode using $textCodec"
}
final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError {
final case class MalformedBody(details: String, cause: Option[Throwable] = None) extends HttpCodecError {
def message = s"Malformed request body failed to decode: $details"
}
final case class CustomError(message: String) extends HttpCodecError
final case class InvalidEntity(details: String, cause: Chunk[ValidationError] = Chunk.empty) extends HttpCodecError {
def message = s"A well-formed entity failed validation: $details"
}
object InvalidEntity {
def wrap(errors: Chunk[ValidationError]): InvalidEntity =
InvalidEntity(
errors.foldLeft("")((acc, err) => acc + err.message + "\n"),
errors,
)
}
final case class CustomError(message: String) extends HttpCodecError

def isHttpCodecError(cause: Cause[Any]): Boolean = {
!cause.isFailure && cause.defects.forall(e => e.isInstanceOf[HttpCodecError])
Expand Down
22 changes: 17 additions & 5 deletions zio-http/src/main/scala/zio/http/codec/internal/BodyCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package zio.http.codec.internal
import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

import zio.stream.ZStream
import zio.stream.{ZPipeline, ZStream}

import zio.schema._
import zio.schema.codec.BinaryCodec

import zio.http.codec.HttpCodecError
import zio.http.{Body, FormField, MediaType}

/**
Expand Down Expand Up @@ -92,13 +93,12 @@ private[internal] object BodyCodec {

final case class Single[A](schema: Schema[A], mediaType: Option[MediaType], name: Option[String])
extends BodyCodec[A] {
def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] = {
def decodeFromBody(body: Body, codec: BinaryCodec[A])(implicit trace: Trace): IO[Throwable, A] =
if (schema == Schema[Unit]) ZIO.unit.asInstanceOf[IO[Throwable, A]]
else
body.asChunk.flatMap { chunk =>
ZIO.fromEither(codec.decode(chunk))
}
}
}.flatMap(validateZIO(schema))

def encodeToBody(value: A, codec: BinaryCodec[A])(implicit trace: Trace): Body =
Body.fromChunk(codec.encode(value))
Expand All @@ -111,11 +111,23 @@ private[internal] object BodyCodec {
def decodeFromBody(body: Body, codec: BinaryCodec[E])(implicit
trace: Trace,
): IO[Throwable, ZStream[Any, Nothing, E]] =
ZIO.succeed((body.asStream >>> codec.streamDecoder).orDie)
ZIO.succeed((body.asStream >>> codec.streamDecoder >>> validateStream(schema)).orDie)

def encodeToBody(value: ZStream[Any, Nothing, E], codec: BinaryCodec[E])(implicit trace: Trace): Body =
Body.fromStream(value >>> codec.streamEncoder)

type Element = E
}

private[internal] def validateZIO[A](schema: Schema[A])(e: A)(implicit trace: Trace): ZIO[Any, HttpCodecError, A] = {
val errors = Schema.validate(e)(schema)
if (errors.isEmpty) ZIO.succeed(e)
else ZIO.fail(HttpCodecError.InvalidEntity.wrap(errors))
}

private[internal] def validateStream[E](schema: Schema[E])(implicit
trace: Trace,
): ZPipeline[Any, HttpCodecError, E, E] =
ZPipeline.mapZIO(validateZIO(schema))

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package zio.http.codec.internal

import zio._
import zio.test._

import zio.stream.{ZSink, ZStream}

import zio.schema._
import zio.schema.annotation.validate
import zio.schema.validation.Validation

import zio.http.codec.HttpCodecError

object BodyCodecSpec extends ZIOSpecDefault {
import BodyCodec._

case class User(
@validate(Validation.greaterThan(0))
id: Int,
@validate(Validation.minLength(2) && Validation.maxLength(64))
name: String,
)
object User {
val schema: Schema[User] = DeriveSchema.gen[User]
}

def spec = suite("BodyCodecSpec")(
suite("validateZIO")(
test("returns a valid entity") {
val valid = User(12, "zio")

for {
actual <- validateZIO(User.schema)(valid)
} yield assertTrue(valid == actual)
} +
test("fails with HttpCodecError for invalid entity") {
val invalid = User(-4, "z")
val validated = BodyCodec.validateZIO(User.schema)(invalid)

assertZIO(validated.exit)(Assertion.failsWithA[HttpCodecError.InvalidEntity])
},
),
suite("validateStream")(
test("returns all valid entities") {
val users = Chunk(
User(1, "Will"),
User(2, "Ammon"),
)
val valids = ZStream.fromChunk(users)

for {
validatedUsers <- valids.via(validateStream(User.schema)).runCollect
} yield assertTrue(validatedUsers == users)
},
test("fails with HttpCodecError for invalid entity") {
val users = Chunk(
User(1, "Will"),
User(-5, "Ammon"),
)
val invalid = ZStream.fromChunk(users)

for {
validatedUsers <- invalid.via(validateStream(User.schema)).runCollect.exit
} yield assert(validatedUsers)(Assertion.failsWithA[HttpCodecError.InvalidEntity])
},
),
)
}
40 changes: 40 additions & 0 deletions zio-http/src/test/scala/zio/http/endpoint/EndpointSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import zio.test._

import zio.stream.ZStream

import zio.schema.annotation.validate
import zio.schema.codec.{DecodeError, JsonCodec}
import zio.schema.validation.Validation
import zio.schema.{DeriveSchema, Schema, StandardType}

import zio.http.Header.ContentType
Expand All @@ -38,6 +40,11 @@ object EndpointSpec extends ZIOHttpSpec {

case class NewPost(value: String)

case class User(
@validate(Validation.greaterThan(0))
id: Int,
)

def spec = suite("EndpointSpec")(
suite("handler")(
test("simple request") {
Expand Down Expand Up @@ -547,6 +554,39 @@ object EndpointSpec extends ZIOHttpSpec {
body2 == "{\"message\":\"something went wrong\"}",
)
},
test("validation occurs automatically on schema") {

implicit val schema: Schema[User] = DeriveSchema.gen[User]

val routes =
Endpoint(POST / "users")
.in[User]
.out[String]
.implement {
Handler.fromFunctionZIO { _ =>
ZIO.succeed("User ID is greater than 0")
}
}
.handleErrorCause { case cause =>
Response.text("Caught: " + cause.defects.headOption.fold("no known cause")(d => d.getMessage))
}

val request1 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":0}"""))
val request2 = Request.post(URL.decode("/users").toOption.get, Body.fromString("""{"id":1}"""))

for {
response1 <- routes.toHttpApp.runZIO(request1)
body1 <- response1.body.asString.orDie

response2 <- routes.toHttpApp.runZIO(request2)
body2 <- response2.body.asString.orDie
} yield assertTrue(
extractStatus(response1) == Status.BadRequest,
body1 == "",
extractStatus(response2) == Status.Ok,
body2 == "\"User ID is greater than 0\"",
)
},
),
suite("byte stream input/output")(
test("responding with a byte stream") {
Expand Down

0 comments on commit 37fe777

Please sign in to comment.