Skip to content

Instantly share code, notes, and snippets.

@nrbnlulu
Last active September 16, 2024 09:09
Show Gist options
  • Save nrbnlulu/29737080b82b1abb4a1bec0df4cc10be to your computer and use it in GitHub Desktop.
Save nrbnlulu/29737080b82b1abb4a1bec0df4cc10be to your computer and use it in GitHub Desktop.
Strawberry-GraphQL Node type with support for foreign node fields and custom id types.
from __future__ import annotations
import base64
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING, Annotated, Any, Self
import strawberry
from strawberry.annotation import StrawberryAnnotation
from strawberry.types.field import StrawberryField
from strawberry.types.private import StrawberryPrivate
from gql_.context import Info
@dataclass(slots=True)
class LazyRef[T]:
store: dict[str, T]
id_: str
def resolve(self) -> T:
return self.store[self.id_]
# Interfaces can't be generic; see https://github.com/strawberry-graphql/strawberry/issues/3602
# Also, I'm not sure it is possible to have a private generic class.
# These two limitations are the cause of the various hacks in this file.
# Base node interface. Every subclass is recorded in ``node_registry`` under
# its class name, and each annotation whose name ends in ``_p`` gets a sibling
# ``LazyIdField`` attribute named without that suffix.
@strawberry.interface
class NodeV2:
    # Raw identifier of the node, declared as a Strawberry private value.
    id_p: strawberry.Private[Any]

    def __init_subclass__(cls) -> None:
        """Register ``cls`` and attach a ``LazyIdField`` per ``*_p`` annotation."""
        node_registry[cls.__name__] = cls
        # Walk a snapshot of the annotations rather than the live mapping.
        for attr_name, raw_annotation in dict(cls.__annotations__).items():
            if not attr_name.endswith("_p"):
                continue
            field_annotation, target_ref = _resolve_field_type(raw_annotation)
            public_field = LazyIdField(
                attr_name,
                name=attr_name.removesuffix("_p"),
                annotation=field_annotation,
                node_ref=target_ref,
            )
            setattr(cls, public_field.name, public_field)

    @classmethod
    def _create_id(cls, id_) -> GlobalID:
        """Build the global ID for raw id ``id_``; not implemented on the base."""
        raise NotImplementedError

    # Public ``id`` field: the base64 form of the global ID built from ``id_p``.
    @strawberry.field
    def id(self) -> strawberry.ID:
        encoded = self._create_id(self.id_p).to_base64()
        return strawberry.ID(encoded)
def _resolve_field_type(
    annotation: str | Any,
) -> tuple[StrawberryAnnotation, LazyRef[type[NodeV2]]]:
    """Translate a ``LazyID[...]`` / ``LazyIDOpt[...]`` annotation string.

    Args:
        annotation: The raw class annotation. Only the string form (a quoted
            annotation, or any annotation under
            ``from __future__ import annotations``) is supported.

    Returns:
        A pair ``(field_annotation, node_ref)``: the annotation of the exposed
        field (``strawberry.ID``, or ``strawberry.ID | None`` for
        ``LazyIDOpt``) and a lazy reference to the target node class, looked
        up by name in ``node_registry``.

    Raises:
        NotImplementedError: If ``annotation`` is not a string.
        AssertionError: If the string does not start with ``LazyID``.
        ValueError: If the type arguments cannot be located in the string.
    """
    if not isinstance(annotation, str):
        # Already-evaluated annotation objects are not supported.
        raise NotImplementedError

    # Compare against literal names: at runtime ``LazyIDOpt`` is bound to the
    # very same class as ``LazyID``, so ``LazyIDOpt.__name__`` is "LazyID" and
    # using it here would mark every field as optional.
    is_optional = annotation.startswith("LazyIDOpt")
    assert annotation.startswith("LazyID"), f"Annotation {annotation} is not a LazyID"

    # The node type is the LAST type argument. Slicing between the first "["
    # and the last "]" and splitting on the last comma keeps id types that
    # themselves contain brackets or commas (e.g. ``tuple[int, int]``) intact.
    open_idx = annotation.find("[")
    close_idx = annotation.rfind("]")
    if open_idx == -1 or close_idx < open_idx:
        raise ValueError(f"Annotation {annotation} has no type arguments")
    _, comma, node_type_name = annotation[open_idx + 1 : close_idx].rpartition(",")
    node_type_name = node_type_name.replace(" ", "")
    if not comma or not node_type_name:
        raise ValueError(
            f"Annotation {annotation} must have the form LazyID[IdType, NodeType]"
        )

    field_type = strawberry.ID | None if is_optional else strawberry.ID
    return StrawberryAnnotation(field_type), LazyRef(node_registry, node_type_name)
