diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index b7d4ea0b02..9dffa2a76b 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -150,7 +150,7 @@ final case class EndpointGen(config: Config) { // we don't need a separate code file, as it will be included in the sealed trait companion. // this var stores the bookkeeping of such case classes, and is later used to omit the redundant code files. var replacedCasesToOmitAsTopComponents = Set.empty[String] - val allComponents = componentNameToCodeFile.view.map { case (name, codeFile) => + val allComponents: List[Code.File] = componentNameToCodeFile.view.map { case (name, codeFile) => traitToSubtypes .get(name) .fold(codeFile) { subtypes => @@ -166,13 +166,110 @@ final case class EndpointGen(config: Config) { } }.toList - allComponents.filterNot { cf => + val noDuplicateFiles = allComponents.filterNot { cf => cf.enums.isEmpty && cf.objects.isEmpty && cf.caseClasses.nonEmpty && cf.caseClasses.forall(cc => replacedCasesToOmitAsTopComponents(cc.name), ) } + + // After filtering out duplicate files, we can make sure that in all the files left, all fields has proper types. + // The `mapType` function is used to alter any relevant part of each code file + noDuplicateFiles.map { cf => + cf.copy( + objects = cf.objects.map(mapType(_.name, subtypeToTraits)), + caseClasses = cf.caseClasses.map(mapType(_.name, subtypeToTraits)), + enums = cf.enums.map(mapType(_.name, subtypeToTraits)), + ) + } } + /** + * The types may not be valid in case we reference a concrete subtype of a + * sealed trait, as the subtype is defined as an inner class encapsulated + * inside the trait's companion. Therefore, we can alter the type to include + * the enclosing trait/object's name. The following function will be used to + * alter the type of all fields needed. The `mapCaseClasses` helper takes a + * function that alters a case class, and lifts it such that we can apply it + * to any structure, and it'll take care to recurse when needed. + * + * @param getEncapsulatingName + * used to get the name of the code structure we operate on + * (Object/CaseClass/Enum) + * @param subtypeToTraits + * mappings of subtypes to their mixins - if theres only one, we assume + * subtype is nested. + * @param codeStructureToAlter + * the structure to modify + * @return + * the modified structure + */ + def mapType[T <: Code.ScalaType](getEncapsulatingName: T => String, subtypeToTraits: Map[String, Set[String]])( + codeStructureToAlter: T, + ): T = + mapCaseClasses { cc => + cc.copy(fields = cc.fields.foldRight(List.empty[Code.Field]) { case (f @ Code.Field(_, scalaType), tail) => + f.copy(fieldType = mapTypeRef(scalaType) { case originalType @ Code.TypeRef(tName) => + // We use the subtypeToTraits map to check if the type is a concrete subtype of a sealed trait. + // As of the time of writing this code, there should be only a single trait. + // In case future code generalizes to allow multiple mixins, this code should be updated. + subtypeToTraits.get(tName).fold(originalType) { set => + // If the type parameter has exactly 1 super type trait, + // and that trait's name is different from our enclosing object's name, + // then we should alter the type to include the object's name. + if (set.size != 1 || set.head == getEncapsulatingName(codeStructureToAlter)) originalType + else Code.TypeRef(set.head + "." + tName) + } + }) :: tail + }) + }(codeStructureToAlter) + + /** + * Given the type parameter of a field, we may want to alter it, e.g. by + * prepending the enclosing trait/object's name. This function will + * recursively alter the type of a field. Recursion is needed for types that + * contain a type parameter. e.g. transforming: {{{Chunk[Option[Zebra]]}}} to + * {{{Chunk[Option[Animal.Zebra]]}}} + * + * @param sType + * the original type we want to alter + * @param f + * a function that may alter the type, None means no altering is needed. + * @return + * The altered type, or gives back the input if no modification was needed. + */ + def mapTypeRef(sType: Code.ScalaType)(f: Code.TypeRef => Code.TypeRef): Code.ScalaType = + sType match { + case tref: Code.TypeRef => f(tref) + case Collection.Seq(inner) => Collection.Seq(mapTypeRef(inner)(f)) + case Collection.Set(inner) => Collection.Set(mapTypeRef(inner)(f)) + case Collection.Map(inner) => Collection.Map(mapTypeRef(inner)(f)) + case Collection.Opt(inner) => Collection.Opt(mapTypeRef(inner)(f)) + case _ => sType + } + + /** + * Given a function to alter a case class, this function will apply it to any + * structure recursively. + * + * @param f + * function to transform a [[zio.http.gen.scala.Code.CaseClass]] + * @param code + * the structure to apply transformation of case classes on + * @return + * the transformed structure + */ + def mapCaseClasses[T <: Code.ScalaType](f: Code.CaseClass => Code.CaseClass)(code: T): T = + (code match { + case obj: Code.Object => + obj.copy( + caseClasses = obj.caseClasses.map(mapCaseClasses(f)), + objects = obj.objects.map(mapCaseClasses(f)), + ) + case cc: Code.CaseClass => f(cc) + case sum: Code.Enum => sum.copy(cases = sum.cases.map(mapCaseClasses(f))) + case _ => code + }).asInstanceOf[T] + def fromOpenAPI(openAPI: OpenAPI): Code.Files = Code.Files { val componentsCode = extractComponents(openAPI) diff --git a/zio-http-gen/src/test/resources/ComponentAnimalWithFieldsReferencingSubs.scala b/zio-http-gen/src/test/resources/ComponentAnimalWithFieldsReferencingSubs.scala new file mode 100644 index 0000000000..0a37bc61a7 --- /dev/null +++ b/zio-http-gen/src/test/resources/ComponentAnimalWithFieldsReferencingSubs.scala @@ -0,0 +1,36 @@ +package test.component + +import zio.schema._ +import zio.schema.annotation._ +import zio.Chunk + +@noDiscriminator +sealed trait Animal { + def age: Int + def weight: Float +} +object Animal { + + implicit val codec: Schema[Animal] = DeriveSchema.gen[Animal] + case class Alligator( + age: Int, + weight: Float, + num_teeth: Int, + ) extends Animal + object Alligator { + + implicit val codec: Schema[Alligator] = DeriveSchema.gen[Alligator] + + } + case class Zebra( + age: Int, + weight: Float, + num_stripes: Int, + dazzle: Chunk[Zebra], + ) extends Animal + object Zebra { + + implicit val codec: Schema[Zebra] = DeriveSchema.gen[Zebra] + + } +} diff --git a/zio-http-gen/src/test/resources/ComponentLion.scala b/zio-http-gen/src/test/resources/ComponentLion.scala new file mode 100644 index 0000000000..b28b727c5e --- /dev/null +++ b/zio-http-gen/src/test/resources/ComponentLion.scala @@ -0,0 +1,13 @@ +package test.component + +import zio.schema._ + +case class Lion( + eats: Animal.Zebra, + enemy: Option[Animal.Alligator], +) +object Lion { + + implicit val codec: Schema[Lion] = DeriveSchema.gen[Lion] + +} \ No newline at end of file diff --git a/zio-http-gen/src/test/resources/inline_schema_sumtype_with_subtype_referenced_directly.yaml b/zio-http-gen/src/test/resources/inline_schema_sumtype_with_subtype_referenced_directly.yaml new file mode 100644 index 0000000000..96c1123d9d --- /dev/null +++ b/zio-http-gen/src/test/resources/inline_schema_sumtype_with_subtype_referenced_directly.yaml @@ -0,0 +1,86 @@ +info: + title: Animals Service + version: 0.0.1 +servers: + - url: http://127.0.0.1:5000/ +tags: + - name: Animals_API +paths: + /api/v1/zoo/{animal}: + get: + operationId: get_animal + parameters: + - in: path + name: animal + schema: + type: string + required: true + tags: + - Animals_API + description: Get animals by species name + responses: + "200": + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Animal' + description: OK +openapi: 3.0.3 +components: + schemas: + Animal: + oneOf: + - $ref: '#/components/schemas/Alligator' + - $ref: '#/components/schemas/Zebra' + AnimalSharedFields: + type: object + required: + - age + - weight + properties: + age: + type: integer + format: int32 + minimum: 0 + weight: + type: number + format: float + minimum: 0 + Alligator: + allOf: + - $ref: '#/components/schemas/AnimalSharedFields' + - type: object + required: + - num_teeth + properties: + num_teeth: + type: integer + format: int32 + minimum: 0 + Zebra: + allOf: + - $ref: '#/components/schemas/AnimalSharedFields' + - type: object + required: + - num_stripes + - dazzle + properties: + num_stripes: + type: integer + format: int32 + minimum: 0 + dazzle: + type: array + items: + $ref: '#/components/schemas/Zebra' + Lion: + type: object + required: + - eats + properties: + eats: + $ref: '#/components/schemas/Zebra' + enemy: + $ref: '#/components/schemas/Alligator' \ No newline at end of file diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index 84c0518b1f..e139209b77 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -491,6 +491,57 @@ object CodeGenSpec extends ZIOSpecDefault { } } } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 + test("OpenAPI spec with inline schema response body of sum-type whose concrete subtype is referenced directly") { + + import zio.json.yaml.DecoderYamlOps + implicit val decoder: JsonDecoder[OpenAPI] = JsonCodec.jsonDecoder(OpenAPI.schema) + + val openAPIString = + Files + .readAllLines( + Paths.get( + getClass + .getResource("/inline_schema_sumtype_with_subtype_referenced_directly.yaml") + .toURI, + ), + ) + .asScala + .mkString("\n") + + openAPIString.fromYaml match { + case Left(error) => TestResult(TestArrow.make(_ => TestTrace.fail(ErrorMessage.text(error)))) + case Right(oapi) => + val t = Try(EndpointGen.fromOpenAPI(oapi, Config(commonFieldsOnSuperType = true))) + assert(t)(isSuccess) && { + val tempDir = Files.createTempDirectory("codegen") + val testDir = tempDir.resolve("test") + + CodeGen.writeFiles(t.get, testDir, "test", Some(scalaFmtPath)) + + allFilesShouldBe( + testDir.toFile, + List( + "api/v1/zoo/Animal.scala", + "component/Animal.scala", + "component/AnimalSharedFields.scala", + "component/Lion.scala", + ), + ) && fileShouldBe( + testDir, + "api/v1/zoo/Animal.scala", + "/EndpointForZooNoError.scala", + ) && fileShouldBe( + testDir, + "component/Animal.scala", + "/ComponentAnimalWithFieldsReferencingSubs.scala", + ) && fileShouldBe( + testDir, + "component/Lion.scala", + "/ComponentLion.scala", + ) + } + } + } @@ TestAspect.exceptScala3, // for some reason, the temp dir is empty in Scala 3 test("Endpoint with array field in input") { val endpoint = Endpoint(Method.POST / "api" / "v1" / "users").in[UserNameArray].out[User] val openAPI = OpenAPIGen.fromEndpoints("", "", endpoint)