Last active: November 6, 2020 17:35
-
-
Save antonagestam/bc437a063536704eca3d60166fd65e32 to your computer and use it in GitHub Desktop.
A (very naïve) implementation of dependent parsers for Python dataclasses.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import dataclasses | |
from typing import Annotated | |
from typing import Callable | |
from typing import Generic | |
from typing import Iterable | |
from typing import Mapping | |
from typing import Tuple | |
from typing import Type | |
from typing import TypeVar | |
from typing import get_type_hints | |
from phantom.ext.iso3166 import CountryCode | |
from phantom.ext.phonenumbers import FormattedPhoneNumber | |
# Placeholder for the dataclass type that ``from_dict`` builds and returns.
T = TypeVar("T")
# Bound for the callables that may be wrapped in a ``Parser``.
F = TypeVar("F", bound=Callable)


class Parser(Generic[F]):
    """Marker that attaches a parser callable to a dataclass field.

    Subscripting the class does not produce a generic alias: ``Parser[fn]``
    returns a ``Parser`` instance wrapping ``fn``, so it can be placed directly
    in ``Annotated[FieldType, Parser[fn]]`` metadata.
    """

    def __init__(self, parser: F) -> None:
        # The wrapped parser callable.
        self.parser = parser

    @classmethod
    def __class_getitem__(cls, item: F) -> Parser[F]:
        # Python already treats ``__class_getitem__`` as a classmethod; the
        # decorator only makes that explicit. ``Parser[fn]`` -> ``Parser(fn)``.
        return cls(parser=item)
def _get_parsers(dcls: type) -> Iterable[Tuple[str, Callable]]: | |
assert dataclasses.is_dataclass(dcls) | |
class_annotations = get_type_hints(dcls, include_extras=True) | |
for field_name, annotation in class_annotations.items(): | |
metadata = getattr(annotation, "__metadata__", ()) | |
for metadatum in metadata: | |
if isinstance(metadatum, Parser): | |
# Make sure the parser's return type is compatible with the field. | |
parser_annotations = get_type_hints(metadatum.parser) | |
parser_ret = parser_annotations["return"] | |
assert parser_ret is annotation.__origin__ or issubclass( | |
parser_ret, annotation.__origin__ | |
) | |
# TODO: Cache the dependencies to reuse when resolving. | |
# TODO: Iterate and transform less. | |
# Make sure every dependency of the parser is a supertype of the | |
# corresponding field on the dataclass. | |
dependencies = { | |
dependency_name: parser_annotations[dependency_name] | |
for dependency_name in set(parser_annotations.keys()) | |
- {"data", "return"} | |
} | |
for dependency_name, dependency_type in dependencies.items(): | |
assert class_annotations[ | |
dependency_name | |
] is dependency_type or issubclass( | |
class_annotations[dependency_name], dependency_type | |
) | |
yield field_name, metadatum.parser | |
def from_dict(dcls: Type[T], data: Mapping[str, object]) -> T:
    """Build an instance of dataclass *dcls* from the raw mapping *data*.

    Fields without a ``Parser`` are read from ``data[field.name]``. A value
    that is not already an instance of the field's type is passed to that
    type as a constructor. Fields with a ``Parser`` are resolved afterwards by
    calling ``parser(data, **dependencies)``. The dependencies are the
    parser's annotated parameters other than ``data``, taken from the
    already-resolved fields. Parsers run in passes, so a parser may depend on
    a field produced by another parser. Each parser is called exactly once.

    Raises:
        KeyError: if a field without a parser is missing from *data*.
        TypeError: if the parsers' dependencies can never be satisfied, e.g.
            a parser depends on its own field, or parsers depend on each
            other cyclically. Also raised by the parser validation step.
    """
    fields = dataclasses.fields(dcls)
    # TODO: Raise error if parser annotation is not identical to field annotation
    # TODO: Raise error if parser return type is not identical to field annotation
    parsers = dict(_get_parsers(dcls))
    annotations = get_type_hints(dcls)

    def resolve(field: dataclasses.Field) -> object:
        # Coerce the raw value through the field's type unless it already
        # has that type.
        value = data[field.name]
        field_type = annotations[field.name]
        if isinstance(value, field_type):
            return value
        return field_type(value)

    # First pass: fields without a parser come straight from ``data``.
    resolved = {
        field.name: resolve(field) for field in fields if field.name not in parsers
    }

    # A parser's dependencies are its annotated parameters other than
    # ``data`` and the return value; compute them once up front.
    pending = {
        field_name: (parser, set(get_type_hints(parser)) - {"data", "return"})
        for field_name, parser in parsers.items()
    }
    # Each pass runs every parser whose dependencies are all resolved.
    while pending:
        ready = [
            field_name
            for field_name, (_, dependencies) in pending.items()
            if dependencies.issubset(resolved)
        ]
        if not ready:
            # No parser can make progress; bail out instead of looping forever.
            raise TypeError(
                f"Cannot resolve fields {sorted(pending)} of {dcls.__name__}: "
                "their parsers depend on their own field, on each other, or "
                "on names that are never resolved"
            )
        for field_name in ready:
            parser, dependencies = pending.pop(field_name)
            resolved[field_name] = parser(
                data, **{name: resolved[name] for name in dependencies}
            )
    return dcls(**resolved)  # type: ignore[call-arg]
# --- Example usage ---
def parse_phone_number(
    data: Mapping[str, object], country: CountryCode
) -> FormattedPhoneNumber:
    """Parse the raw ``"phone_number"`` entry of *data* relative to *country*.

    The ``country`` parameter is an annotated dependency: when used as a
    dependent parser, it receives the already-resolved value of the
    dataclass field with the same name.
    """
    raw_number = data["phone_number"]
    return FormattedPhoneNumber.parse(raw_number, country)
@dataclasses.dataclass(frozen=True)
class User:
    """Example dataclass whose phone number is parsed using its country.

    ``country`` has no parser and is resolved directly from the input.
    ``phone_number`` carries ``Parser[parse_phone_number]``, whose ``country``
    parameter makes it depend on the resolved ``country`` field.
    """

    # Plain field: no parser attached.
    country: CountryCode
    # Dependent field: produced by ``parse_phone_number`` from the raw input.
    phone_number: Annotated[FormattedPhoneNumber, Parser[parse_phone_number]]
# Import-time smoke checks. The phone number is parsed in the context of the
# country resolved from the same mapping, so a national-format number is
# expected to equal its international-format counterpart.
assert from_dict(User, dict(country="SE", phone_number="0701234567")) == User(
    country=CountryCode.parse("SE"),
    phone_number=FormattedPhoneNumber.parse("+46701234567"),
)
# A lower-case country code is expected to resolve to the same country as
# its upper-case form.
assert from_dict(User, dict(country="dk", phone_number="86180311")) == User(
    country=CountryCode.parse("DK"),
    phone_number=FormattedPhoneNumber.parse("+4586180311"),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.