Skip to content

Commit

Permalink
Move base schema support to plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jan 30, 2024
1 parent 93f9f97 commit 6b08fce
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 100 deletions.
33 changes: 30 additions & 3 deletions openapify/core/base_plugins.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from typing import Any, ByteString, Dict, Optional, Union

from apispec import APISpec
from mashumaro.jsonschema import OPEN_API_3_1, JSONSchemaBuilder

from openapify.core.models import Body, Cookie, Header, QueryParam
from openapify.plugin import BasePlugin


class BodyBinaryPlugin(BasePlugin):
def schema_helper(
self,
definition: Union[Body, Cookie, Header, QueryParam],
obj: Union[Body, Cookie, Header, QueryParam],
name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
try:
if isinstance(definition, Body) and issubclass(
definition.value_type, ByteString # type: ignore
if isinstance(obj, Body) and issubclass(
obj.value_type, ByteString # type: ignore
):
return {}
else:
Expand All @@ -31,3 +34,27 @@ def media_type_helper(
return "application/octet-stream"
else:
return "application/json"


class BaseSchemaPlugin(BasePlugin):
spec: APISpec

def init_spec(self, spec: APISpec) -> None:
self.spec = spec

def schema_helper(
self,
obj: Union[Body, Cookie, Header, QueryParam],
name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
builder = JSONSchemaBuilder(
dialect=OPEN_API_3_1, ref_prefix=f"#/components/schemas"
)
try:
json_schema = builder.build(obj.value_type)
except Exception:
return None
schemas = self.spec.components.schemas
for name, schema in builder.context.definitions.items():
schemas[name] = schema.to_dict()
return json_schema.to_dict()
77 changes: 38 additions & 39 deletions openapify/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@

import apispec

from openapify.core.base_plugins import BodyBinaryPlugin, GuessMediaTypePlugin
from openapify.core.base_plugins import (
BaseSchemaPlugin,
BodyBinaryPlugin,
GuessMediaTypePlugin,
)
from openapify.core.const import (
DEFAULT_OPENAPI_VERSION,
DEFAULT_SPEC_TITLE,
DEFAULT_SPEC_VERSION,
RESPONSE_DESCRIPTIONS,
)
from openapify.core.document import OpenAPIDocument
from openapify.core.jsonschema import ComponentType, build_json_schema
from openapify.core.models import (
Body,
Cookie,
Expand All @@ -35,7 +38,7 @@
from openapify.core.openapi import models as openapi
from openapify.plugin import BasePlugin

BASE_PLUGINS = (BodyBinaryPlugin(), GuessMediaTypePlugin())
BASE_PLUGINS = (BodyBinaryPlugin(), GuessMediaTypePlugin(), BaseSchemaPlugin())


METHOD_ORDER = [
Expand Down Expand Up @@ -94,6 +97,8 @@ def __init__(
)
self.spec = spec
self.plugins: Sequence[BasePlugin] = (*plugins, *BASE_PLUGINS)
for plugin in self.plugins:
plugin.init_spec(spec)

def feed_routes(self, routes: Iterable[RouteDef]) -> None:
for route in sorted(
Expand Down Expand Up @@ -198,13 +203,11 @@ def _build_query_params(
for name, param in query_params.items():
if not isinstance(param, QueryParam):
param = QueryParam(param)
parameter_schema = self.__build_definition_schema_with_plugins(
parameter_schema = self.__build_object_schema_with_plugins(
param, name
)
if parameter_schema is None:
parameter_schema = build_json_schema(
param.value_type, self.spec, ComponentType.PARAMETER
)
parameter_schema = {}
if param.default is not None:
parameter_schema["default"] = param.default
result.append(
Expand All @@ -231,15 +234,11 @@ def _build_request_headers(
for name, header in headers.items():
if not isinstance(header, Header):
header = Header(description=header)
parameter_schema = self.__build_definition_schema_with_plugins(
parameter_schema = self.__build_object_schema_with_plugins(
header, name
)
if parameter_schema is None:
parameter_schema = build_json_schema(
instance_type=header.value_type,
spec=self.spec,
component_type=ComponentType.PARAMETER,
)
parameter_schema = {}
result.append(
openapi.Parameter(
name=name,
Expand All @@ -262,15 +261,11 @@ def _build_response_headers(
for name, header in headers.items():
if not isinstance(header, Header):
header = Header(description=header)
header_schema = self.__build_definition_schema_with_plugins(
header_schema = self.__build_object_schema_with_plugins(
header, name
)
if header_schema is None:
header_schema = build_json_schema(
instance_type=header.value_type,
spec=self.spec,
component_type=ComponentType.HEADER,
)
header_schema = {}
result[name] = openapi.Header(
schema=header_schema,
description=header.description,
Expand All @@ -289,15 +284,11 @@ def _build_cookies(
for name, cookie in cookies.items():
if not isinstance(cookie, Cookie):
cookie = Cookie(cookie)
parameter_schema = self.__build_definition_schema_with_plugins(
parameter_schema = self.__build_object_schema_with_plugins(
cookie, name
)
if parameter_schema is None:
parameter_schema = build_json_schema(
instance_type=cookie.value_type,
spec=self.spec,
component_type=ComponentType.PARAMETER,
)
parameter_schema = {}
result.append(
openapi.Parameter(
name=name,
Expand Down Expand Up @@ -339,9 +330,9 @@ def _update_request_body(
example=example,
examples=examples,
)
body_schema = self.__build_definition_schema_with_plugins(body)
body_schema = self.__build_object_schema_with_plugins(body)
if body_schema is None:
body_schema = build_json_schema(value_type, self.spec)
body_schema = {}
if media_type is None:
media_type = self._determine_body_media_type(body, body_schema)
elif media_type is not None:
Expand Down Expand Up @@ -392,9 +383,9 @@ def _update_responses(
example=example,
examples=examples,
)
body_schema = self.__build_definition_schema_with_plugins(body_obj)
body_schema = self.__build_object_schema_with_plugins(body_obj)
if body_schema is None:
body_schema = build_json_schema(body, self.spec)
body_schema = {}
if media_type is None:
media_type = self._determine_body_media_type(
body_obj, body_schema
Expand Down Expand Up @@ -455,19 +446,12 @@ def _build_examples(
result[key] = openapi.Example(value)
return result

def __build_definition_schema_with_plugins(
def __build_object_schema_with_plugins(
self,
definition: Union[Body, Cookie, Header, QueryParam],
obj: Union[Body, Cookie, Header, QueryParam],
name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
for plugin in self.plugins:
try:
schema = plugin.schema_helper(definition, name)
if schema is not None:
return schema
except NotImplementedError:
continue
return None
return build_object_schema_with_plugins(obj, self.plugins, name)

def _determine_body_media_type(
self, body: Body, schema: Dict[str, Any]
Expand All @@ -482,6 +466,21 @@ def _determine_body_media_type(
return None


def build_object_schema_with_plugins(
obj: Union[Body, Cookie, Header, QueryParam],
plugins: Sequence[BasePlugin],
name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
for plugin in plugins:
try:
schema = plugin.schema_helper(obj, name)
if schema is not None:
return schema
except NotImplementedError:
continue
return None


@overload
def build_spec(
routes: Iterable[RouteDef], spec: apispec.APISpec
Expand Down
54 changes: 0 additions & 54 deletions openapify/core/jsonschema.py

This file was deleted.

6 changes: 4 additions & 2 deletions openapify/ext/web/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from aiohttp.typedefs import Handler
from aiohttp.web_app import Application
from apispec import APISpec
from mashumaro.jsonschema import OPEN_API_3_1, build_json_schema
from mashumaro.jsonschema.annotations import Pattern
from typing_extensions import Annotated

Expand All @@ -28,7 +29,6 @@
DEFAULT_SPEC_TITLE,
DEFAULT_SPEC_VERSION,
)
from openapify.core.jsonschema import build_json_schema
from openapify.core.models import RouteDef
from openapify.core.openapi.models import (
Parameter,
Expand Down Expand Up @@ -84,7 +84,9 @@ def _sub(match: re.Match) -> str:
name=name,
location=ParameterLocation.PATH,
required=True,
schema=build_json_schema(instance_type),
schema=build_json_schema(
instance_type, dialect=OPEN_API_3_1
).to_dict(),
)
)
return f"{{{name}}}"
Expand Down
4 changes: 2 additions & 2 deletions openapify/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@


class BasePlugin:
def init_spec(self, spec: Optional[APISpec]) -> None:
def init_spec(self, spec: APISpec) -> None:
pass

def schema_helper(
self,
definition: Union[Body, Cookie, Header, QueryParam],
obj: Union[Body, Cookie, Header, QueryParam],
name: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
raise NotImplementedError
Expand Down

0 comments on commit 6b08fce

Please sign in to comment.