diff --git a/README.md b/README.md index dc5c5c3..b730b62 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ Object-oriented rest framework based on starlette, marshmallow and apispec. * [Starlette] 0.12.0+ * [Marshmallow] 3.0.0rc8+ * [APISpec] 2.0.2+ +* [python-multipart] 0.0.5+ ## Installation @@ -17,35 +18,53 @@ $ pip install star_resty ## Example ```python -from dataclasses import dataclass -from typing import Optional - -from marshmallow import Schema, fields, post_load, ValidationError +from marshmallow import Schema, fields, ValidationError, post_load from starlette.applications import Starlette +from starlette.datastructures import UploadFile from starlette.responses import JSONResponse -from star_resty import Method, Operation, endpoint, json_schema, query, setup_spec +from dataclasses import dataclass +from star_resty import Method, Operation, endpoint, json_schema, json_payload, form_payload, query, setup_spec +from typing import Optional class EchoInput(Schema): a = fields.Int() +# Json Payload (by schema) +class JsonPayloadSchema(Schema): + a = fields.Int(required=True) + s = fields.String() + + +# Json Payload (by dataclass) @dataclass class Payload: a: int s: Optional[str] = None - -class PayloadSchema(Schema): - a = fields.Int(required=True) - s = fields.String() +class JsonPayloadDataclass(Schema): + a=fields.Int(required=True) + s=fields.Str() @post_load def create_payload(self, data, **kwargs): return Payload(**data) +# Form Payload +class FormFile(fields.Field): + def _validate(self, value): + if not isinstance(value, UploadFile): + raise ValidationError('Not a file') + + +class FormPayload(Schema): + id = fields.Int(required=True) + file = FormFile() + + app = Starlette(debug=True) @app.exception_handler(ValidationError) @@ -64,13 +83,32 @@ class Echo(Method): return query_params -@app.route('/post', methods=['POST']) +@app.route('/post/schema', methods=['POST']) +@endpoint +class PostSchema(Method): + meta = Operation(tag='default', description='post json (by schema)') + + async def execute(self, item: json_payload(JsonPayloadSchema)): + return {'a': item.get('a') * 2, 's': item.get('s')} + + +@app.route('/post/dataclass', methods=['POST']) +@endpoint +class PostDataclass(Method): + meta = Operation(tag='default', description='post json (by dataclass)') + + async def execute(self, item: json_schema(JsonPayloadDataclass, Payload)): + return {'a': item.a * 3, 's': item.s} + +@app.route('/form', methods=['POST']) @endpoint -class Post(Method): - meta = Operation(tag='default', description='post') +class PostForm(Method): + meta = Operation(tag='default', description='post form') - async def execute(self, item: json_schema(PayloadSchema, Payload)): - return {'a': item.a * 2, 's': item.s} + async def execute(self, form_data: form_payload(FormPayload)): + file_name = form_data.get('file').filename + id = form_data.get('id') + return {'message': f"file {file_name} with id {id} received"} if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 78b4959..313e3ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ typing_extensions marshmallow>=3.0.0rc8,<4 starlette<1 apispec<4 +python-multipart # Testing pytest diff --git a/setup.py b/setup.py index ccea59f..fc66020 100644 --- a/setup.py +++ b/setup.py @@ -28,8 +28,9 @@ def get_packages(package): 'marshmallow>=3.0.0rc8,<4', 'starlette<1', 'apispec<4', + 'python-multipart' ], - version='0.0.14', + version='0.0.15', url='https://github.com/slv0/start_resty', license='BSD', description='The web framework', diff --git a/star_resty/method/parser.py b/star_resty/method/parser.py index c6045f0..c0c3d81 100644 --- a/star_resty/method/parser.py +++ b/star_resty/method/parser.py @@ -67,12 +67,11 @@ def create_parser_from_data(data: Mapping) -> RequestParser: parsers = [] async_parsers = [] for key, value in data.items(): - if is_dataclass(value): + parser = getattr(value, 'parser', None) + if parser is None and is_dataclass(value): data = {field.name: field.type for field in fields(value)} factory = partial(DataClassParser, value) parser = create_parser_for_dc(data, factory=factory) - else: - parser = getattr(value, 'parser', None) if parser is None or not isinstance(parser, (Parser, RequestParser)): continue diff --git a/star_resty/payload/__init__.py b/star_resty/payload/__init__.py index cdc6a42..bb91186 100644 --- a/star_resty/payload/__init__.py +++ b/star_resty/payload/__init__.py @@ -2,3 +2,4 @@ from .json import json_payload, json_schema from .path import path, path_schema from .query import query, query_schema +from .form import form_payload, form_schema diff --git a/star_resty/payload/form.py b/star_resty/payload/form.py new file mode 100644 index 0000000..07e5ea6 --- /dev/null +++ b/star_resty/payload/form.py @@ -0,0 +1,42 @@ +import types +from typing import Mapping, Type, TypeVar, Union + +from marshmallow import EXCLUDE, Schema +from starlette.requests import Request + +from star_resty.exceptions import DecodeError +from .parser import Parser, set_parser + +__all__ = ('form_schema', 'form_payload', 'FormParser') + +P = TypeVar('P') + + +def form_schema(schema: Union[Schema, Type[Schema]], cls: P, + unknown: str = EXCLUDE) -> P: + return types.new_class('FormDataInputParams', (cls,), + exec_body=set_parser(FormParser.create(schema, unknown=unknown))) + + +def form_payload(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]: + return form_schema(schema, Mapping, unknown=unknown) + + +class FormParser(Parser): + __slots__ = () + + @property + def location(self): + return 'body' + + @property + def media_type(self): + return 'multipart/form-data' + + async def parse(self, request: Request): + try: + form_data = await request.form() + form_data = {} if not form_data else form_data + except Exception as e: + raise DecodeError('Invalid form data: %s' % (str(e))) from e + return self.schema.load(form_data, unknown=self.unknown)