Skip to content

Commit

Permalink
Support for explicit field numbers and skipping unknown fields during…
Browse files Browse the repository at this point in the history
… deserialization (#637)

* Support for explicit field numbers and skipping unknown fields during deserialization

* ScalaFix
  • Loading branch information
vigoo authored Jan 15, 2024
1 parent a87dd6e commit be43667
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 34 deletions.
19 changes: 19 additions & 0 deletions docs/derivations/codecs/protobuf.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,25 @@ object ProtobufCodec {
}
```

Optionally the `@fieldNumber(1)` annotation can be used on fields to specify the field number for a case class field. This together with default values
can be used to keep binary compatibility when evolving schemas. Default field numbers are indexed starting from 1.

For example considering the following three versions of a record:

```scala
final case class RecordV1(x: Int, y: Int)
final case class RecordV2(x: Int = 100, y: Int, z: Int)
final case class RecordV3(@fieldNumber(2) y: Int, @fieldNumber(4) extra: String = "unknown", @fieldNumber(3) z: Int)
```

The decoder of V1 can decode a binary encoded by V2, but cannot decode a binary encoded by V3 because it does not have a field number 1 (x).
The decoder of V2 can decode a binary encoded by V3 because it has a default value for field number 1 (x), 100. The decoder of V3 can read V2 but
cannot read V1 (as it does not have field number 3 (z)). As demonstrated, using explicit field numbers also allows reordering the fields without
breaking the format.


```scala

## Example: BinaryCodec

Let's try an example:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package zio.schema.codec

import scala.collection.mutable

import zio.schema.Schema

/**
* A per-encooding/decoding cache for field mappings. No need for thread safety as a single encoding/decoding
* is sequential.
*/
private class FieldMappingCache {
private val mapping: mutable.Map[Schema[_], FieldMapping] = mutable.Map.empty

def get(schema: Schema.Record[_]): FieldMapping =
mapping.getOrElseUpdate(schema, FieldMapping.fromSchema(schema))
}

final case class FieldMapping(indexToFieldNumber: Map[Int, Int], fieldNumberToIndex: Map[Int, Int])

object FieldMapping {

def fromSchema(schema: Schema.Record[_]): FieldMapping = {
val indexToFieldNumber = schema.fields.zipWithIndex.map {
case (field, index) => {
val customFieldNumber = getFieldNumber(field)
index -> customFieldNumber.getOrElse(index + 1)
}
}.toMap
val fieldNumberToIndex = indexToFieldNumber.map(_.swap)
FieldMapping(indexToFieldNumber, fieldNumberToIndex)
}

def getFieldNumber(field: Schema.Field[_, _]): Option[Int] =
field.annotations.collectFirst {
case fieldNumber(n) => n
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import java.util.UUID
import scala.collection.immutable.ListMap
import scala.util.control.NonFatal

import zio.schema.MutableSchemaBasedValueBuilder.CreateValueFromSchemaError
import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult }
import zio.schema._
import zio.schema.annotation.fieldDefaultValue
import zio.schema.codec.DecodeError.{ ExtraFields, MalformedField, MissingField }
import zio.schema.codec.ProtobufCodec.Protobuf.WireType.LengthDelimited
import zio.stream.ZPipeline
Expand Down Expand Up @@ -256,8 +257,10 @@ object ProtobufCodec {
context: EncoderContext,
index: Int,
field: Schema.Field[_, _]
): EncoderContext =
context.copy(fieldNumber = Some(index + 1))
): EncoderContext = {
val fieldNumber = FieldMapping.getFieldNumber(field).getOrElse(index + 1)
context.copy(fieldNumber = Some(fieldNumber))
}

override protected def contextForTuple(context: EncoderContext, index: Int): EncoderContext =
context.copy(fieldNumber = Some(index))
Expand Down Expand Up @@ -492,6 +495,7 @@ object ProtobufCodec {
import Protobuf._

private val state: DecoderState = new DecoderState(chunk, 0)
private val fieldMappingCache = new FieldMappingCache()

def decode[A](schema: Schema[A]): scala.util.Either[DecodeError, A] =
try {
Expand Down Expand Up @@ -595,24 +599,53 @@ object ProtobufCodec {
context: DecoderContext,
record: Schema.Record[_],
index: Int
): Option[(DecoderContext, Int)] =
if (index == record.fields.size) {
None
): ReadingFieldResult[DecoderContext] =
if (state.length(context) <= 0) {
ReadingFieldResult.Finished()
} else {
keyDecoder(context) match {
case (wt, fieldNumber) =>
if (record.fields.isDefinedAt(fieldNumber - 1)) {
Some(wt match {
case LengthDelimited(width) =>
(context.limitedTo(state, width), fieldNumber - 1)
case _ =>
(context, fieldNumber - 1)
})
} else {
throw ExtraFields(
"Unknown",
s"Failed to decode record. Schema does not contain field number $fieldNumber."
)
val fieldMapping = fieldMappingCache.get(record)
fieldMapping.fieldNumberToIndex.get(fieldNumber) match {
case Some(index) => {
if (record.fields.isDefinedAt(index)) {
ReadingFieldResult.ReadField(wt match {
case LengthDelimited(width) =>
context.limitedTo(state, width)
case _ =>
context
}, index)
} else {
throw ExtraFields(
"Unknown",
s"Failed to decode record. Schema does not contain field number $fieldNumber."
)
}
}
case None =>
wt match {
case WireType.VarInt => {
varIntDecoder(context)
ReadingFieldResult.UpdateContext(context)
}
case WireType.Bit64 => {
state.move(8)
ReadingFieldResult.UpdateContext(context)
}
case LengthDelimited(width) => {
state.move(width)
ReadingFieldResult.UpdateContext(context)
}
case WireType.Bit32 => {
state.move(4)
ReadingFieldResult.UpdateContext(context)
}
case _ =>
throw ExtraFields(
"Unknown",
s"Failed to decode record. Schema does not contain field number $fieldNumber and it's length is unknown"
)
}
}
}
}
Expand All @@ -623,9 +656,39 @@ object ProtobufCodec {
values: Chunk[(Int, Any)]
): Any =
Unsafe.unsafe { implicit u =>
record.construct(values.map(_._2)) match {
case Right(result) => result
case Left(message) => throw DecodeError.ReadError(Cause.empty, message)
val array = new Array[Any](record.fields.length)
val mask = Array.fill(record.fields.length)(false)

for ((field, index) <- record.fields.zipWithIndex) {
val defaultValue = field.annotations.collectFirst {
case fieldDefaultValue(defaultValue) => defaultValue
}

defaultValue match {
case Some(defaultValue) =>
array(index) = defaultValue
mask(index) = true
case None =>
}
}

for ((index, value) <- values) {
if (index < array.length) {
array(index) = value
mask(index) = true;
}
}

if (mask.forall(set => set)) {
record.construct(Chunk.fromArray(array)) match {
case Right(result) => result
case Left(message) => throw DecodeError.ReadError(Cause.empty, message)
}
} else {
throw DecodeError.ReadError(
Cause.empty,
s"Failed to decode record. Missing fields: ${mask.zip(record.fields).filter(!_._1).map(_._2).mkString(", ")}"
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package zio.schema.codec

import scala.annotation.StaticAnnotation

final case class fieldNumber(n: Int) extends StaticAnnotation
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,52 @@ object ProtobufCodecSpec extends ZIOSpecDefault {
assertZIO(encodeAndDecode(MetaSchema.schema, metaSchema))(equalTo(Chunk(metaSchema)))
}
}
),
suite("custom field numbers")(
test("V2 decoder can read encoded V3") {
val recordV3 = RecordV3(y = 200, extra = "hello", z = -10)
for {
encoded <- encode(schemaRecordV3, recordV3)
decoded <- decodeChunkNS(schemaRecordV2, encoded)
} yield assertTrue(decoded == RecordV2(100, 200, -10))
},
test("nested V2 decoder can read encoded V3") {
val v3 = NestedV3(
RecordV3(y = 200, extra = "hello", z = -10),
RecordV3(y = 2, extra = "world", z = 3)
)
for {
encoded <- encode(schemaNestedV3, v3)
decoded <- decodeChunkNS(schemaNestedV2, encoded)
} yield assertTrue(
decoded == NestedV2(
RecordV2(100, 200, -10),
RecordV2(100, 2, 3)
)
)
},
test("V3 decoder cannot read encoded V2") {
val recordV2 = RecordV2(150, 200, -10)
for {
encoded <- encode(schemaRecordV2, recordV2)
decoded <- decodeChunkNS(schemaRecordV3, encoded)
} yield assertTrue(decoded == RecordV3(200, "unknown", -10))
},
test("V1 decoder can read encoded V2") {
val recordV2 = RecordV2(150, 200, -10)
for {
encoded <- encode(schemaRecordV2, recordV2)
decoded <- decodeChunkNS(schemaRecordV1, encoded)
} yield assertTrue(decoded == RecordV1(150, 200))
},
test("V2 decoder cannot read encoded V1") {
// because there is no default value for 'z'
val recordV1 = RecordV1(150, 200)
for {
encoded <- encode(schemaRecordV1, recordV1)
decoded <- decodeChunkNS(schemaRecordV2, encoded).exit
} yield assertTrue(decoded.isFailure)
}
)
)

Expand Down Expand Up @@ -1011,6 +1057,18 @@ object ProtobufCodecSpec extends ZIOSpecDefault {

lazy val schemaChunkOfBytes: Schema[ChunkOfBytes] = DeriveSchema.gen[ChunkOfBytes]

final case class RecordV1(x: Int, y: Int)
final case class RecordV2(x: Int = 100, y: Int, z: Int)
final case class RecordV3(@fieldNumber(2) y: Int, @fieldNumber(4) extra: String = "unknown", @fieldNumber(3) z: Int)
final case class NestedV2(a: RecordV2, b: RecordV2)
final case class NestedV3(a: RecordV3, b: RecordV3)

implicit val schemaRecordV1: Schema[RecordV1] = DeriveSchema.gen[RecordV1]
implicit val schemaRecordV2: Schema[RecordV2] = DeriveSchema.gen[RecordV2]
implicit val schemaRecordV3: Schema[RecordV3] = DeriveSchema.gen[RecordV3]
implicit val schemaNestedV2: Schema[NestedV2] = DeriveSchema.gen[NestedV2]
implicit val schemaNestedV3: Schema[NestedV3] = DeriveSchema.gen[NestedV3]

// "%02X".format doesn't work the same in ScalaJS
def toHex(chunk: Chunk[Byte]): String =
chunk.toArray.map { byte =>
Expand Down Expand Up @@ -1056,12 +1114,25 @@ object ProtobufCodecSpec extends ZIOSpecDefault {
ZStream
.fromChunk(fromHex(hex))
)
.runCollect
.run(ZSink.collectAll)

def decodeChunk[A](schema: Schema[A], bytes: Chunk[Byte]): ZIO[Any, DecodeError, Chunk[A]] =
ProtobufCodec
.protobufCodec(schema)
.streamDecoder
.apply(
ZStream
.fromChunk(bytes)
)
.run(ZSink.collectAll)

//NS == non streaming variant of decode
def decodeNS[A](schema: Schema[A], hex: String): ZIO[Any, DecodeError, A] =
ZIO.succeed(ProtobufCodec.protobufCodec(schema).decode(fromHex(hex))).absolve[DecodeError, A]

def decodeChunkNS[A](schema: Schema[A], bytes: Chunk[Byte]): ZIO[Any, DecodeError, A] =
ZIO.succeed(ProtobufCodec.protobufCodec(schema).decode(bytes)).absolve[DecodeError, A]

def encodeAndDecode[A](schema: Schema[A], input: A): ZIO[Any, DecodeError, Chunk[A]] =
ProtobufCodec
.protobufCodec(schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import scala.util.control.NonFatal

import org.apache.thrift.protocol._

import zio.schema.MutableSchemaBasedValueBuilder.CreateValueFromSchemaError
import zio.schema.MutableSchemaBasedValueBuilder.{ CreateValueFromSchemaError, ReadingFieldResult }
import zio.schema._
import zio.schema.annotation.{ fieldDefaultValue, optionalField, transientField }
import zio.schema.codec.DecodeError.{ EmptyContent, MalformedFieldWithPath, ReadError, ReadErrorWithPath }
Expand Down Expand Up @@ -582,14 +582,14 @@ object ThriftCodec {
context: DecoderContext,
record: Schema.Record[_],
index: Int
): Option[(DecoderContext, Int)] =
): ReadingFieldResult[DecoderContext] =
if (record.fields.nonEmpty) {
val tfield = p.readFieldBegin()
if (tfield.`type` == TType.STOP) None
else Some((context.copy(path = context.path :+ s"fieldId:${tfield.id}"), tfield.id - 1))
if (tfield.`type` == TType.STOP) ReadingFieldResult.Finished()
else ReadingFieldResult.ReadField(context.copy(path = context.path :+ s"fieldId:${tfield.id}"), tfield.id - 1)
} else {
val _ = p.readByte()
None
ReadingFieldResult.Finished()
}

override protected def createRecord(
Expand Down
Loading

0 comments on commit be43667

Please sign in to comment.