-
-
Save prostomarkeloff/9c321d9afd25d39e4a70804073e06237 to your computer and use it in GitHub Desktop.
FastAPI CBV
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from typing import Any, Callable, ClassVar, List, Type, TypeVar, Union, get_type_hints | |
from fastapi import APIRouter, Depends | |
from starlette.routing import Route, WebSocketRoute | |
# Generic placeholder for "the decorated class", so the decorator's return
# annotation matches the class type it was given.
T = TypeVar("T")
def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
    """Build a class decorator bound to ``router``.

    The returned decorator passes the decorated class, together with
    ``router``, to ``_cbv`` and returns whatever ``_cbv`` returns.
    """

    def bind_to_router(cls: Type[T]) -> Type[T]:
        converted = _cbv(router, cls)
        return converted

    return bind_to_router
def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]:
    """Turn ``cls`` into a class-based view and re-register its endpoints on ``router``.

    Every route already present on ``router`` whose endpoint is a function
    defined on ``cls`` is removed, has its endpoint signature rewritten by
    ``_update_cbv_route_endpoint_signature`` and is added back through a
    temporary router passed to ``router.include_router``.

    Args:
        router: Router on which the methods of ``cls`` were registered.
        cls: The class whose methods are endpoints.

    Returns:
        ``cls`` itself when it already derives from ``BaseCBV``; otherwise a
        new subclass of both ``cls`` and ``BaseCBV``.
    """
    if not issubclass(cls, BaseCBV):
        # TODO: Don't set up cls as a CBV via inheritance; just do the __init_subclass__ in a normal function?
        # Carry the identifying metadata over so the synthesized subclass
        # still reports the user's module, qualified name and docstring.
        cls = type(
            cls.__name__,
            (cls, BaseCBV),
            {"__module__": cls.__module__, "__qualname__": cls.__qualname__, "__doc__": cls.__doc__},
        )
    endpoints = {func for _, func in inspect.getmembers(cls, inspect.isfunction)}
    # Walk the router's own list (not the alphabetically sorted members) so the
    # routes keep their declaration order, and so a method registered under
    # several routes has every one of them converted.
    cbv_routes = [
        route
        for route in router.routes
        if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in endpoints
    ]
    # Routers that carry a prefix apply it again when routes are re-added, so
    # strip it here to avoid doubling it. Routers without the attribute are
    # treated as having no prefix.
    prefix = getattr(router, "prefix", "") or ""
    cbv_router = APIRouter()
    for route in cbv_routes:
        router.routes.remove(route)
        if prefix and route.path.startswith(prefix):
            route.path = route.path[len(prefix):]
        _update_cbv_route_endpoint_signature(cls, route)
        cbv_router.routes.append(route)
    router.include_router(cbv_router)
    return cls
def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
    """Rewrite the signature of ``route.endpoint`` for use as a class-based-view method.

    The first parameter receives ``Depends(cls)`` as its default and every
    following parameter is made keyword-only. The result is stored on the
    endpoint as ``__signature__``.
    """
    signature = inspect.signature(route.endpoint)
    params: List[inspect.Parameter] = list(signature.parameters.values())

    # The leading parameter (the method's receiver) is filled by the dependency on ``cls``.
    head = params[0]
    rewritten = [head.replace(default=Depends(cls))]
    for param in params[1:]:
        rewritten.append(param.replace(kind=inspect.Parameter.KEYWORD_ONLY))

    route.endpoint.__signature__ = signature.replace(parameters=rewritten)
class BaseCBV:
    """Base class that turns annotated class attributes into ``__init__`` keyword arguments.

    When a subclass is created, its ``__init__`` is replaced by a wrapper whose
    published signature is the original ``__init__`` signature (minus
    ``*args``/``**kwargs``) extended with one keyword-only parameter per
    annotated, non-``ClassVar`` attribute. The wrapper stores those keyword
    arguments on the instance and forwards everything else to the original
    ``__init__``.
    """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Install the wrapping ``__init__`` on the newly created subclass ``cls``.

        Args:
            **kwargs: Class keyword arguments, forwarded to the next
                ``__init_subclass__`` in the MRO.
        """
        super().__init_subclass__(**kwargs)
        old_init: Callable[..., Any] = cls.__init__
        if "__init__" not in cls.__dict__:
            # ``cls`` inherits its __init__. If that is a wrapper installed for
            # a parent CBV class, wrap the function underneath instead: the
            # wrapper's signature already lists the inherited fields, and
            # adding them again would raise "duplicate parameter name".
            old_init = getattr(old_init, "_cbv_original_init", old_init)
        old_signature = inspect.signature(old_init)
        new_parameters = [
            parameter
            for parameter in old_signature.parameters.values()
            if parameter.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        ]
        taken_names = {parameter.name for parameter in new_parameters}
        dependency_names: List[str] = []
        for name, hint in get_type_hints(cls).items():
            # Skip both ``ClassVar[...]`` and a bare ``ClassVar`` annotation.
            if hint is ClassVar or getattr(hint, "__origin__", None) is ClassVar:
                continue
            if name in taken_names:
                # The original __init__ already takes a parameter of this name;
                # leave it to that __init__ rather than declaring it twice.
                continue
            value = getattr(cls, name, Ellipsis)
            parameter_kwargs = {}
            if value is not Ellipsis:
                parameter_kwargs["default"] = value
            dependency_names.append(name)
            new_parameters.append(
                inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
            )
        new_signature = old_signature.replace(parameters=new_parameters)

        def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
            for dep_name in dependency_names:
                if dep_name in kwargs:
                    setattr(self, dep_name, kwargs.pop(dep_name))
                elif not hasattr(self, dep_name):
                    # No value supplied and no class-level default to fall back on.
                    raise TypeError(
                        f"{cls.__name__}.__init__() missing required keyword-only argument: {dep_name!r}"
                    )
                # Otherwise the attribute already resolves (class-level default
                # or a value set earlier), matching the published signature.
            old_init(self, *args, **kwargs)

        # Remember the wrapped function so subclasses can wrap it directly.
        setattr(new_init, "_cbv_original_init", old_init)
        setattr(new_init, "__signature__", new_signature)
        setattr(cls, "__init__", new_init)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import APIRouter, Depends, FastAPI | |
from starlette.testclient import TestClient | |
from fastapi_cbv import cbv | |
# Router that the example view's endpoints are registered on; it is attached
# to the application further below.
router = APIRouter()
def dependency() -> int:
    """Return the constant ``1``; used as the example dependency below."""
    value = 1
    return value
@cbv(router)
class CBV:
    """Example class-based view combining a class-level and an ``__init__``-level dependency."""

    # Declared as an annotated class attribute with a ``Depends`` default.
    x: int = Depends(dependency)

    def __init__(self, z: int = Depends(dependency)):
        """Store the injected ``z`` and a fixed ``y`` of 1."""
        self.z = z
        self.y = 1

    @router.get("/", response_model=int)
    def f(self) -> int:
        """Return the sum of ``x``, ``y`` and ``z``."""
        total = self.x + self.y + self.z
        return total
# Mount the router on an application and exercise the single endpoint.
app = FastAPI()
app.include_router(router)

client = TestClient(app)
response = client.get("/")
# x, y and z are each 1, so the endpoint is expected to answer with 3.
assert response.content == b"3"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment