Skip to content

Commit

Permalink
Support sealed abstract class in DeriveSchema (#636)
Browse files Browse the repository at this point in the history
* Support sealed abstract class in DeriveSchema

* Fix infinite macro recursion

* Merge two branches

---------

Co-authored-by: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com>
  • Loading branch information
vigoo and 987Nabil authored Jan 15, 2024
1 parent 10fc670 commit a87dd6e
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ object DeriveSchema {

def isSealedTrait(tpe: Type): Boolean = tpe.typeSymbol.asClass.isTrait && tpe.typeSymbol.asClass.isSealed

def isSealedAbstractClass(tpe: Type): Boolean =
tpe.typeSymbol.asClass.isClass && tpe.typeSymbol.asClass.isSealed && tpe.typeSymbol.asClass.isAbstract

def isMap(tpe: Type): Boolean = tpe.typeSymbol.fullName == "scala.collection.immutable.Map"

def collectTypeAnnotations(tpe: Type): List[Tree] =
Expand Down Expand Up @@ -63,13 +66,13 @@ object DeriveSchema {
else q"_root_.zio.Chunk.apply(..$typeAnnotations)"
q"_root_.zio.schema.Schema.CaseClass0($typeId, () => ${tpe.typeSymbol.asClass.module}, $annotations)"
} else if (isCaseClass(tpe)) deriveRecord(tpe, stack)
else if (isSealedTrait(tpe))
else if (isSealedTrait(tpe) || isSealedAbstractClass(tpe))
deriveEnum(tpe, stack)
else if (isMap(tpe)) deriveMap(tpe)
else
c.abort(
c.enclosingPosition,
s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait"
s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait or sealed abstract class with case class children."
)

def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree = {
Expand All @@ -86,7 +89,7 @@ object DeriveSchema {
else {
c.inferImplicitValue(
c.typecheck(tq"_root_.zio.schema.Schema[$schemaType]", c.TYPEmode).tpe,
withMacrosDisabled = false
withMacrosDisabled = true
) match {
case EmptyTree =>
schemaType.typeArgs match {
Expand Down Expand Up @@ -530,20 +533,24 @@ object DeriveSchema {
sortedKnownSubclasses.flatMap { child =>
child.typeSignature
val childClass = child.asClass
if (childClass.isSealed && childClass.isTrait) knownSubclassesOf(childClass)
else if (childClass.isCaseClass) {
if (childClass.isSealed && childClass.isTrait)
knownSubclassesOf(childClass)
else if (childClass.isCaseClass || (childClass.isClass && childClass.isAbstract)) {
val st = concreteType(concreteType(tpe, parent.asType.toType), child.asType.toType)
Set(appliedSubtype(st))
} else c.abort(c.enclosingPosition, s"child $child of $parent is not a sealed trait or case class")
}
}

val subtypes = knownSubclassesOf(tpe.typeSymbol.asClass)
.map(concreteType(tpe, _))

val isSimpleEnum: Boolean =
!tpe.typeSymbol.asClass.knownDirectSubclasses.map { subtype =>
subtype.typeSignature.decls.sorted.collect {
!subtypes.map { subtype =>
subtype.typeSymbol.typeSignature.decls.sorted.collect {
case p: TermSymbol if p.isCaseAccessor && !p.isMethod => p
}.size
}.exists(_ > 0)
}.exists(_ > 0) && subtypes.forall(subtype => subtype.typeSymbol.asClass.isCaseClass)

val hasSimpleEnum: Boolean =
tpe.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[_root_.zio.schema.annotation.simpleEnum])
Expand Down Expand Up @@ -590,9 +597,6 @@ object DeriveSchema {

val currentFrame = Frame[c.type](c, selfRefName, tpe)

val subtypes = knownSubclassesOf(tpe.typeSymbol.asClass)
.map(concreteType(tpe, _))

val typeArgs = subtypes ++ Iterable(tpe)

val cases = subtypes.map { (subtype: Type) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
sealed trait SimpleEnum2
case class SimpleClass2() extends SimpleEnum2

sealed abstract class AbstractBaseClass(val x: Int)
final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x)
final case class ConcreteClass2(override val x: Int, s: String) extends AbstractBaseClass(x)

sealed abstract class AbstractBaseClass2(val x: Int)
sealed abstract class MiddleClass(override val x: Int, val y: Int) extends AbstractBaseClass2(x)
final case class ConcreteClass3(override val x: Int, override val y: Int, s: String) extends MiddleClass(x, y)

override def spec: Spec[Environment, Any] = suite("DeriveSchemaSpec")(
suite("Derivation")(
test("correctly derives case class 0") {
Expand Down Expand Up @@ -484,6 +492,109 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS
test("correctly derives simpleEnum without annotation") {
val derived = DeriveSchema.gen[SimpleEnum2]
assertTrue(derived.annotations == Chunk(simpleEnum(true)))
},
test("correctly derives schema for abstract sealed class with case class subclasses") {
val derived = DeriveSchema.gen[AbstractBaseClass]
val expected: Schema[AbstractBaseClass] =
Schema.Enum2(
TypeId.parse("zio.schema.DeriveSchemaSpec.AbstractBaseClass"),
Schema.Case(
"ConcreteClass1",
Schema.CaseClass2(
TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass1"),
field01 = Schema.Field[ConcreteClass1, Int](
"x",
Schema.Primitive(StandardType.IntType),
get0 = _.x,
set0 = (a, b: Int) => a.copy(x = b)
),
field02 = Schema.Field[ConcreteClass1, Int](
"y",
Schema.Primitive(StandardType.IntType),
get0 = _.y,
set0 = (a, b: Int) => a.copy(y = b)
),
ConcreteClass1.apply
),
(a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass1],
(a: ConcreteClass1) => a.asInstanceOf[AbstractBaseClass],
(a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass1]
),
Schema.Case(
"ConcreteClass2",
Schema.CaseClass2(
TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass2"),
field01 = Schema.Field[ConcreteClass2, Int](
"x",
Schema.Primitive(StandardType.IntType),
get0 = _.x,
set0 = (a, b: Int) => a.copy(x = b)
),
field02 = Schema.Field[ConcreteClass2, String](
"s",
Schema.Primitive(StandardType.StringType),
get0 = _.s,
set0 = (a, b: String) => a.copy(s = b)
),
ConcreteClass2.apply
),
(a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass2],
(a: ConcreteClass2) => a.asInstanceOf[AbstractBaseClass],
(a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass2]
),
Chunk.empty
)
assert(derived)(hasSameSchema(expected))
},
test(
"correctly derives schema for abstract sealed class with intermediate subclasses, having case class leaf classes"
) {
val derived = DeriveSchema.gen[AbstractBaseClass2]
val expected: Schema[AbstractBaseClass2] =
Schema.Enum1[MiddleClass, AbstractBaseClass2](
TypeId.parse("zio.schema.DeriveSchemaSpec.AbstractBaseClass2"),
Schema.Case[AbstractBaseClass2, MiddleClass](
"MiddleClass",
Schema.Enum1[ConcreteClass3, MiddleClass](
TypeId.parse("zio.schema.DeriveSchemaSpec.MiddleClass"),
Schema.Case[MiddleClass, ConcreteClass3](
"ConcreteClass3",
Schema.CaseClass3(
TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass3"),
field01 = Schema.Field[ConcreteClass3, Int](
"x",
Schema.Primitive(StandardType.IntType),
get0 = _.x,
set0 = (a, b: Int) => a.copy(x = b)
),
field02 = Schema.Field[ConcreteClass3, Int](
"y",
Schema.Primitive(StandardType.IntType),
get0 = _.y,
set0 = (a, b: Int) => a.copy(y = b)
),
field03 = Schema.Field[ConcreteClass3, String](
"s",
Schema.Primitive(StandardType.StringType),
get0 = _.s,
set0 = (a, b: String) => a.copy(s = b)
),
ConcreteClass3.apply
),
(a: MiddleClass) => a.asInstanceOf[ConcreteClass3],
(a: ConcreteClass3) => a.asInstanceOf[MiddleClass],
(a: MiddleClass) => a.isInstanceOf[ConcreteClass3],
Chunk.empty
),
Chunk.empty
),
(a: AbstractBaseClass2) => a.asInstanceOf[MiddleClass],
(a: MiddleClass) => a.asInstanceOf[AbstractBaseClass2],
(a: AbstractBaseClass2) => a.isInstanceOf[MiddleClass],
Chunk.empty
)
)
assert(derived)(hasSameSchema(expected))
}
),
versionSpecificSuite
Expand Down

0 comments on commit a87dd6e

Please sign in to comment.