# Getting Strawberry to work with nested Pydantic Models
# GitHub gist jerber/22c22d949d553be804e75a7703e17c2d (last active February 1, 2022 21:38).
# GitHub warning: this file contains bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review it, open the file in an
# editor that reveals hidden Unicode characters.
from __future__ import annotations

import logging
import typing as T
from enum import Enum

import strawberry
from pydantic import BaseModel
from pydantic.fields import SHAPE_SINGLETON, ModelField
from strawberry.annotation import StrawberryAnnotation
from strawberry.field import StrawberryField
from strawberry.types.types import TypeDefinition
def add_back_fields_to_straw_cls(
    straw_cls: strawberry.type, model_fields: T.List[ModelField]
) -> strawberry.type:
    """Re-attach self-referencing pydantic fields to an already built Strawberry type.

    ``to_strawberry`` removes fields whose type is the model being converted (the
    Strawberry class cannot be referenced before it exists) and passes them here once
    ``straw_cls`` has been created. Each field is appended to
    ``straw_cls._type_definition.fields`` with ``straw_cls`` substituted for the
    pydantic model, keeping the field's list-ness (any non-singleton pydantic shape is
    exposed as ``List[...]``) and nullability (``Optional[...]`` when the pydantic
    field allows ``None``).

    The fields are added to the type definition only, not to the dataclass fields, so
    ``straw_cls.__init__`` does not accept them; assign them as attributes on
    instances after construction.

    Args:
        straw_cls: Strawberry type generated from the pydantic model.
        model_fields: pydantic fields that were removed from that model because
            they refer back to it.

    Returns:
        ``straw_cls`` itself, mutated in place.
    """
    logger = logging.getLogger(__name__)
    td: TypeDefinition = straw_cls._type_definition
    for model_field in model_fields:
        logger.debug("re-adding recursive field %r", model_field)

        # pydantic strips List[...] / Optional[...] off ``type_`` and records them in
        # ``shape`` / ``allow_none``; rebuild the wrapper around the Strawberry type.
        # GraphQL only has lists, so every collection shape becomes a list.
        annotation: T.Any = straw_cls
        if model_field.shape != SHAPE_SINGLETON:
            annotation = T.List[annotation]
        if model_field.allow_none:
            annotation = T.Optional[annotation]

        # pydantic reports "no default" as None / None; forward only the defaults the
        # field really has and let StrawberryField use its own "unset" sentinels
        # otherwise (a default_factory of None must never be forwarded).
        default_kwargs: T.Dict[str, T.Any] = {}
        if model_field.default_factory is not None:
            default_kwargs["default_factory"] = model_field.default_factory
        elif not model_field.required:
            default_kwargs["default"] = model_field.default

        straw_field = StrawberryField(
            python_name=model_field.name,
            # None lets Strawberry derive the GraphQL name like it does for the other
            # fields (auto camel-case); an explicit pydantic alias wins.
            # NOTE(review): assumed to match strawberry's pydantic integration naming.
            graphql_name=model_field.alias if model_field.has_alias else None,
            type_annotation=StrawberryAnnotation(annotation),
            **default_kwargs,
        )
        logger.debug("added %r", straw_field)
        td.fields.append(straw_field)
    return straw_cls
