Skip to content

Commit

Permalink
Merge pull request #352 from fd4s/union-regression
Browse files Browse the repository at this point in the history
Fix #351
  • Loading branch information
bplommer authored May 18, 2021
2 parents 172b1c5 + 8c335c9 commit 94b979a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 52 deletions.
102 changes: 55 additions & 47 deletions modules/core/src/main/scala/vulcan/Codec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1092,60 +1092,68 @@ object Codec extends CodecCompanionCompat {
},
(value, schema) => {
val schemaTypes =
schema.getType() match {
schema.getType match {
case UNION => schema.getTypes.asScala
case _ => Seq(schema)
}

def decodeNamedContainerType(container: GenericContainer) = {
val altName =
container.getSchema.getName

val altWriterSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName))

def altMatching =
alts
.find(_.codec.schema.exists { schema =>
schema.getType match {
case RECORD | FIXED | ENUM =>
schema.getName == altName || schema.getAliases.asScala
.exists(alias => alias == altName || alias.endsWith(s".$altName"))
case _ => false
}
})
.toRight(AvroError.decodeMissingUnionAlternative(altName))

altWriterSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
}
}

def decodeUnnamedType(other: Any) =
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other))
}

value match {
case container: GenericContainer =>
val altName =
container.getSchema.getName

val altWriterSchema =
schemaTypes
.find(_.getName == altName)
.toRight(AvroError.decodeMissingUnionSchema(altName))

def altMatching =
alts
.find(_.codec.schema.exists { schema =>
schema.getType match {
case RECORD | FIXED | ENUM =>
schema.getName == altName || schema.getAliases.asScala
.exists(alias => alias == altName || alias.endsWith(s".$altName"))
case _ => false
}
})
.toRight(AvroError.decodeMissingUnionAlternative(altName))

altWriterSchema.flatMap { altSchema =>
altMatching.flatMap { alt =>
alt.codec
.decode(container, altSchema)
.map(alt.prism.reverseGet)
}
container.getSchema.getType match {
case RECORD | FIXED | ENUM => decodeNamedContainerType(container)
case _ => decodeUnnamedType(container)
}

case other =>
alts
.collectFirstSome { alt =>
alt.codec.schema
.traverse { altSchema =>
val altName = altSchema.getName
schemaTypes
.find(_.getName == altName)
.flatMap { schema =>
alt.codec
.decode(other, schema)
.map(alt.prism.reverseGet)
.toOption
}
}
}
.getOrElse {
Left(AvroError.decodeExhaustedAlternatives(other))
}
case other => decodeUnnamedType(other)
}
}
)
Expand Down
2 changes: 1 addition & 1 deletion modules/core/src/test/scala/vulcan/CodecSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2777,7 +2777,7 @@ final class CodecSpec extends BaseSpec with CodecSpecHelpers {
describe("schema") {
it("should encode as union") {
assertSchemaIs[SealedTraitCaseClass] {
"""[{"type":"record","name":"FirstInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"int"}]},{"type":"record","name":"SecondInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"string"}]},"int"]"""
"""[{"type":"record","name":"FirstInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"int"}]},{"type":"record","name":"SecondInSealedTraitCaseClass","namespace":"com.example","fields":[{"name":"value","type":"string"}]},{"type":"array","items":"int"}]"""
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object SealedTraitCaseClass {
Gen.oneOf[SealedTraitCaseClass](
arbitrary[Int].map(FirstInSealedTraitCaseClass(_)),
arbitrary[String].map(SecondInSealedTraitCaseClass(_)),
arbitrary[Int].map(ThirdInSealedTraitCaseClass(_))
arbitrary[List[Int]].map(ThirdInSealedTraitCaseClass(_))
)
}
}
Expand Down Expand Up @@ -61,12 +61,12 @@ object SecondInSealedTraitCaseClass {
Arbitrary(arbitrary[String].map(apply))
}

final case class ThirdInSealedTraitCaseClass(value: Int) extends SealedTraitCaseClass
final case class ThirdInSealedTraitCaseClass(value: List[Int]) extends SealedTraitCaseClass

object ThirdInSealedTraitCaseClass {
implicit val codec: Codec[ThirdInSealedTraitCaseClass] =
Codec[Int].imap(apply)(_.value)
Codec[List[Int]].imap(apply)(_.value)

implicit val arb: Arbitrary[ThirdInSealedTraitCaseClass] =
Arbitrary(arbitrary[Int].map(apply))
Arbitrary(arbitrary[List[Int]].map(apply))
}

0 comments on commit 94b979a

Please sign in to comment.