From 7610e6085046c4ab3c6711b4d63390e60a062995 Mon Sep 17 00:00:00 2001 From: Jesse Date: Sat, 18 Nov 2023 17:45:30 +0100 Subject: [PATCH] Fix compilation of Recursive GADT schema derivation (#561) * Fix compilation of GADT schema derivation on scala 2 * Fix deriving schema for generic types on scala 3 * Add tests for generically deriving schemas for enums * Update readme * Fix --------- Co-authored-by: Daniel Vigovszky --- .../scala-2/zio/schema/DeriveSchema.scala | 31 ++++-- .../scala-3/zio/schema/DeriveSchema.scala | 103 +++++++++--------- .../scala/zio/schema/DeriveSchemaSpec.scala | 10 ++ 3 files changed, 80 insertions(+), 64 deletions(-) diff --git a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala index 78e8f5b0f..7ca21b964 100644 --- a/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala +++ b/zio-schema-derivation/shared/src/main/scala-2/zio/schema/DeriveSchema.scala @@ -14,6 +14,16 @@ object DeriveSchema { val JavaAnnotationTpe = typeOf[java.lang.annotation.Annotation] + lazy val optionType = typeOf[Option[_]] + lazy val listType = typeOf[List[_]] + lazy val setType = typeOf[Set[_]] + lazy val vectorType = typeOf[Vector[_]] + lazy val chunkType = typeOf[Chunk[_]] + lazy val eitherType = typeOf[Either[_, _]] + lazy val tuple2Type = typeOf[(_, _)] + lazy val tuple3Type = typeOf[(_, _, _)] + lazy val tuple4Type = typeOf[(_, _, _, _)] + val tpe = weakTypeOf[T] def concreteType(seenFrom: Type, tpe: Type): Type = @@ -62,7 +72,7 @@ object DeriveSchema { s"Failed to derive schema for $tpe. Can only derive Schema for case class or sealed trait" ) - def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree = + def directInferSchema(parentType: Type, schemaType: Type, stack: List[Frame[c.type]]): Tree = { stack .find(_.tpe =:= schemaType) .map { @@ -83,20 +93,20 @@ object DeriveSchema { case Nil => recurse(schemaType, stack) case typeArg1 :: Nil => - if (schemaType <:< c.typeOf[Option[_]]) + if (schemaType <:< optionType) q"_root_.zio.schema.Schema.option(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))" - else if (schemaType <:< typeOf[List[_]]) + else if (schemaType <:< listType) q"_root_.zio.schema.Schema.list(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))" - else if (schemaType <:< typeOf[Set[_]]) + else if (schemaType <:< setType) q"_root_.zio.schema.Schema.set(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))" - else if (schemaType <:< typeOf[Vector[_]]) + else if (schemaType <:< vectorType) q"_root_.zio.schema.Schema.vector(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))" - else if (schemaType <:< typeOf[Chunk[_]]) + else if (schemaType <:< chunkType) q"_root_.zio.schema.Schema.chunk(_root_.zio.schema.Schema.defer(${directInferSchema(parentType, concreteType(parentType, typeArg1), stack)}))" else recurse(schemaType, stack) case typeArg1 :: typeArg2 :: Nil => - if (schemaType <:< typeOf[Either[_, _]]) + if (schemaType <:< eitherType) q"""_root_.zio.schema.Schema.either( _root_.zio.schema.Schema.defer(${directInferSchema( parentType, @@ -110,7 +120,7 @@ object DeriveSchema { )}) ) """ - else if (schemaType <:< typeOf[(_, _)]) + else if (schemaType <:< tuple2Type) q"""_root_.zio.schema.Schema.tuple2( _root_.zio.schema.Schema.defer(${directInferSchema( parentType, @@ -127,7 +137,7 @@ object DeriveSchema { else recurse(schemaType, stack) case typeArg1 :: typeArg2 :: typeArg3 :: Nil => - if (schemaType <:< typeOf[(_, _, _)]) + if (schemaType <:< tuple3Type) q"""_root_.zio.schema.Schema.tuple3( _root_.zio.schema.Schema.defer(${directInferSchema( parentType, @@ -150,7 +160,7 @@ object DeriveSchema { else recurse(schemaType, stack) case typeArg1 :: typeArg2 :: typeArg3 :: typeArg4 :: Nil => - if (schemaType <:< typeOf[(_, _, _)]) + if (schemaType <:< tuple4Type) q"""_root_.zio.schema.Schema.tuple4( _root_.zio.schema.Schema.defer(${directInferSchema( parentType, @@ -183,6 +193,7 @@ object DeriveSchema { } } } + } def getFieldName(annotations: List[Tree]): Option[String] = annotations.collectFirst { diff --git a/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala b/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala index 55d7a5cf6..3b02dc3c3 100644 --- a/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala +++ b/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala @@ -52,62 +52,57 @@ private case class DeriveSchema()(using val ctx: Quotes) { val result = stack.find(typeRepr) match { case Some(ref) => '{ Schema.defer(${ref.asExprOf[Schema[T]]}) } - case None => - val summoned = Expr.summon[Schema[T]] + case None => + val summoned = if (!top) Expr.summon[Schema[T]] else None if (!top && summoned.isDefined) { - '{ Schema.defer(${summoned.get}) }.asExprOf[Schema[T]] + '{ + Schema.defer(${ + summoned.get + }) + }.asExprOf[Schema[T]] } else { - typeRepr.asType match { - case '[List[a]] => - val schema = deriveSchema[a](stack) - '{ Schema.list(Schema.defer(${schema})) }.asExprOf[Schema[T]] - case '[scala.util.Either[a, b]] => - val schemaA = deriveSchema[a](stack) - val schemaB = deriveSchema[b](stack) - '{ Schema.either(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]] - case '[Option[a]] => - val schema = deriveSchema[a](stack) - // throw new Error(s"OPITOS ${schema.show}") - '{ Schema.option(Schema.defer($schema)) }.asExprOf[Schema[T]] - case '[scala.collection.Set[a]] => - val schema = deriveSchema[a](stack) - '{ Schema.set(Schema.defer(${schema})) }.asExprOf[Schema[T]] - case '[Vector[a]] => - val schema = deriveSchema[a](stack) - '{ Schema.vector(Schema.defer(${schema})) }.asExprOf[Schema[T]] - case '[scala.collection.Map[a, b]] => - val schemaA = deriveSchema[a](stack) - val schemaB = deriveSchema[b](stack) - '{ Schema.map(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]] - case '[zio.Chunk[a]] => - val schema = deriveSchema[a](stack) - '{ Schema.chunk(Schema.defer(${schema})) }.asExprOf[Schema[T]] - case _ => - val summoned = if (!top) Expr.summon[Schema[T]] else None - summoned match { - case Some(schema) => - // println(s"FOR TYPE ${typeRepr.show}") - // println(s"STACK ${stack.find(typeRepr)}") - // println(s"Found schema ${schema.show}") - schema - case _ => - Mirror(typeRepr) match { - case Some(mirror) => - mirror.mirrorType match { - case MirrorType.Sum => - deriveEnum[T](mirror, stack) - case MirrorType.Product => - deriveCaseClass[T](mirror, stack, top) - } - case None => - val sym = typeRepr.typeSymbol - if (sym.isClassDef && sym.flags.is(Flags.Module)) { - deriveCaseObject[T](stack, top) - } - else { - report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported") - } - } + typeRepr.asType match { + case '[List[a]] => + val schema = deriveSchema[a](stack) + '{ Schema.list(Schema.defer(${schema})) }.asExprOf[Schema[T]] + case '[scala.util.Either[a, b]] => + val schemaA = deriveSchema[a](stack) + val schemaB = deriveSchema[b](stack) + '{ Schema.either(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]] + case '[Option[a]] => + val schema = deriveSchema[a](stack) + // throw new Error(s"OPITOS ${schema.show}") + '{ Schema.option(Schema.defer($schema)) }.asExprOf[Schema[T]] + case '[scala.collection.Set[a]] => + val schema = deriveSchema[a](stack) + '{ Schema.set(Schema.defer(${schema})) }.asExprOf[Schema[T]] + case '[Vector[a]] => + val schema = deriveSchema[a](stack) + '{ Schema.vector(Schema.defer(${schema})) }.asExprOf[Schema[T]] + case '[scala.collection.Map[a, b]] => + val schemaA = deriveSchema[a](stack) + val schemaB = deriveSchema[b](stack) + '{ Schema.map(Schema.defer(${schemaA}), Schema.defer(${schemaB})) }.asExprOf[Schema[T]] + case '[zio.Chunk[a]] => + val schema = deriveSchema[a](stack) + '{ Schema.chunk(Schema.defer(${schema})) }.asExprOf[Schema[T]] + case _ => + Mirror(typeRepr) match { + case Some(mirror) => + mirror.mirrorType match { + case MirrorType.Sum => + deriveEnum[T](mirror, stack) + case MirrorType.Product => + deriveCaseClass[T](mirror, stack, top) + } + case None => + val sym = typeRepr.typeSymbol + if (sym.isClassDef && sym.flags.is(Flags.Module)) { + deriveCaseObject[T](stack, top) + } + else { + report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported") + } } } } diff --git a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala index f179f90c8..efc3159b9 100644 --- a/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala +++ b/zio-schema-derivation/shared/src/test/scala/zio/schema/DeriveSchemaSpec.scala @@ -165,6 +165,8 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A] case class Leaf[A](value: A) extends Tree[A] case object Root extends Tree[Nothing] + + implicit def schema[A: Schema]: Schema[Tree[A]] = DeriveSchema.gen[Tree[A]] } sealed trait RBTree[+A, +B] @@ -173,6 +175,8 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS case class Branch[A, B](left: RBTree[A, B], right: RBTree[A, B]) extends RBTree[A, B] case class RLeaf[A](value: A) extends RBTree[A, Nothing] case class BLeaf[B](value: B) extends RBTree[Nothing, B] + + implicit def schema[A: Schema, B: Schema]: Schema[RBTree[A, B]] = DeriveSchema.gen[RBTree[A, B]] } sealed trait AdtWithTypeParameters[+Param1, +Param2] @@ -400,10 +404,16 @@ object DeriveSchemaSpec extends ZIOSpecDefault with VersionSpecificDeriveSchemaS val derived: Schema[Tree[Recursive]] = DeriveSchema.gen[Tree[Recursive]] assert(derived)(anything) }, + test("correctly derives generic recursive Enum") { + assert(Schema[Tree[Recursive]].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) + }, test("correctly derives recursive Enum with multiple type parameters") { val derived: Schema[RBTree[String, Int]] = DeriveSchema.gen[RBTree[String, Int]] assert(derived)(anything) }, + test("correctly derives generic recursive Enum with multiple type parameters") { + assert(Schema[RBTree[String, Int]].toString)(not(containsString("null")) && not(equalTo("$Lazy$"))) + }, test("correctly derives schema with unused type parameters") { val derived: Schema[AdtWithTypeParameters[Int, Int]] = DeriveSchema.gen[AdtWithTypeParameters[Int, Int]] assert(derived)(anything)