Skip to content

Commit

Permalink
JsonCodec respects annotations for GenericRecord (#700)
Browse files Browse the repository at this point in the history
  • Loading branch information
987Nabil committed Jun 30, 2024
1 parent 91899ff commit ad289bc
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,10 @@ object DeriveSchema {
q"""(m: scala.collection.immutable.ListMap[String, _]) => try { Right($tpeCompanion.apply(..$casts)) } catch { case e: Throwable => Left(e.getMessage) }"""
}
val toMap = {
val tuples = fieldAccessors.map { fieldName =>
q"(${fieldName.toString},b.$fieldName)"
val tuples = fieldAccessors.zip(fieldAnnotations).map {
case (fieldName, annotations) =>
val newName = getFieldName(annotations).getOrElse(fieldName.toString)
q"(${newName},b.$fieldName)"
}
q"""(b: $tpe) => Right(scala.collection.immutable.ListMap.apply(..$tuples))"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ object JsonCodecJVMSpec extends ZIOSpecDefault {

case class RecordExample2(
f1: Option[String],
f2: Option[String],
f2: String,
f3: Option[String] = None,
f4: Option[String] = None,
f5: Option[String] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ object JsonCodec {
pad(indent_, out)
var first = true
structure.foreach {
case Schema.Field(k, a, _, _, _, _) =>
case field if field.transient || isEmptyOptionalValue(field, value(field.fieldName), cfg) => ()
case f @ Schema.Field(_, a, _, _, _, _) =>
val enc = schemaEncoder(a.asInstanceOf[Schema[Any]], cfg)
if (first)
first = false
Expand All @@ -471,10 +472,10 @@ object JsonCodec {
if (indent.isDefined)
ZJsonEncoder.pad(indent_, out)
}
string.encoder.unsafeEncode(JsonFieldEncoder.string.unsafeEncodeField(k), indent_, out)
string.encoder.unsafeEncode(JsonFieldEncoder.string.unsafeEncodeField(f.fieldName), indent_, out)
if (indent.isEmpty) out.write(':')
else out.write(" : ")
enc.unsafeEncode(value(k), indent_, out)
enc.unsafeEncode(value(f.fieldName), indent_, out)
}
pad(indent, out)
out.write('}')
Expand Down Expand Up @@ -551,7 +552,7 @@ object JsonCodec {
case Schema.Map(ks, vs, _) => mapDecoder(ks, vs)
case Schema.Set(s, _) => ZJsonDecoder.chunk(schemaDecoder(s, -1)).map(entries => entries.toSet)
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk, schema.annotations.contains(rejectExtraFields()))
case Schema.Either(left, right, _) => ZJsonDecoder.either(schemaDecoder(left, -1), schemaDecoder(right, -1))
case s @ Schema.Fallback(_, _, _, _) => fallbackDecoder(s)
case l @ Schema.Lazy(_) => schemaDecoder(l.schema, discriminator)
Expand Down Expand Up @@ -806,45 +807,48 @@ object JsonCodec {
private def deAliasCaseName(alias: String, caseNameAliases: Map[String, String]): String =
caseNameAliases.getOrElse(alias, alias)

private def recordDecoder[Z](structure: Seq[Schema.Field[Z, _]]): ZJsonDecoder[ListMap[String, Any]] = {
(trace: List[JsonError], in: RetractReader) =>
{
val builder: ChunkBuilder[(String, Any)] = zio.ChunkBuilder.make[(String, Any)](structure.size)
Lexer.char(trace, in, '{')
if (Lexer.firstField(trace, in)) {
while ({
val field = Lexer.string(trace, in).toString
structure.find(
f => f.name == field || f.annotations.collectFirst { case fieldName(name) => name }.contains(field)
) match {
case Some(Schema.Field(label, schema, _, _, _, _)) =>
val trace_ = JsonError.ObjectAccess(label) :: trace
Lexer.char(trace_, in, ':')
val value = schemaDecoder(schema).unsafeDecode(trace_, in)
builder += ((JsonFieldDecoder.string.unsafeDecodeField(trace_, label), value))
case None =>
Lexer.char(trace, in, ':')
Lexer.skipValue(trace, in)
private def recordDecoder[Z](
structure: Seq[Schema.Field[Z, _]],
rejectAdditionalFields: Boolean
): ZJsonDecoder[ListMap[String, Any]] = { (trace: List[JsonError], in: RetractReader) =>
{
val builder: ChunkBuilder[(String, Any)] = zio.ChunkBuilder.make[(String, Any)](structure.size)
Lexer.char(trace, in, '{')
if (Lexer.firstField(trace, in)) {
while ({
val field = Lexer.string(trace, in).toString
structure.find(f => f.nameAndAliases.contains(field)) match {
case Some(s @ Schema.Field(_, schema, _, _, _, _)) =>
val fieldName = s.fieldName
val trace_ = JsonError.ObjectAccess(fieldName) :: trace
Lexer.char(trace_, in, ':')
val value = schemaDecoder(schema).unsafeDecode(trace_, in)
builder += ((JsonFieldDecoder.string.unsafeDecodeField(trace_, fieldName), value))
case None if rejectAdditionalFields =>
throw UnsafeJson(JsonError.Message(s"unexpected field: $field") :: trace)
case None =>
Lexer.char(trace, in, ':')
Lexer.skipValue(trace, in)

}
(Lexer.nextField(trace, in))
}) {
()
}
Lexer.nextField(trace, in)
}) {
()
}
val tuples = builder.result()
val collectedFields: Set[String] = tuples.map { case (fieldName, _) => fieldName }.toSet
val resultBuilder = ListMap.newBuilder[String, Any]

// add fields with default values if they are not present in the JSON
structure.foreach { field =>
if (!collectedFields.contains(field.name) && field.optional && field.defaultValue.isDefined) {
val value = field.name -> field.defaultValue.get
resultBuilder += value
}
}
val tuples = builder.result()
val collectedFields: Set[String] = tuples.map { case (fieldName, _) => fieldName }.toSet
val resultBuilder = ListMap.newBuilder[String, Any]

// add fields with default values if they are not present in the JSON
structure.foreach { field =>
if (!collectedFields.contains(field.fieldName) && field.optional && field.defaultValue.isDefined) {
val value = field.fieldName -> field.defaultValue.get
resultBuilder += value
}
(resultBuilder ++= tuples).result()
}
(resultBuilder ++= tuples).result()
}
}

private def fallbackDecoder[A, B](schema: Schema.Fallback[A, B]): ZJsonDecoder[Fallback[A, B]] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,19 @@ object JsonCodecSpec extends ZIOSpecDefault {
charSequenceToByteChunk("""null""")
)
}
)
),
suite("Generic Record") {
test("Do not encode transient field") {
assertEncodes(
RecordExample.schema.annotate(rejectExtraFields()),
RecordExample(f1 = Some("test"), f3 = Some("transient")),
charSequenceToByteChunk(
"""{"$f1":"test","f2":null,"f4":null,"f5":null,"f6":null,"f7":null,"f8":null,"f9":null,"f10":null,"f11":null,"f12":null,"f13":null,"f14":null,"f15":null,"f16":null,"f17":null,"f18":null,"f19":null,"f20":null,"f21":null,"f22":null,"$f23":null}""".stripMargin
)
)
}

}
)

private val decoderSuite = suite("decoding")(
Expand Down Expand Up @@ -447,7 +459,28 @@ object JsonCodecSpec extends ZIOSpecDefault {
assertDecodes(
RecordExample.schema,
RecordExample(f1 = Some("test"), f2 = None),
charSequenceToByteChunk("""{"f1":"test"}""")
charSequenceToByteChunk("""{"$f1":"test"}""")
)
},
test("aliased field") {
assertDecodes(
RecordExample.schema,
RecordExample(f1 = Some("test"), f2 = Some("alias")),
charSequenceToByteChunk("""{"$f1":"test", "field2":"alias"}""")
)
},
test("reject extra fields") {
assertDecodes(
RecordExample.schema.annotate(rejectExtraFields()),
RecordExample(f1 = Some("test")),
charSequenceToByteChunk("""{"$f1":"test", "extraField":"extra"}""")
).flip.map(err => assertTrue(err.getMessage() == "(unexpected field: extraField)"))
},
test("optional field with schema or annotated default value") {
assertDecodes(
RecordExampleWithOptField.schema,
RecordExampleWithOptField(f1 = Some("test"), f2 = None, f4 = "", f5 = "hello"),
charSequenceToByteChunk("""{"$f1":"test"}""")
)
}
),
Expand Down Expand Up @@ -1232,11 +1265,9 @@ object JsonCodecSpec extends ZIOSpecDefault {
Enumeration3(StringValue3("foo"))
) &> assertEncodesThenDecodes(
Schema[Enumeration3],
Enumeration3(StringValue3Multi("foo", "bar"))
) &> assertEncodesThenDecodes(Schema[Enumeration3], Enumeration3(IntValue3(-1))) &> assertEncodesThenDecodes(
Schema[Enumeration3],
Enumeration3(BooleanValue3(false))
) &> assertEncodesThenDecodes(Schema[Enumeration3], Enumeration3(Nested(StringValue3("foo"))))
Enumeration3(StringValue3Multi("foo", "bar")),
print = true
)
},
test("of case classes with discriminator") {
assertEncodesThenDecodes(Schema[Command], Command.Cash) &>
Expand Down Expand Up @@ -1884,9 +1915,9 @@ object JsonCodecSpec extends ZIOSpecDefault {
}

case class RecordExample(
f1: Option[String], // the only field that does not have a default value
f2: Option[String] = None,
f3: Option[String] = None,
@fieldName("$f1") f1: Option[String], // the only field that does not have a default value
@fieldNameAliases("field2") f2: Option[String] = None,
@transientField f3: Option[String] = None,
f4: Option[String] = None,
f5: Option[String] = None,
f6: Option[String] = None,
Expand All @@ -1909,8 +1940,39 @@ object JsonCodecSpec extends ZIOSpecDefault {
@fieldName("$f23") f23: Option[String] = None
)

case class RecordExampleWithOptField(
@fieldName("$f1") f1: Option[String], // the only field that does not have a default value
@optionalField @fieldNameAliases("field2") f2: Option[String] = None,
@transientField f3: Option[String] = None,
@optionalField f4: String,
@optionalField @fieldDefaultValue("hello") f5: String,
f6: Option[String] = None,
f7: Option[String] = None,
f8: Option[String] = None,
f9: Option[String] = None,
f10: Option[String] = None,
f11: Option[String] = None,
f12: Option[String] = None,
f13: Option[String] = None,
f14: Option[String] = None,
f15: Option[String] = None,
f16: Option[String] = None,
f17: Option[String] = None,
f18: Option[String] = None,
f19: Option[String] = None,
f20: Option[String] = None,
f21: Option[String] = None,
f22: Option[String] = None,
@fieldName("$f23") f23: Option[String] = None
)

object RecordExample {
implicit lazy val schema: Schema[RecordExample] = DeriveSchema.gen[RecordExample]
}

object RecordExampleWithOptField {
implicit lazy val schema: Schema[RecordExampleWithOptField] =
DeriveSchema.gen[RecordExampleWithOptField]
}

}
10 changes: 10 additions & 0 deletions zio-schema/shared/src/main/scala/zio/schema/Schema.scala
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,16 @@ object Schema extends SchemaPlatformSpecific with SchemaEquality {
val transient: Boolean =
annotations.exists(_.isInstanceOf[transientField])

val nameAndAliases: scala.collection.immutable.Set[String] =
annotations.collect {
case aliases: fieldNameAliases => aliases.aliases
case f: fieldName => Seq(f.name)
}.flatten.toSet + name

val fieldName: String = annotations.collectFirst {
case f: fieldName => f.name
}.getOrElse(name)

override def toString: String = s"Field($name,$schema)"
}

Expand Down

0 comments on commit ad289bc

Please sign in to comment.