-
Notifications
You must be signed in to change notification settings - Fork 1
/
helpers.py
104 lines (91 loc) · 3.69 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Credit: FastAPI-Utils
Source: https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py
"""
import inspect
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
from fastapi import APIRouter, Depends
from pydantic.typing import is_classvar
from starlette.routing import Route, WebSocketRoute
# Type variable tying class_based_view's return type to the class it is given.
T = TypeVar("T")
# Marker attribute set on a class once _init_cbv has processed it.
CBV_CLASS_KEY: str = "__cbv_class__"
def class_based_view(router: APIRouter, cls: Type[T]) -> Type[T]:
    """
    Replace the methods of ``cls`` that are endpoints of routes in ``router`` with
    endpoints that have an instance of ``cls`` injected as their first argument.

    Steps:
    1. ``_init_cbv(cls)`` prepares the class so it can be used as a dependency.
    2. Every ``Route``/``WebSocketRoute`` in ``router`` whose endpoint is a plain
       function found on ``cls`` (via ``inspect.getmembers``) is removed from
       ``router``.
    3. Each removed route has its endpoint signature rewritten by
       ``_update_cbv_route_endpoint_signature``.
    4. The routes are re-created from the updated signatures and appended to
       ``router.routes``.

    Args:
        router: Router holding the routes declared on ``cls``; modified in place.
        cls: Class whose methods are route endpoints.

    Returns:
        ``cls`` itself, so this function can back a class decorator.
    """
    _init_cbv(cls)

    class_functions = {func for _, func in inspect.getmembers(cls, inspect.isfunction)}
    cbv_routes = [
        route
        for route in router.routes
        if isinstance(route, (Route, WebSocketRoute))
        and route.endpoint in class_functions
    ]
    if not cbv_routes:
        return cls

    # Remove the matched routes in one pass. Match by identity: Starlette routes
    # compare by value (path/endpoint/methods) and are not safely hashable.
    # Slice-assign so the existing ``router.routes`` list object is kept.
    cbv_ids = {id(route) for route in cbv_routes}
    router.routes[:] = [route for route in router.routes if id(route) not in cbv_ids]

    for route in cbv_routes:
        _update_cbv_route_endpoint_signature(cls, route)

    cbv_router = APIRouter()
    cbv_router.routes.extend(cbv_routes)

    # The routes must be rebuilt so the new endpoint signatures are picked up;
    # ``include_router`` re-creates each route from its stored settings.
    # Rebuild on a bare router instead of ``router`` itself: ``add_api_route``
    # prepends the including router's ``prefix`` (and adds its dependencies,
    # tags and callbacks) to every route it creates, and these routes already
    # carry those from their first registration. Re-including into ``router``
    # would apply them twice, e.g. "/items/{item_id}" -> "/items/items/{item_id}"
    # for ``prefix="/items"``.
    rebuilt = APIRouter(
        dependency_overrides_provider=router.dependency_overrides_provider
    )
    rebuilt.include_router(cbv_router)
    router.routes.extend(rebuilt.routes)
    return cls
def _init_cbv(cls: Type[Any]) -> None:
    """
    Prepare ``cls`` so FastAPI can instantiate it as a dependency; no-op if done.

    Changes made to ``cls``:
    * ``__signature__`` becomes the signature of the current ``__init__``, minus
      ``self`` and any ``*args`` / ``**kwargs``. It is followed by one keyword-only
      parameter per non-``ClassVar`` type hint on the class. Each such parameter
      defaults to the class attribute of the same name, or ``Ellipsis`` when the
      class has no such attribute. This tells FastAPI what arguments to pass when
      it constructs the class.
    * ``__init__`` is replaced by a wrapper. The wrapper pops each of those
      keyword arguments, stores it as an instance attribute, and forwards the
      remaining arguments to the previous ``__init__``.
    * ``CBV_CLASS_KEY`` is set to ``True`` to mark the class as processed.

    The guard reads ``CBV_CLASS_KEY`` with ``getattr``, so an inherited marker
    also satisfies it. A subclass of an already-processed class is left as-is.
    """
    if getattr(cls, CBV_CLASS_KEY, False):  # pragma: no cover
        return  # Already initialized

    wrapped_init: Callable[..., Any] = cls.__init__
    base_signature = inspect.signature(wrapped_init)
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # Skip `self` and drop *args/**kwargs. Keyword-only parameters are appended
    # below, and a signature may not place any parameter after **kwargs.
    signature_params: List[inspect.Parameter] = []
    for param in list(base_signature.parameters.values())[1:]:
        if param.kind not in variadic:
            signature_params.append(param)

    injected_names: List[str] = []
    for attr_name, attr_hint in get_type_hints(cls).items():
        if is_classvar(attr_hint):
            continue  # ClassVar-annotated attributes are not injected
        attr_default = getattr(cls, attr_name, Ellipsis)
        injected_names.append(attr_name)
        signature_params.append(
            inspect.Parameter(
                attr_name,
                inspect.Parameter.KEYWORD_ONLY,
                default=attr_default,
                annotation=attr_hint,
            )
        )

    cbv_signature = base_signature.replace(parameters=signature_params)

    def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
        for dep in injected_names:
            setattr(self, dep, kwargs.pop(dep))
        wrapped_init(self, *args, **kwargs)

    # Applied in this order: signature, initializer, then the processed marker.
    for attribute, value in (
        ("__signature__", cbv_signature),
        ("__init__", new_init),
        (CBV_CLASS_KEY, True),
    ):
        setattr(cls, attribute, value)
def _update_cbv_route_endpoint_signature(
    cls: Type[Any], route: Union[Route, WebSocketRoute]
) -> None:
    """
    Rewrite ``route.endpoint``'s signature so FastAPI injects an instance of ``cls``.

    The endpoint's first parameter (normally ``self``) gets ``Depends(cls)`` as its
    default. Every remaining parameter is made keyword-only. Otherwise a parameter
    without a default after ``self``, which now has one, would make the
    signature invalid.

    The endpoint function object itself is modified by setting ``__signature__``.

    Raises:
        IndexError: if the endpoint takes no parameters at all.
    """
    endpoint = route.endpoint
    signature = inspect.signature(endpoint)
    params: List[inspect.Parameter] = list(signature.parameters.values())

    self_param = params[0].replace(default=Depends(cls))
    trailing = [p.replace(kind=inspect.Parameter.KEYWORD_ONLY) for p in params[1:]]

    setattr(
        endpoint,
        "__signature__",
        signature.replace(parameters=[self_param, *trailing]),
    )