From be43667e09d676a331cdab926bb8175b83edf70a Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 15 Jan 2024 12:12:14 +0100 Subject: [PATCH] Support for explicit field numbers and skipping unknown fields during deserialization (#637) * Support for explicit field numbers and skipping unknown fields during deserialization * ScalaFix --- docs/derivations/codecs/protobuf.md | 19 ++++ .../zio/schema/codec/FieldMappingCache.scala | 37 ++++++ .../zio/schema/codec/ProtobufCodec.scala | 105 ++++++++++++++---- .../scala/zio/schema/codec/fieldNumber.scala | 5 + .../zio/schema/codec/ProtobufCodecSpec.scala | 73 +++++++++++- .../scala/zio/schema/codec/ThriftCodec.scala | 10 +- .../MutableSchemaBasedValueBuilder.scala | 34 ++++-- 7 files changed, 249 insertions(+), 34 deletions(-) create mode 100644 zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/FieldMappingCache.scala create mode 100644 zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/fieldNumber.scala diff --git a/docs/derivations/codecs/protobuf.md b/docs/derivations/codecs/protobuf.md index f734d8eb8..4bd26bb19 100644 --- a/docs/derivations/codecs/protobuf.md +++ b/docs/derivations/codecs/protobuf.md @@ -26,6 +26,25 @@ object ProtobufCodec { } ``` +Optionally, the `@fieldNumber(1)` annotation can be used on fields to specify the field number for a case class field. This, together with default values, +can be used to keep binary compatibility when evolving schemas. Default field numbers are indexed starting from 1. + +For example, consider the following three versions of a record: + +```scala +final case class RecordV1(x: Int, y: Int) +final case class RecordV2(x: Int = 100, y: Int, z: Int) +final case class RecordV3(@fieldNumber(2) y: Int, @fieldNumber(4) extra: String = "unknown", @fieldNumber(3) z: Int) +``` + +The decoder of V1 can decode a binary encoded by V2, but cannot decode a binary encoded by V3 because V3 does not have a field number 1 (x). 
+The decoder of V2 can decode a binary encoded by V3 because it has a default value for field number 1 (x), 100. The decoder of V3 can read V2 but +cannot read V1 (as V1 does not have field number 3 (z)). As demonstrated, using explicit field numbers also allows reordering the fields without +breaking the format. + + + + ## Example: BinaryCodec Let's try an example: diff --git a/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/FieldMappingCache.scala b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/FieldMappingCache.scala new file mode 100644 index 000000000..85b456398 --- /dev/null +++ b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/FieldMappingCache.scala @@ -0,0 +1,37 @@ +package zio.schema.codec + +import scala.collection.mutable + +import zio.schema.Schema + +/** + * A per-encoding/decoding cache for field mappings. No need for thread safety as a single encoding/decoding + * is sequential. + */ +private class FieldMappingCache { + private val mapping: mutable.Map[Schema[_], FieldMapping] = mutable.Map.empty + + def get(schema: Schema.Record[_]): FieldMapping = + mapping.getOrElseUpdate(schema, FieldMapping.fromSchema(schema)) +} + +final case class FieldMapping(indexToFieldNumber: Map[Int, Int], fieldNumberToIndex: Map[Int, Int]) + +object FieldMapping { + + def fromSchema(schema: Schema.Record[_]): FieldMapping = { + val indexToFieldNumber = schema.fields.zipWithIndex.map { + case (field, index) => { + val customFieldNumber = getFieldNumber(field) + index -> customFieldNumber.getOrElse(index + 1) + } + }.toMap + val fieldNumberToIndex = indexToFieldNumber.map(_.swap) + FieldMapping(indexToFieldNumber, fieldNumberToIndex) + } + + def getFieldNumber(field: Schema.Field[_, _]): Option[Int] = + field.annotations.collectFirst { + case fieldNumber(n) => n + } +} diff --git a/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/ProtobufCodec.scala 
b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/ProtobufCodec.scala index 01a25f7fd..f65d10924 100644 --- a/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/ProtobufCodec.scala +++ b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/ProtobufCodec.scala @@ -9,8 +9,9 @@ import java.util.UUID import scala.collection.immutable.ListMap import scala.util.control.NonFatal -import zio.schema.MutableSchemaBasedValueBuilder.CreateValueFromSchemaError +import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult } import zio.schema._ +import zio.schema.annotation.fieldDefaultValue import zio.schema.codec.DecodeError.{ ExtraFields, MalformedField, MissingField } import zio.schema.codec.ProtobufCodec.Protobuf.WireType.LengthDelimited import zio.stream.ZPipeline @@ -256,8 +257,10 @@ object ProtobufCodec { context: EncoderContext, index: Int, field: Schema.Field[_, _] - ): EncoderContext = - context.copy(fieldNumber = Some(index + 1)) + ): EncoderContext = { + val fieldNumber = FieldMapping.getFieldNumber(field).getOrElse(index + 1) + context.copy(fieldNumber = Some(fieldNumber)) + } override protected def contextForTuple(context: EncoderContext, index: Int): EncoderContext = context.copy(fieldNumber = Some(index)) @@ -492,6 +495,7 @@ object ProtobufCodec { import Protobuf._ private val state: DecoderState = new DecoderState(chunk, 0) + private val fieldMappingCache = new FieldMappingCache() def decode[A](schema: Schema[A]): scala.util.Either[DecodeError, A] = try { @@ -595,24 +599,53 @@ object ProtobufCodec { context: DecoderContext, record: Schema.Record[_], index: Int - ): Option[(DecoderContext, Int)] = - if (index == record.fields.size) { - None + ): ReadingFieldResult[DecoderContext] = + if (state.length(context) <= 0) { + ReadingFieldResult.Finished() } else { keyDecoder(context) match { case (wt, fieldNumber) => - if (record.fields.isDefinedAt(fieldNumber - 1)) { - Some(wt match { - case 
LengthDelimited(width) => - (context.limitedTo(state, width), fieldNumber - 1) - case _ => - (context, fieldNumber - 1) - }) - } else { - throw ExtraFields( - "Unknown", - s"Failed to decode record. Schema does not contain field number $fieldNumber." - ) + val fieldMapping = fieldMappingCache.get(record) + fieldMapping.fieldNumberToIndex.get(fieldNumber) match { + case Some(index) => { + if (record.fields.isDefinedAt(index)) { + ReadingFieldResult.ReadField(wt match { + case LengthDelimited(width) => + context.limitedTo(state, width) + case _ => + context + }, index) + } else { + throw ExtraFields( + "Unknown", + s"Failed to decode record. Schema does not contain field number $fieldNumber." + ) + } + } + case None => + wt match { + case WireType.VarInt => { + varIntDecoder(context) + ReadingFieldResult.UpdateContext(context) + } + case WireType.Bit64 => { + state.move(8) + ReadingFieldResult.UpdateContext(context) + } + case LengthDelimited(width) => { + state.move(width) + ReadingFieldResult.UpdateContext(context) + } + case WireType.Bit32 => { + state.move(4) + ReadingFieldResult.UpdateContext(context) + } + case _ => + throw ExtraFields( + "Unknown", + s"Failed to decode record. 
Schema does not contain field number $fieldNumber and it's length is unknown" + ) + } } } } @@ -623,9 +656,39 @@ object ProtobufCodec { values: Chunk[(Int, Any)] ): Any = Unsafe.unsafe { implicit u => - record.construct(values.map(_._2)) match { - case Right(result) => result - case Left(message) => throw DecodeError.ReadError(Cause.empty, message) + val array = new Array[Any](record.fields.length) + val mask = Array.fill(record.fields.length)(false) + + for ((field, index) <- record.fields.zipWithIndex) { + val defaultValue = field.annotations.collectFirst { + case fieldDefaultValue(defaultValue) => defaultValue + } + + defaultValue match { + case Some(defaultValue) => + array(index) = defaultValue + mask(index) = true + case None => + } + } + + for ((index, value) <- values) { + if (index < array.length) { + array(index) = value + mask(index) = true; + } + } + + if (mask.forall(set => set)) { + record.construct(Chunk.fromArray(array)) match { + case Right(result) => result + case Left(message) => throw DecodeError.ReadError(Cause.empty, message) + } + } else { + throw DecodeError.ReadError( + Cause.empty, + s"Failed to decode record. 
Missing fields: ${mask.zip(record.fields).filter(!_._1).map(_._2).mkString(", ")}" + ) } } diff --git a/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/fieldNumber.scala b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/fieldNumber.scala new file mode 100644 index 000000000..b4a3458f9 --- /dev/null +++ b/zio-schema-protobuf/shared/src/main/scala/zio/schema/codec/fieldNumber.scala @@ -0,0 +1,5 @@ +package zio.schema.codec + +import scala.annotation.StaticAnnotation + +final case class fieldNumber(n: Int) extends StaticAnnotation diff --git a/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala b/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala index b3130742a..431a2c3de 100644 --- a/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala +++ b/zio-schema-protobuf/shared/src/test/scala-2/zio/schema/codec/ProtobufCodecSpec.scala @@ -793,6 +793,52 @@ object ProtobufCodecSpec extends ZIOSpecDefault { assertZIO(encodeAndDecode(MetaSchema.schema, metaSchema))(equalTo(Chunk(metaSchema))) } } + ), + suite("custom field numbers")( + test("V2 decoder can read encoded V3") { + val recordV3 = RecordV3(y = 200, extra = "hello", z = -10) + for { + encoded <- encode(schemaRecordV3, recordV3) + decoded <- decodeChunkNS(schemaRecordV2, encoded) + } yield assertTrue(decoded == RecordV2(100, 200, -10)) + }, + test("nested V2 decoder can read encoded V3") { + val v3 = NestedV3( + RecordV3(y = 200, extra = "hello", z = -10), + RecordV3(y = 2, extra = "world", z = 3) + ) + for { + encoded <- encode(schemaNestedV3, v3) + decoded <- decodeChunkNS(schemaNestedV2, encoded) + } yield assertTrue( + decoded == NestedV2( + RecordV2(100, 200, -10), + RecordV2(100, 2, 3) + ) + ) + }, + test("V3 decoder cannot read encoded V2") { + val recordV2 = RecordV2(150, 200, -10) + for { + encoded <- encode(schemaRecordV2, recordV2) + decoded <- decodeChunkNS(schemaRecordV3, encoded) + } 
yield assertTrue(decoded == RecordV3(200, "unknown", -10)) + }, + test("V1 decoder can read encoded V2") { + val recordV2 = RecordV2(150, 200, -10) + for { + encoded <- encode(schemaRecordV2, recordV2) + decoded <- decodeChunkNS(schemaRecordV1, encoded) + } yield assertTrue(decoded == RecordV1(150, 200)) + }, + test("V2 decoder cannot read encoded V1") { + // because there is no default value for 'z' + val recordV1 = RecordV1(150, 200) + for { + encoded <- encode(schemaRecordV1, recordV1) + decoded <- decodeChunkNS(schemaRecordV2, encoded).exit + } yield assertTrue(decoded.isFailure) + } ) ) @@ -1011,6 +1057,18 @@ object ProtobufCodecSpec extends ZIOSpecDefault { lazy val schemaChunkOfBytes: Schema[ChunkOfBytes] = DeriveSchema.gen[ChunkOfBytes] + final case class RecordV1(x: Int, y: Int) + final case class RecordV2(x: Int = 100, y: Int, z: Int) + final case class RecordV3(@fieldNumber(2) y: Int, @fieldNumber(4) extra: String = "unknown", @fieldNumber(3) z: Int) + final case class NestedV2(a: RecordV2, b: RecordV2) + final case class NestedV3(a: RecordV3, b: RecordV3) + + implicit val schemaRecordV1: Schema[RecordV1] = DeriveSchema.gen[RecordV1] + implicit val schemaRecordV2: Schema[RecordV2] = DeriveSchema.gen[RecordV2] + implicit val schemaRecordV3: Schema[RecordV3] = DeriveSchema.gen[RecordV3] + implicit val schemaNestedV2: Schema[NestedV2] = DeriveSchema.gen[NestedV2] + implicit val schemaNestedV3: Schema[NestedV3] = DeriveSchema.gen[NestedV3] + // "%02X".format doesn't work the same in ScalaJS def toHex(chunk: Chunk[Byte]): String = chunk.toArray.map { byte => @@ -1056,12 +1114,25 @@ object ProtobufCodecSpec extends ZIOSpecDefault { ZStream .fromChunk(fromHex(hex)) ) - .runCollect + .run(ZSink.collectAll) + + def decodeChunk[A](schema: Schema[A], bytes: Chunk[Byte]): ZIO[Any, DecodeError, Chunk[A]] = + ProtobufCodec + .protobufCodec(schema) + .streamDecoder + .apply( + ZStream + .fromChunk(bytes) + ) + .run(ZSink.collectAll) //NS == non streaming variant of 
decode def decodeNS[A](schema: Schema[A], hex: String): ZIO[Any, DecodeError, A] = ZIO.succeed(ProtobufCodec.protobufCodec(schema).decode(fromHex(hex))).absolve[DecodeError, A] + def decodeChunkNS[A](schema: Schema[A], bytes: Chunk[Byte]): ZIO[Any, DecodeError, A] = + ZIO.succeed(ProtobufCodec.protobufCodec(schema).decode(bytes)).absolve[DecodeError, A] + def encodeAndDecode[A](schema: Schema[A], input: A): ZIO[Any, DecodeError, Chunk[A]] = ProtobufCodec .protobufCodec(schema) diff --git a/zio-schema-thrift/src/main/scala/zio/schema/codec/ThriftCodec.scala b/zio-schema-thrift/src/main/scala/zio/schema/codec/ThriftCodec.scala index 411b5c558..f659b20ae 100644 --- a/zio-schema-thrift/src/main/scala/zio/schema/codec/ThriftCodec.scala +++ b/zio-schema-thrift/src/main/scala/zio/schema/codec/ThriftCodec.scala @@ -10,7 +10,7 @@ import scala.util.control.NonFatal import org.apache.thrift.protocol._ -import zio.schema.MutableSchemaBasedValueBuilder.CreateValueFromSchemaError +import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult } import zio.schema._ import zio.schema.annotation.{ fieldDefaultValue, optionalField, transientField } import zio.schema.codec.DecodeError.{ EmptyContent, MalformedFieldWithPath, ReadError, ReadErrorWithPath } @@ -582,14 +582,14 @@ object ThriftCodec { context: DecoderContext, record: Schema.Record[_], index: Int - ): Option[(DecoderContext, Int)] = + ): ReadingFieldResult[DecoderContext] = if (record.fields.nonEmpty) { val tfield = p.readFieldBegin() - if (tfield.`type` == TType.STOP) None - else Some((context.copy(path = context.path :+ s"fieldId:${tfield.id}"), tfield.id - 1)) + if (tfield.`type` == TType.STOP) ReadingFieldResult.Finished() + else ReadingFieldResult.ReadField(context.copy(path = context.path :+ s"fieldId:${tfield.id}"), tfield.id - 1) } else { val _ = p.readByte() - None + ReadingFieldResult.Finished() } override protected def createRecord( diff --git 
a/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueBuilder.scala b/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueBuilder.scala index 4a49f6b1a..6c83017c8 100644 --- a/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueBuilder.scala +++ b/zio-schema/shared/src/main/scala/zio/schema/MutableSchemaBasedValueBuilder.scala @@ -2,7 +2,7 @@ package zio.schema import scala.util.control.NonFatal -import zio.schema.MutableSchemaBasedValueBuilder.CreateValueFromSchemaError +import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult } import zio.{ Chunk, ChunkBuilder } /** @@ -29,13 +29,15 @@ trait MutableSchemaBasedValueBuilder[Target, Context] { /** The next value to build is a record with the given schema */ protected def startCreatingRecord(context: Context, record: Schema.Record[_]): Context - /** Called for each field of a record. The resulting tuple is either None indicating there are no more fields to read, + /** Called for each field of a record. The result is either Finished indicating there are no more fields to read, * or it contains an updated context belonging to the field and the next field's index in the schema. This allows * the implementation to instantiate fields in a different order than what the schema defines. + * A third option is to just update the context without reading any field - this can be used to skip data, + * for example when reading a newer format of a record that has more fields than the older one. * * The index parameter is a 0-based index, incremented by one for each field read within a record. 
*/ - protected def startReadingField(context: Context, record: Schema.Record[_], index: Int): Option[(Context, Int)] + protected def startReadingField(context: Context, record: Schema.Record[_], index: Int): ReadingFieldResult[Context] /** Creates a record value from the gathered field values */ protected def createRecord(context: Context, record: Schema.Record[_], values: Chunk[(Int, Target)]): Target @@ -183,7 +185,7 @@ trait MutableSchemaBasedValueBuilder[Target, Context] { def readField(index: Int): Unit = { contextStack = contextStack.tail startReadingField(contextStack.head, record, index) match { - case Some((updatedState, idx)) => + case ReadingFieldResult.ReadField(updatedState, idx) => pushContext(updatedState) currentSchema = record.fields(idx).schema push { field => @@ -191,7 +193,10 @@ trait MutableSchemaBasedValueBuilder[Target, Context] { values += elem readField(index + 1) } - case None => + case ReadingFieldResult.UpdateContext(updatedState) => + pushContext(updatedState) + readField(index) + case ReadingFieldResult.Finished() => finishWith(createRecord(contextStack.head, record, values.result())) } } @@ -1013,6 +1018,14 @@ trait MutableSchemaBasedValueBuilder[Target, Context] { object MutableSchemaBasedValueBuilder { case class CreateValueFromSchemaError[Context](context: Context, cause: Throwable) extends RuntimeException + + sealed trait ReadingFieldResult[Context] + + object ReadingFieldResult { + final case class Finished[Context]() extends ReadingFieldResult[Context] + final case class ReadField[Context](context: Context, index: Int) extends ReadingFieldResult[Context] + final case class UpdateContext[Context](context: Context) extends ReadingFieldResult[Context] + } } /** A simpler version of SimpleMutableSchemaBasedValueBuilder without using any Context */ @@ -1025,8 +1038,15 @@ trait SimpleMutableSchemaBasedValueBuilder[Target] extends MutableSchemaBasedVal startCreatingRecord(record) protected def startCreatingRecord(record: 
Schema.Record[_]): Unit - override protected def startReadingField(context: Unit, record: Schema.Record[_], index: Int): Option[(Unit, Int)] = - startReadingField(record, index).map(((), _)) + override protected def startReadingField( + context: Unit, + record: Schema.Record[_], + index: Int + ): ReadingFieldResult[Unit] = + startReadingField(record, index) match { + case Some(idx) => ReadingFieldResult.ReadField((), idx) + case None => ReadingFieldResult.Finished() + } protected def startReadingField(record: Schema.Record[_], index: Int): Option[Int] override protected def createRecord(context: Unit, record: Schema.Record[_], values: Chunk[(Int, Target)]): Target =