Last active
February 10, 2023 10:25
-
-
Save SF-300/243b73a2260632bcddf1d99af37e43f9 to your computer and use it in GitHub Desktop.
Standalone FastAPI-like dependency injector POC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import contextlib | |
import inspect | |
from inspect import Parameter, Signature | |
from dataclasses import dataclass | |
from contextlib import AsyncExitStack, AbstractContextManager, AbstractAsyncContextManager | |
from typing import Callable, ParamSpec, TypeVar | |
# Names exported by ``from <module> import *``.
__all__ = ("Depends", "inject_and_run")

# Type variables that tie inject_and_run's extra arguments and its return type
# to the signature of the *root* callable it runs.
_P = ParamSpec("_P")
_R = TypeVar("_R")
@dataclass(frozen=True)
class Depends:
    """Marker used as a parameter default to request an injected dependency.

    ``inject_and_run`` replaces any parameter whose bound value is a ``Depends``
    instance with the (per-call cached) result of resolving ``provider``.
    """

    # Callable whose result is injected; its own parameters may use Depends too.
    provider: Callable[..., object]
def _get_params_subset(s: Signature, orig_args: list, orig_kwargs: dict): | |
args, kwargs = [], dict() | |
for p in s.parameters.values(): | |
try: | |
if p.kind is Parameter.POSITIONAL_ONLY: | |
args.append(orig_args.pop()) | |
elif p.kind is Parameter.KEYWORD_ONLY: | |
kwargs[p.name] = orig_kwargs.pop(p.name) | |
elif p.kind is Parameter.POSITIONAL_OR_KEYWORD: | |
try: | |
kwargs[p.name] = orig_kwargs.pop(p.name) | |
except KeyError: | |
if p.default is Parameter.empty: | |
args.append(orig_args.pop()) | |
else: | |
return orig_args, orig_kwargs | |
except (KeyError, IndexError): | |
continue | |
return args, kwargs | |
async def inject_and_run(root: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R:
    """Call *root* after recursively resolving its ``Depends(...)`` defaults.

    The caller's ``args``/``kwargs`` are offered to *root* and to every provider
    in its dependency graph; each callable receives only the subset its
    signature accepts (selected by ``_get_params_subset``). Any bound argument
    that is a ``Depends`` instance is replaced by the resolved value of its
    provider, after this post-processing of the provider's return value:

    * a sync context manager is entered and its ``__enter__`` result injected;
    * an async context manager is entered and its ``__aenter__`` result injected;
    * any other awaitable is awaited.

    Each provider is resolved at most once per call (results are cached by
    provider). If *root* itself returns an awaitable (e.g. it is ``async def``),
    it is awaited while the entered context managers are still open; they are
    all exited when this coroutine finishes.

    Returns:
        The (awaited, if awaitable) result of calling *root*.
    """
    cache = {}
    async with AsyncExitStack() as stack:

        async def get_dependency(provider):
            # Resolve each provider once; later requests reuse the cached value.
            if provider not in cache:
                dep = await resolve(provider)
                if isinstance(dep, AbstractContextManager):
                    dep = stack.enter_context(dep)
                elif isinstance(dep, AbstractAsyncContextManager):
                    dep = await stack.enter_async_context(dep)
                elif inspect.isawaitable(dep):
                    dep = await dep
                cache[provider] = dep
            return cache[provider]

        async def resolve(func):
            signature = inspect.signature(func)
            # Copies, because _get_params_subset consumes the pools it is given.
            args_subset, kwargs_subset = _get_params_subset(signature, list(args), dict(kwargs))
            bound = signature.bind(*args_subset, **kwargs_subset)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                if isinstance(value, Depends):
                    # Replacing an existing key's value is safe during iteration.
                    bound.arguments[name] = await get_dependency(value.provider)
            return func(*bound.args, **bound.kwargs)

        result = await resolve(root)
        if inspect.isawaitable(result):
            # Await inside the exit stack so injected resources are still open.
            result = await result
        return result
if __name__ == "__main__":
    from collections.abc import AsyncIterator

    async def main() -> None:
        """Demo: an endpoint sharing one cached async context-manager dependency.

        Expected output (each provider runs once; cleanup runs after the endpoint):

            connection_settings!
            Victory!
            42, 42, Settings!
            Success!
        """

        def connection_settings() -> str:
            print("connection_settings!")
            return "Settings!"

        @contextlib.asynccontextmanager
        async def connection(
            omg: int | None = None,
            settings: str = Depends(connection_settings),
        ) -> AsyncIterator[int]:
            print("Victory!")
            yield 42
            print("Success!")

        def endpoint(
            conn1: int = Depends(connection),
            conn2: int = Depends(connection),
            settings: str = Depends(connection_settings),
            test=None,
        ):
            print(f"{conn1}, {conn2}, {settings}")
            return "test"

        # ``test=22`` is forwarded only to callables that declare a ``test`` parameter.
        await inject_and_run(endpoint, test=22)

    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment