From da8d8be32dd955643d78a3820d4a9d92e3044198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Femen=C3=ADa?= <131800808+pablf@users.noreply.github.com> Date: Fri, 23 Aug 2024 12:17:05 +0200 Subject: [PATCH] Simplify Encoderdecoder (#2998) Co-authored-by: pablf --- .../src/main/scala/zio/http/FormField.scala | 15 + .../zio/http/codec/internal/BodyCodec.scala | 77 +- .../http/codec/internal/EncoderDecoder.scala | 953 ++++++++---------- 3 files changed, 488 insertions(+), 557 deletions(-) diff --git a/zio-http/shared/src/main/scala/zio/http/FormField.scala b/zio-http/shared/src/main/scala/zio/http/FormField.scala index d41cadea4a..2c49de6a5e 100644 --- a/zio-http/shared/src/main/scala/zio/http/FormField.scala +++ b/zio-http/shared/src/main/scala/zio/http/FormField.scala @@ -77,6 +77,21 @@ sealed trait FormField { ZIO.succeed(Chunk.fromArray(value.getBytes(Charsets.Utf8))) } + /** + * Gets the value of this form field as a chunk of bytes. If it is a text + * field, the value gets encoded as an UTF-8 byte stream. + */ + final def asStream(implicit trace: Trace): ZStream[Any, Nothing, Byte] = this match { + case FormField.Text(_, value, _, _) => + ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) + case FormField.Binary(_, value, _, _, _) => + ZStream.fromChunk(value) + case FormField.StreamingBinary(_, _, _, _, stream) => + stream + case FormField.Simple(_, value) => + ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) + } + def name(newName: String): FormField = this match { case FormField.Binary(_, data, contentType, transferEncoding, filename) => FormField.Binary(newName, data, contentType, transferEncoding, filename) diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/BodyCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/BodyCodec.scala index 72225d7691..314c1ef103 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/BodyCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/BodyCodec.scala @@ -24,7 +24,7 @@ import zio.schema._ import zio.http.Header.Accept.MediaTypeWithQFactor import zio.http.codec.{BinaryCodecWithSchema, HttpCodecError, HttpContentCodec} -import zio.http.{Body, MediaType} +import zio.http.{Body, FormField, MediaType} /** * A BodyCodec encapsulates the logic necessary to both encode and decode bodies @@ -39,11 +39,21 @@ private[http] sealed trait BodyCodec[A] { self => */ type Element + /** + * Attempts to decode the `A` from a FormField using the given codec. + */ + def decodeFromField(field: FormField)(implicit trace: Trace): IO[Throwable, A] + /** * Attempts to decode the `A` from a body using the given codec. */ def decodeFromBody(body: Body)(implicit trace: Trace): IO[Throwable, A] + /** + * Encodes the `A` to a FormField in the given codec. + */ + def encodeToField(value: A, mediaTypes: Chunk[MediaTypeWithQFactor], name: String)(implicit trace: Trace): FormField + /** * Encodes the `A` to a body in the given codec. */ @@ -74,8 +84,15 @@ private[http] object BodyCodec { case object Empty extends BodyCodec[Unit] { type Element = Unit + def decodeFromField(field: FormField)(implicit trace: Trace): IO[Throwable, Unit] = ZIO.unit + def decodeFromBody(body: Body)(implicit trace: Trace): IO[Nothing, Unit] = ZIO.unit + def encodeToField(value: Unit, mediaTypes: Chunk[MediaTypeWithQFactor], name: String)(implicit + trace: Trace, + ): FormField = + throw HttpCodecError.CustomError("UnsupportedEncodingType", s"Unit can't be encoded to a FormField") + def encodeToBody(value: Unit, mediaTypes: Chunk[MediaTypeWithQFactor])(implicit trace: Trace): Body = Body.empty def schema: Schema[Unit] = Schema[Unit] @@ -90,6 +107,19 @@ private[http] object BodyCodec { def mediaType(accepted: Chunk[MediaTypeWithQFactor]): Option[MediaType] = Some(codec.chooseFirstOrDefault(accepted)._1) + def decodeFromField(field: FormField)(implicit trace: Trace): IO[Throwable, A] = { + val codec0 = codec + .lookup(field.contentType) + .toRight(HttpCodecError.CustomError("UnsupportedMediaType", s"MediaType: ${field.contentType}")) + codec0 match { + case Left(error) => ZIO.fail(error) + case Right(BinaryCodecWithSchema(_, schema)) if schema == Schema[Unit] => + ZIO.unit.asInstanceOf[IO[Throwable, A]] + case Right(BinaryCodecWithSchema(codec, schema)) => + field.asChunk.flatMap { chunk => ZIO.fromEither(codec.decode(chunk)) }.flatMap(validateZIO(schema)) + } + } + def decodeFromBody(body: Body)(implicit trace: Trace): IO[Throwable, A] = { val codec0 = codecForBody(codec, body) codec0 match { @@ -101,8 +131,27 @@ private[http] object BodyCodec { } } + def encodeToField(value: A, mediaTypes: Chunk[MediaTypeWithQFactor], name: String)(implicit + trace: Trace, + ): FormField = { + val (mediaType, BinaryCodecWithSchema(codec0, _)) = codec.chooseFirstOrDefault(mediaTypes) + if (mediaType.binary) { + FormField.binaryField( + name, + codec0.encode(value), + mediaType, + ) + } else { + FormField.textField( + name, + codec0.encode(value).asString, + mediaType, + ) + } + } + def encodeToBody(value: A, mediaTypes: Chunk[MediaTypeWithQFactor])(implicit trace: Trace): Body = { - val (mediaType, bc @ BinaryCodecWithSchema(_, _)) = codec.chooseFirst(mediaTypes) + val (mediaType, bc @ BinaryCodecWithSchema(_, _)) = codec.chooseFirstOrDefault(mediaTypes) Body.fromChunk(bc.codec.encode(value)).contentType(mediaType) } @@ -115,20 +164,40 @@ private[http] object BodyCodec { def mediaType(accepted: Chunk[MediaTypeWithQFactor]): Option[MediaType] = Some(codec.chooseFirstOrDefault(accepted)._1) + def decodeFromField(field: FormField)(implicit trace: Trace): IO[Throwable, ZStream[Any, Nothing, E]] = + ZIO.fromEither { + codec + .lookup(field.contentType) + .toRight(HttpCodecError.CustomError("UnsupportedMediaType", s"MediaType: ${field.contentType}")) + .map { case BinaryCodecWithSchema(codec, schema) => + (field.asStream >>> codec.streamDecoder >>> validateStream(schema)).orDie + } + } + def decodeFromBody(body: Body)(implicit trace: Trace, - ): IO[Throwable, ZStream[Any, Nothing, E]] = { + ): IO[Throwable, ZStream[Any, Nothing, E]] = ZIO.fromEither { codecForBody(codec, body).map { case bc @ BinaryCodecWithSchema(_, schema) => (body.asStream >>> bc.codec.streamDecoder >>> validateStream(schema)).orDie } } + + def encodeToField(value: ZStream[Any, Nothing, E], mediaTypes: Chunk[MediaTypeWithQFactor], name: String)(implicit + trace: Trace, + ): FormField = { + val (mediaType, BinaryCodecWithSchema(codec0, _)) = codec.chooseFirstOrDefault(mediaTypes) + FormField.streamingBinaryField( + name, + value >>> codec0.streamEncoder, + mediaType, + ) } def encodeToBody(value: ZStream[Any, Nothing, E], mediaTypes: Chunk[MediaTypeWithQFactor])(implicit trace: Trace, ): Body = { - val (mediaType, bc @ BinaryCodecWithSchema(_, _)) = codec.chooseFirst(mediaTypes) + val (mediaType, bc @ BinaryCodecWithSchema(_, _)) = codec.chooseFirstOrDefault(mediaTypes) Body.fromStreamChunked(value >>> bc.codec.streamEncoder).contentType(mediaType) } diff --git a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala index d5cd1e7344..2a8bebdaed 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/internal/EncoderDecoder.scala @@ -152,81 +152,14 @@ private[codec] object EncoderDecoder { private val flattened: AtomizedCodecs = AtomizedCodecs.flatten(httpCodec) - private val formFieldEncoders: Chunk[(String, Any) => FormField] = - flattened.content.map { bodyCodec => (name: String, value: Any) => - { - val (mediaType, codec) = { - bodyCodec match { - case BodyCodec.Single(codec, _) => - codec.choices.headOption - case BodyCodec.Multiple(codec, _) => - codec.choices.headOption - case _ => - None - } - }.getOrElse { - throw HttpCodecError.CustomError( - "CodecNotFound", - s"Cannot encode multipart/form-data field $name: no codec found", - ) - } - - if (mediaType.binary) { - FormField.binaryField( - name, - codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value), - mediaType, - ) - } else { - FormField.textField( - name, - codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString, - mediaType, - ) - } - } - } - - implicit val trace: Trace = Trace.empty - - private val formFieldDecoders: Chunk[FormField => IO[Throwable, Any]] = - flattened.content.map { bodyCodec => (field: FormField) => - { - val mediaType = field.contentType - val codec = { - bodyCodec match { - case BodyCodec.Empty => - None - case BodyCodec.Single(codec, _) => - codec.lookup(mediaType) - case BodyCodec.Multiple(codec, _) => - codec.lookup(mediaType) - } - }.getOrElse { throw HttpCodecError.UnsupportedContentType(mediaType.fullType) } - - field.asChunk.flatMap(chunk => ZIO.fromEither(codec.codec.decode(chunk))) - - } - } - private val formBoundary = Boundary("----zio-http-boundary-D4792A5C-93E0-43B5-9A1F-48E38FDE5714") - private val indexByName = flattened.content.zipWithIndex.map { case (codec, idx) => + implicit val trace: Trace = Trace.empty + private lazy val formBoundary = Boundary("----zio-http-boundary-D4792A5C-93E0-43B5-9A1F-48E38FDE5714") + private lazy val indexByName = flattened.content.zipWithIndex.map { case (codec, idx) => codec.name.getOrElse("field" + idx.toString) -> idx }.toMap - private val nameByIndex = indexByName.map(_.swap) - private val isByteStream = - if (flattened.content.length == 1) { - isByteStreamBody(flattened.content(0)) - } else { - false - } - private val onlyTheLastFieldIsStreaming = - if (flattened.content.size > 1) { - !flattened.content.init.exists(isByteStreamBody) && isByteStreamBody(flattened.content.last) - } else { - false - } + private lazy val nameByIndex = indexByName.map(_.swap) - def decode(url: URL, status: Status, method: Method, headers: Headers, body: Body)(implicit + final def decode(url: URL, status: Status, method: Method, headers: Headers, body: Body)(implicit trace: Trace, ): Task[Value] = ZIO.suspendSucceed { val inputsBuilder = flattened.makeInputsBuilder() @@ -256,170 +189,177 @@ private[codec] object EncoderDecoder { f(URL(path, queryParams = query), status, method, headers0, body) } - private def decodePaths(path: Path, inputs: Array[Any]): Unit = { - assert(flattened.path.length == inputs.length) - - var i = 0 - - while (i < inputs.length) { - val pathCodec = flattened.path(i).erase - - val decoded = pathCodec.decode(path) - - inputs(i) = decoded match { - case Left(error) => - throw HttpCodecError.MalformedPath(path, pathCodec, error) - - case Right(value) => value - } - - i = i + 1 + private def genericDecode[A, Codec]( + a: A, + codecs: Chunk[Codec], + inputs: Array[Any], + decode: (Codec, A) => Any, + ): Unit = { + for (i <- 0 until inputs.length) { + val codec = codecs(i) + inputs(i) = decode(codec, a) } } - private def decodeQuery(queryParams: QueryParams, inputs: Array[Any]): Unit = { - var i = 0 - val queries = flattened.query - while (i < queries.length) { - val query = queries(i).erase - - val isOptional = query.isOptional - - query.queryType match { - case QueryType.Primitive(name, BinaryCodecWithSchema(codec, schema)) => - val count = queryParams.valueCount(name) - val hasParam = queryParams.hasQueryParam(name) - if (!hasParam && isOptional) inputs(i) = None - else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) - else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) - else { - val decoded = codec.decode( - Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), - ) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value - } - val validationErrors = schema.validate(decoded)(schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - inputs(i) = + private def decodePaths(path: Path, inputs: Array[Any]): Unit = + genericDecode[Path, PathCodec[_]]( + path, + flattened.path, + inputs, + (codec, path) => { + codec.erase.decode(path) match { + case Left(error) => throw HttpCodecError.MalformedPath(path, codec, error) + case Right(value) => value + } + }, + ) + + private def decodeQuery(queryParams: QueryParams, inputs: Array[Any]): Unit = + genericDecode[QueryParams, HttpCodec.Query[_, _]]( + queryParams, + flattened.query, + inputs, + (codec, queryParams) => { + val query = codec.erase + val isOptional = query.isOptional + query.queryType match { + case QueryType.Primitive(name, BinaryCodecWithSchema(codec, schema)) => + val count = queryParams.valueCount(name) + val hasParam = queryParams.hasQueryParam(name) + if (!hasParam && isOptional) None + else if (!hasParam) throw HttpCodecError.MissingQueryParam(name) + else if (count != 1) throw HttpCodecError.InvalidQueryParamCount(name, 1, count) + else { + val decoded = codec.decode( + Chunk.fromArray(queryParams.unsafeQueryParam(name).getBytes(Charsets.Utf8)), + ) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Right(value) => value + } + val validationErrors = schema.validate(decoded)(schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) if (isOptional && decoded == None && emptyStringIsValue(schema.asInstanceOf[Schema.Optional[_]].schema)) Some("") else decoded - } - case c @ QueryType.Collection(_, QueryType.Primitive(name, BinaryCodecWithSchema(codec, _)), optional) => - if (!queryParams.hasQueryParam(name)) { - if (!optional) inputs(i) = c.toCollection(Chunk.empty) - else inputs(i) = None - } else { - val values = queryParams.queryParams(name) - val decoded = c.toCollection { - values.map { value => - codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) - case Right(value) => value + } + case c @ QueryType.Collection(_, QueryType.Primitive(name, BinaryCodecWithSchema(codec, _)), optional) => + if (!queryParams.hasQueryParam(name)) { + if (!optional) c.toCollection(Chunk.empty) + else None + } else { + val values = queryParams.queryParams(name) + val decoded = c.toCollection { + values.map { value => + codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(name, error) + case Right(value) => value + } } } - } - val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(decoded)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - inputs(i) = + val erasedSchema = c.colSchema.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(decoded)(erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) if (optional) Some(decoded) else decoded - } - case query @ QueryType.Record(recordSchema) => - val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => - queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined - } - if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) inputs(i) = None - else if (!hasAllParams && isOptional) { - inputs(i) = recordSchema.defaultValue match { - case Left(err) => - throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") - case Right(value) => value } - } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { - query.fieldAndCodecs.collect { - case (field, _) - if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => - field.name + case query @ QueryType.Record(recordSchema) => + val hasAllParams = query.fieldAndCodecs.forall { case (field, _) => + queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined } - } - else { - val decoded = query.fieldAndCodecs.map { - case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => - if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get - else { - val values = queryParams.queryParams(field.name) - val decoded = values.map { value => - codec.codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value + if (!hasAllParams && recordSchema.isInstanceOf[Schema.Optional[_]]) None + else if (!hasAllParams && isOptional) { + recordSchema.defaultValue match { + case Left(err) => + throw new IllegalStateException(s"Cannot compute default value for $recordSchema. Error was: $err") + case Right(value) => value + } + } else if (!hasAllParams) throw HttpCodecError.MissingQueryParams { + query.fieldAndCodecs.collect { + case (field, _) + if !(queryParams.hasQueryParam(field.name) || field.optional || field.defaultValue.isDefined) => + field.name + } + } + else { + val decoded = query.fieldAndCodecs.map { + case (field, codec) if field.schema.isInstanceOf[Schema.Collection[_, _]] => + if (!queryParams.hasQueryParam(field.name) && field.defaultValue.nonEmpty) field.defaultValue.get + else { + val values = queryParams.queryParams(field.name) + val decoded = values.map { value => + codec.codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) + case Right(value) => value + } } + val decodedCollection = + field.schema match { + case s @ Schema.Sequence(_, fromChunk, _, _, _) => + val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) + val erasedSchema = s.asInstanceOf[Schema[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + collection + case s @ Schema.Set(_, _) => + val collection = decoded.toSet[Any] + val erasedSchema = s.asInstanceOf[Schema.Set[Any]] + val validationErrors = erasedSchema.validate(collection)(erasedSchema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + collection + case _ => throw new IllegalStateException("Only Sequence and Set are supported.") + } + decodedCollection } - val decodedCollection = - field.schema match { - case s @ Schema.Sequence(_, fromChunk, _, _, _) => - val collection = fromChunk.asInstanceOf[Chunk[Any] => Any](decoded) - val erasedSchema = s.asInstanceOf[Schema[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case s @ Schema.Set(_, _) => - val collection = decoded.toSet[Any] - val erasedSchema = s.asInstanceOf[Schema.Set[Any]] - val validationErrors = erasedSchema.validate(collection)(erasedSchema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - collection - case _ => throw new IllegalStateException("Only Sequence and Set are supported.") + case (field, codec) => + val value = queryParams.queryParamOrElse(field.name, null) + val decoded = { + if (value == null) field.defaultValue.get + else { + codec.codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { + case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) + case Right(value) => value + } + } + } + val validationErrors = codec.schema.validate(decoded)(codec.schema) + if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) + decoded + } + if (recordSchema.isInstanceOf[Schema.Optional[_]]) { + val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] + val constructed = schema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${schema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + schema.validate(value)(schema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => Some(value) } - decodedCollection } - case (field, codec) => - val value = queryParams.queryParamOrElse(field.name, null) - val decoded = { - if (value == null) field.defaultValue.get - else { - codec.codec.decode(Chunk.fromArray(value.getBytes(Charsets.Utf8))) match { - case Left(error) => throw HttpCodecError.MalformedQueryParam(field.name, error) - case Right(value) => value + } else { + val schema = recordSchema.asInstanceOf[Schema.Record[Any]] + val constructed = schema.construct(decoded)(Unsafe.unsafe) + constructed match { + case Left(value) => + throw HttpCodecError.MalformedQueryParam( + s"${schema.id}", + DecodeError.ReadError(Cause.empty, value), + ) + case Right(value) => + schema.validate(value)(schema) match { + case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) + case _ => value } - } } - val validationErrors = codec.schema.validate(decoded)(codec.schema) - if (validationErrors.nonEmpty) throw HttpCodecError.InvalidEntity.wrap(validationErrors) - decoded - } - if (recordSchema.isInstanceOf[Schema.Optional[_]]) { - val schema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam(s"${schema.id}", DecodeError.ReadError(Cause.empty, value)) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => inputs(i) = Some(value) - } - } - } else { - val schema = recordSchema.asInstanceOf[Schema.Record[Any]] - val constructed = schema.construct(decoded)(Unsafe.unsafe) - constructed match { - case Left(value) => - throw HttpCodecError.MalformedQueryParam(s"${schema.id}", DecodeError.ReadError(Cause.empty, value)) - case Right(value) => - schema.validate(value)(schema) match { - case errors if errors.nonEmpty => throw HttpCodecError.InvalidEntity.wrap(errors) - case _ => inputs(i) = value - } } } - } - } - i = i + 1 - } - } + } + }, + ) private def emptyStringIsValue(schema: Schema[_]): Boolean = schema.asInstanceOf[Schema.Primitive[_]].standardType match { @@ -430,380 +370,287 @@ private[codec] object EncoderDecoder { case _ => false } - private def decodeStatus(status: Status, inputs: Array[Any]): Unit = { - var i = 0 - while (i < inputs.length) { - inputs(i) = flattened.status(i) match { - case _: SimpleCodec.Unspecified[_] => status - case SimpleCodec.Specified(expected) => - if (status != expected) - throw HttpCodecError.MalformedStatus(expected, status) - else () - } - - i = i + 1 - } - } - - private def decodeMethod(method: Method, inputs: Array[Any]): Unit = { - var i = 0 - while (i < inputs.length) { - inputs(i) = flattened.method(i) match { - case _: SimpleCodec.Unspecified[_] => method - case SimpleCodec.Specified(expected) => - if (method != expected) throw HttpCodecError.MalformedMethod(expected, method) - else () - } - - i = i + 1 - } - } - - private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = { - var i = 0 - while (i < flattened.header.length) { - val header = flattened.header(i).erase - - headers.get(header.name) match { - case Some(value) => - inputs(i) = header.textCodec - .decode(value) - .getOrElse(throw HttpCodecError.MalformedHeader(header.name, header.textCodec)) + private def decodeHeaders(headers: Headers, inputs: Array[Any]): Unit = + genericDecode[Headers, HttpCodec.Header[_]]( + headers, + flattened.header, + inputs, + (codec, headers) => + headers.get(codec.name) match { + case Some(value) => + codec.erase.textCodec + .decode(value) + .getOrElse(throw HttpCodecError.MalformedHeader(codec.name, codec.textCodec)) + + case None => + throw HttpCodecError.MissingHeader(codec.name) + }, + ) - case None => - throw HttpCodecError.MissingHeader(header.name) - } + private def decodeStatus(status: Status, inputs: Array[Any]): Unit = + genericDecode[Status, SimpleCodec[Status, _]]( + status, + flattened.status, + inputs, + (codec, status) => + codec match { + case SimpleCodec.Specified(expected) if expected != status => + throw HttpCodecError.MalformedStatus(expected, status) + case _: SimpleCodec.Unspecified[_] => status + case _ => () + }, + ) - i = i + 1 - } - } + private def decodeMethod(method: Method, inputs: Array[Any]): Unit = + genericDecode[Method, SimpleCodec[Method, _]]( + method, + flattened.method, + inputs, + (codec, method) => + codec match { + case SimpleCodec.Specified(expected) if expected != method => + throw HttpCodecError.MalformedMethod(expected, method) + case _: SimpleCodec.Unspecified[_] => method + case _ => () + }, + ) private def decodeBody(body: Body, inputs: Array[Any])(implicit trace: Trace, ): Task[Unit] = { - if (isByteStream) { - ZIO.attempt(inputs(0) = body.asStream.orDie) - } else if (flattened.content.isEmpty) { - ZIO.unit - } else if (flattened.content.size == 1) { - val bodyCodec = flattened.content(0) - bodyCodec - .decodeFromBody(body) - .mapBoth( - { err => HttpCodecError.MalformedBody(err.getMessage(), Some(err)) }, - result => inputs(0) = result, - ) + val codecs = flattened.content + + if (inputs.length < 2) { + // non multi-part + codecs.headOption.map { codec => + codec + .decodeFromBody(body) + .mapBoth( + { err => HttpCodecError.MalformedBody(err.getMessage(), Some(err)) }, + result => inputs(0) = result, + ) + }.getOrElse(ZIO.unit) } else { - body.asMultipartFormStream.flatMap { form => - if (onlyTheLastFieldIsStreaming) - processStreamingForm(form, inputs) - else - collectAndProcessForm(form, inputs) - }.zipRight { - ZIO.attempt { - var idx = 0 - while (idx < inputs.length) { - if (inputs(idx) == null) - throw HttpCodecError.MalformedBody( - s"Missing multipart/form-data field (${Try(nameByIndex(idx))}", - ) - idx += 1 - } - } - } + // multi-part + decodeForm(body.asMultipartFormStream, inputs) *> check(inputs) } } - private def processStreamingForm(form: StreamingForm, inputs: Array[Any])(implicit - trace: Trace, - ): ZIO[Any, Throwable, Unit] = - Promise.make[Throwable, Unit].flatMap { ready => - form.fields.mapZIO { field => - indexByName.get(field.name) match { - case Some(idx) => - (flattened.content(idx) match { - case BodyCodec.Multiple(codec, _) if codec.defaultMediaType.binary => - field match { - case FormField.Binary(_, data, _, _, _) => - inputs(idx) = ZStream.fromChunk(data) - case FormField.StreamingBinary(_, _, _, _, data) => - inputs(idx) = data - case FormField.Text(_, value, _, _) => - inputs(idx) = ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) - case FormField.Simple(_, value) => - inputs(idx) = ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) - } - ZIO.unit - case _ => - formFieldDecoders(idx)(field).map { result => inputs(idx) = result } - }) - .zipRight( - ready - .succeed(()) - .unless( - inputs.exists(_ == null), - ), // Marking as ready so the handler can start consuming the streaming field before this stream ends - ) - case None => - ready.fail(HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}")) - } - }.runDrain - .intoPromise(ready) - .forkDaemon - .zipRight( - ready.await, - ) - } - - private def collectAndProcessForm(form: StreamingForm, inputs: Array[Any])(implicit - trace: Trace, - ): ZIO[Any, Throwable, Unit] = - form.collectAll.flatMap { collectedForm => + private def decodeForm(form: Task[StreamingForm], inputs: Array[Any]): ZIO[Any, Throwable, Unit] = + form.flatMap(_.collectAll).flatMap { collectedForm => ZIO.foreachDiscard(collectedForm.formData) { field => - indexByName.get(field.name) match { - case Some(idx) => - flattened.content(idx) match { - case BodyCodec.Multiple(codec, _) if codec.defaultMediaType.binary => - field match { - case FormField.Binary(_, data, _, _, _) => - inputs(idx) = ZStream.fromChunk(data) - case FormField.StreamingBinary(_, _, _, _, data) => - inputs(idx) = data - case FormField.Text(_, value, _, _) => - inputs(idx) = ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) - case FormField.Simple(_, value) => - inputs(idx) = ZStream.fromChunk(Chunk.fromArray(value.getBytes(Charsets.Utf8))) - } - ZIO.unit - case _ => - formFieldDecoders(idx)(field).map { result => inputs(idx) = result } - } - case None => - ZIO.fail(HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}")) - } + val codecs = flattened.content + val i = indexByName + .get(field.name) + .getOrElse(throw HttpCodecError.MalformedBody(s"Unexpected multipart/form-data field: ${field.name}")) + val codec = codecs(i).erase + for { + decoded <- codec.decodeFromField(field) + _ <- ZIO.attempt { inputs(i) = decoded } + } yield () } } - private def encodePath(inputs: Array[Any]): Path = { - var path: Path = Path.empty - - var i = 0 - while (i < inputs.length) { - val pathCodec = flattened.path(i).erase - val input = inputs(i) - - val encoded = pathCodec.encode(input) match { - case Left(error) => - throw HttpCodecError.MalformedPath(path, pathCodec, error) - case Right(value) => value + private def check(inputs: Array[Any]): ZIO[Any, Throwable, Unit] = + ZIO.attempt { + for (i <- 0 until inputs.length) { + if (inputs(i) == null) + throw HttpCodecError.MalformedBody( + s"Missing multipart/form-data field (${Try(nameByIndex(i))}", + ) } - path = path ++ encoded - - i = i + 1 } - path + private def genericEncode[A, Codec]( + codecs: Chunk[Codec], + inputs: Array[Any], + init: A, + encoding: (Codec, Any, A) => A, + ): A = { + var res = init + for (i <- 0 until inputs.length) { + val codec = codecs(i) + val input = inputs(i) + res = encoding(codec, input, res) + } + res } - private def encodeQuery(inputs: Array[Any]): QueryParams = { - var queryParams = QueryParams.empty + private def simpleEncode[A](codecs: Chunk[SimpleCodec[A, _]], inputs: Array[Any]): Option[A] = + codecs.headOption.map { codec => + codec match { + case _: SimpleCodec.Unspecified[_] => inputs(0).asInstanceOf[A] + case SimpleCodec.Specified(elem) => elem + } + } - var i = 0 - while (i < inputs.length) { - val query = flattened.query(i).erase - val input = inputs(i) + private def encodePath(inputs: Array[Any]): Path = + genericEncode[Path, PathCodec[_]]( + flattened.path, + inputs, + Path.empty, + (codec, a, acc) => { + val encoded = codec.erase.encode(a) match { + case Left(error) => + throw HttpCodecError.MalformedPath(acc, codec, error) + case Right(value) => value + } + acc ++ encoded + }, + ) - query.queryType match { - case QueryType.Primitive(name, codec) => - val schema = codec.schema - if (schema.isInstanceOf[Schema.Primitive[_]]) { - if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { - queryParams = queryParams.addQueryParams(name, Chunk.empty[String]) - } else { + private def encodeQuery(inputs: Array[Any]): QueryParams = + genericEncode[QueryParams, HttpCodec.Query[_, _]]( + flattened.query, + inputs, + QueryParams.empty, + (codec, input, queryParams) => { + val query = codec.erase + + query.queryType match { + case QueryType.Primitive(name, codec) => + val schema = codec.schema + if (schema.isInstanceOf[Schema.Primitive[_]]) { + if (schema.asInstanceOf[Schema.Primitive[_]].standardType.isInstanceOf[StandardType.UnitType.type]) { + queryParams.addQueryParams(name, Chunk.empty[String]) + } else { + val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(input).asString + queryParams.addQueryParams(name, Chunk(encoded)) + } + } else if (schema.isInstanceOf[Schema.Optional[_]]) { val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(input).asString - queryParams = queryParams.addQueryParams(name, Chunk(encoded)) + if (encoded.nonEmpty) queryParams.addQueryParams(name, Chunk(encoded)) else queryParams + } else { + throw new IllegalStateException( + "Only primitive schema is supported for query parameters of type Primitive", + ) } - } else if (schema.isInstanceOf[Schema.Optional[_]]) { - val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(input).asString - if (encoded.nonEmpty) queryParams = queryParams.addQueryParams(name, Chunk(encoded)) - } else { - throw new IllegalStateException( - "Only primitive schema is supported for query parameters of type Primitive", - ) - } - case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => - var in: Any = input - if (optional) { - in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) - } - val values = input.asInstanceOf[Iterable[Any]] - if (values.nonEmpty) { - queryParams = queryParams.addQueryParams( - name, - Chunk.fromIterable( - values.map { value => - codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString - }, - ), - ) - } - case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => - input match { - case None => - () - case Some(value) => - val innerSchema = recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) - var j = 0 - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue - } - value match { - case values: Iterable[_] => - queryParams = queryParams.addQueryParams( - name, - Chunk.fromIterable(values.map { v => - codec.codec.asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString - queryParams = queryParams.addQueryParam(name, encoded) + case QueryType.Collection(_, QueryType.Primitive(name, codec), optional) => + var in: Any = input + if (optional) { + in = input.asInstanceOf[Option[Any]].getOrElse(Chunk.empty) + } + val values = input.asInstanceOf[Iterable[Any]] + if (values.nonEmpty) { + queryParams.addQueryParams( + name, + Chunk.fromIterable( + values.map { value => + codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString + }, + ), + ) + } else queryParams + case query @ QueryType.Record(recordSchema) if recordSchema.isInstanceOf[Schema.Optional[_]] => + input match { + case None => queryParams + case Some(value) => + val innerSchema = + recordSchema.asInstanceOf[Schema.Optional[_]].schema.asInstanceOf[Schema.Record[Any]] + val fieldValues = innerSchema.deconstruct(value)(Unsafe.unsafe) + var j = 0 + var qp = queryParams + while (j < fieldValues.size) { + val (field, codec) = query.fieldAndCodecs(j) + val name = field.name + val value = fieldValues(j) match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values: Iterable[_] => + qp = qp.addQueryParams( + name, + Chunk.fromIterable(values.map { v => + codec.codec.asInstanceOf[BinaryCodec[Any]].encode(v).asString + }), + ) + case _ => + val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString + qp = qp.addQueryParam(name, encoded) + } + j = j + 1 } - j = j + 1 - } - } - case query @ QueryType.Record(recordSchema) => - val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] - val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) - var j = 0 - while (j < fieldValues.size) { - val (field, codec) = query.fieldAndCodecs(j) - val name = field.name - val value = fieldValues(j) match { - case Some(value) => value - case None => field.defaultValue + qp } - value match { - case values if values.isInstanceOf[Iterable[_]] => - queryParams = queryParams.addQueryParams( - name, - Chunk.fromIterable(values.asInstanceOf[Iterable[Any]].map { v => - codec.codec.asInstanceOf[BinaryCodec[Any]].encode(v).asString - }), - ) - case _ => - val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString - queryParams = queryParams.addQueryParam(name, encoded) + case query @ QueryType.Record(recordSchema) => + val innerSchema = recordSchema.asInstanceOf[Schema.Record[Any]] + val fieldValues = innerSchema.deconstruct(input)(Unsafe.unsafe) + var j = 0 + var qp = queryParams + while (j < fieldValues.size) { + val (field, codec) = query.fieldAndCodecs(j) + val name = field.name + val value = fieldValues(j) match { + case Some(value) => value + case None => field.defaultValue + } + value match { + case values if values.isInstanceOf[Iterable[_]] => + qp = qp.addQueryParams( + name, + Chunk.fromIterable(values.asInstanceOf[Iterable[Any]].map { v => + codec.codec.asInstanceOf[BinaryCodec[Any]].encode(v).asString + }), + ) + case _ => + val encoded = codec.codec.asInstanceOf[BinaryCodec[Any]].encode(value).asString + qp = qp.addQueryParam(name, encoded) + } + j = j + 1 } - j = j + 1 - } - } - i = i + 1 - } - - queryParams - } - - private def encodeStatus(inputs: Array[Any]): Option[Status] = { - if (flattened.status.length == 0) { - None - } else { - flattened.status(0) match { - case _: SimpleCodec.Unspecified[_] => Some(inputs(0).asInstanceOf[Status]) - case SimpleCodec.Specified(status) => Some(status) - } - } - } - - private def encodeHeaders(inputs: Array[Any]): Headers = { - var headers = Headers.empty - - var i = 0 - while (i < inputs.length) { - val header = flattened.header(i).erase - val input = inputs(i) - - val value = header.textCodec.encode(input) + qp + } + }, + ) - headers = headers ++ Headers(header.name, value) + private def encodeHeaders(inputs: Array[Any]): Headers = + genericEncode[Headers, HttpCodec.Header[_]]( + flattened.header, + inputs, + Headers.empty, + (codec, input, headers) => headers ++ Headers(codec.name, codec.erase.textCodec.encode(input)), + ) - i = i + 1 - } + private def encodeStatus(inputs: Array[Any]): Option[Status] = + simpleEncode(flattened.status, inputs) - headers - } + private def encodeMethod(inputs: Array[Any]): Option[Method] = + simpleEncode(flattened.method, inputs) - private def encodeMethod(inputs: Array[Any]): Option[zio.http.Method] = - if (flattened.method.nonEmpty) { - flattened.method.head match { - case _: SimpleCodec.Unspecified[_] => Some(inputs(0).asInstanceOf[Method]) - case SimpleCodec.Specified(method) => Some(method) - } - } else None private def encodeBody(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Body = - if (isByteStream) { - Body.fromStreamChunked(inputs(0).asInstanceOf[ZStream[Any, Nothing, Byte]]) - } else { - inputs.length match { - case 0 => - Body.empty - case 1 => - val bodyCodec = flattened.content(0) - bodyCodec.erase.encodeToBody(inputs(0), outputTypes) - case _ => - Body.fromMultipartForm(encodeMultipartFormData(inputs, outputTypes), formBoundary) - } + inputs.length match { + case 0 => + Body.empty + case 1 => + val bodyCodec = flattened.content(0) + bodyCodec.erase.encodeToBody(inputs(0), outputTypes) + case _ => + Body.fromMultipartForm(encodeMultipartFormData(inputs, outputTypes), formBoundary) } private def encodeMultipartFormData(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Form = { - Form( - flattened.content.zipWithIndex.map { case (bodyCodec, idx) => - val input = inputs(idx) - val name = nameByIndex(idx) - bodyCodec match { - case BodyCodec.Multiple(codec, _) if codec.defaultMediaType.binary => - FormField.streamingBinaryField( - name, - input.asInstanceOf[ZStream[Any, Nothing, Byte]], - bodyCodec.mediaType(outputTypes).getOrElse(MediaType.application.`octet-stream`), - ) - case _ => - formFieldEncoders(idx)(name, input) - } - }: _*, - ) - } - - private def encodeContentType(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Headers = { - if (isByteStream) { - val mediaType = flattened.content(0).mediaType(outputTypes).getOrElse(MediaType.application.`octet-stream`) - Headers(Header.ContentType(mediaType)) - } else { - if (inputs.length > 1) { - Headers(Header.ContentType(MediaType.multipart.`form-data`)) - } else { - if (flattened.content.length < 1) Headers.empty - else { - val mediaType = flattened - .content(0) - .mediaType(outputTypes) - .getOrElse(throw HttpCodecError.CustomError("InvalidHttpContentCodec", "No codecs found.")) - Headers(Header.ContentType(mediaType)) - } - } + val formFields = flattened.content.zipWithIndex.map { case (bodyCodec, idx) => + val input = inputs(idx) + val name = nameByIndex(idx) + bodyCodec.erase.encodeToField(input, outputTypes, name) } + + Form(formFields: _*) } - private def isByteStreamBody(codec: BodyCodec[_]): Boolean = - codec match { - case BodyCodec.Multiple(codec, _) if codec.defaultMediaType.binary => true - case _ => false + private def encodeContentType(inputs: Array[Any], outputTypes: Chunk[MediaTypeWithQFactor]): Headers = + inputs.length match { + case 0 => + Headers.empty + case 1 => + val mediaType = flattened + .content(0) + .mediaType(outputTypes) + .getOrElse(throw HttpCodecError.CustomError("InvalidHttpContentCodec", "No codecs found.")) + Headers(Header.ContentType(mediaType)) + case _ => + Headers(Header.ContentType(MediaType.multipart.`form-data`)) } } - }