diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index 967bd1630d..b7f2f46640 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -160,6 +160,11 @@ object OpenAPIGenSpec extends ZIOSpecDefault { implicit def schema[T: Schema]: Schema[WithGenericPayload[T]] = DeriveSchema.gen } + final case class WithOptionalAdtPayload(optionalAdtField: Option[SealedTraitCustomDiscriminator]) + object WithOptionalAdtPayload { + implicit val schema: Schema[WithOptionalAdtPayload] = DeriveSchema.gen + } + private val simpleEndpoint = Endpoint( (GET / "static" / int("id") / uuid("uuid") ?? Doc.p("user id") / string("name")) ?? Doc.p("get path"), @@ -2231,9 +2236,9 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | "discriminator" : { | "propertyName" : "type", | "mapping" : { - | "One" : "#/components/schemas/One}", - | "Two" : "#/components/schemas/Two}", - | "three" : "#/components/schemas/Three}" + | "One" : "#/components/schemas/One", + | "Two" : "#/components/schemas/Two", + | "three" : "#/components/schemas/Three" | } | } | }, @@ -2810,7 +2815,10 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | ] | }, | "nestedOption" : { - | "$ref" : "#/components/schemas/Recursive" + | "anyOf" : [ + | { "type" : "null" }, + | { "$ref" : "#/components/schemas/Recursive" } + | ] | }, | "nestedMap" : { | "type" : "object", @@ -2831,7 +2839,6 @@ object OpenAPIGenSpec extends ZIOSpecDefault { | } | }, | "required" : [ - | "nestedOption", | "nestedList", | "nestedMap", | "nestedSet", @@ -2989,6 +2996,111 @@ object OpenAPIGenSpec extends ZIOSpecDefault { |""".stripMargin assertTrue(json == toJsonAst(expectedJson)) }, + test("Optional ADT payload") { + val endpoint = Endpoint(GET / "static").in[WithOptionalAdtPayload] + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", endpoint) + val json = toJsonAst(generated) + val expectedJson = + """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/static" : { + | "get" : { + | "requestBody" : + | { + | "content" : { + | "application/json" : { + | "schema" : + | { + | "$ref" : "#/components/schemas/WithOptionalAdtPayload" + | } + | } + | }, + | "required" : true + | } + | } + | } + | }, + | "components" : { + | "schemas" : { + | "One" : + | { + | "type" : + | "object", + | "properties" : {} + | }, + | "SealedTraitCustomDiscriminator" : + | { + | "oneOf" : [ + | { + | "$ref" : "#/components/schemas/One" + | }, + | { + | "$ref" : "#/components/schemas/Two" + | }, + | { + | "$ref" : "#/components/schemas/Three" + | } + | ], + | "discriminator" : { + | "propertyName" : "type", + | "mapping" : { + | "One" : "#/components/schemas/One", + | "Two" : "#/components/schemas/Two", + | "three" : "#/components/schemas/Three" + | } + | } + | }, + | "WithOptionalAdtPayload" : + | { + | "type" : + | "object", + | "properties" : { + | "optionalAdtField" : { + | "anyOf": [ + | { "type": "null" }, + | { "$ref": "#/components/schemas/SealedTraitCustomDiscriminator" } + | ] + | } + | } + | }, + | "Two" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "required" : [ + | "name" + | ] + | }, + | "Three" : + | { + | "type" : + | "object", + | "properties" : { + | "name" : { + | "type" : + | "string" + | } + | }, + | "required" : [ + | "name" + | ] + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, ) } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala index ed8e1bc284..40203a8419 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala @@ -12,7 +12,8 @@ import zio.schema.codec._ import zio.schema.codec.json._ import zio.schema.validation._ -import zio.http.codec.{PathCodec, SegmentCodec, TextCodec} +import zio.http.codec.{SegmentCodec, TextCodec} +import zio.http.endpoint.openapi.JsonSchema.MetaData @nowarn("msg=possible missing interpolator") private[openapi] case class SerializableJsonSchema( @@ -47,23 +48,35 @@ private[openapi] case class SerializableJsonSchema( uniqueItems: Option[Boolean] = None, minItems: Option[Int] = None, ) { - def asNullableType(nullable: Boolean): SerializableJsonSchema = + def asNullableType(nullable: Boolean): SerializableJsonSchema = { + import SerializableJsonSchema.typeNull + if (nullable && schemaType.isDefined) copy(schemaType = Some(schemaType.get.add("null"))) else if (nullable && oneOf.isDefined) - copy(oneOf = Some(oneOf.get :+ SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))) + copy(oneOf = Some(oneOf.get :+ typeNull)) else if (nullable && allOf.isDefined) - SerializableJsonSchema(allOf = - Some(Chunk(this, SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))), - ) + SerializableJsonSchema(allOf = Some(Chunk(this, typeNull))) else if (nullable && anyOf.isDefined) - copy(anyOf = Some(anyOf.get :+ SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))))) + copy(anyOf = Some(anyOf.get :+ typeNull)) + else if (nullable && ref.isDefined) + SerializableJsonSchema(anyOf = Some(Chunk(typeNull, this))) else this - + } } private[openapi] object SerializableJsonSchema { + + /** + * Used to generate a OpenAPI schema part looking like this: + * {{{ + * { "type": "null"} + * }}} + */ + private[SerializableJsonSchema] val typeNull: SerializableJsonSchema = + SerializableJsonSchema(schemaType = Some(TypeOrTypes.Type("null"))) + implicit val doubleOrLongSchema: Schema[Either[Double, Long]] = Schema.fallback(Schema[Double], Schema[Long]).transform(_.toEither, Fallback.fromEither) @@ -149,6 +162,12 @@ sealed trait JsonSchema extends Product with Serializable { self => case _ => Chunk.empty } + final def isNullable: Boolean = + annotations.exists { + case MetaData.Nullable(nullable) => nullable + case _ => false + } + def withoutAnnotations: JsonSchema = self match { case JsonSchema.AnnotatedSchema(schema, _) => schema.withoutAnnotations case _ => self @@ -861,22 +880,18 @@ object JsonSchema { val markedForRemoval = (for { obj <- objects otherObj <- objects - notNullableSchemas = obj.withoutAnnotations.asInstanceOf[JsonSchema.Object].properties.collect { - case (name, schema) - if !schema.annotations.exists { case MetaData.Nullable(nullable) => nullable; case _ => false } => - name -> schema - } + notNullableSchemas = + obj.withoutAnnotations + .asInstanceOf[JsonSchema.Object] + .properties + .filterNot { case (_, schema) => schema.isNullable } if notNullableSchemas == otherObj.withoutAnnotations.asInstanceOf[JsonSchema.Object].properties } yield otherObj).distinct val minified = objects.filterNot(markedForRemoval.contains).map { obj => val annotations = obj.annotations val asObject = obj.withoutAnnotations.asInstanceOf[JsonSchema.Object] - val notNullableSchemas = asObject.properties.collect { - case (name, schema) - if !schema.annotations.exists { case MetaData.Nullable(nullable) => nullable; case _ => false } => - name -> schema - } + val notNullableSchemas = asObject.properties.filterNot { case (_, schema) => schema.isNullable } asObject.required(asObject.required.filter(notNullableSchemas.contains)).annotate(annotations) } val newAnyOf = minified ++ others @@ -1266,11 +1281,20 @@ object JsonSchema { case Left(false) => Some(BoolOrSchema.BooleanWrapper(false)) case Right(schema) => Some(BoolOrSchema.SchemaWrapper(schema.toSerializableSchema)) } + + val nullableFields = properties.collect { case (name, schema) if schema.isNullable => name }.toSet + SerializableJsonSchema( schemaType = Some(TypeOrTypes.Type("object")), properties = Some(properties.map { case (name, schema) => name -> schema.toSerializableSchema }), additionalProperties = additionalProperties, - required = if (required.isEmpty) None else Some(required), + required = + if (required.isEmpty) None + else if (nullableFields.isEmpty) Some(required) + else { + val newRequired = required.filterNot(nullableFields.contains) + if (newRequired.isEmpty) None else Some(newRequired) + }, ) } } diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 2e4d0e500f..dee8fe5476 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -988,7 +988,7 @@ object OpenAPIGen { } private def schemaReferencePath(nominal: TypeId.Nominal, referenceType: SchemaStyle): String = { referenceType match { - case SchemaStyle.Compact => s"#/components/schemas/${nominal.typeName}}" + case SchemaStyle.Compact => s"#/components/schemas/${nominal.typeName}" case _ => s"#/components/schemas/${nominal.fullyQualified.replace(".", "_")}}" } }