Skip to content

Commit

Permalink
add_pagination
Browse files Browse the repository at this point in the history
Add different pagination strategies
  • Loading branch information
so-saf authored Sep 8, 2024
2 parents 8d997e7 + fd4642f commit 09a60be
Show file tree
Hide file tree
Showing 11 changed files with 143 additions and 65 deletions.
4 changes: 3 additions & 1 deletion pydantic_filters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
get_suffixes_map,
)
from .pagination import (
BasePagination,
OffsetPagination,
PagePagination,
PaginationInterface,
)
from .sort import (
BaseSort,
Expand Down
8 changes: 4 additions & 4 deletions pydantic_filters/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

from typing_extensions import TypeAlias, Union

if sys.version_info >= (3, 10):
if sys.version_info >= (3, 10): # pragma: no cover
from types import NoneType
else:
else: # pragma: no cover
NoneType = type(None)

if sys.version_info >= (3, 9):
if sys.version_info >= (3, 9): # pragma: no cover
from types import GenericAlias
else:
else: # pragma: no cover
GenericAlias = type(List[int])


Expand Down
17 changes: 7 additions & 10 deletions pydantic_filters/drivers/sqlalchemy/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

from pydantic_filters import (
BaseFilter,
BasePagination,
BaseSort,
PaginationInterface,
SortByOrder,
)

from ._exceptions import AttributeNotFoundSaDriverError
from ._mapping import filter_to_column_clauses, filter_to_column_options

_Filter = TypeVar("_Filter", bound=BaseFilter)
_Pagination = TypeVar("_Pagination", bound=BasePagination)
_Pagination = TypeVar("_Pagination", bound=PaginationInterface)
_Sort = TypeVar("_Sort", bound=BaseSort)
_Model = TypeVar("_Model", bound=so.DeclarativeBase)
_T = TypeVar("_T")
Expand All @@ -41,14 +41,11 @@ def append_pagination_to_statement(
pagination: _Pagination,
) -> sa.Select[_T]:

if pagination.limit is not None:
statement = (
statement
.limit(pagination.limit)
.offset(pagination.offset)
)

return statement
return (
statement
.limit(pagination.get_limit())
.offset(pagination.get_offset())
)


def append_sort_to_statement(
Expand Down
6 changes: 5 additions & 1 deletion pydantic_filters/pagination/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .base import BasePagination
from ._base import (
OffsetPagination,
PagePagination,
PaginationInterface,
)
36 changes: 36 additions & 0 deletions pydantic_filters/pagination/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel, Field


class PaginationInterface(BaseModel, ABC):

@abstractmethod
def get_limit(self) -> int:
...

@abstractmethod
def get_offset(self) -> int:
...


class OffsetPagination(PaginationInterface):
limit: int = Field(100, ge=1)
offset: int = Field(0, ge=0)

def get_limit(self) -> int:
return self.limit

def get_offset(self) -> int:
return self.offset


class PagePagination(PaginationInterface):
page: int = Field(1, ge=1)
per_page: int = Field(100, ge=1)

def get_limit(self) -> int:
return self.per_page

def get_offset(self) -> int:
return (self.page - 1) * self.per_page
10 changes: 0 additions & 10 deletions pydantic_filters/pagination/base.py

This file was deleted.

12 changes: 5 additions & 7 deletions pydantic_filters/plugins/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, TypeVar
from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar

from pydantic_filters.filter._base import BaseFilter

Expand Down Expand Up @@ -31,8 +31,7 @@ def squash_filter(
filter_: Type[_Filter],
prefix: str,
delimiter: str,
converter: Callable[["FieldInfo"], _T],
) -> Dict[str, _T]:
) -> Dict[str, "FieldInfo"]:
"""
**Example:**
Expand All @@ -44,23 +43,22 @@ def squash_filter(
>>> class MyFilter(BaseFilter):
... a: int
... b: NestedFilter
>>> squash_filter(MyFilter, "", "__", lambda f: None)
{"a": None, "b__c": None, "b__d__e": None}
>>> squash_filter(MyFilter, "", "__")
{"a": ..., "b__c": ..., "b__d__e": ...}
"""

squashed = {}

for key in chain(filter_.filter_fields, filter_.search_fields):
field_info = filter_.model_fields[key]
squashed[add_prefix(key, prefix, delimiter)] = converter(field_info)
squashed[add_prefix(key, prefix, delimiter)] = field_info

for key, nested_filter in filter_.nested_filters.items():
squashed.update(
squash_filter(
filter_=nested_filter,
prefix=add_prefix(key, prefix, delimiter),
delimiter=delimiter,
converter=converter,
),
)

Expand Down
53 changes: 33 additions & 20 deletions pydantic_filters/plugins/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from copy import deepcopy
from inspect import Parameter, signature
from typing import Any, List, Tuple, Type, TypeVar
from typing import Any, List, Type, TypeVar

from fastapi import Depends, Query
from fastapi import params as fastapi_params
from pydantic import BaseModel
from pydantic.fields import FieldInfo

from pydantic_filters import BaseFilter, BasePagination, BaseSort
from pydantic_filters import BaseFilter, BaseSort, PaginationInterface

from ._utils import inflate_filter, squash_filter

_Filter = TypeVar("_Filter", bound=BaseFilter)
_Pagination = TypeVar("_Pagination", bound=BasePagination)
_Pagination = TypeVar("_Pagination", bound=PaginationInterface)
_Sort = TypeVar("_Sort", bound=BaseSort)
_PydanticModel = TypeVar("_PydanticModel", bound=BaseModel)


def _field_info_to_query(
Expand Down Expand Up @@ -42,32 +44,28 @@ def _get_custom_params(
delimiter: str,
) -> List[Parameter]:

def _converter(f: FieldInfo) -> Tuple[FieldInfo, fastapi_params.Query]:
return f, _field_info_to_query(f)

squashed = squash_filter(
filter_=filter_,
prefix=prefix,
delimiter=delimiter,
converter=_converter,
)

return [
Parameter(
name=key,
kind=Parameter.KEYWORD_ONLY,
default=fastapi_query,
default=_field_info_to_query(field_info),
annotation=field_info.annotation,
)
for key, (field_info, fastapi_query) in squashed.items()
for key, field_info in squashed.items()
]


def FilterDepends( # noqa: N802
filter_: Type[_Filter],
prefix: str = "",
delimiter: str = "__",
) -> _Filter:
) -> _Filter: # pragma: no cover

def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401
"""Signature of this function is replaced with Query parameters,
Expand All @@ -81,28 +79,43 @@ def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401
data=kwargs,
)

# Переопределяем то, что функция принимает на вход
_depends.__signature__ = signature(_depends).replace(
parameters=_get_custom_params(filter_, prefix, delimiter),
)

return Depends(_depends)


def PaginationDepends(pagination: Type[_Pagination]) -> _Pagination:
limit_field = pagination.model_fields["limit"]
offset_field = pagination.model_fields["offset"]
def _PydanticModelAsDepends(pydantic_model: Type[_PydanticModel]) -> _PydanticModel: # pragma: no cover
def _depends(**kwargs: Any) -> _Filter: # noqa: ANN401
return pydantic_model.model_construct(**kwargs)

custom_params = []
for key, field_info in pydantic_model.model_fields.items():
if issubclass(field_info.annotation, BaseModel):
continue

custom_params.append(
Parameter(
name=key,
kind=Parameter.KEYWORD_ONLY,
default=_field_info_to_query(field_info),
annotation=field_info.annotation,
),
)

def _depends(
limit: limit_field.annotation = _field_info_to_query(limit_field),
offset: offset_field.annotation = _field_info_to_query(offset_field),
) -> _Filter:
return pagination.model_construct(limit=limit, offset=offset)
_depends.__signature__ = signature(_depends).replace(
parameters=custom_params,
)

return Depends(_depends)


def SortDepends(sort: Type[_Sort]) -> _Sort:
def PaginationDepends(pagination: Type[_Pagination]) -> _Pagination: # pragma: no cover
return _PydanticModelAsDepends(pagination)


def SortDepends(sort: Type[_Sort]) -> _Sort: # pragma: no cover
sort_by_field = sort.model_fields["sort_by"]
sort_by_order_field = sort.model_fields["sort_by_order"]

Expand Down
23 changes: 14 additions & 9 deletions tests/drivers/sqlalchemy/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import sqlalchemy.orm as so
from sqlalchemy.dialects.sqlite import dialect as sa_sqlite_dialect

from pydantic_filters import BaseFilter, BasePagination, BaseSort, SortByOrder
from pydantic_filters import (
BaseFilter,
OffsetPagination,
BaseSort,
SortByOrder,
PaginationInterface,
PagePagination,
)
from pydantic_filters.drivers.sqlalchemy._main import (
append_filter_to_statement,
append_pagination_to_statement,
Expand Down Expand Up @@ -67,15 +74,13 @@ def test_append_filter_to_statement() -> None:
@pytest.mark.parametrize(
"pagination, expected_stmt",
[
(BasePagination(),
"SELECT a.id, a.b_id FROM a"),
(BasePagination(limit=10),
"SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"),
(BasePagination(limit=10, offset=20),
"SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 20"),
(OffsetPagination(limit=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"),
(OffsetPagination(limit=10, offset=20), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 20"),
(PagePagination(page=1, per_page=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 0"),
(PagePagination(page=10, per_page=10), "SELECT a.id, a.b_id FROM a LIMIT 10 OFFSET 90"),
]
)
def test_append_pagination_to_statement(pagination: BasePagination, expected_stmt: str) -> None:
def test_append_pagination_to_statement(pagination: PaginationInterface, expected_stmt: str) -> None:
stmt = append_pagination_to_statement(
statement=sa.select(AModel),
pagination=pagination,
Expand Down Expand Up @@ -109,7 +114,7 @@ def test_append_to_statement() -> None:
statement=sa.select(AModel),
model=AModel,
filter_=AFilter(id=1, b=BFilter(id=2)),
pagination=BasePagination(limit=10, offset=20),
pagination=OffsetPagination(limit=10, offset=20),
sort=BaseSort(sort_by="id", sort_by_order=SortByOrder.desc)
)
expected_stmt = (
Expand Down
31 changes: 31 additions & 0 deletions tests/pagination/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from pydantic_filters.pagination import PaginationInterface, PagePagination, OffsetPagination


class TestOffsetPagination:

@pytest.mark.parametrize(
"obj, limit, offset",
[
(OffsetPagination(limit=1, offset=1), 1, 1),
(OffsetPagination(limit=10, offset=0), 10, 0),
]
)
def test_get_limit_get_offset(self, obj: PaginationInterface, limit: int, offset: int) -> None:
assert obj.get_limit() == limit
assert obj.get_offset() == offset


class TestPagePagination:

@pytest.mark.parametrize(
"obj, limit, offset",
[
(PagePagination(page=1, per_page=10), 10, 0),
(PagePagination(page=12, per_page=34), 34, 374),
]
)
def test_get_limit_get_offset(self, obj: PaginationInterface, limit: int, offset: int) -> None:
assert obj.get_limit() == limit
assert obj.get_offset() == offset
8 changes: 5 additions & 3 deletions tests/plugins/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, Any
from unittest import mock

import pytest
from pydantic.fields import FieldInfo

from pydantic_filters import BaseFilter
from pydantic_filters.plugins._utils import (
Expand Down Expand Up @@ -48,7 +50,8 @@ class FilterTest(BaseFilter):
a: int
b: NestedFilter



@mock.patch.object(FieldInfo, "__eq__", new=lambda *_: True)
@pytest.mark.parametrize(
"prefix, delimiter, res",
[
Expand All @@ -61,8 +64,7 @@ def test_squash_filter(prefix: str, delimiter: str, res: Dict[str, Any]):
assert squash_filter(
filter_=FilterTest,
prefix=prefix,
delimiter=delimiter,
converter=lambda f: None,
delimiter=delimiter,
) == res


Expand Down

0 comments on commit 09a60be

Please sign in to comment.