From df4bbc05421b77b52539b87ba92ac1c6b713f570 Mon Sep 17 00:00:00 2001 From: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> Date: Mon, 29 Jan 2024 22:09:41 +0100 Subject: [PATCH] Ignore transient fields of generic records (#388) (#648) --- .../zio/schema/codec/ProtobufCodecSpec.scala | 37 +++++++- .../MutableSchemaBasedValueProcessor.scala | 95 +++++++++---------- 2 files changed, 81 insertions(+), 51 deletions(-) diff --git a/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala b/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala index c1c346709..dd4a94db7 100644 --- a/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala +++ b/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala @@ -9,6 +9,7 @@ import scala.util.Try import zio.Console._ import zio._ import zio.schema.CaseSet._ +import zio.schema.annotation.transientField import zio.schema.meta.MetaSchema import zio.schema.{ CaseSet, DeriveSchema, DynamicValue, DynamicValueGen, Schema, SchemaGen, StandardType, TypeId } import zio.stream.{ ZSink, ZStream } @@ -124,7 +125,12 @@ object ProtobufCodecSpec extends ZIOSpecDefault { test("records with arity greater than 22") { for { ed <- encodeAndDecodeNS(schemaHighArityRecord, HighArity()) - } yield assert(ed)(equalTo(HighArity())) + } yield assertTrue(ed == HighArity()) + }, + test("records with arity greater than 22 and transient field") { + for { + ed <- encodeAndDecodeNS(schemaHighArityRecordTransient, HighArityTransient(f24 = 10)) + } yield assertTrue(ed == HighArityTransient()) }, test("integer") { check(Gen.int) { value => @@ -1037,9 +1043,38 @@ object ProtobufCodecSpec extends ZIOSpecDefault { f23: Int = 23, f24: Int = 24 ) + case class HighArityTransient( + f1: Int = 1, + f2: Int = 2, + f3: Int = 3, + f4: Int = 4, + f5: Int = 5, + f6: Int = 6, + f7: Int = 7, + f8: Int = 8, + f9: Int = 9, + f10: Int = 10, + f11: Int = 11, + f12: Int = 12, + f13: Int = 13, + f14: Int = 14, + f15: Int = 15, + f16: Int = 16, + f17: Int = 17, + f18: Int = 18, + f19: Int = 19, + f20: Int = 20, + f21: Int = 21, + f22: Int = 22, + f23: Int = 23, + @transientField + f24: Int = 24 + ) lazy val schemaHighArityRecord: Schema[HighArity] = DeriveSchema.gen[HighArity] + lazy val schemaHighArityRecordTransient: Schema[HighArityTransient] = DeriveSchema.gen[HighArityTransient] + lazy val schemaOneOf: Schema[OneOf] = DeriveSchema.gen[OneOf] case class MyRecord(age: Int) diff --git a/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueProcessor.scala b/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueProcessor.scala index 044e02ad3..432c1fecf 100644 --- a/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueProcessor.scala +++ b/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueProcessor.scala @@ -143,35 +143,30 @@ trait MutableSchemaBasedValueProcessor[Target, Context] { } def fields(s: Schema.Record[_], record: Any, fs: Schema.Field[_, _]*): Unit = { - val nonTransientFields = fs.filter { - case Schema.Field(_, _, annotations, _, _, _) - if annotations.collectFirst { case a: transientField => a }.isDefined => - false - case _ => true - } - val values = ChunkBuilder.make[Target](nonTransientFields.size) - - def processNext(index: Int, remaining: List[Schema.Field[_, _]]): Unit = - remaining match { - case next :: _ => - currentSchema = next.schema - currentValue = next.asInstanceOf[Schema.Field[Any, Any]].get(record) - pushContext(contextForRecordField(contextStack.head, index, next)) - push(processField(index, remaining, _)) - case Nil => - finishWith( - processRecord( - contextStack.head, - s, - nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) { - case (lm, pair) => - lm.updated(pair._1, pair._2) - } - ) + val nonTransientFields = fs.filterNot(_.annotations.exists(_.isInstanceOf[transientField])) + val values = ChunkBuilder.make[Target](nonTransientFields.size) + + def processNext(index: Int, remaining: Seq[Schema.Field[_, _]]): Unit = + if (remaining.isEmpty) { + finishWith( + processRecord( + contextStack.head, + s, + nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) { + case (lm, pair) => + lm.updated(pair._1, pair._2) + } ) + ) + } else { + val next = remaining.head + currentSchema = next.schema + currentValue = next.asInstanceOf[Schema.Field[Any, Any]].get(record) + pushContext(contextForRecordField(contextStack.head, index, next)) + push(processField(index, remaining, _)) } - def processField(index: Int, currentStructure: List[Schema.Field[_, _]], fieldResult: Target): Unit = { + def processField(index: Int, currentStructure: Seq[Schema.Field[_, _]], fieldResult: Target): Unit = { contextStack = contextStack.tail values += fieldResult val remaining = currentStructure.tail @@ -179,7 +174,7 @@ trait MutableSchemaBasedValueProcessor[Target, Context] { } startProcessingRecord(contextStack.head, s) - processNext(0, nonTransientFields.toList) + processNext(0, nonTransientFields) } def enumCases(s: Schema.Enum[_], cs: Schema.Case[_, _]*): Unit = { @@ -223,33 +218,33 @@ trait MutableSchemaBasedValueProcessor[Target, Context] { finishWith(processPrimitive(currentContext, currentValue, p.asInstanceOf[StandardType[Any]])) case s @ Schema.GenericRecord(_, structure, _) => - val map = currentValue.asInstanceOf[ListMap[String, _]] - val structureChunk = structure.toChunk - val values = ChunkBuilder.make[Target](structureChunk.size) - - def processNext(index: Int, remaining: List[Schema.Field[ListMap[String, _], _]]): Unit = - remaining match { - case next :: _ => - currentSchema = next.schema - currentValue = map(next.name) - pushContext(contextForRecordField(currentContext, index, next)) - push(processField(index, remaining, _)) - case Nil => - finishWith( - processRecord( - currentContext, - s, - structureChunk.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) { - case (lm, pair) => - lm.updated(pair._1, pair._2) - } - ) + val map = currentValue.asInstanceOf[ListMap[String, _]] + val nonTransientFields = structure.toChunk.filterNot(_.annotations.exists(_.isInstanceOf[transientField])) + val values = ChunkBuilder.make[Target](nonTransientFields.size) + + def processNext(index: Int, remaining: Seq[Schema.Field[ListMap[String, _], _]]): Unit = + if (remaining.isEmpty) { + finishWith( + processRecord( + currentContext, + s, + nonTransientFields.map(_.name).zip(values.result()).foldLeft(ListMap.empty[String, Target]) { + case (lm, pair) => + lm.updated(pair._1, pair._2) + } ) + ) + } else { + val next = remaining.head + currentSchema = next.schema + currentValue = map(next.name) + pushContext(contextForRecordField(currentContext, index, next)) + push(processField(index, remaining, _)) } def processField( index: Int, - currentStructure: List[Schema.Field[ListMap[String, _], _]], + currentStructure: Seq[Schema.Field[ListMap[String, _], _]], fieldResult: Target ): Unit = { contextStack = contextStack.tail @@ -259,7 +254,7 @@ trait MutableSchemaBasedValueProcessor[Target, Context] { } startProcessingRecord(currentContext, s) - processNext(0, structureChunk.toList) + processNext(0, nonTransientFields) case s @ Schema.Enum1(_, case1, _) => enumCases(s, case1)