From 3defd4930fa8932f13648f4d41c7af8dddecfb48 Mon Sep 17 00:00:00 2001 From: vvcheremushkin Date: Wed, 12 Feb 2020 11:54:56 +0300 Subject: [PATCH] version 0.12.0 --- requirements.txt | 4 +- setup.py | 52 +++---- star_resty/__init__.py | 2 +- star_resty/apidocs/operation.py | 40 ++++++ star_resty/apidocs/request.py | 53 ++++++++ star_resty/apidocs/response.py | 40 ++++++ star_resty/apidocs/route.py | 31 +++++ star_resty/apidocs/setup.py | 157 ++-------------------- star_resty/method/meta.py | 86 +----------- star_resty/method/method.py | 12 +- star_resty/method/options.py | 11 -- star_resty/method/parser.py | 118 ++++++++++++++++ star_resty/method/render.py | 67 +++++++++ star_resty/method/request_parser.py | 31 ----- star_resty/operation/__init__.py | 2 +- star_resty/operation/schema.py | 22 ++- star_resty/parsers/__init__.py | 1 - star_resty/parsers/query.py | 51 ------- star_resty/{types => payload}/__init__.py | 1 + star_resty/payload/header.py | 32 +++++ star_resty/{types => payload}/json.py | 5 +- star_resty/{types => payload}/parser.py | 10 +- star_resty/{types => payload}/path.py | 3 +- star_resty/payload/query.py | 87 ++++++++++++ star_resty/serializers/json.py | 2 + star_resty/serializers/serializer.py | 2 + star_resty/types/query.py | 32 ----- tests/test_dependencies.py | 2 +- tests/test_method.py | 65 ++++++--- tests/test_query_parser.py | 8 +- 30 files changed, 611 insertions(+), 418 deletions(-) create mode 100644 star_resty/apidocs/operation.py create mode 100644 star_resty/apidocs/request.py create mode 100644 star_resty/apidocs/response.py create mode 100644 star_resty/apidocs/route.py delete mode 100644 star_resty/method/options.py create mode 100644 star_resty/method/parser.py create mode 100644 star_resty/method/render.py delete mode 100644 star_resty/method/request_parser.py delete mode 100644 star_resty/parsers/__init__.py delete mode 100644 star_resty/parsers/query.py rename star_resty/{types => payload}/__init__.py (73%) create mode 100644 star_resty/payload/header.py rename star_resty/{types => payload}/json.py (89%) rename star_resty/{types => payload}/parser.py (80%) rename star_resty/{types => payload}/path.py (92%) create mode 100644 star_resty/payload/query.py delete mode 100644 star_resty/types/query.py diff --git a/requirements.txt b/requirements.txt index d316afa..78b4959 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ ujson typing_extensions marshmallow>=3.0.0rc8,<4 -starlette<=1 -apispec<3 +starlette<1 +apispec<4 # Testing pytest diff --git a/setup.py b/setup.py index 3be5753..24f9a76 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ def get_long_description(): - with open("README.md", encoding="utf8") as f: + with open('README.md', encoding='utf8') as f: return f.read() @@ -15,39 +15,39 @@ def get_packages(package): return [ dirpath for dirpath, dirnames, filenames in os.walk(package) - if os.path.exists(os.path.join(dirpath, "__init__.py")) + if os.path.exists(os.path.join(dirpath, '__init__.py')) ] setup( - name="star_resty", - python_requires=">=3.7", + name='star_resty', + python_requires='>=3.7', install_requires=[ - "ujson", - "typing_extensions", - "marshmallow>=3.0.0rc8,<4", - "starlette<=1", - "apispec<3", + 'ujson', + 'typing_extensions', + 'marshmallow>=3.0.0rc8,<4', + 'starlette<1', + 'apispec<4', ], - version="0.0.11", - url="https://github.com/slv0/start_resty", - license="BSD", - description="The web framework", + version='0.0.12', + url='https://github.com/slv0/start_resty', + license='BSD', + description='The web framework', long_description=get_long_description(), - long_description_content_type="text/markdown", - author="Slava Cheremushkin", - author_email="slv0.chr@gmail.com", - packages=get_packages("star_resty"), - package_data={"star_resty": ["py.typed"]}, - data_files=[("", ["LICENSE"])], + long_description_content_type='text/markdown', + author='Slava Cheremushkin', + author_email='slv0.chr@gmail.com', + packages=get_packages('star_resty'), + package_data={'star_resty': ['py.typed']}, + data_files=[('', ['LICENSE'])], classifiers=[ - "Development Status :: 3 - Alpha", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", - "Operating System :: OS Independent", - "Topic :: Internet :: WWW/HTTP", - "Programming Language :: Python :: 3.7", + 'Development Status :: 3 - Alpha', + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: BSD License', + 'Operating System :: OS Independent', + 'Topic :: Internet :: WWW/HTTP', + 'Programming Language :: Python :: 3.7', ], zip_safe=False, ) diff --git a/star_resty/__init__.py b/star_resty/__init__.py index 0fac278..995bc6f 100644 --- a/star_resty/__init__.py +++ b/star_resty/__init__.py @@ -1,4 +1,4 @@ from .apidocs import setup_spec from .method import * from .operation import Operation -from .types import * +from .payload import * diff --git a/star_resty/apidocs/operation.py b/star_resty/apidocs/operation.py new file mode 100644 index 0000000..7664a18 --- /dev/null +++ b/star_resty/apidocs/operation.py @@ -0,0 +1,40 @@ +from starlette.routing import Route + +from star_resty.method import Method +from .request import resolve_parameters, resolve_request_body, resolve_request_body_params +from .response import resolve_responses + +__all__ = ('setup_route_operations',) + + +def setup_route_operations(route: Route, endpoint: Method, version: int = 2, + add_head_methods: bool = False): + operation = setup_operation(endpoint, version=version) + operation = {key: val for key, val in operation.items() if val is not None} + return {method.lower(): operation for method in route.methods + if (method != 'HEAD' or add_head_methods)} + + +def setup_operation(endpoint: Method, version=2): + options = endpoint.meta + res = { + 'tags': [options.tag], + 'description': options.description, + 'summary': options.summary, + 'produces': [endpoint.serializer.media_type], + 'parameters': resolve_parameters(endpoint), + 'responses': resolve_responses(endpoint), + } + + if options.security is not None: + res['security'] = options.security + + if version > 2: + res['requestBody'] = resolve_request_body(endpoint) + else: + res['parameters'].extend(resolve_request_body_params(endpoint)) + + if options.meta: + res.update(options.meta) + + return res diff --git a/star_resty/apidocs/request.py b/star_resty/apidocs/request.py new file mode 100644 index 0000000..cef3343 --- /dev/null +++ b/star_resty/apidocs/request.py @@ -0,0 +1,53 @@ +from star_resty.method import Method +from star_resty.method.parser import RequestParser + +__all__ = ('resolve_parameters', 'resolve_request_body', 'resolve_request_body_params') + + +def resolve_parameters(endpoint: Method): + parameters = [] + parser = getattr(endpoint, '__parser__', None) + if parser is None: + return parameters + + for p in parser: + if p.schema is not None and p.location != 'body': + parameters.append({'in': p.location, 'schema': p.schema}) + + return parameters + + +def resolve_request_body(endpoint: Method): + parser = getattr(endpoint, '__parser__', None) + if parser is None: + return None + + content = resolve_request_body_content(parser) + if content: + return {'content': content} + + +def resolve_request_body_content(parser: RequestParser): + content = {} + for p in parser: + if p.schema is not None and p.location == 'body' and p.media_type: + content[p.media_type] = {'schema': p.schema} + + return content + + +def resolve_request_body_params(endpoint: Method): + params = [] + parser = getattr(endpoint, '__parser__', None) + if parser is None: + return params + + for p in parser: + if p.schema is not None and p.location == 'body' and p.media_type: + params.append({ + 'name': 'body', + 'in': 'body', + 'schema': p.schema + }) + + return params diff --git a/star_resty/apidocs/response.py b/star_resty/apidocs/response.py new file mode 100644 index 0000000..fea014e --- /dev/null +++ b/star_resty/apidocs/response.py @@ -0,0 +1,40 @@ +import inspect +from typing import Union, Dict, Type + +from star_resty.method import Method + +__all__ = ('resolve_responses',) + + +def resolve_responses(endpoint: Method): + responses = {} + if endpoint.response_schema: + responses[str(endpoint.status_code)] = { + 'schema': endpoint.response_schema + } + + errors = endpoint.meta.errors or () + for e in errors: + if isinstance(e, dict) and e.get('status_code'): + responses[str(e['status_code'])] = {key: val for key, val in e.items() if key != 'status_code'} + elif isinstance(e, Exception) and getattr(e, 'status_code', None) is not None: + responses[str(getattr(e, 'status_code'))] = create_error_schema_by_exc(e) + elif inspect.isclass(e) and issubclass(e, Exception) and getattr(e, 'status_code', None) is not None: + responses[str(getattr(e, 'status_code'))] = create_error_schema_by_exc(e) + + parser = getattr(endpoint, '__parser__', None) + if parser and '404' not in responses: + responses['400'] = {'description': 'Bad request'} + + return responses + + +def create_error_schema_by_exc(e: Union[Exception, Type[Exception]]) -> Dict: + schema = {'description': (getattr(e, 'detail', None) + or getattr(e, 'description', None) + or str(e))} + error_schema = getattr(e, 'schema', None) + if error_schema is not None: + schema['schema'] = error_schema + + return schema diff --git a/star_resty/apidocs/route.py b/star_resty/apidocs/route.py new file mode 100644 index 0000000..5a3ba45 --- /dev/null +++ b/star_resty/apidocs/route.py @@ -0,0 +1,31 @@ +from typing import Sequence, Union + +from apispec import APISpec +from starlette.routing import Route, Mount + +from .operation import setup_route_operations +from .utils import convert_path + +__all__ = ('setup_routes',) + + +def setup_routes(routes: Sequence[Union[Route, Mount]], + spec: APISpec, version: int = 2, + add_head_methods: bool = False, + path: str = ''): + for route in routes: + if isinstance(route, Mount): + setup_routes(route.routes, spec, version=version, add_head_methods=add_head_methods, + path=f'{path}{route.path}') + continue + elif isinstance(route, Route) and not route.include_in_schema: + continue + + endpoint = getattr(route.endpoint, '__endpoint__', None) + if endpoint is None: + continue + + operations = setup_route_operations(route, endpoint, version=version, + add_head_methods=add_head_methods) + route_path = f'{path}{route.path}' + spec.path(convert_path(route_path), operations=operations) diff --git a/star_resty/apidocs/setup.py b/star_resty/apidocs/setup.py index 2fc3e81..18d90b0 100644 --- a/star_resty/apidocs/setup.py +++ b/star_resty/apidocs/setup.py @@ -1,18 +1,14 @@ -import inspect import logging -from typing import Dict, Optional, Sequence, Type, Union +from typing import Optional, Mapping from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import UJSONResponse -from starlette.routing import Mount, Route -from star_resty.method import Method -from star_resty.method.meta import MethodMetaOptions -from star_resty.method.request_parser import RequestParser -from .utils import convert_path, resolve_schema_name +from .route import setup_routes +from .utils import resolve_schema_name logger = logging.getLogger(__name__) @@ -26,14 +22,18 @@ def setup_spec(app: Starlette, title: str, base_path='/', route: str = '/apidocs.json', add_head_methods: bool = False, + options: Optional[Mapping] = None, **kwargs): + if options is None: + options = {} + spec = APISpec( title=title, version=version, openapi_version=openapi_version, schemes=schemes or ['http', 'https'], plugins=[MarshmallowPlugin(schema_name_resolver=resolve_schema_name)], - **{**kwargs, 'swagger': '2.0', 'basePath': base_path} + **{'swagger': openapi_version, 'basePath': base_path, **options, **kwargs} ) initialized = False @@ -43,9 +43,8 @@ def generate_api_docs(_: Request): nonlocal spec if not initialized: logger.info('initialize open api schema') - setup_paths(app, spec, - version=get_open_api_version(openapi_version), - add_head_methods=add_head_methods) + setup_routes(app.routes, spec, version=get_open_api_version(openapi_version) + , add_head_methods=add_head_methods) initialized = True return UJSONResponse(spec.to_dict()) @@ -57,139 +56,3 @@ def get_open_api_version(version: str) -> int: return int(v) except (ValueError, TypeError): raise ValueError(f'Invalid open api version: {version}') - - -def setup_paths(app: Starlette, spec: APISpec, version: int = 2, - add_head_methods: bool = False): - setup_routes(app.routes, spec, version, add_head_methods) - - -def setup_routes(routes: Sequence[Union[Route, Mount]], - spec: APISpec, version: int = 2, - add_head_methods: bool = False, - path: str = ''): - for route in routes: - if isinstance(route, Mount): - setup_routes(route.routes, spec, version=version, add_head_methods=add_head_methods, - path=f'{path}{route.path}') - continue - elif isinstance(route, Route) and not route.include_in_schema: - continue - - endpoint: Optional[Method] = getattr(route.endpoint, '__endpoint__', None) - if endpoint is None: - continue - - operations = setup_operations(route, endpoint, version=version, - add_head_methods=add_head_methods) - route_path = f'{path}{route.path}' - spec.path(convert_path(route_path), operations=operations) - - -def setup_operations(route: Route, endpoint: Method, version: int = 2, - add_head_methods: bool = False): - operation = setup_operation(endpoint, version=version) - operation = {key: val for key, val in operation.items() if val is not None} - return {method.lower(): operation for method in route.methods - if (method != 'HEAD' or add_head_methods)} - - -def setup_operation(endpoint: Method, version=2): - options = endpoint.meta - meta = endpoint.__meta__ - res = { - 'tags': [options.tag], - 'description': options.description, - 'summary': options.summary, - 'produces': [endpoint.serializer.media_type], - 'parameters': resolve_parameters(meta), - 'responses': resolve_responses(endpoint), - } - - if options.security is not None: - res['security'] = options.security - - if version > 2: - res['requestBody'] = resolve_request_body(meta) - else: - res['parameters'].extend(resolve_request_body_params(meta.parser)) - - return res - - -def resolve_parameters(meta: MethodMetaOptions): - parameters = [] - parser = meta.parser - if parser is None: - return parameters - - for p in parser.iter_parsers(): - if p.schema is not None and p.location != 'body': - parameters.append({'in': p.location, 'schema': p.schema}) - - return parameters - - -def resolve_request_body(meta: MethodMetaOptions): - parser = meta.parser - if parser is None: - return None - - content = resolve_request_body_content(parser) - if content: - return {'content': content} - - -def resolve_request_body_content(parser: RequestParser): - content = {} - for p in parser.iter_parsers(): - if p.schema is not None and p.location == 'body' and p.media_type: - content[p.media_type] = {'schema': p.schema} - - return content - - -def resolve_request_body_params(parser: RequestParser): - params = [] - for p in parser.iter_parsers(): - if p.schema is not None and p.location == 'body' and p.media_type: - params.append({ - 'name': 'body', - 'in': 'body', - 'schema': p.schema - }) - - return params - - -def resolve_responses(endpoint: Method): - responses = {} - if endpoint.response_schema: - responses[str(endpoint.status_code)] = { - 'schema': endpoint.response_schema - } - - errors = endpoint.meta.errors or () - for e in errors: - if isinstance(e, dict) and e.get('status_code'): - responses[str(e['status_code'])] = {key: val for key, val in e.items() if key != 'status_code'} - elif isinstance(e, Exception) and getattr(e, 'status_code', None) is not None: - responses[str(getattr(e, 'status_code'))] = create_error_schema_by_exc(e) - elif inspect.isclass(e) and issubclass(e, Exception) and getattr(e, 'status_code', None) is not None: - responses[str(getattr(e, 'status_code'))] = create_error_schema_by_exc(e) - - if not endpoint.__meta__.parser.is_empty and '404' not in responses: - responses['400'] = {'description': 'Bad request'} - - return responses - - -def create_error_schema_by_exc(e: Union[Exception, Type[Exception]]) -> Dict: - schema = {'description': (getattr(e, 'detail', None) - or getattr(e, 'description', None) - or str(e))} - error_schema = getattr(e, 'schema', None) - if error_schema is not None: - schema['schema'] = error_schema - - return schema diff --git a/star_resty/method/meta.py b/star_resty/method/meta.py index 00917d1..4f98bd6 100644 --- a/star_resty/method/meta.py +++ b/star_resty/method/meta.py @@ -1,15 +1,9 @@ import abc -import inspect -from typing import Any, Callable, Tuple -from marshmallow import Schema -from marshmallow.exceptions import MarshmallowError -from starlette.responses import Response +from .parser import create_parser +from .render import create_render -from star_resty.exceptions import DumpError -from star_resty.types.parser import Parser -from .options import MethodMetaOptions -from .request_parser import RequestParser +__all__ = ('MethodMeta',) class MethodMeta(abc.ABCMeta): @@ -18,75 +12,9 @@ def __new__(mcs, name, bases, attrs, **kwargs): cls = super().__new__(mcs, name, bases, attrs, **kwargs) func = getattr(cls, 'execute', None) - meta = MethodMetaOptions(request_parser=mcs.create_parser(func), - render=mcs.create_render(cls)) - cls.__meta__ = meta - return cls - - @classmethod - def create_render(mcs, method): - renders = [] - response_schema = getattr(method, 'response_schema', None) - if response_schema is None: - response_schema = getattr(method, 'Response', None) - - if response_schema is not None: - if inspect.isclass(response_schema): - response_schema = response_schema() - - renders.append(mcs.dump_content(response_schema)) - - if method.serializer is not None: - renders.append(mcs.render_bytes(method.serializer, method.status_code or 200)) - - return mcs.create_content_render(tuple(renders)) - - @staticmethod - def dump_content(response_schema: Schema): - def dump(content): - try: - return response_schema.dump(content) - except MarshmallowError as e: - raise DumpError(e) from e - except (ValueError, TypeError) as e: - raise DumpError(e) from e - - return dump - - @staticmethod - def create_content_render(renders: Tuple) -> Callable: - def render(content: Any): - for r in renders: - content = r(content) - - return content - - return render - - @staticmethod - def create_parser(func): - req_parser = RequestParser() if func is None: - return req_parser - - data = func.__annotations__ - for key, value in data.items(): - parser = getattr(value, 'parser', None) - if parser is None or not isinstance(parser, Parser): - continue + raise TypeError(f'Invalid method class={name}') - if inspect.iscoroutinefunction(parser.parse): - req_parser.async_parsers.append((key, parser)) - else: - req_parser.parsers.append((key, parser)) - - return req_parser - - @staticmethod - def render_bytes(serializer, status_code): - def render(content): - return Response(serializer.render(content), - media_type=serializer.media_type, - status_code=status_code) - - return render + cls.__parser__ = create_parser(func) + cls.__render__ = create_render(cls) + return cls diff --git a/star_resty/method/method.py b/star_resty/method/method.py index 90b6deb..4a68995 100644 --- a/star_resty/method/method.py +++ b/star_resty/method/method.py @@ -1,6 +1,6 @@ import abc from functools import wraps -from typing import ClassVar, Type, Union, Optional +from typing import ClassVar, Type, Union, Optional, Callable, Any from marshmallow import Schema from starlette.requests import Request @@ -8,14 +8,15 @@ from star_resty.operation import Operation from star_resty.serializers import JsonSerializer, Serializer -from .meta import MethodMeta, MethodMetaOptions +from .meta import MethodMeta __all__ = ('Method', 'endpoint') class Method(abc.ABC, metaclass=MethodMeta): __slots__ = ('request',) - __meta__: ClassVar[MethodMetaOptions] + __parser__: ClassVar[Callable[[Request], Any]] + __render__: ClassVar[Callable[[Any], Response]] meta: ClassVar[Operation] = Operation(tag='default') serializer: ClassVar[Serializer] = JsonSerializer @@ -31,10 +32,9 @@ async def execute(self, *args, **kwargs): pass async def dispatch(self) -> Response: - meta = self.__meta__ - params = await meta.parser.parse(self.request) + params = await self.__parser__.parse(self.request) content = await self.execute(**params) - return meta.render(content) + return self.__render__(content) @classmethod def as_endpoint(cls): diff --git a/star_resty/method/options.py b/star_resty/method/options.py deleted file mode 100644 index 4961b68..0000000 --- a/star_resty/method/options.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Callable - -from .request_parser import RequestParser - - -class MethodMetaOptions: - __slots__ = ('parser', 'render') - - def __init__(self, request_parser: RequestParser, render: Callable): - self.parser = request_parser - self.render = render diff --git a/star_resty/method/parser.py b/star_resty/method/parser.py new file mode 100644 index 0000000..c6045f0 --- /dev/null +++ b/star_resty/method/parser.py @@ -0,0 +1,118 @@ +import inspect +import itertools +from dataclasses import is_dataclass, fields +from functools import partial +from typing import Dict, Tuple, Sequence, Generator, Mapping, Generic, TypeVar, Type, Callable, Optional, Any, Union + +from starlette.requests import Request + +from star_resty.payload.parser import Parser + +__all__ = ('RequestParser', 'create_parser') + +D = TypeVar('D') + + +class RequestParser: + __slots__ = ('_parsers', '_async_parsers') + + def __init__(self, parsers: Sequence[Tuple[str, Union[Parser, 'RequestParser']]] = (), + async_parsers: Sequence[Tuple[str, Union[Parser, 'RequestParser']]] = ()): + self._parsers = parsers + self._async_parsers = async_parsers + + def __nonzero__(self): + return bool(self._parsers or self._async_parsers) + + def __iter__(self) -> Generator[Parser, None, None]: + parsers = itertools.chain(self._parsers, self._async_parsers) + for (_, p) in parsers: + if isinstance(p, RequestParser): + yield from p + else: + yield p + + async def parse(self, request: Request) -> Dict: + params = {} + for (key, p) in self._parsers: + params[key] = p.parse(request) + + for (key, p) in self._async_parsers: + params[key] = await p.parse(request) + + return params + + +class DataClassParser(RequestParser, Generic[D]): + __slots__ = ('_data_cls',) + + def __init__(self, data_cls: Type[D], parsers: Sequence[Tuple[str, Parser]] = (), + async_parsers: Sequence[Tuple[str, Parser]] = ()): + super().__init__(parsers=parsers, async_parsers=async_parsers) + self._data_cls = data_cls + + async def parse(self, request: Request) -> D: + params = await super().parse(request) + return self._data_cls(**params) + + +def create_parser(func) -> RequestParser: + if func is None: + return RequestParser() + + return create_parser_from_data(func.__annotations__) + + +def create_parser_from_data(data: Mapping) -> RequestParser: + parsers = [] + async_parsers = [] + for key, value in data.items(): + if is_dataclass(value): + data = {field.name: field.type for field in fields(value)} + factory = partial(DataClassParser, value) + parser = create_parser_for_dc(data, factory=factory) + else: + parser = getattr(value, 'parser', None) + + if parser is None or not isinstance(parser, (Parser, RequestParser)): + continue + + if inspect.iscoroutinefunction(parser.parse): + async_parsers.append((key, parser)) + else: + parsers.append((key, parser)) + + return RequestParser(parsers=parsers, async_parsers=async_parsers) + + +def create_parser_for_dc(data: Mapping, factory: Callable) -> RequestParser: + parsers = [] + async_parsers = [] + for key, value in data.items(): + if is_dataclass(value): + data = {field.name: field.type for field in fields(value)} + factory = partial(DataClassParser, value) + parser = create_parser_for_dc(data, factory=factory) + else: + parser = getattr(value, 'parser', None) + + if parser is None: + parser = get_parser_from_args(value) + + if parser is None or not isinstance(parser, (Parser, RequestParser)): + continue + + if inspect.iscoroutinefunction(parser.parse): + async_parsers.append((key, parser)) + else: + parsers.append((key, parser)) + + return factory(parsers=parsers, async_parsers=async_parsers) + + +def get_parser_from_args(value: Any) -> Optional[Parser]: + args = getattr(value, '__args__', None) or () + for a in args: + parser = getattr(a, 'parser', None) + if parser is not None and isinstance(parser, Parser): + return parser diff --git a/star_resty/method/render.py b/star_resty/method/render.py new file mode 100644 index 0000000..2861ec9 --- /dev/null +++ b/star_resty/method/render.py @@ -0,0 +1,67 @@ +import inspect +import logging +from typing import Sequence, Callable, Any + +from marshmallow import Schema +from marshmallow.exceptions import MarshmallowError +from starlette.responses import Response + +from star_resty.exceptions import DumpError + +__all__ = ('create_render', 'Render') + +logger = logging.getLogger(__name__) + + +def create_render(method) -> 'Render': + renders = [] + response_schema = getattr(method, 'response_schema', None) + if response_schema is None: + response_schema = getattr(method, 'Response', None) + + if response_schema is not None: + if inspect.isclass(response_schema): + response_schema = response_schema() + renders.append(dump_content(response_schema)) + + serializer = getattr(method, 'serializer', None) + if serializer is not None: + renders.append(render_bytes(serializer, method.status_code or 200)) + + return Render(renders) + + +class Render: + __slots__ = ('_renders',) + + def __init__(self, renders: Sequence): + self._renders = renders + + def __call__(self, content: Any): + for r in self._renders: + content = r(content) + + return content + + +def render_bytes(serializer, status_code): + def render(content): + return Response(serializer.render(content), + media_type=serializer.media_type, + status_code=status_code) + + return render + + +def dump_content(response_schema: Schema) -> Callable: + def dump(content): + try: + return response_schema.dump(content) + except MarshmallowError as e: + logger.error('Dump error: %s', e) + raise DumpError(e) from e + except (ValueError, TypeError) as e: + logger.error('Dump error: %s', e) + raise DumpError(e) from e + + return dump diff --git a/star_resty/method/request_parser.py b/star_resty/method/request_parser.py deleted file mode 100644 index 916726a..0000000 --- a/star_resty/method/request_parser.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Dict, List, Tuple - -from starlette.requests import Request - -from star_resty.types.parser import Parser - - -class RequestParser: - __slots__ = ('parsers', 'async_parsers') - - def __init__(self): - self.parsers: List[Tuple[str, Parser]] = [] - self.async_parsers: List[Tuple[str, Parser]] = [] - - @property - def is_empty(self) -> bool: - return not (self.parsers or self.async_parsers) - - def iter_parsers(self): - yield from (p[1] for p in self.parsers) - yield from (p[1] for p in self.async_parsers) - - async def parse(self, request: Request) -> Dict: - params = {} - for (key, p) in self.parsers: - params[key] = p.parse(request) - - for (key, p) in self.async_parsers: - params[key] = await p.parse(request) - - return params diff --git a/star_resty/operation/__init__.py b/star_resty/operation/__init__.py index 7b50bb1..23e80ae 100644 --- a/star_resty/operation/__init__.py +++ b/star_resty/operation/__init__.py @@ -1 +1 @@ -from .schema import Operation +from .schema import * diff --git a/star_resty/operation/schema.py b/star_resty/operation/schema.py index 5effb52..3525b16 100644 --- a/star_resty/operation/schema.py +++ b/star_resty/operation/schema.py @@ -1,12 +1,26 @@ -from typing import Any, NamedTuple, Optional, Sequence +from dataclasses import dataclass +from typing import Optional, Sequence, Any, Mapping +__all__ = ('Operation',) -class Operation(NamedTuple): + +@dataclass(frozen=True) +class Operation: tag: str = 'default' description: Optional[str] = None summary: Optional[str] = None errors: Sequence[Any] = () security: Optional[Sequence] = None + meta: Optional[Mapping] = None - def update(self, **kwargs): - return self._replace(**kwargs) + @classmethod + def create(cls, + tag: str = 'default', + description: Optional[str] = None, + summary: Optional[str] = None, + errors: Sequence[Any] = (), + security: Optional[Sequence] = None, + **kwargs) -> 'Operation': + return cls(tag=tag, description=description, + summary=summary, errors=errors, + security=security, meta=kwargs) diff --git a/star_resty/parsers/__init__.py b/star_resty/parsers/__init__.py deleted file mode 100644 index da641ce..0000000 --- a/star_resty/parsers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .query import parse_query_params diff --git a/star_resty/parsers/query.py b/star_resty/parsers/query.py deleted file mode 100644 index bb17b77..0000000 --- a/star_resty/parsers/query.py +++ /dev/null @@ -1,51 +0,0 @@ -from functools import lru_cache -from typing import Callable, List, Mapping, Sequence - -from marshmallow import EXCLUDE, Schema, fields -from starlette.requests import Request - -__all__ = ('parse_query_params',) - - -def parse_query_params(request: Request, schema: Schema, unknown: str = EXCLUDE) -> Mapping: - query_schema = get_query_schema(schema) - query_params = request.query_params - getlist = request.query_params.getlist - data = ((key, query_schema[key](getlist(key))) - for key in query_params.keys() if key in query_schema) - data = {key: val for (key, val) in data if val is not None} - - return schema.load(data, many=False, unknown=unknown) - - -@lru_cache(1024) -def get_query_schema(schema: Schema) -> Mapping[str, Callable]: - def iter_fields(): - for key, field in schema.fields.items(): - field: fields.Field - if field.dump_only or isinstance(field, fields.Constant): - continue - - if field.attribute: - name = field.attribute - else: - name = key - - if isinstance(field, fields.List): - func = get_list_value - else: - func = get_value - - yield name, func - if field.data_key: - yield field.data_key, func - - return {key: val for (key, val) in iter_fields()} - - -def get_list_value(values: Sequence) -> List: - return [v for v in values if v] - - -def get_value(values: Sequence): - return next((v for v in values if v), None) diff --git a/star_resty/types/__init__.py b/star_resty/payload/__init__.py similarity index 73% rename from star_resty/types/__init__.py rename to star_resty/payload/__init__.py index ac9a986..cdc6a42 100644 --- a/star_resty/types/__init__.py +++ b/star_resty/payload/__init__.py @@ -1,3 +1,4 @@ +from .header import header, header_schema from .json import json_payload, json_schema from .path import path, path_schema from .query import query, query_schema diff --git a/star_resty/payload/header.py b/star_resty/payload/header.py new file mode 100644 index 0000000..84c9db7 --- /dev/null +++ b/star_resty/payload/header.py @@ -0,0 +1,32 @@ +import types +from typing import Mapping, Type, TypeVar, Union + +from marshmallow import EXCLUDE, Schema +from starlette.requests import Request + +from .parser import Parser, set_parser + +__all__ = ('header', 'header_schema', 'HeaderParser') + +P = TypeVar('P') + + +def header_schema(schema: Union[Schema, Type[Schema]], cls: P, + unknown=EXCLUDE) -> P: + return types.new_class('HeaderInputParams', (cls,), + exec_body=set_parser(HeaderParser.create(schema, unknown=unknown))) + + +def header(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]: + return header_schema(schema, Mapping, unknown=unknown) + + +class HeaderParser(Parser): + __slots__ = () + + @property + def location(self): + return 'header' + + def parse(self, request: Request): + return self.schema.load(request.headers, unknown=self.unknown) diff --git a/star_resty/types/json.py b/star_resty/payload/json.py similarity index 89% rename from star_resty/types/json.py rename to star_resty/payload/json.py index 910fe44..a63a3e2 100644 --- a/star_resty/types/json.py +++ b/star_resty/payload/json.py @@ -8,14 +8,14 @@ from star_resty.exceptions import DecodeError from .parser import Parser, set_parser -__all__ = ('json_schema', 'json_payload') +__all__ = ('json_schema', 'json_payload', 'JsonParser') P = TypeVar('P') def json_schema(schema: Union[Schema, Type[Schema]], cls: P, unknown: str = EXCLUDE) -> P: - return types.new_class('QueryInputParams', (cls,), + return types.new_class('JsonInputParams', (cls,), exec_body=set_parser(JsonParser.create(schema, unknown=unknown))) @@ -24,6 +24,7 @@ def json_payload(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[M class JsonParser(Parser): + __slots__ = () @property def location(self): diff --git a/star_resty/types/parser.py b/star_resty/payload/parser.py similarity index 80% rename from star_resty/types/parser.py rename to star_resty/payload/parser.py index 322a6a1..10d08ef 100644 --- a/star_resty/types/parser.py +++ b/star_resty/payload/parser.py @@ -4,6 +4,7 @@ from marshmallow import EXCLUDE, Schema from starlette.requests import Request +from functools import lru_cache __all__ = ('Parser', 'set_parser') @@ -13,15 +14,20 @@ class Parser(abc.ABC): @classmethod def create(cls, schema: Union[Schema, Type[Schema]], unknown: str = EXCLUDE): + return cls(cls._convert_schema(schema), unknown) + + @staticmethod + @lru_cache(maxsize=1024) + def _convert_schema(schema: Union[Schema, Type[Schema]]) -> Schema: if inspect.isclass(schema): if not issubclass(schema, Schema): raise TypeError(f'Invalid schema type: {schema}') - schema = schema() + return schema() elif not isinstance(schema, Schema): raise TypeError(f'Invalid schema type: {type(schema)}') - return cls(schema, unknown) + return schema def __init__(self, schema: Schema, unknown=EXCLUDE): self.schema = schema diff --git a/star_resty/types/path.py b/star_resty/payload/path.py similarity index 92% rename from star_resty/types/path.py rename to star_resty/payload/path.py index 44d540c..ddff7c8 100644 --- a/star_resty/types/path.py +++ b/star_resty/payload/path.py @@ -6,7 +6,7 @@ from .parser import Parser, set_parser -__all__ = ('path', 'path_schema') +__all__ = ('path', 'path_schema', 'PathParser') P = TypeVar('P') @@ -22,6 +22,7 @@ def path(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]: class PathParser(Parser): + __slots__ = () @property def location(self): diff --git a/star_resty/payload/query.py b/star_resty/payload/query.py new file mode 100644 index 0000000..6dd1462 --- /dev/null +++ b/star_resty/payload/query.py @@ -0,0 +1,87 @@ +import inspect +import types +from functools import lru_cache +from typing import Mapping, Type, TypeVar, Union, Callable, Sequence, List, Tuple, Iterator + +from marshmallow import EXCLUDE, Schema, fields +from starlette.requests import Request + +from .parser import Parser, set_parser + +__all__ = ('query', 'query_schema', 'QueryParser') + +Q = TypeVar('Q') + + +def query_schema(schema: Union[Schema, Type[Schema]], cls: Q, + unknown=EXCLUDE) -> Q: + return types.new_class('QueryInputParams', (cls,), + exec_body=set_parser(QueryParser.create(schema, unknown=unknown))) + + +def query(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]: + return query_schema(schema, Mapping, unknown=unknown) + + +class QueryParser(Parser): + __slots__ = ('fields',) + + @classmethod + def create(cls, schema: Union[Schema, Type[Schema]], unknown: str = EXCLUDE): + schema, query_fields = get_query_fields(schema) + return cls(schema, query_fields, unknown) + + def __init__(self, schema: Schema, query_fields: Mapping, unknown=EXCLUDE): + super().__init__(schema, unknown=unknown) + self.fields = query_fields + + @property + def location(self): + return 'query' + + def parse(self, request: Request): + query_params = request.query_params + getlist = request.query_params.getlist + query_fields = self.fields + data = ((key, query_fields[key](getlist(key))) + for key in query_params.keys() if key in query_fields) + data = {key: val for (key, val) in data if val is not None} + return self.schema.load(data, many=False, unknown=self.unknown) + + +@lru_cache(typed=False, maxsize=1024) +def get_query_fields(schema: Union[Schema, Type[Schema]]) -> Tuple[Schema, Mapping[str, Callable]]: + if inspect.isclass(schema): + schema = schema() + + query_fields = dict(iter_query_fields(schema)) + return schema, query_fields + + +def iter_query_fields(schema: Schema) -> Iterator[Tuple[str, Callable]]: + for key, field in schema.fields.items(): + field: fields.Field + if field.dump_only or isinstance(field, fields.Constant): + continue + + if field.attribute: + name = field.attribute + else: + name = key + + if isinstance(field, fields.List): + func = get_list_value + else: + func = get_value + + yield name, func + if field.data_key: + yield field.data_key, func + + +def get_list_value(values: Sequence) -> List: + return [v for v in values if v] + + +def get_value(values: Sequence): + return next((v for v in values if v), None) diff --git a/star_resty/serializers/json.py b/star_resty/serializers/json.py index ce38d84..34f1fea 100644 --- a/star_resty/serializers/json.py +++ b/star_resty/serializers/json.py @@ -1,5 +1,7 @@ import ujson +__all__ = ('JsonSerializer',) + class JsonSerializer: media_type = 'application/json' diff --git a/star_resty/serializers/serializer.py b/star_resty/serializers/serializer.py index 249f46b..debd9b0 100644 --- a/star_resty/serializers/serializer.py +++ b/star_resty/serializers/serializer.py @@ -1,5 +1,7 @@ from typing_extensions import Protocol +__all__ = ('Serializer',) + class Serializer(Protocol): media_type: str diff --git a/star_resty/types/query.py b/star_resty/types/query.py deleted file mode 100644 index 818ac04..0000000 --- a/star_resty/types/query.py +++ /dev/null @@ -1,32 +0,0 @@ -import types -from typing import Mapping, Type, TypeVar, Union - -from marshmallow import EXCLUDE, Schema -from starlette.requests import Request - -from star_resty.parsers import parse_query_params -from .parser import Parser, set_parser - -__all__ = ('query', 'query_schema') - -Q = TypeVar('Q') - - -def query_schema(schema: Union[Schema, Type[Schema]], cls: Q, - unknown=EXCLUDE) -> Q: - return types.new_class('QueryInputParams', (cls,), - exec_body=set_parser(QueryParser.create(schema, unknown=unknown))) - - -def query(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]: - return query_schema(schema, Mapping, unknown=unknown) - - -class QueryParser(Parser): - - @property - def location(self): - return 'query' - - def parse(self, request: Request): - return parse_query_params(request, self.schema, unknown=self.unknown) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 29cdbd3..87cfede 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -4,7 +4,7 @@ from marshmallow import EXCLUDE, Schema, fields -from star_resty.types import (json_payload, json_schema, path, path_schema, query, query_schema) +from star_resty.payload import (json_payload, json_schema, path, path_schema, query, query_schema) class QuerySchema(Schema): diff --git a/tests/test_method.py b/tests/test_method.py index a5c3280..8504318 100644 --- a/tests/test_method.py +++ b/tests/test_method.py @@ -1,14 +1,16 @@ import json +from dataclasses import dataclass +from typing import Mapping, Optional import pytest from asynctest import mock -from marshmallow import Schema, fields +from marshmallow import Schema, fields, post_load from starlette.requests import Request from starlette.responses import Response -from star_resty import Method +from star_resty import Method, path_schema, json_payload from star_resty.exceptions import DumpError -from .utils.method import CreateUser +from .utils.method import CreateUser, BodySchema @pytest.mark.asyncio @@ -17,15 +19,7 @@ async def test_create_user(): request.path_params = {'id': 1} request.body.return_value = json.dumps({'name': 'Name', 'email': 'email@mail.com'}).encode('utf8') - endpoint = CreateUser.as_endpoint() - resp = await endpoint(request) - assert resp is not None - assert isinstance(resp, Response) - assert resp.status_code == 201 - assert resp.media_type == 'application/json' - body = resp.body - assert body is not None - user = json.loads(body) + user = await execute(CreateUser, request, status_code=201) assert user == {'name': 'Name', 'email': 'email@mail.com', 'id': 1} @@ -56,13 +50,50 @@ async def execute(self): return {'id': 1} request = mock.MagicMock(spec_set=Request) - endpoint = TestMethod.as_endpoint() + user = await execute(TestMethod, request) + assert user == {'id': 1} + + +@pytest.mark.asyncio +async def test_parse_dataclass(): + class PathParams(Schema): + group_id = fields.Integer(required=True) + + @post_load() + def load(self, data, **_): + return data.get('group_id') + + class TestMethod(Method): + @dataclass() + class Payload: + group_id: path_schema(PathParams, int) + body: Optional[json_payload(BodySchema)] = None + + class Response(Schema): + group_id = fields.Integer() + body = fields.Nested(BodySchema) + + async def execute(self, payload: Payload): + from dataclasses import asdict + return asdict(payload) + + request = mock.MagicMock(spec_set=Request) + request.path_params = {'group_id': 1000} + request.body.return_value = json.dumps({'name': 'Dataclass', 'email': 'email@mail.com'}).encode('utf8') + + user = await execute(TestMethod, request) + assert user == {'body': {'name': 'Dataclass', 'email': 'email@mail.com'}, 'group_id': 1000} + + +async def execute(method, request, status_code: int = 200, + media_type='application/json') -> Mapping: + endpoint = method.as_endpoint() resp = await endpoint(request) assert resp is not None assert isinstance(resp, Response) - assert resp.status_code == 200 - assert resp.media_type == 'application/json' + assert resp.status_code == status_code + assert resp.media_type == media_type body = resp.body assert body is not None - user = json.loads(body) - assert user == {'id': 1} + data = json.loads(body) + return data diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 8d5378c..52f2455 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -5,7 +5,7 @@ from starlette.datastructures import QueryParams from starlette.requests import Request -from star_resty.parsers.query import parse_query_params +from star_resty.payload.query import QueryParser class QuerySchema(Schema): @@ -18,8 +18,9 @@ class QuerySchema(Schema): def test_parse_query_args_raise_validation_error(): request = MagicMock(spec=Request) request.query_params = QueryParams([('item_id', '1'), ('item_id', '2')]) + parser = QueryParser.create(QuerySchema) with pytest.raises(ValidationError): - parse_query_params(request, QuerySchema()) + parser.parse(request) def test_parse_query_args(): @@ -27,5 +28,6 @@ def test_parse_query_args(): request.query_params = QueryParams([ ('item_id', '1'), ('item_id', '2'), ('limit', '1000'), ('b', '2'), ('n', '100')]) - params = parse_query_params(request, QuerySchema()) + parser = QueryParser.create(QuerySchema) + params = parser.parse(request) assert params == {'limit': 1000, 'item_id': [1, 2], 'a': '2', 'n': 1}