# Maps each NodeV2 subclass name to the class itself. Filled in by
# NodeV2.__init_subclass__; read by GlobalID and by the LazyRef objects that
# _resolve_field_type creates.
node_registry: dict[str, type[NodeV2]] = {}
# NodeV2 flavour whose raw identifier is an ``int``.
@strawberry.interface
class IntNodeV2(NodeV2):
    if TYPE_CHECKING:
        # Only visible to type checkers, so no runtime annotation is created.
        id_p: int

    @classmethod
    def _create_id(cls, id_: int) -> IntGlobalID:
        """Wrap ``id_`` in an ``IntGlobalID`` tagged with this class's name."""
        owner_name = cls.__name__
        return IntGlobalID(owner_name, id_)
class GlobalID[T]:
__slots__ = ("type_name", "id_")
type_name: str
id_: T
def __init__(self, type_name: str, id_: T) -> None:
assert type_name in node_registry, f"{type_name} is not in the node registry"
self.type_name = type_name
self.id_ = id_
def create_id(self) -> str:
return f"{self.type_name}:{self.id_}"
@classmethod
def id_from_str(cls, id_str: str) -> T:
raise NotImplementedError
def to_base64(self) -> str:
return base64.b64encode(f"{self.type_name}:{self.id_}".encode()).decode()
@classmethod
def from_base64(cls, base64_id: str) -> Self:
type_name, id_ = base64.b64decode(base64_id).decode().split(":")
return cls(
type_name=type_name,
id_=cls.id_from_str(id_),
)
def get_type(self) -> type[NodeV2]:
return node_registry[self.type_name]
class IntGlobalID(GlobalID[int]):
    """Global ID whose raw identifier is an ``int``."""

    @classmethod
    def id_from_str(cls, id_str: str) -> int:
        """Convert the raw-id text back to an ``int``."""
        as_int = int(id_str)
        return as_int
if TYPE_CHECKING:
type LazyID[T, R: NodeV2] = Annotated[T, StrawberryPrivate(), R]
type LazyIDOpt[T, R: NodeV2] = Annotated[T | None, StrawberryPrivate(), R | None]
else:
class LazyID[T, R](StrawberryPrivate):
def __class_getitem__(cls, item):
return Annotated[Any, StrawberryPrivate(), item]
LazyIDOpt = LazyID
class LazyIdField(StrawberryField):
    """Field that exposes a privately stored raw id as a global ``strawberry.ID``.

    The raw id is read from ``private_field_name`` on the source object and
    encoded with the ``_create_id`` of the node class that ``node_ref``
    resolves to.
    """

    def __init__(
        self,
        private_field_name: str,
        name: str,
        annotation: StrawberryAnnotation,
        node_ref: LazyRef[type[NodeV2]],
    ) -> None:
        """Configure the field.

        Args:
            private_field_name: Attribute on the source object holding the raw id.
            name: Name assigned to this field.
            annotation: Type annotation assigned to this field.
            node_ref: Lazy reference to the node class the raw id belongs to.
        """
        super().__init__()
        self.private_field_name = private_field_name
        self.name = name
        self.type_annotation = annotation
        self.node_ref = node_ref

        # Placeholder resolver; the value is actually produced by get_result.
        def foo() -> strawberry.ID: ...

        # set the resolver in order not to create a dataclass field
        self(foo)

    @cached_property
    def node_type(self) -> type[NodeV2]:
        """Node class referenced by this field, resolved once on first access."""
        return self.node_ref.resolve()

    def get_result(
        self, source: strawberry.auto, info: Info | None, args, kwargs
    ) -> strawberry.ID | None:
        """Return the global ID for the raw id stored on ``source``.

        Returns:
            The base64 global ID, or ``None`` when the stored raw id is ``None``.
        """
        origin = getattr(source, self.private_field_name)
        # Compare against None explicitly: a falsy but valid raw id such as 0
        # must still be encoded rather than reported as missing.
        if origin is None:
            return None
        return strawberry.ID(self.node_type._create_id(origin).to_base64())
# Names exported by ``from <module> import *``.
__all__ = ["NodeV2", "IntNodeV2", "LazyID", "LazyIDOpt", "IntGlobalID"]
import strawberry
from gql_.node import IntNodeV2, LazyID
@strawberry.type
class Apple(IntNodeV2):
    color: str


@strawberry.type
class Worm(IntNodeV2):
    length: int
    # The ``_p`` suffix is what makes NodeV2.__init_subclass__ expose this raw
    # id as an ``apple_id`` field. The annotation is quoted because only
    # string annotations are parsed there and this module does not use
    # ``from __future__ import annotations``.
    apple_id_p: "LazyID[int, Apple]"


@strawberry.type
class Query:
    # Without the decorator the method is a plain method, not a schema field.
    @strawberry.field
    def worm(self) -> Worm:
        return Worm(id_p=1, length=1, apple_id_p=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment