Skip to content

Commit

Permalink
add schema fallback (#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
pablf authored Jan 20, 2024
1 parent 745abda commit 6a0e914
Show file tree
Hide file tree
Showing 27 changed files with 1,245 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ object DynamicValueGen {
case Schema.Tuple2(left, right, _) => anyDynamicTupleValue(left, right)
case Schema.Either(left, right, _) =>
Gen.oneOf(anyDynamicLeftValueOfSchema(left), anyDynamicRightValueOfSchema(right))
case Schema.Fallback(left, right, _, _) =>
Gen.oneOf(anyDynamicLeftValueOfSchema(left), anyDynamicRightValueOfSchema(right), anyDynamicBothValueOfSchema(left, right))
case Schema.Transform(schema, _, _, _, _) => anyDynamicValueOfSchema(schema)
case Schema.Fail(message, _) => Gen.const(DynamicValue.Error(message))
case l @ Schema.Lazy(_) => anyDynamicValueOfSchema(l.schema)
Expand All @@ -92,6 +94,11 @@ object DynamicValueGen {
def anyDynamicRightValueOfSchema[A](schema: Schema[A]): Gen[Sized, DynamicValue.RightValue] =
anyDynamicValueOfSchema(schema).map(DynamicValue.RightValue(_))

def anyDynamicBothValueOfSchema[A, B](left: Schema[A], right: Schema[A]): Gen[Sized, DynamicValue.BothValue] =
anyDynamicValueOfSchema(left).zip(anyDynamicValueOfSchema(right)).map {
case (l, r) => DynamicValue.BothValue(l, r)
}

def anyDynamicSomeValueOfSchema[A](schema: Schema[A]): Gen[Sized, DynamicValue.SomeValue] =
anyDynamicValueOfSchema(schema).map(DynamicValue.SomeValue(_))

Expand Down
27 changes: 27 additions & 0 deletions tests/shared/src/test/scala-2/zio/schema/SchemaGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,33 @@ object SchemaGen {
value <- gen
} yield (schema, value)

def anyFallback(fullDecode: Boolean): Gen[Sized, Schema.Fallback[_, _]] =
for {
left <- anyPrimitive
right <- anyPrimitive
} yield Schema.Fallback(left, right, fullDecode)

type FallbackAndGen[A, B] = (Schema.Fallback[A, B], Gen[Sized, Fallback[A, B]])

def anyFallbackAndGen(fullDecode: Boolean): Gen[Sized, FallbackAndGen[_, _]] =
for {
(leftSchema, leftGen) <- anyPrimitiveAndGen
(rightSchema, rightGen) <- anyPrimitiveAndGen
} yield (
Schema.Fallback(leftSchema, rightSchema, fullDecode),
Gen.oneOf(leftGen.map(Fallback.Left(_)), rightGen.map(Fallback.Right(_)), leftGen.zip(rightGen).map {
case (l, r) => Fallback.Both(l, r)
})
)

type FallbackAndValue[A, B] = (Schema.Fallback[A, B], Fallback[A, B])

def anyFallbackAndValue(fullDecode: Boolean): Gen[Sized, FallbackAndValue[_, _]] =
for {
(schema, gen) <- anyFallbackAndGen(fullDecode)
value <- gen
} yield (schema, value)

lazy val anyTuple: Gen[Sized, Schema.Tuple2[_, _]] =
anySchema.zipWith(anySchema) { (a, b) =>
Schema.Tuple2(a, b)
Expand Down
57 changes: 54 additions & 3 deletions zio-schema-avro/src/main/scala/zio/schema/codec/AvroCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.apache.avro.io.{ DecoderFactory, EncoderFactory }
import org.apache.avro.util.Utf8
import org.apache.avro.{ Conversions, LogicalTypes, Schema => SchemaAvro }

import zio.schema.{ FieldSet, Schema, StandardType, TypeId }
import zio.schema.{ Fallback, FieldSet, Schema, StandardType, TypeId }
import zio.stream.ZPipeline
import zio.{ Chunk, Unsafe, ZIO }

Expand Down Expand Up @@ -213,6 +213,7 @@ object AvroCodec {
case Schema.Fail(message, _) => Left(DecodeError.MalformedFieldWithPath(Chunk.empty, message))
case Schema.Tuple2(left, right, _) => decodeTuple2(raw, left, right).map(_.asInstanceOf[A])
case Schema.Either(left, right, _) => decodeEitherValue(raw, left, right)
case s @ Schema.Fallback(_, _, _, _) => decodeFallbackValue(raw, s)
case lzy @ Schema.Lazy(_) => decodeValue(raw, lzy.schema)
case unknown => Left(DecodeError.MalformedFieldWithPath(Chunk.empty, s"Unknown schema: $unknown"))
}
Expand Down Expand Up @@ -468,13 +469,42 @@ object AvroCodec {
val result2 = decodeValue(record.get("_2"), schemaRight)
result1.flatMap(a => result2.map(b => (a, b)))
}

private def decodeEitherValue[A, B](value: Any, schemaLeft: Schema[A], schemaRight: Schema[B]) = {
val record = value.asInstanceOf[GenericRecord]
val result = decodeValue(record.get("value"), schemaLeft)
if (result.isRight) result.map(Left(_))
else decodeValue(record.get("value"), schemaRight).map(Right(_))
}

private def decodeFallbackValue[A, B](value: Any, schema: Schema.Fallback[A, B]) = {
var error: Option[DecodeError] = None

val record = value.asInstanceOf[GenericRecord]
val left: Option[A] = decodeValue(record.get("_1"), Schema.Optional(schema.left)) match {
case Right(value) => value
case Left(err) => {
error = Some(err)
None
}
}

val right = left match {
case Some(_) =>
if (schema.fullDecode) decodeValue(record.get("_2"), Schema.Optional(schema.right)).getOrElse(None)
else None
case _ =>
decodeValue(record.get("_2"), Schema.Optional(schema.right)).getOrElse(None)
}

(left, right) match {
case (Some(a), Some(b)) => Right(Fallback.Both(a, b))
case (_, Some(b)) => Right(Fallback.Right(b))
case (Some(a), _) => Right(Fallback.Left(a))
case _ => Left(error.get)
}
}

private def decodeOptionalValue[A](value: Any, schema: Schema[A]) =
if (value == null) Right(None)
else decodeValue(value, schema).map(Some(_))
Expand Down Expand Up @@ -631,8 +661,9 @@ object AvroCodec {
case Schema.Optional(schema, _) => encodeOption(schema, a)
case Schema.Tuple2(left, right, _) =>
encodeTuple2(left.asInstanceOf[Schema[Any]], right.asInstanceOf[Schema[Any]], a)
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case s @ Schema.Fallback(_, _, _, _) => encodeFallback(s, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.CaseClass0(_, _, _) =>
encodeCaseClass(schema, a, Seq.empty: _*) //encodePrimitive((), StandardType.UnitType)
case Schema.CaseClass1(_, f, _, _) => encodeCaseClass(schema, a, f)
Expand Down Expand Up @@ -942,6 +973,26 @@ object AvroCodec {
result.build()
}

private def encodeFallback[A, B](s: Schema.Fallback[A, B], f: zio.schema.Fallback[A, B]): Any = {
val schema = AvroSchemaCodec
.encodeToApacheAvro(s)
.getOrElse(throw new Exception("Avro schema could not be generated for Fallback."))

val value: (Option[A], Option[B]) = f match {
case zio.schema.Fallback.Left(a) => (Some(a), None)
case zio.schema.Fallback.Right(b) => (None, Some(b))
case zio.schema.Fallback.Both(a, b) => (Some(a), Some(b))
}

val left = encodeOption[A](s.left, value._1)
val right = encodeOption[B](s.right, value._2)

val record = new GenericData.Record(schema)
record.put("_1", left)
record.put("_2", right)
record
}

private def encodeTuple2[A](schema1: Schema[Any], schema2: Schema[Any], a: A) = {
val schema = AvroSchemaCodec
.encodeToApacheAvro(Schema.Tuple2(schema1, schema2, Chunk.empty))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ object AvroSchemaCodec extends AvroSchemaCodec {
name <- getName(e)
} yield wrapAvro(union, name, EitherWrapper)

case Schema.Fallback(left, right, _, _) =>
toAvroSchema(Schema.Tuple2(Schema.Optional(left), Schema.Optional(right)))

case Lazy(schema0) => toAvroSchema(schema0())
case Dynamic(_) => toAvroSchema(Schema[MetaSchema])
}
Expand Down Expand Up @@ -784,6 +787,7 @@ object AvroSchemaCodec extends AvroSchemaCodec {
case _ => Left("ZIO schema wrapped either must have exactly two cases")
}
case e: Schema.Either[_, _] => Right(e)
case f: Schema.Fallback[_, _] => Right(f)
case c: CaseClass0[_] => Right(c)
case c: CaseClass1[_, _] => Right(c)
case c: CaseClass2[_, _, _] => Right(c)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.avro.generic.GenericData

import zio._
import zio.schema.codec.AvroAnnotations.avroEnum
import zio.schema.{ DeriveSchema, Schema }
import zio.schema.{ DeriveSchema, Fallback, Schema }
import zio.stream.ZStream
import zio.test._

Expand Down Expand Up @@ -131,13 +131,15 @@ object AvroCodecSpec extends ZIOSpecDefault {
collectionsEncoderSpec,
optionEncoderSpec,
eitherEncoderSpec,
fallbackEncoderSpec,
tupleEncoderSpec,
genericRecordEncoderSpec,
caseClassEncoderSpec,
enumEncoderSpec,
primitiveDecoderSpec,
optionDecoderSpec,
eitherDecoderSpec,
fallbackDecoderSpec,
tupleDecoderSpec,
sequenceDecoderSpec,
genericRecordDecoderSpec,
Expand Down Expand Up @@ -308,6 +310,24 @@ object AvroCodecSpec extends ZIOSpecDefault {
}
)

private val fallbackEncoderSpec = suite("Avro Codec - Encoder Fallback spec")(
test("Encode Fallback.Right") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[Int, String]]
val bytes = codec.encode(Fallback.Right("John"))
assertTrue(bytes.length == 7)
},
test("Encode Fallback.Left") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[Int, String]]
val bytes = codec.encode(Fallback.Left(42))
assertTrue(bytes.length == 3)
},
test("Encode Fallback.Both") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[Int, String]]
val bytes = codec.encode(Fallback.Both(42, "John"))
assertTrue(bytes.length == 8)
}
)

