From e9aa674750b12a6654a34da669f6c4c217814783 Mon Sep 17 00:00:00 2001 From: Tamas Tomorkenyi Date: Mon, 21 Oct 2024 13:01:06 +0200 Subject: [PATCH] Add auth to generated OpenAPI (#3183) --- .../endpoint/openapi/OpenAPIGenSpec.scala | 637 +++++++++++++++++- .../http/endpoint/openapi/OpenAPIGen.scala | 88 ++- 2 files changed, 716 insertions(+), 9 deletions(-) diff --git a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala index 73ceede3a0..3a5ada64c9 100644 --- a/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/endpoint/openapi/OpenAPIGenSpec.scala @@ -1,18 +1,19 @@ package zio.http.endpoint.openapi +import zio.http.Header.Authorization.{Basic, Bearer} import zio.json.ast.Json import zio.test._ import zio.{Chunk, Scope, ZIO} - import zio.schema.annotation._ import zio.schema.validation.Validation import zio.schema.{DeriveSchema, Schema} - import zio.http.Method.{GET, POST} import zio.http._ import zio.http.codec.PathCodec.{empty, string} +import zio.http.codec.TextCodec.StringCodec import zio.http.codec._ import zio.http.endpoint._ +import zio.http.endpoint.openapi.OpenAPIGen.{apiKeyHeaderName, apiKeyQueryParamName} object OpenAPIGenSpec extends ZIOSpecDefault { @@ -208,6 +209,44 @@ object OpenAPIGenSpec extends ZIOSpecDefault { .out[SimpleOutputBody] .outError[NotFoundError](Status.NotFound) + private val apiKeyAuthHeaderEndpoint = + Endpoint(GET / "withAuthHeader") + .in[SimpleInputBody] + .header(HttpCodec.Header.apply(apiKeyHeaderName, StringCodec).optional) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + + private val apiKeyAuthQueryEndpoint = + Endpoint(GET / "withAuthQuery") + .in[SimpleInputBody] + .query(HttpCodec.query[String](apiKeyQueryParamName).optional) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + + private val basicAuthEndpoint = + Endpoint(GET / "withAuthHeader") + .in[SimpleInputBody] + .header(HttpCodec.authorization) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + .auth(AuthType.Basic) + + private val bearerAuthEndpoint = + Endpoint(GET / "withAuthHeader") + .in[SimpleInputBody] + .header(HttpCodec.authorization) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + .auth(AuthType.Bearer) + + private val digestAuthEndpoint = + Endpoint(GET / "withAuthHeader") + .in[SimpleInputBody] + .header(HttpCodec.authorization) + .out[SimpleOutputBody] + .outError[NotFoundError](Status.NotFound) + .auth(AuthType.Digest) + private val optionalPayloadEndpoint = Endpoint(GET / "withPayload") .inCodec(HttpCodec.content[Payload].optional) @@ -1010,6 +1049,600 @@ object OpenAPIGenSpec extends ZIOSpecDefault { |}""".stripMargin assertTrue(json == toJsonAst(expectedJson)) }, + test("api key auth in header") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", apiKeyAuthHeaderEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withAuthHeader" : { + | "get" : { + | "parameters" : [ + | { + | "name" : "x-api-key", + | "in" : "header", + | "schema" : { + | "type" : [ + | "string", + | "null" + | ] + | }, + | "style" : "simple" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody", + | "description" : "" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "security" : [ + | { + | "apiKeyAuth" : [] + | } + | ] + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | }, + | "securitySchemes" : { + | "apiKeyAuth" : { + | "type" : "apiKey", + | "name" : "x-api-key", + | "in" : "header" + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("api key auth in query") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", apiKeyAuthQueryEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withAuthQuery" : { + | "get" : { + | "parameters" : [ + | { + | "name" : "api_key", + | "in" : "query", + | "schema" : { + | "type" : [ + | "string", + | "null" + | ] + | }, + | "allowReserved" : false, + | "style" : "form" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "security" : [ + | { + | "apiKeyAuth" : [] + | } + | ] + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | }, + | "securitySchemes" : { + | "apiKeyAuth" : + | { + | "type" : "apiKey", + | "name" : "api_key", + | "in" : "query" + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("basic auth in query") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", basicAuthEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withAuthHeader" : { + | "get" : { + | "parameters" : [ + | { + | "name": "authorization", + | "in" : "header", + | "required" : true, + | "schema" : { + | "type" : "string" + | }, + | "style" : "simple" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "security" : [ + | { + | "basicAuth" : [] + | } + | ] + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | }, + | "securitySchemes" : { + | "basicAuth" : { + | "type" : "http", + | "scheme" : "basic" + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("bearer auth in header") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", bearerAuthEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withAuthHeader" : { + | "get" : { + | "parameters" : [ + | { + | "name": "authorization", + | "in" : "header", + | "required" : true, + | "schema" : { + | "type" : "string" + | }, + | "style" : "simple" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "security" : [ + | { + | "bearerAuth" : [] + | } + | ] + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | }, + | "securitySchemes" : { + | "bearerAuth" : { + | "type" : "http", + | "scheme" : "bearer" + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, + test("digest auth in header") { + val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", digestAuthEndpoint) + val json = toJsonAst(generated) + val expectedJson = """{ + | "openapi" : "3.1.0", + | "info" : { + | "title" : "Simple Endpoint", + | "version" : "1.0" + | }, + | "paths" : { + | "/withAuthHeader" : { + | "get" : { + | "parameters" : [ + | { + | "name": "authorization", + | "in" : "header", + | "required" : true, + | "schema" : { + | "type" : "string" + | }, + | "style" : "simple" + | } + | ], + | "requestBody" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleInputBody" + | } + | } + | }, + | "required" : true + | }, + | "responses" : { + | "200" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/SimpleOutputBody" + | } + | } + | } + | }, + | "404" : { + | "content" : { + | "application/json" : { + | "schema" : { + | "$ref" : "#/components/schemas/NotFoundError" + | } + | } + | } + | } + | }, + | "security" : [ + | { + | "digestAuth" : [] + | } + | ] + | } + | } + | }, + | "components" : { + | "schemas" : { + | "NotFoundError" : { + | "type" : "object", + | "properties" : { + | "message" : { + | "type" : "string" + | } + | }, + | "required" : [ + | "message" + | ] + | }, + | "SimpleInputBody" : { + | "type" : "object", + | "properties" : { + | "name" : { + | "type" : "string" + | }, + | "age" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "name", + | "age" + | ] + | }, + | "SimpleOutputBody" : { + | "type" : "object", + | "properties" : { + | "userName" : { + | "type" : "string" + | }, + | "score" : { + | "type" : "integer", + | "format" : "int32" + | } + | }, + | "required" : [ + | "userName", + | "score" + | ] + | } + | }, + | "securitySchemes" : { + | "digestAuth" : { + | "type" : "http", + | "scheme" : "digest" + | } + | } + | } + |}""".stripMargin + assertTrue(json == toJsonAst(expectedJson)) + }, test("optional payload") { val generated = OpenAPIGen.fromEndpoints("Simple Endpoint", "1.0", optionalPayloadEndpoint) val json = toJsonAst(generated) diff --git a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala index 64ba057807..9d86d9f106 100644 --- a/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala +++ b/zio-http/shared/src/main/scala/zio/http/endpoint/openapi/OpenAPIGen.scala @@ -1,29 +1,37 @@ package zio.http.endpoint.openapi import java.util.UUID - import scala.annotation.tailrec import scala.collection.immutable.ListMap import scala.collection.{immutable, mutable} - import zio._ +import zio.http.Header.Authorization +import zio.http.Header.Authorization.Basic import zio.json.EncoderOps import zio.json.ast.Json - import zio.schema.Schema.{Record, Transform} import zio.schema.codec.JsonCodec import zio.schema.{Schema, TypeId} - import zio.http._ import zio.http.codec.HttpCodec.Metadata import zio.http.codec._ +import zio.http.endpoint.AuthType.Bearer import zio.http.endpoint._ import zio.http.endpoint.openapi.JsonSchema.SchemaStyle -import zio.http.endpoint.openapi.OpenAPI.{Path, PathItem} +import zio.http.endpoint.openapi.OpenAPI.SecurityScheme.{ApiKey, SecurityRequirement} +import zio.http.endpoint.openapi.OpenAPI.{Key, Path, PathItem, ReferenceOr, SecurityScheme, securityRequirementSchema} object OpenAPIGen { private val PathWildcard = "pathWildcard" + val apiKeyHeaderName = "x-api-key" // opinionated names + val apiKeyQueryParamName = "api_key" + val apiKeyAuth = "apiKeyAuth" + val basicAuth = "basicAuth" + val bearerAuth = "bearerAuth" + val digestAuth = "digestAuth" + private val noAuth = "noAuth" // will be removed + private[openapi] def groupMap[A, K, B](chunk: Chunk[A])(key: A => K)(f: A => B): immutable.Map[K, Chunk[B]] = { val m = mutable.Map.empty[K, mutable.Builder[B, Chunk[B]]] for (elem <- chunk) { @@ -591,7 +599,7 @@ object OpenAPIGen { requestBody = requestBody, responses = responses, callbacks = Map.empty, - security = Nil, + security = security(endpoint), servers = Nil, ) } @@ -629,6 +637,10 @@ object OpenAPIGen { def responses: OpenAPI.Responses = responsesForAlternatives(outs) + def security(endpoint: Endpoint[_, _, _, _, _]): List[SecurityRequirement] = { + extractSecurityRequirements(endpoint, queryParams ++ headerParams) + } + def parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]] = queryParams ++ pathParams ++ headerParams @@ -780,6 +792,51 @@ object OpenAPIGen { } } + def collectSecurityScheme(endpoint: Endpoint[_, _, _, _, _], params: Set[ReferenceOr[OpenAPI.Parameter]]) = { + ListMap( + params.collect { + case OpenAPI.ReferenceOr.Or(param) + if param.name.equalsIgnoreCase(apiKeyHeaderName) || param.name.equalsIgnoreCase(apiKeyQueryParamName) => + createApiKeySecurityScheme(param.name) + }.toSeq ++ + Set( + endpoint.authType match { + case AuthType.Basic => createHttpSecurityScheme(basicAuth, "basic") + case AuthType.Bearer => createHttpSecurityScheme(bearerAuth, "bearer") + case AuthType.Digest => createHttpSecurityScheme(digestAuth, "digest") + case _ => createHttpSecurityScheme() + }, + ).filterNot(_._1 == OpenAPI.Key.fromString("noAuth").get).toSeq: _*, + ) + } + + def createApiKeySecurityScheme(name: String) = + OpenAPI.Key.fromString(apiKeyAuth).get -> OpenAPI.ReferenceOr.Or( + OpenAPI.SecurityScheme.ApiKey(description = None, name = name, in = in(name)), + ) + + def in(paramName: String): OpenAPI.SecurityScheme.ApiKey.In = { + if (paramName.equalsIgnoreCase(apiKeyHeaderName)) + OpenAPI.SecurityScheme.ApiKey.In.Header + else if (paramName.equalsIgnoreCase(apiKeyQueryParamName)) + OpenAPI.SecurityScheme.ApiKey.In.Query + else throw new IllegalArgumentException(s"Unknown apiKey param name: $paramName") + } + + def createHttpSecurityScheme( + name: String = noAuth, + scheme: String = "", + description: Option[Doc] = None, + ): (OpenAPI.Key, OpenAPI.ReferenceOr[OpenAPI.SecurityScheme.Http]) = { + OpenAPI.Key.fromString(name).get -> OpenAPI.ReferenceOr.Or( + OpenAPI.SecurityScheme.Http( + scheme = scheme, + bearerFormat = None, + description = description, + ), + ) + } + def components = OpenAPI.Components( schemas = ListMap(componentSchemas.toSeq.sortBy(_._1.name): _*), responses = ListMap.empty, @@ -787,7 +844,7 @@ object OpenAPIGen { examples = ListMap.empty, requestBodies = ListMap.empty, headers = ListMap.empty, - securitySchemes = ListMap.empty, + securitySchemes = collectSecurityScheme(endpoint, headerParams ++ queryParams), links = ListMap.empty, callbacks = ListMap.empty, ) @@ -1012,6 +1069,23 @@ object OpenAPIGen { ) } + private def extractSecurityRequirements( + endpoint: Endpoint[_, _, _, _, _], + parameters: Set[OpenAPI.ReferenceOr[OpenAPI.Parameter]], + ): List[SecurityRequirement] = { + parameters.collect { + case OpenAPI.ReferenceOr.Or(param: OpenAPI.Parameter) + if param.name.equalsIgnoreCase(apiKeyHeaderName) || param.name.equalsIgnoreCase(apiKeyQueryParamName) => + SecurityRequirement(Map(apiKeyAuth -> Nil)) + }.toList ++ + List(endpoint.authType match { + case AuthType.Basic => SecurityRequirement(Map(basicAuth -> Nil)) + case AuthType.Bearer => SecurityRequirement(Map(bearerAuth -> Nil)) + case AuthType.Digest => SecurityRequirement(Map(digestAuth -> Nil)) + case _ => SecurityRequirement(Map.empty) + }).filterNot(_.securitySchemes.isEmpty) + } + private def headersFrom(codec: AtomizedMetaCodecs) = { codec.header.map { case mc @ MetaCodec(codec, _) => codec.name -> OpenAPI.ReferenceOr.Or(