def to_strawberry(
    pydantic_model: T.Type[BaseModel],
    name: str,
    input: bool = False,
    fields: T.Optional[T.Set[str]] = None,
) -> strawberry.type:
    """Convert a pydantic model, including nested models, into a Strawberry type.

    Nested ``BaseModel`` fields are converted first (recursively, each named after
    its pydantic class) and ``Enum`` fields are registered with ``strawberry.enum``.
    Fields whose type is ``pydantic_model`` itself (e.g. ``friends: List[Student]``
    on ``Student``) are hidden from the Strawberry decorator and re-attached
    afterwards with ``add_back_fields_to_straw_cls``.

    ``pydantic_model.__fields__`` is edited only for the duration of the conversion
    and restored afterwards (same dict, original order and field types), so the
    pydantic model keeps working and the function can be called again on it.

    Only direct self-reference is detected; mutually recursive models (A -> B -> A)
    still recurse without bound.

    Args:
        pydantic_model: model to convert.
        name: name of the generated Strawberry class.
        input: ``False`` builds an output type (``...pydantic.type``); any other
            value builds an input type (``...pydantic.input``).
        fields: names of the fields to expose; all model fields when empty/None.

    Returns:
        The generated Strawberry class.
    """
    logger = logging.getLogger(__name__)
    model_fields = pydantic_model.__fields__
    requested = list(fields) if fields else list(model_fields)

    # Snapshot so every temporary change below can be undone exactly.
    original_items = list(model_fields.items())
    replaced_types: T.List[T.Tuple[ModelField, T.Any]] = []
    removed_field_models: T.List[ModelField] = []
    try:
        for field_name, field_model in original_items:
            t = field_model.type_
            logger.debug("field_name=%r, field_model=%r", field_name, field_model)
            if not isinstance(t, type):
                continue
            if issubclass(t, Enum):
                strawberry.enum(t)
            if not issubclass(t, BaseModel):
                continue
            if t is pydantic_model:
                # The Strawberry type for this model does not exist yet, so the
                # decorator cannot resolve this field; hide it and add it back later.
                logger.debug("recursive field %r on %r", field_name, pydantic_model)
                del model_fields[field_name]
                if field_name in requested:
                    removed_field_models.append(field_model)
                continue
            replaced_types.append((field_model, t))
            field_model.type_ = to_strawberry(
                pydantic_model=t, name=t.__name__, input=input
            )

        pyd = strawberry.experimental.pydantic
        decorator = pyd.type if input is False else pyd.input
        new_cls = decorator(model=pydantic_model, fields=requested)(
            type(name, (object,), {})
        )
    finally:
        # Undo the temporary edits so the pydantic model validates as before.
        for field_model, original_type in replaced_types:
            field_model.type_ = original_type
        model_fields.clear()
        model_fields.update(original_items)

    if removed_field_models:
        new_cls = add_back_fields_to_straw_cls(
            straw_cls=new_cls, model_fields=removed_field_models
        )
    return new_cls
# example cases
from pydantic import Field


class Teacher(BaseModel):
    """Example model; uncommenting ``students`` creates a Teacher <-> Student cycle."""

    name: str
    # students: T.List[Student] = Field(default_factory=list)


class Student(BaseModel):
    """Example model with a nested model (``teacher``) and self-references."""

    name: str
    friends: T.List[Student] = Field(default_factory=list)
    teacher: T.Optional[Teacher] = None
    best_friend: T.Optional[Student] = None


# With ``from __future__ import annotations`` the self-references above stay
# unresolved until these calls.
Student.update_forward_refs()
Teacher.update_forward_refs()
def test_existing_pydantic_func():
    """Show that strawberry's own pydantic decorator cannot handle ``Student``.

    This errors: nested BaseModels must themselves be converted into strawberry
    types first, which is what the recursive ``to_strawberry`` helper attempts.
    """
    every_field = list(Student.__fields__)
    make_strawberry_type = strawberry.experimental.pydantic.type(
        Student, fields=every_field
    )

    class StudentType:
        pass

    make_strawberry_type(StudentType)
def test_custom_pydantic_func():
    """Check that ``to_strawberry`` keeps the self-referencing ``friends`` field.

    Known not to work as intended: ``friends`` is re-attached to the Strawberry type
    definition only, so it never shows up among the dataclass fields. A more
    fundamental pydantic -> strawberry conversion is needed. Uncommenting
    ``Teacher.students`` makes the conversion recurse infinitely.
    """
    converted = to_strawberry(pydantic_model=Student, name="StudentType", input=False)
    assert "friends" in converted.__dataclass_fields__.keys()  # this fails
if __name__ == "__main__": | |
test_existing_pydantic_func() | |
test_custom_pydantic_func() | |
"""Starting a server with it... fails when you try to query Student.friends.""" | |
StudentType = to_strawberry(pydantic_model=Student, name="StudentType", input=False) | |
@strawberry.type | |
class Query: | |
@strawberry.field | |
def get_student(self) -> StudentType: | |
student_type = StudentType(name="Jon") | |
student_type.friends = [] | |
return student_type | |
schema = strawberry.Schema(query=Query) |
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.