private val tupleEncoderSpec = suite("Avro Codec - Encode Tuples spec")(
test("Encode Tuple2[Int, String]") {
val codec = AvroCodec.schemaBasedBinaryCodec[(Int, String)]
Expand Down Expand Up @@ -590,6 +610,33 @@ object AvroCodecSpec extends ZIOSpecDefault {
}
)

private val fallbackDecoderSpec = suite("Avro Codec - Fallback Decoder spec")(
test("Decode Fallback") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[String, Int]]
val bytes = codec.encode(Fallback.Right(42))
val result = codec.decode(bytes)
assertTrue(result == Right(Fallback.Right(42)))
},
test("Decode Fallback[List[String], Int]") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[List[String], Int]]
val bytes = codec.encode(Fallback.Left(List("John", "Adam", "Daniel")))
val result = codec.decode(bytes)
assertTrue(result == Right(Fallback.Left(List("John", "Adam", "Daniel"))))
},
test("Decode Fallback.Both full decode") {
val codec = AvroCodec.schemaBasedBinaryCodec(Schema.Fallback[String, Int](Schema[String], Schema[Int], true))
val bytes = codec.encode(Fallback.Both("hello", 42))
val result = codec.decode(bytes)
assertTrue(result == Right(Fallback.Both("hello", 42)))
},
test("Decode Fallback.Both non full decode") {
val codec = AvroCodec.schemaBasedBinaryCodec[Fallback[String, Int]]
val bytes = codec.encode(Fallback.Both("hello", 42))
val result = codec.decode(bytes)
assertTrue(result == Right(Fallback.Left("hello")))
}
)

