Last active
September 9, 2022 09:04
-
-
Save antont/9c9e502dac0e8802497b29515d3d3769 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
fastapi-cache decorator modified for SQLModel use. | |
Also, caches and serves the final JSON, instead of Python objects. | |
Thus looses pydantic validation of the response objects up in fastapi, | |
but this cache func does `response_model.from_orm` itself, so runs the same validation. | |
Elsewhere, there's an example that uses unmodified fastapi-cache with SQLModel, that works too | |
https://github.com/jonra1993/fastapi-alembic-sqlmodel-async/blob/main/fastapi-alembic-sqlmodel-async/app/api/v1/endpoints/cache.py | |
""" | |
import hashlib
import logging
import os
import platform
#import asyncio
from dataclasses import dataclass
from functools import wraps  # , partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Type

from fastapi.responses import Response, RedirectResponse, JSONResponse
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from sqlmodel import SQLModel
# Default cache namespace: the App Engine service name when GAE_SERVICE is set,
# otherwise this machine's network node name. NOTE: modify for your deployment.
# platform.node() replaces os.uname().nodename because os.uname() does not
# exist on Windows, where importing this module would otherwise fail.
NAMESPACE = os.environ.get('GAE_SERVICE') or platform.node()
class NoopJSONResponse(JSONResponse):
    """JSON response for content that is already serialized JSON.

    Subclassing JSONResponse keeps its JSON content type, while ``render``
    skips re-serialization and only turns the text into bytes.
    """

    def render(self, content) -> bytes:
        # Already-encoded payloads (e.g. raw bytes read back from a cache
        # backend) are sent as-is instead of failing on a missing .encode().
        if isinstance(content, (bytes, bytearray)):
            return bytes(content)
        return content.encode("utf-8")
class SQLModelFastapicacheCoder:
    """fastapi-cache style coder that stores SQLModel objects as JSON text.

    There is intentionally no ``decode``: the cached JSON is meant to be served
    as-is rather than turned back into model objects.
    """

    # staticmethod: the original plain function had no ``self``, so it only
    # worked when called on the class; this makes instance calls work too.
    @staticmethod
    def encode(ob: "SQLModel"):
        """Return *ob* serialized through its own ``json()`` method."""
        return ob.json()
if TYPE_CHECKING: | |
import concurrent.futures | |
def user_as_id(kwargs, argname):
    """Replace an object-valued keyword argument with its ``id``, in place.

    When ``kwargs[argname]`` exists and is not None, that entry is removed and
    ``kwargs[f"{argname}_id"]`` is set to the object's ``.id``. A missing key
    or a None value leaves *kwargs* unchanged. The same dict is returned.
    """
    value = kwargs.get(argname)
    if value is None:
        # Absent or explicitly None: nothing to swap.
        return kwargs
    kwargs[f'{argname}_id'] = value.id
    kwargs.pop(argname)
    return kwargs
"""registry of cached routes for pre-populating to run through""" | |
@dataclass | |
class CachedRoute: | |
func: Callable | |
namespace: str | |
inner: Callable | |
cached_routes: dict[str, CachedRoute] = {} | |
_logger = logging.getLogger(__name__)


def _etag_for(json_text: str) -> str:
    """Return a weak ETag for *json_text* that is identical in every process.

    The builtin ``hash()`` of a str is salted per interpreter (PYTHONHASHSEED),
    so it would produce different validators on different workers and after
    every restart; a content digest does not.
    """
    digest = hashlib.sha256(json_text.encode("utf-8")).hexdigest()
    return f'W/"{digest}"'


def _if_none_match_hits(header: Optional[str], etag: str) -> bool:
    """Return True when an If-None-Match header value matches *etag*.

    Accepts ``*`` and comma-separated lists, and compares weakly (ignoring a
    ``W/`` prefix on either side), as is done for GET conditional requests.
    """
    if not header:
        return False
    if header.strip() == "*":
        return True
    wanted = etag.removeprefix("W/")
    return any(tag.strip().removeprefix("W/") == wanted for tag in header.split(","))


def _json_response(request, json_text: str, max_age, from_cache: bool) -> Response:
    """Build the response for *json_text*; answer 304 if the client copy is current."""
    # Negative or missing TTLs would produce an invalid max-age directive.
    max_age = max(int(max_age or 0), 0)
    etag = _etag_for(json_text)
    headers = {"Cache-Control": f"max-age={max_age}", "ETag": etag}
    if _if_none_match_hits(request.headers.get("if-none-match"), etag):
        # A 304 carries no body, but repeats the validators.
        return Response(status_code=304, headers=headers)
    headers["Cached-Server-Response"] = "true" if from_cache else "false"
    return NoopJSONResponse(json_text, headers=headers)


def _serialize(ret, response_model, many: bool) -> str:
    """Validate *ret* through *response_model* and render it as JSON text."""
    if response_model is None:
        raise TypeError("sqlmodel_cache needs a response_model to serialize results")
    if many:
        return "[" + ",".join(response_model.from_orm(ob).json() for ob in ret) + "]"
    return response_model.from_orm(ret).json()


def sqlmodel_cache(
    expire: Optional[int] = None,
    coder: Optional[Type[Coder]] = None,
    key_builder: Optional[Callable] = None,
    namespace: Optional[str] = NAMESPACE,
    executor: Optional["concurrent.futures.Executor"] = None,
    path="",
    response_model=None,
    response_model_is_list=False,
):
    """Cache a GET endpoint's validated JSON output in the fastapi-cache backend.

    :param expire: cache lifetime in seconds; falls back to ``FastAPICache.get_expire()``.
    :param coder: unused by this implementation; values are stored as the JSON
        text produced through ``response_model``.
    :param key_builder: cache-key function; falls back to ``FastAPICache.get_key_builder()``.
    :param namespace: namespace handed to the key builder.
    :param executor: unused by this implementation.
    :param path: route path; passed to the key builder and used as the key of
        this route in ``cached_routes``.
    :param response_model: model whose ``from_orm(...).json()`` renders each
        result; required whenever a cache miss has to be serialized.
    :param response_model_is_list: True when the endpoint returns an iterable
        of objects that should be rendered as a JSON array.
    :return: a decorator for an async endpoint. Caching applies only when the
        endpoint receives ``request`` as a keyword argument; an optional
        ``session`` keyword is excluded from the cache key.
    """

    def wrapper(func):
        @wraps(func)
        async def inner(*args, **kwargs):
            key_kwargs = dict(kwargs)
            request = key_kwargs.pop("request", None)
            # Bypass the cache when there is no request to inspect, the client
            # asked for no-store, caching is disabled, or the method is not GET.
            if (
                request is None
                or request.headers.get("Cache-Control") == "no-store"
                or not FastAPICache.get_enable()
                or request.method != "GET"
            ):
                return await func(*args, **kwargs)

            # Resolved per call rather than written back into the closure, so
            # a re-initialised FastAPICache takes effect.
            lifetime = expire or FastAPICache.get_expire()
            build_key = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()

            # The DB session is per request and must not influence the key.
            key_kwargs.pop("session", None)
            # NOTE(review): `path` is passed in the key builder's third positional
            # slot, as in the original code - confirm that is the intended parameter.
            cache_key = build_key(func, namespace, path, args=args, kwargs=key_kwargs)
            _logger.debug("sqlmodel_cache key: %s", cache_key)

            ttl, cached = await backend.get_with_ttl(cache_key)
            if cached is not None:
                # The value is stored as str; depending on the backend it may
                # come back as bytes or as that same str.
                json_text = cached.decode() if isinstance(cached, bytes) else cached
                return _json_response(request, json_text, ttl, from_cache=True)

            ret = await func(*args, **kwargs)
            if isinstance(ret, Response):
                # Redirects, errors and hand-built responses cannot be rendered
                # through response_model; pass them through uncached.
                # TODO: redirects could be cached if they were serialized somehow.
                return ret
            json_text = _serialize(ret, response_model, response_model_is_list)
            await backend.set(cache_key, json_text, lifetime)
            _logger.debug("sqlmodel_cache stored %d chars under %s", len(json_text), cache_key)
            return _json_response(request, json_text, lifetime, from_cache=False)

        cached_routes[path] = CachedRoute(func, namespace, inner)
        return inner

    return wrapper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment