From a87dd6e3eea0019297de77df59c91f11c3e464e9 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Mon, 15 Jan 2024 10:57:47 +0100 Subject: [PATCH] Support sealed abstract class in DeriveSchema (#636) * Support sealed abstract class in DeriveSchema * Fix infinite macro recursion * Merge two branches --------- Co-authored-by: Nabil Abdel-Hafeez <7283535+987Nabil@users.noreply.github.com> --- .../scala-2/zio/schema/DeriveSchema.scala | 26 ++-- .../scala/zio/schema/DeriveSchemaSpec.scala | 111 ++++++++++++++++++ 2 files changed, 126 insertions(+), 11 deletions(-) diff --git a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala index 7ca21b964..f5772dd7d 100644 --- a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala +++ b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala @@ -35,6 +35,9 @@ object DeriveSchema { def isSealedTrait(tpe: Type): Boolean = tpe.typeSymbol.asClass.isTrait && tpe.typeSymbol.asClass.isSealed + def isSealedAbstractClass(tpe: Type): Boolean = + tpe.typeSymbol.asClass.isClass && tpe.typeSymbol.asClass.isSealed && tpe.typeSymbol.asClass.isAbstract + def isMap(tpe: Type): Boolean = tpe.typeSymbol.fullName == "scala.collection.immutable.Map" def collectTypeAnnotations(tpe: Type): List[Tree] = @@ -63,13 +66,13 @@ object DeriveSchema { else q"_root_.zio.Chunk.apply(..$typeAnnotations)" q"_root_.zio.schema.Schema.CaseClass0($typeId, () => ${tpe.typeSymbol.asClass.module}, $annotations)" } else if (isCaseClass(tpe)) deriveRecord(tpe, stack) - else if (isSealedTrait(tpe)) + else if (isSealedTrait(tpe) || isSealedAbstractClass(tpe)) deriveEnum(tpe, stack) else if (isMap(tpe)) deriveMap(tpe) else c.abort( c.enclosingPosition, - s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait" + s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait or sealed abstract class with case class children." ) def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree = { @@ -86,7 +89,7 @@ object DeriveSchema { else { c.inferImplicitValue( c.typecheck(tq"_root_.zio.schema.Schema[$schemaType]", c.TYPEmode).tpe, - withMacrosDisabled = false + withMacrosDisabled = true ) match { case EmptyTree => schemaType.typeArgs match { @@ -530,20 +533,24 @@ object DeriveSchema { sortedKnownSubclasses.flatMap { child => child.typeSignature val childClass = child.asClass - if (childClass.isSealed && childClass.isTrait) knownSubclassesOf(childClass) - else if (childClass.isCaseClass) { + if (childClass.isSealed && childClass.isTrait) + knownSubclassesOf(childClass) + else if (childClass.isCaseClass || (childClass.isClass && childClass.isAbstract)) { val st = concreteType(concreteType(tpe, parent.asType.toType), child.asType.toType) Set(appliedSubtype(st)) } else c.abort(c.enclosingPosition, s"child $child of $parent is not a sealed trait or case class") } } + val subtypes = knownSubclassesOf(tpe.typeSymbol.asClass) + .map(concreteType(tpe, _)) + val isSimpleEnum: Boolean = - !tpe.typeSymbol.asClass.knownDirectSubclasses.map { subtype => - subtype.typeSignature.decls.sorted.collect { + !subtypes.map { subtype => + subtype.typeSymbol.typeSignature.decls.sorted.collect { case p: TermSymbol if p.isCaseAccessor && !p.isMethod => p }.size - }.exists(_ > 0) + }.exists(_ > 0) && subtypes.forall(subtype => subtype.typeSymbol.asClass.isCaseClass) val hasSimpleEnum: Boolean = tpe.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[_root_.zio.schema.annotation.simpleEnum]) @@ -590,9 +597,6 @@ object DeriveSchema { val currentFrame = Frame[c.type](c, selfRefName, tpe) - val subtypes = knownSubclassesOf(tpe.typeSymbol.asClass) - .map(concreteType(tpe, _)) - val typeArgs = subtypes ++ Iterable(tpe) val cases = subtypes.map { (subtype: Type) => diff --git a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala index efc3159b9..7927478c9 100644 --- a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala +++ b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala @@ -260,6 +260,14 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS sealed trait SimpleEnum2 case class SimpleClass2() extends SimpleEnum2 + sealed abstract class AbstractBaseClass(val x: Int) + final case class ConcreteClass1(override val x: Int, y: Int) extends AbstractBaseClass(x) + final case class ConcreteClass2(override val x: Int, s: String) extends AbstractBaseClass(x) + + sealed abstract class AbstractBaseClass2(val x: Int) + sealed abstract class MiddleClass(override val x: Int, val y: Int) extends AbstractBaseClass2(x) + final case class ConcreteClass3(override val x: Int, override val y: Int, s: String) extends MiddleClass(x, y) + override def spec: Spec[Environment, Any] = suite("DeriveSchemaSpec")( suite("Derivation")( test("correctly derives case class 0") { @@ -484,6 +492,109 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS test("correctly derives simpleEnum without annotation") { val derived = DeriveSchema.gen[SimpleEnum2] assertTrue(derived.annotations == Chunk(simpleEnum(true))) + }, + test("correctly derives schema for abstract sealed class with case class subclasses") { + val derived = DeriveSchema.gen[AbstractBaseClass] + val expected: Schema[AbstractBaseClass] = + Schema.Enum2( + TypeId.parse("zio.schema.DeriveSchemaSpec.AbstractBaseClass"), + Schema.Case( + "ConcreteClass1", + Schema.CaseClass2( + TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass1"), + field01 = Schema.Field[ConcreteClass1, Int]( + "x", + Schema.Primitive(StandardType.IntType), + get0 = _.x, + set0 = (a, b: Int) => a.copy(x = b) + ), + field02 = Schema.Field[ConcreteClass1, Int]( + "y", + Schema.Primitive(StandardType.IntType), + get0 = _.y, + set0 = (a, b: Int) => a.copy(y = b) + ), + ConcreteClass1.apply + ), + (a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass1], + (a: ConcreteClass1) => a.asInstanceOf[AbstractBaseClass], + (a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass1] + ), + Schema.Case( + "ConcreteClass2", + Schema.CaseClass2( + TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass2"), + field01 = Schema.Field[ConcreteClass2, Int]( + "x", + Schema.Primitive(StandardType.IntType), + get0 = _.x, + set0 = (a, b: Int) => a.copy(x = b) + ), + field02 = Schema.Field[ConcreteClass2, String]( + "s", + Schema.Primitive(StandardType.StringType), + get0 = _.s, + set0 = (a, b: String) => a.copy(s = b) + ), + ConcreteClass2.apply + ), + (a: AbstractBaseClass) => a.asInstanceOf[ConcreteClass2], + (a: ConcreteClass2) => a.asInstanceOf[AbstractBaseClass], + (a: AbstractBaseClass) => a.isInstanceOf[ConcreteClass2] + ), + Chunk.empty + ) + assert(derived)(hasSameSchema(expected)) + }, + test( + "correctly derives schema for abstract sealed class with intermediate subclasses, having case class leaf classes" + ) { + val derived = DeriveSchema.gen[AbstractBaseClass2] + val expected: Schema[AbstractBaseClass2] = + Schema.Enum1[MiddleClass, AbstractBaseClass2]( + TypeId.parse("zio.schema.DeriveSchemaSpec.AbstractBaseClass2"), + Schema.Case[AbstractBaseClass2, MiddleClass]( + "MiddleClass", + Schema.Enum1[ConcreteClass3, MiddleClass]( + TypeId.parse("zio.schema.DeriveSchemaSpec.MiddleClass"), + Schema.Case[MiddleClass, ConcreteClass3]( + "ConcreteClass3", + Schema.CaseClass3( + TypeId.parse("zio.schema.DeriveSchemaSpec.ConcreteClass3"), + field01 = Schema.Field[ConcreteClass3, Int]( + "x", + Schema.Primitive(StandardType.IntType), + get0 = _.x, + set0 = (a, b: Int) => a.copy(x = b) + ), + field02 = Schema.Field[ConcreteClass3, Int]( + "y", + Schema.Primitive(StandardType.IntType), + get0 = _.y, + set0 = (a, b: Int) => a.copy(y = b) + ), + field03 = Schema.Field[ConcreteClass3, String]( + "s", + Schema.Primitive(StandardType.StringType), + get0 = _.s, + set0 = (a, b: String) => a.copy(s = b) + ), + ConcreteClass3.apply + ), + (a: MiddleClass) => a.asInstanceOf[ConcreteClass3], + (a: ConcreteClass3) => a.asInstanceOf[MiddleClass], + (a: MiddleClass) => a.isInstanceOf[ConcreteClass3], + Chunk.empty + ), + Chunk.empty + ), + (a: AbstractBaseClass2) => a.asInstanceOf[MiddleClass], + (a: MiddleClass) => a.asInstanceOf[AbstractBaseClass2], + (a: AbstractBaseClass2) => a.isInstanceOf[MiddleClass], + Chunk.empty + ) + ) + assert(derived)(hasSameSchema(expected)) } ), versionSpecificSuite