private val tupleDecoderSpec = suite("Avro Codec - Tuple Decoder Spec")(
test("Decode Tuple2") {
val codec = AvroCodec.schemaBasedBinaryCodec[(Int, String)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import zio.schema.annotation.{
transientCase,
transientField
}
import zio.schema.{ DynamicValue, Schema, StandardType, TypeId }
import zio.schema.{ DynamicValue, Fallback, Schema, StandardType, TypeId }
import zio.{ Chunk, ChunkBuilder, Unsafe }

object BsonSchemaCodec {
Expand Down Expand Up @@ -312,6 +312,78 @@ object BsonSchemaCodec {
}
}

protected[codec] def fallbackEncoder[A: BsonEncoder, B: BsonEncoder]: BsonEncoder[Fallback[A, B]] =
new BsonEncoder[Fallback[A, B]] {
override def encode(writer: BsonWriter, value: Fallback[A, B], ctx: BsonEncoder.EncoderContext): Unit = {
val nextCtx = BsonEncoder.EncoderContext.default

if (!ctx.inlineNextObject) writer.writeStartDocument()

value match {
case Fallback.Left(value) =>
BsonEncoder[A].encode(writer, value, nextCtx)
case Fallback.Right(value) =>
BsonEncoder[B].encode(writer, value, nextCtx)
case Fallback.Both(left, right) =>
writer.writeStartArray()
BsonEncoder[A].encode(writer, left, nextCtx)
BsonEncoder[B].encode(writer, right, nextCtx)
writer.writeEndArray()
}

if (!ctx.inlineNextObject) writer.writeEndDocument()
}

override def toBsonValue(value: Fallback[A, B]): BsonValue = value match {
case Fallback.Left(value) => array(value.toBsonValue)
case Fallback.Right(value) => array(value.toBsonValue)
case Fallback.Both(left, right) => array(left.toBsonValue, right.toBsonValue)
}
}

protected[codec] def fallbackDecoder[A: BsonDecoder, B: BsonDecoder]: BsonDecoder[Fallback[A, B]] =
new BsonDecoder[Fallback[A, B]] {

override def decodeUnsafe(
reader: BsonReader,
trace: List[BsonTrace],
ctx: BsonDecoder.BsonDecoderContext
): Fallback[A, B] = unsafeCall(trace) {
val nextCtx = BsonDecoder.BsonDecoderContext.default

try {
Fallback.Left(BsonDecoder[A].decodeUnsafe(reader, trace, nextCtx))
} catch {
case _: BsonDecoder.Error =>
try {
Fallback.Right(BsonDecoder[B].decodeUnsafe(reader, trace, nextCtx))
} catch {
case _: BsonDecoder.Error => throw BsonDecoder.Error(trace, "Both `left` and `right` cases missing.")
}
}
}

override def fromBsonValueUnsafe(
value: BsonValue,
trace: List[BsonTrace],
ctx: BsonDecoder.BsonDecoderContext
): Fallback[A, B] =
assumeType(trace)(BsonType.DOCUMENT, value) { value =>
val nextCtx = BsonDecoder.BsonDecoderContext.default

try {
Fallback.Left(BsonDecoder[A].fromBsonValueUnsafe(value, trace, nextCtx))
} catch {
case _: BsonDecoder.Error =>
try {
Fallback.Right(BsonDecoder[B].fromBsonValueUnsafe(value, trace, nextCtx))
} catch {
case _: BsonDecoder.Error => throw BsonDecoder.Error(trace, "Both `left` and `right` cases missing.")
}
}
}
}

protected[codec] def failDecoder[A](message: String): BsonDecoder[A] =
new BsonDecoder[A] {
override def decodeUnsafe(reader: BsonReader, trace: List[BsonTrace], ctx: BsonDecoder.BsonDecoderContext): A =
Expand Down Expand Up @@ -394,6 +466,7 @@ object BsonSchemaCodec {
case Schema.Fail(_, _) => unitEncoder.contramap(_ => ())
case Schema.GenericRecord(_, structure, _) => genericRecordEncoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherEncoder(schemaEncoder(left), schemaEncoder(right))
case Schema.Fallback(left, right, _, _) => fallbackEncoder(schemaEncoder(left), schemaEncoder(right))
case l @ Schema.Lazy(_) => schemaEncoder(l.schema)
case r: Schema.Record[A] => caseClassEncoder(r)
case e: Schema.Enum[A] => enumEncoder(e, e.cases)
Expand Down Expand Up @@ -464,6 +537,8 @@ object BsonSchemaCodec {
throw new Exception(s"DynamicValue.LeftValue is not supported in directDynamicMapping mode")
case DynamicValue.RightValue(_) =>
throw new Exception(s"DynamicValue.RightValue is not supported in directDynamicMapping mode")
case DynamicValue.BothValue(_, _) =>
throw new Exception(s"DynamicValue.BothValue is not supported in directDynamicMapping mode")
case DynamicValue.DynamicAst(_) =>
throw new Exception(s"DynamicValue.DynamicAst is not supported in directDynamicMapping mode")
case DynamicValue.Error(message) =>
Expand Down Expand Up @@ -496,6 +571,8 @@ object BsonSchemaCodec {
throw new Exception(s"DynamicValue.LeftValue is not supported in directDynamicMapping mode")
case DynamicValue.RightValue(_) =>
throw new Exception(s"DynamicValue.RightValue is not supported in directDynamicMapping mode")
case DynamicValue.BothValue(_, _) =>
throw new Exception(s"DynamicValue.BothValue is not supported in directDynamicMapping mode")
case DynamicValue.DynamicAst(_) =>
throw new Exception(s"DynamicValue.DynamicAst is not supported in directDynamicMapping mode")
case DynamicValue.Error(message) =>
Expand Down Expand Up @@ -696,6 +773,7 @@ object BsonSchemaCodec {
case Schema.Fail(message, _) => failDecoder(message)
case Schema.GenericRecord(_, structure, _) => recordDecoder(structure.toChunk)
case Schema.Either(left, right, _) => eitherDecoder(schemaDecoder(left), schemaDecoder(right))
case Schema.Fallback(left, right, _, _) => fallbackDecoder(schemaDecoder(left), schemaDecoder(right))
case l @ Schema.Lazy(_) => schemaDecoder(l.schema)
case s: Schema.Record[A] => caseClassDecoder(s)
case e: Schema.Enum[A] => enumDecoder(e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private[schema] object CachedDeriver {
final case class WithIdentityObject[A](inner: CacheKey[_], id: Any) extends CacheKey[A]
final case class Optional[A](key: CacheKey[A]) extends CacheKey[A]
final case class Either[A, B](leftKey: CacheKey[A], rightKey: CacheKey[B]) extends CacheKey[Either[A, B]]
final case class Fallback[A, B](leftKey: CacheKey[A], rightKey: CacheKey[B]) extends CacheKey[Fallback[A, B]]
final case class Tuple2[A, B](leftKey: CacheKey[A], rightKey: CacheKey[B]) extends CacheKey[(A, B)]
final case class Set[A](element: CacheKey[A]) extends CacheKey[Set[A]]
final case class Map[K, V](key: CacheKey[K], valuew: CacheKey[V]) extends CacheKey[Map[K, V]]
Expand All @@ -144,6 +145,8 @@ private[schema] object CachedDeriver {
Tuple2(fromSchema(tuple.left), fromSchema(tuple.right)).asInstanceOf[CacheKey[A]]
case either: Schema.Either[_, _] =>
Either(fromSchema(either.leftSchema), fromSchema(either.rightSchema)).asInstanceOf[CacheKey[A]]
case fallback: Schema.Fallback[_, _] =>
Fallback(fromSchema(fallback.left), fromSchema(fallback.right)).asInstanceOf[CacheKey[A]]
case Schema.Lazy(schema0) => fromSchema(schema0())
case Schema.Dynamic(_) => Misc(schema)
case Schema.Fail(_, _) => Misc(schema)
Expand Down
Loading

0 comments on commit 6a0e914

Please sign in to comment.