From fd4642f152ecd929aca5cc4abee5abe531b74ca2 Mon Sep 17 00:00:00 2001 From: Insaf Nureev Date: Sun, 8 Sep 2024 19:00:40 +0300 Subject: [PATCH] Add different pagination strategies --- pydantic_filters/__init__.py | 4 +- pydantic_filters/_types.py | 8 +-- pydantic_filters/drivers/sqlalchemy/_main.py | 17 +++---- pydantic_filters/pagination/__init__.py | 6 ++- pydantic_filters/pagination/_base.py | 36 +++++++++++++ pydantic_filters/pagination/base.py | 10 ---- pydantic_filters/plugins/_utils.py | 12 ++--- pydantic_filters/plugins/fastapi.py | 53 ++++++++++++-------- tests/drivers/sqlalchemy/test_main.py | 23 +++++---- tests/pagination/test_base.py | 31 ++++++++++++ tests/plugins/test_utils.py | 8 +-- 11 files changed, 143 insertions(+), 65 deletions(-) create mode 100644 pydantic_filters/pagination/_base.py delete mode 100644 pydantic_filters/pagination/base.py create mode 100644 tests/pagination/test_base.py diff --git a/pydantic_filters/__init__.py b/pydantic_filters/__init__.py index 2bc1f16..78c7156 100644 --- a/pydantic_filters/__init__.py +++ b/pydantic_filters/__init__.py @@ -8,7 +8,9 @@ get_suffixes_map, ) from .pagination import ( - BasePagination, + OffsetPagination, + PagePagination, + PaginationInterface, ) from .sort import ( BaseSort, diff --git a/pydantic_filters/_types.py b/pydantic_filters/_types.py index 496a2ba..5f0bc14 100644 --- a/pydantic_filters/_types.py +++ b/pydantic_filters/_types.py @@ -9,14 +9,14 @@ from typing_extensions import TypeAlias, Union -if sys.version_info >= (3, 10): +if sys.version_info >= (3, 10): # pragma: no cover from types import NoneType -else: +else: # pragma: no cover NoneType = type(None) -if sys.version_info >= (3, 9): +if sys.version_info >= (3, 9): # pragma: no cover from types import GenericAlias -else: +else: # pragma: no cover GenericAlias = type(List[int]) diff --git a/pydantic_filters/drivers/sqlalchemy/_main.py b/pydantic_filters/drivers/sqlalchemy/_main.py index 652e146..38c6434 100644 --- a/pydantic_filters/drivers/sqlalchemy/_main.py +++ b/pydantic_filters/drivers/sqlalchemy/_main.py @@ -5,8 +5,8 @@ from pydantic_filters import ( BaseFilter, - BasePagination, BaseSort, + PaginationInterface, SortByOrder, ) @@ -14,7 +14,7 @@ from ._mapping import filter_to_column_clauses, filter_to_column_options _Filter = TypeVar("_Filter", bound=BaseFilter) -_Pagination = TypeVar("_Pagination", bound=BasePagination) +_Pagination = TypeVar("_Pagination", bound=PaginationInterface) _Sort = TypeVar("_Sort", bound=BaseSort) _Model = TypeVar("_Model", bound=so.DeclarativeBase) _T = TypeVar("_T") @@ -41,14 +41,11 @@ def append_pagination_to_statement( pagination: _Pagination, ) -> sa.Select[_T]: - if pagination.limit is not None: - statement = ( - statement - .limit(pagination.limit) - .offset(pagination.offset) - ) - - return statement + return ( + statement + .limit(pagination.get_limit()) + .offset(pagination.get_offset()) + ) def append_sort_to_statement( diff --git a/pydantic_filters/pagination/__init__.py b/pydantic_filters/pagination/__init__.py index 267f3b4..d36ea24 100644 --- a/pydantic_filters/pagination/__init__.py +++ b/pydantic_filters/pagination/__init__.py @@ -1 +1,5 @@ -from .base import BasePagination +from ._base import ( + OffsetPagination, + PagePagination, + PaginationInterface, +) diff --git a/pydantic_filters/pagination/_base.py b/pydantic_filters/pagination/_base.py new file mode 100644 index 0000000..8c2b063 --- /dev/null +++ b/pydantic_filters/pagination/_base.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod + +from pydantic import BaseModel, Field + + +class PaginationInterface(BaseModel, ABC): + + @abstractmethod + def get_limit(self) -> int: + ... + + @abstractmethod + def get_offset(self) -> int: + ... + + +class OffsetPagination(PaginationInterface): + limit: int = Field(100, ge=1) + offset: int = Field(0, ge=0) + + def get_limit(self) -> int: + return self.limit + + def get_offset(self) -> int: + return self.offset + + +class PagePagination(PaginationInterface): + page: int = Field(1, ge=1) + per_page: int = Field(100, ge=1) + + def get_limit(self) -> int: + return self.per_page + + def get_offset(self) -> int: + return (self.page - 1) * self.per_page diff --git a/pydantic_filters/pagination/base.py b/pydantic_filters/pagination/base.py deleted file mode 100644 index 931cf23..0000000 --- a/pydantic_filters/pagination/base.py +++ /dev/null @@ -1,10 +0,0 @@ -from pydantic import BaseModel, Field - -__all__ = ( - "BasePagination", -) - - -class BasePagination(BaseModel): - limit: int = Field(None, ge=0) - offset: int = Field(0, ge=0) diff --git a/pydantic_filters/plugins/_utils.py b/pydantic_filters/plugins/_utils.py index e10f935..49f08d0 100644 --- a/pydantic_filters/plugins/_utils.py +++ b/pydantic_filters/plugins/_utils.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar from pydantic_filters.filter._base import BaseFilter @@ -31,8 +31,7 @@ def squash_filter( filter_: Type[_Filter], prefix: str, delimiter: str, - converter: Callable[["FieldInfo"], _T], -) -> Dict[str, _T]: +) -> Dict[str, "FieldInfo"]: """ **Example:** @@ -44,15 +43,15 @@ def squash_filter( >>> class MyFilter(BaseFilter): ... a: int ... b: NestedFilter - >>> squash_filter(MyFilter, "", "__", lambda f: None) - {"a": None, "b__c": None, "b__d__e": None} + >>> squash_filter(MyFilter, "", "__") + {"a": ..., "b__c": ..., "b__d__e": ...} """ squashed = {} for key in chain(filter_.filter_fields, filter_.search_fields): field_info = filter_.model_fields[key] - squashed[add_prefix(key, prefix, delimiter)] = converter(field_info) + squashed[add_prefix(key, prefix, delimiter)] = field_info for key, nested_filter in filter_.nested_filters.items(): squashed.update( @@ -60,7 +59,6 @@ def squash_filter( filter_=nested_filter, prefix=add_prefix(key, prefix, delimiter), delimiter=delimiter, - converter=converter, ), ) diff --git a/pydantic_filters/plugins/fastapi.py b/pydantic_filters/plugins/fastapi.py index 935d044..53778ac 100644 --- a/pydantic_filters/plugins/fastapi.py +++ b/pydantic_filters/plugins/fastapi.py @@ -1,18 +1,20 @@ from copy import deepcopy from inspect import Parameter, signature -from typing import Any, List, Tuple, Type, TypeVar +from typing import Any, List, Type, TypeVar from fastapi import Depends, Query from fastapi import params as fastapi_params +from pydantic import BaseModel from pydantic.fields import FieldInfo -from pydantic_filters import BaseFilter, BasePagination, BaseSort +from pydantic_filters import BaseFilter, BaseSort, PaginationInterface from ._utils import inflate_filter, squash_filter _Filter = TypeVar("_Filter", bound=BaseFilter) -_Pagination = TypeVar("_Pagination", bound=BasePagination) +_Pagination = TypeVar("_Pagination", bound=PaginationInterface) _Sort = TypeVar("_Sort", bound=BaseSort) +_PydanticModel = TypeVar("_PydanticModel", bound=BaseModel) def _field_info_to_query( @@ -42,24 +44,20 @@ def _get_custom_params( delimiter: str, ) -> List[Parameter]: - def _converter(f: FieldInfo) -> Tuple[FieldInfo, fastapi_params.Query]: - return f, _field_info_to_query(f) - squashed = squash_filter( filter_=filter_, prefix=prefix, delimiter=delimiter, - converter=_converter, ) return [ Parameter( name=key, kind=Parameter.KEYWORD_ONLY, - default=fastapi_query, + default=_field_info_to_query(field_info), annotation=field_info.annotation, ) - for key, (field_info, fastapi_query) in squashed.items() + for key, field_info in squashed.items() ] @@ -67,7 +65,7 @@ def FilterDepends( # noqa: N802 filter_: Type[_Filter], prefix: str = "", delimiter: str = "__", -) -> _Filter: +) -> _Filter: # pragma: no cover def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401 """Signature of this function is replaced with Query parameters, @@ -81,7 +79,6 @@ def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401 data=kwargs, ) - # Переопределяем то, что функция принимает на вход _depends.__signature__ = signature(_depends).replace( parameters=_get_custom_params(filter_, prefix, delimiter), ) @@ -89,20 +86,36 @@ def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401 return Depends(_depends) -def PaginationDepends(pagination: Type[_Pagination]) -> _Pagination: - limit_field = pagination.model_fields["limit"] - offset_field = pagination.model_fields["offset"] +def _PydanticModelAsDepends(pydantic_model: Type[_PydanticModel]) -> _PydanticModel: # pragma: no cover + def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401 + return pydantic_model.model_construct(**kwargs) + + custom_params = [] + for key, field_info in pydantic_model.model_fields.items(): + if issubclass(field_info.annotation, BaseModel): + continue + + custom_params.append( + Parameter( + name=key, + kind=Parameter.KEYWORD_ONLY, + default=_field_info_to_query(field_info), + annotation=field_info.annotation, + ), + ) - def _depends( - limit: limit_field.annotation = _field_info_to_query(limit_field), - offset: offset_field.annotation = _field_info_to_query(offset_field), - ) -> _Filter: - return pagination.model_construct(limit=limit, offset=offset) + _depends.__signature__ = signature(_depends).replace( + parameters=custom_params, + ) return Depends(_depends) -def SortDepends(sort: Type[_Sort]) -> _Sort: +def PaginationDepends(pagination: Type[_Pagination]) -> _Pagination: # pragma: no cover + return _PydanticModelAsDepends(pagination) + + +def SortDepends(sort: Type[_Sort]) -> _Sort: # pragma: no cover sort_by_field = sort.model_fields["sort_by"] sort_by_order_field = sort.model_fields["sort_by_order"] diff --git a/tests/drivers/sqlalchemy/test_main.py b/tests/drivers/sqlalchemy/test_main.py index e6a18ae..d35cc4e 100644 --- a/tests/drivers/sqlalchemy/test_main.py +++ b/tests/drivers/sqlalchemy/test_main.py @@ -5,7 +5,14 @@ import sqlalchemy.orm as so from sqlalchemy.dialects.sqlite import dialect as sa_sqlite_dialect -from pydantic_filters import BaseFilter, BasePagination, BaseSort, SortByOrder +from pydantic_filters import ( + BaseFilter, + OffsetPagination, + BaseSort, + SortByOrder, + PaginationInterface, + PagePagination, +) from pydantic_filters.drivers.sqlalchemy._main import ( append_filter_to_statement, append_pagination_to_statement, @@ -67,15 +74,13 @@ def test_append_filter_to_statement() -> None: @pytest.mark.parametrize( "pagination, expected_stmt", [ - (BasePagination(), - "SELECT a.id, a.b_id FROM a"), - (BasePagination(limit=10), - "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"), - (BasePagination(limit=10, offset=20), - "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 20"), + (OffsetPagination(limit=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"), + (OffsetPagination(limit=10, offset=20), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 20"), + (PagePagination(page=1, per_page=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"), + (PagePagination(page=10, per_page=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 90"), ] ) -def test_append_pagination_to_statement(pagination: BasePagination, expected_stmt: str) -> None: +def test_append_pagination_to_statement(pagination: PaginationInterface, expected_stmt: str) -> None: stmt = append_pagination_to_statement( statement=sa.select(AModel), pagination=pagination, @@ -109,7 +114,7 @@ def test_append_to_statement() -> None: statement=sa.select(AModel), model=AModel, filter_=AFilter(id=1, b=BFilter(id=2)), - pagination=BasePagination(limit=10, offset=20), + pagination=OffsetPagination(limit=10, offset=20), sort=BaseSort(sort_by="id", sort_by_order=SortByOrder.desc) ) expected_stmt = ( diff --git a/tests/pagination/test_base.py b/tests/pagination/test_base.py new file mode 100644 index 0000000..0bae1dd --- /dev/null +++ b/tests/pagination/test_base.py @@ -0,0 +1,31 @@ +import pytest + +from pydantic_filters.pagination import PaginationInterface, PagePagination, OffsetPagination + + +class TestOffsetPagination: + + @pytest.mark.parametrize( + "obj, limit, offset", + [ + (OffsetPagination(limit=1, offset=1), 1, 1), + (OffsetPagination(limit=10, offset=0), 10, 0), + ] + ) + def test_get_limit_get_offset(self, obj: PaginationInterface, limit: int, offset: int) -> None: + assert obj.get_limit() == limit + assert obj.get_offset() == offset + + +class TestPagePagination: + + @pytest.mark.parametrize( + "obj, limit, offset", + [ + (PagePagination(page=1, per_page=10), 10, 0), + (PagePagination(page=12, per_page=34), 34, 374), + ] + ) + def test_get_limit_get_offset(self, obj: PaginationInterface, limit: int, offset: int) -> None: + assert obj.get_limit() == limit + assert obj.get_offset() == offset diff --git a/tests/plugins/test_utils.py b/tests/plugins/test_utils.py index 452e93f..0c0bba4 100644 --- a/tests/plugins/test_utils.py +++ b/tests/plugins/test_utils.py @@ -1,6 +1,8 @@ from typing import Dict, Any +from unittest import mock import pytest +from pydantic.fields import FieldInfo from pydantic_filters import BaseFilter from pydantic_filters.plugins._utils import ( @@ -48,7 +50,8 @@ class FilterTest(BaseFilter): a: int b: NestedFilter - + +@mock.patch.object(FieldInfo, "__eq__", new=lambda *_: True) @pytest.mark.parametrize( "prefix, delimiter, res", [ @@ -61,8 +64,7 @@ def test_squash_filter(prefix: str, delimiter: str, res: Dict[str, Any]): assert squash_filter( filter_=FilterTest, prefix=prefix, - delimiter=delimiter, - converter=lambda f: None, + delimiter=delimiter, ) == res