Skip to content

Commit

Permalink
[gen] simplified code for better maintainability
Browse files Browse the repository at this point in the history
  • Loading branch information
hochgi committed May 28, 2024
1 parent ca18cba commit 8cdbe46
Showing 1 changed file with 55 additions and 47 deletions.
102 changes: 55 additions & 47 deletions zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ object EndpointGen {
case None => m - k
}
}

implicit class OptionCompanionCompatOps(o: Option.type) {
// scala 2.12 collection does not support Option.when natively, so we're adding this as an extension.
def when[A](p: => Boolean)(value: => A): Option[A] =
if (p) Some(value) else None
}
}

final case class EndpointGen(config: Config) {
Expand Down Expand Up @@ -179,42 +173,56 @@ final case class EndpointGen(config: Config) {
}

// After filtering out duplicate files, we can make sure that in all the files left, all fields has proper types.
// The types may not be valid in case we reference a concrete subtype of a sealed trait,
// as the subtype is defined as an inner class encapsulated inside the trait's companion.
// Therefore, we can alter the type to include the enclosing trait/object's name.
//
// The following function will be used to alter the type of all fields needed.
// The `mapCaseClasses` helper takes a function that alters a case class,
// and lifts it such that we can apply it to any structure, and it'll take care to recurse when needed.
val mapType: String => Code.ScalaType => Code.ScalaType = encapsulatingName =>
mapCaseClasses { cc =>
cc.copy(fields = cc.fields.foldRight(List.empty[Code.Field]) { case (o @ Code.Field(_, scalaType), tail) =>
mapTypeRef(scalaType) { case Code.TypeRef(tName) =>
// We use the subtypeToTraits map to check if the type is a concrete subtype of a sealed trait.
// As of the time of writing this code, there should be only a single trait.
// In case future code generalizes to allow multiple mixins, this code should be updated.
subtypeToTraits.get(tName).flatMap { set =>
// If the type parameter has exactly 1 super type trait,
// and that trait's name is different from our enclosing object's name,
// then we should alter the type to include the object's name.
Option.when(set.size == 1 && set.head != encapsulatingName) {
Code.TypeRef(set.head + "." + tName)
}
}
}.fold(o)(o.copy) :: tail
})
}

// The `mapType` function is used to alter any relevant part of each code file
noDuplicateFiles.map { cf =>
cf.copy(
objects = cf.objects.map(o => mapType(o.name)(o)).asInstanceOf[List[Code.Object]],
caseClasses = cf.caseClasses.map(c => mapType(c.name)(c)).asInstanceOf[List[Code.CaseClass]],
enums = cf.enums.map(e => mapType(e.name)(e)).asInstanceOf[List[Code.Enum]],
objects = cf.objects.map(mapType(_.name, subtypeToTraits)),
caseClasses = cf.caseClasses.map(mapType(_.name, subtypeToTraits)),
enums = cf.enums.map(mapType(_.name, subtypeToTraits)),
)
}
}

/**
* The types may not be valid in case we reference a concrete subtype of a
* sealed trait, as the subtype is defined as an inner class encapsulated
* inside the trait's companion. Therefore, we can alter the type to include
* the enclosing trait/object's name. The following function will be used to
* alter the type of all fields needed. The `mapCaseClasses` helper takes a
* function that alters a case class, and lifts it such that we can apply it
* to any structure, and it'll take care to recurse when needed.
*
* @param getEncapsulatingName
* used to get the name of the code structure we operate on
* (Object/CaseClass/Enum)
* @param subtypeToTraits
* mappings of subtypes to their mixins - if theres only one, we assume
* subtype is nested.
* @param codeStructureToAlter
* the structure to modify
* @return
* the modified structure
*/
def mapType[T <: Code.ScalaType](getEncapsulatingName: T => String, subtypeToTraits: Map[String, Set[String]])(
codeStructureToAlter: T,
): T =
mapCaseClasses { cc =>
cc.copy(fields = cc.fields.foldRight(List.empty[Code.Field]) { case (f @ Code.Field(_, scalaType), tail) =>
f.copy(fieldType = mapTypeRef(scalaType) { case originalType @ Code.TypeRef(tName) =>
// We use the subtypeToTraits map to check if the type is a concrete subtype of a sealed trait.
// As of the time of writing this code, there should be only a single trait.
// In case future code generalizes to allow multiple mixins, this code should be updated.
subtypeToTraits.get(tName).fold(originalType) { set =>
// If the type parameter has exactly 1 super type trait,
// and that trait's name is different from our enclosing object's name,
// then we should alter the type to include the object's name.
if (set.size != 1 || set.head == getEncapsulatingName(codeStructureToAlter)) originalType
else Code.TypeRef(set.head + "." + tName)
}
}) :: tail
})
}(codeStructureToAlter)

/**
* Given the type parameter of a field, we may want to alter it, e.g. by
* prepending the enclosing trait/object's name. This function will
Expand All @@ -229,14 +237,14 @@ final case class EndpointGen(config: Config) {
* @return
* [[Option]] of the altered type, or None if no modification was needed.
*/
def mapTypeRef(sType: Code.ScalaType)(f: Code.TypeRef => Option[Code.TypeRef]): Option[Code.ScalaType] =
def mapTypeRef(sType: Code.ScalaType)(f: Code.TypeRef => Code.TypeRef): Code.ScalaType =
sType match {
case tref: Code.TypeRef => f(tref)
case Collection.Seq(inner) => Some(Collection.Seq(mapTypeRef(inner)(f).getOrElse(inner)))
case Collection.Set(inner) => Some(Collection.Set(mapTypeRef(inner)(f).getOrElse(inner)))
case Collection.Map(inner) => Some(Collection.Map(mapTypeRef(inner)(f).getOrElse(inner)))
case Collection.Opt(inner) => Some(Collection.Opt(mapTypeRef(inner)(f).getOrElse(inner)))
case _ => None
case Collection.Seq(inner) => Collection.Seq(mapTypeRef(inner)(f))
case Collection.Set(inner) => Collection.Set(mapTypeRef(inner)(f))
case Collection.Map(inner) => Collection.Map(mapTypeRef(inner)(f))
case Collection.Opt(inner) => Collection.Opt(mapTypeRef(inner)(f))
case _ => sType
}

/**
Expand All @@ -250,17 +258,17 @@ final case class EndpointGen(config: Config) {
* @return
* the transformed structure
*/
def mapCaseClasses(f: Code.CaseClass => Code.CaseClass)(code: Code.ScalaType): Code.ScalaType =
code match {
def mapCaseClasses[T <: Code.ScalaType](f: Code.CaseClass => Code.CaseClass)(code: T): T =
(code match {
case obj: Code.Object =>
obj.copy(
caseClasses = obj.caseClasses.map(mapCaseClasses(f)).asInstanceOf[List[Code.CaseClass]],
objects = obj.objects.map(mapCaseClasses(f)).asInstanceOf[List[Code.Object]],
caseClasses = obj.caseClasses.map(mapCaseClasses(f)),
objects = obj.objects.map(mapCaseClasses(f)),
)
case cc: Code.CaseClass => f(cc)
case sum: Code.Enum => sum.copy(cases = sum.cases.map(mapCaseClasses(f)).asInstanceOf[List[Code.CaseClass]])
case sum: Code.Enum => sum.copy(cases = sum.cases.map(mapCaseClasses(f)))
case _ => code
}
}).asInstanceOf[T]

def fromOpenAPI(openAPI: OpenAPI): Code.Files =
Code.Files {
Expand Down

0 comments on commit 8cdbe46

Please sign in to comment.