diff --git a/modules/core/src/main/scala/vulcan/Codec.scala b/modules/core/src/main/scala/vulcan/Codec.scala index 1cabf3e1..73a90836 100644 --- a/modules/core/src/main/scala/vulcan/Codec.scala +++ b/modules/core/src/main/scala/vulcan/Codec.scala @@ -1092,60 +1092,68 @@ object Codec extends CodecCompanionCompat { }, (value, schema) => { val schemaTypes = - schema.getType() match { + schema.getType match { case UNION => schema.getTypes.asScala case _ => Seq(schema) } + def decodeNamedContainerType(container: GenericContainer) = { + val altName = + container.getSchema.getName + + val altWriterSchema = + schemaTypes + .find(_.getName == altName) + .toRight(AvroError.decodeMissingUnionSchema(altName)) + + def altMatching = + alts + .find(_.codec.schema.exists { schema => + schema.getType match { + case RECORD | FIXED | ENUM => + schema.getName == altName || schema.getAliases.asScala + .exists(alias => alias == altName || alias.endsWith(s".$altName")) + case _ => false + } + }) + .toRight(AvroError.decodeMissingUnionAlternative(altName)) + + altWriterSchema.flatMap { altSchema => + altMatching.flatMap { alt => + alt.codec + .decode(container, altSchema) + .map(alt.prism.reverseGet) + } + } + } + + def decodeUnnamedType(other: Any) = + alts + .collectFirstSome { alt => + alt.codec.schema + .traverse { altSchema => + val altName = altSchema.getName + schemaTypes + .find(_.getName == altName) + .flatMap { schema => + alt.codec + .decode(other, schema) + .map(alt.prism.reverseGet) + .toOption + } + } + } + .getOrElse { + Left(AvroError.decodeExhaustedAlternatives(other)) + } + value match { case container: GenericContainer => - val altName = - container.getSchema.getName - - val altWriterSchema = - schemaTypes - .find(_.getName == altName) - .toRight(AvroError.decodeMissingUnionSchema(altName)) - - def altMatching = - alts - .find(_.codec.schema.exists { schema => - schema.getType match { - case RECORD | FIXED | ENUM => - schema.getName == altName || schema.getAliases.asScala - .exists(alias => alias == altName || alias.endsWith(s".$altName")) - case _ => false - } - }) - .toRight(AvroError.decodeMissingUnionAlternative(altName)) - - altWriterSchema.flatMap { altSchema => - altMatching.flatMap { alt => - alt.codec - .decode(container, altSchema) - .map(alt.prism.reverseGet) - } + container.getSchema.getType match { + case RECORD | FIXED | ENUM => decodeNamedContainerType(container) + case _ => decodeUnnamedType(container) } - - case other => - alts - .collectFirstSome { alt => - alt.codec.schema - .traverse { altSchema => - val altName = altSchema.getName - schemaTypes - .find(_.getName == altName) - .flatMap { schema => - alt.codec - .decode(other, schema) - .map(alt.prism.reverseGet) - .toOption - } - } - } - .getOrElse { - Left(AvroError.decodeExhaustedAlternatives(other)) - } + case other => decodeUnnamedType(other) } } ) diff --git a/modules/core/src/test/scala/vulcan/CodecSpec.scala b/modules/core/src/test/scala/vulcan/CodecSpec.scala index bdd80654..52fade22 100644 --- a/modules/core/src/test/scala/vulcan/CodecSpec.scala +++ b/modules/core/src/test/scala/vulcan/CodecSpec.scala @@ -2777,7 +2777,7 @@ final class CodecSpec extends BaseSpec with CodecSpecHelpers { describe("schema") { it("should encode as union") { assertSchemaIs[SealedTraitCaseClass] { - """[{"type":"record","name":"FirstInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"int"}]},{"type":"record","name":"SecondInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"string"}]},"int"]""" + """[{"type":"record","name":"FirstInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"int"}]},{"type":"record","name":"SecondInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"string"}]},{"type":"array","items":"int"}]""" } } diff --git a/modules/core/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala b/modules/core/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala index 04bf99eb..cd029565 100644 --- a/modules/core/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala +++ b/modules/core/src/test/scala/vulcan/examples/SealedTraitCaseClass.scala @@ -26,7 +26,7 @@ object SealedTraitCaseClass { Gen.oneOf[SealedTraitCaseClass]( arbitrary[Int].map(FirstInSealedTraitCaseClass(_)), arbitrary[String].map(SecondInSealedTraitCaseClass(_)), - arbitrary[Int].map(ThirdInSealedTraitCaseClass(_)) + arbitrary[List[Int]].map(ThirdInSealedTraitCaseClass(_)) ) } } @@ -61,12 +61,12 @@ object SecondInSealedTraitCaseClass { Arbitrary(arbitrary[String].map(apply)) } -final case class ThirdInSealedTraitCaseClass(value: Int) extends SealedTraitCaseClass +final case class ThirdInSealedTraitCaseClass(value: List[Int]) extends SealedTraitCaseClass object ThirdInSealedTraitCaseClass { implicit val codec: Codec[ThirdInSealedTraitCaseClass] = - Codec[Int].imap(apply)(_.value) + Codec[List[Int]].imap(apply)(_.value) implicit val arb: Arbitrary[ThirdInSealedTraitCaseClass] = - Arbitrary(arbitrary[Int].map(apply)) + Arbitrary(arbitrary[List[Int]].map(apply)) }