Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixed enum support in avro codec #578

Merged
merged 13 commits into from
Aug 22, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,31 @@ object AvroCodec {
private def decodeCaseClass1[A, Z](raw: Any, schema: Schema.CaseClass1[A, Z]) =
decodeValue(raw, schema.field.schema).map(schema.defaultConstruct)

private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] = {
val generic = raw.asInstanceOf[GenericData.Record]
val enumCaseName = generic.getSchema.getFullName
val enumCaseValue = generic.get("value")
private def decodeEnum[Z](raw: Any, cases: Schema.Case[Z, _]*): Either[DecodeError, Any] =
raw match {
case enums: GenericData.EnumSymbol =>
decodeGenericEnum(enums.toString, None, cases: _*)
case gr: GenericData.Record =>
val enumCaseName = gr.getSchema.getFullName
if (gr.hasField("value")) {
val enumCaseValue = gr.get("value")
decodeGenericEnum[Z](enumCaseName, Some(enumCaseValue), cases: _*)
} else {
decodeGenericEnum[Z](enumCaseName, None, cases: _*)
}
case _ => Left(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum: $raw"))
}

private def decodeGenericEnum[Z](
enumCaseName: String,
enumCaseValue: Option[AnyRef],
cases: Schema.Case[Z, _]*
): Either[DecodeError, Any] =
cases
.find(_.id == enumCaseName)
.map(s => decodeValue(enumCaseValue, s.schema))
.map(s => decodeValue(enumCaseValue.getOrElse(s), s.schema))
.toRight(DecodeError.MalformedFieldWithPath(Chunk.single("Error"), s"Unknown enum value: $enumCaseName"))
.flatMap(identity)
}

private def decodeRecord[A](value: A, schema: Schema.Record[_]) = {
val record = value.asInstanceOf[GenericRecord]
Expand Down Expand Up @@ -454,41 +469,41 @@ object AvroCodec {
else decodeValue(value, schema).map(Some(_))

private def encodeValue[A](a: A, schema: Schema[A]): Any = schema match {
case Schema.Enum1(_, c1, _) => encodeEnum(a, c1)
case Schema.Enum2(_, c1, c2, _) => encodeEnum(a, c1, c2)
case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(a, c1, c2, c3)
case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(a, c1, c2, c3, c4)
case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(a, c1, c2, c3, c4, c5)
case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(a, c1, c2, c3, c4, c5, c6)
case Schema.Enum1(_, c1, _) => encodeEnum(schema, a, c1)
case Schema.Enum2(_, c1, c2, _) => encodeEnum(schema, a, c1, c2)
case Schema.Enum3(_, c1, c2, c3, _) => encodeEnum(schema, a, c1, c2, c3)
case Schema.Enum4(_, c1, c2, c3, c4, _) => encodeEnum(schema, a, c1, c2, c3, c4)
case Schema.Enum5(_, c1, c2, c3, c4, c5, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5)
case Schema.Enum6(_, c1, c2, c3, c4, c5, c6, _) => encodeEnum(schema, a, c1, c2, c3, c4, c5, c6)
case Schema.Enum7(_, c1, c2, c3, c4, c5, c6, c7, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7)
case Schema.Enum8(_, c1, c2, c3, c4, c5, c6, c7, c8, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8)
case Schema.Enum9(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9)
case Schema.Enum10(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10)
case Schema.Enum11(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11)
case Schema.Enum12(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12)
case Schema.Enum13(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13)
case Schema.Enum14(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14)
case Schema.Enum15(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
case Schema.Enum16(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16)
case Schema.Enum17(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17)
case Schema.Enum18(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18)
case Schema.Enum19(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19)
case Schema
.Enum20(_, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, _) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
encodeEnum(schema, a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20)
case Schema.Enum21(
_,
c1,
Expand All @@ -514,7 +529,31 @@ object AvroCodec {
c21,
_
) =>
encodeEnum(a, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16, c17, c18, c19, c20, c21)
encodeEnum(
schema,
a,
c1,
c2,
c3,
c4,
c5,
c6,
c7,
c8,
c9,
c10,
c11,
c12,
c13,
c14,
c15,
c16,
c17,
c18,
c19,
c20,
c21
)
case Schema.Enum22(
_,
c1,
Expand Down Expand Up @@ -542,6 +581,7 @@ object AvroCodec {
_
) =>
encodeEnum(
schema,
a,
c1,
c2,
Expand Down Expand Up @@ -580,9 +620,10 @@ object AvroCodec {
case Schema.Optional(schema, _) => encodeOption(schema, a)
case Schema.Tuple2(left, right, _) =>
encodeTuple2(left.asInstanceOf[Schema[Any]], right.asInstanceOf[Schema[Any]], a)
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.CaseClass0(_, _, _) => encodePrimitive((), StandardType.UnitType)
case Schema.Either(left, right, _) => encodeEither(left, right, a)
case Schema.Lazy(schema0) => encodeValue(a, schema0())
case Schema.CaseClass0(_, _, _) =>
encodeCaseClass(schema, a, Seq.empty: _*) //encodePrimitive((), StandardType.UnitType)
case Schema.CaseClass1(_, f, _, _) => encodeCaseClass(schema, a, f)
case Schema.CaseClass2(_, f0, f1, _, _) => encodeCaseClass(schema, a, f0, f1)
case Schema.CaseClass3(_, f0, f1, f2, _, _) => encodeCaseClass(schema, a, f0, f1, f2)
Expand Down Expand Up @@ -926,11 +967,20 @@ object AvroCodec {
record
}

private def encodeEnum[Z](value: Z, cases: Schema.Case[Z, _]*): Any = {
private def encodeEnum[Z](schemaRaw: Schema[Z], value: Z, cases: Schema.Case[Z, _]*): Any = {
val schema = AvroSchemaCodec
.encodeToApacheAvro(schemaRaw)
.getOrElse(throw new Exception("Avro schema could not be generated for Enum."))
val fieldIndex = cases.indexWhere(c => c.deconstructOption(value).isDefined)
if (fieldIndex >= 0) {
val subtypeCase = cases(fieldIndex)
encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]])
if (schema.getType == SchemaAvro.Type.ENUM) {
GenericData.get.createEnum(schema.getEnumSymbols.get(fieldIndex), schema)
} else {

encodeValue(subtypeCase.deconstruct(value), subtypeCase.schema.asInstanceOf[Schema[Any]])

}
} else {
throw new Exception("Could not find matching case for enum value.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,8 @@ object AvroSchemaCodec extends AvroSchemaCodec {
}

def hasAvroEnumAnnotation(annotations: Chunk[Any]): Boolean = annotations.exists {
case AvroAnnotations.avroEnum => true
case _ => false
case AvroAnnotations.avroEnum() => true
case _ => false
}

def wrapAvro(schemaAvro: SchemaAvro, name: String, marker: AvroPropMarker): SchemaAvro = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import java.time.{
import java.util.UUID

import zio._
import zio.schema.codec.AvroAnnotations.avroEnum
import zio.schema.{ DeriveSchema, Schema }
import zio.stream.ZStream
import zio.test.TestAspect.failing
import zio.test._

object AvroCodecSpec extends ZIOSpecDefault {
Expand Down Expand Up @@ -106,10 +106,12 @@ object AvroCodecSpec extends ZIOSpecDefault {

case class BooleanValue(value: Boolean) extends OneOf

case object NullValue extends OneOf

implicit val schemaOneOf: Schema[OneOf] = DeriveSchema.gen[OneOf]
}

sealed trait Enums
@avroEnum() sealed trait Enums

object Enums {
case object A extends Enums
Expand Down Expand Up @@ -649,12 +651,18 @@ object AvroCodecSpec extends ZIOSpecDefault {
val result = codec.decode(bytes)
assertTrue(result == Right(OneOf.BooleanValue(true)))
},
test("Decode Enum3 - case object") {
val codec = AvroCodec.schemaBasedBinaryCodec[OneOf]
val bytes = codec.encode(OneOf.NullValue)
val result = codec.decode(bytes)
assertTrue(result == Right(OneOf.NullValue))
},
test("Decode Enum5") {
val codec = AvroCodec.schemaBasedBinaryCodec[Enums]
val bytes = codec.encode(Enums.A)
val result = codec.decode(bytes)
assertTrue(result == Right(Enums.A))
} @@ failing, // TODO: the case object from a sealed trait are not properly encoded and decoded.
},
test("Decode Person") {
val codec = AvroCodec.schemaBasedBinaryCodec[Person]
val bytes = codec.encode(Person("John", 42))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ object AvroSchemaCodecSpec extends ZIOSpecDefault {
},
test("encodes sealed trait objects only as enum when avroEnum annotation is present") {

val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum)
val schema = DeriveSchema.gen[SpecTestData.CaseObjectsOnlyAdt].annotate(AvroAnnotations.avroEnum())
val result = AvroSchemaCodec.encode(schema)

val expected = """{"type":"enum","name":"MyEnum","symbols":["A","B","MyC"]}"""
Expand Down
Loading