Skip to content

Instantly share code, notes, and snippets.

@Sachaa-Thanasius
Last active June 18, 2024 02:16
Show Gist options
  • Save Sachaa-Thanasius/571cafc9087a4de1c5b865079753da29 to your computer and use it in GitHub Desktop.
Save Sachaa-Thanasius/571cafc9087a4de1c5b865079753da29 to your computer and use it in GitHub Desktop.
Proto-draft implementation of a potential `Refined` typeform for Python. Unofficial — just my first pass at it, for fun. Reference: https://discuss.python.org/t/pep-746-typedmetadata-for-type-checking-of-pep-593-annotated/53834
# pyright: enableExperimentalFeatures=true
# PEP 712 is only active for pyright with the "enableExperimentalFeatures" setting enabled.
import operator
import re
import sys
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
Final,
Generic,
NoReturn,
Pattern,
Protocol,
Tuple,
TypeVar,
Union,
final,
)
import attrs
if sys.version_info >= (3, 11):
from typing import ParamSpec, Self, reveal_type
else:
from typing_extensions import ParamSpec, Self, reveal_type
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if sys.version_info >= (3, 9):
from typing import Annotated
else:
from typing_extensions import Annotated
# Sentinel default meaning "argument not supplied"; it is compared by identity
# (``is not _MISSING``) so that any real value, including None, can be passed.
_MISSING = object()
# Type variables shared by the Parser machinery.
T = TypeVar("T")
U = TypeVar("U")
# Parameter-spec variables: P for a parser's own call signature, P2 for the
# right-hand operand when two parsers are combined with ``|``.
P = ParamSpec("P")
P2 = ParamSpec("P2")
# Aliases for attrs' class decorator and field factory, named so the example
# models read like a model/adapter API.
base_model = attrs.define
adapter = attrs.field
# =====================================================================================================================
# ==== Implementation of Refined.
# =====================================================================================================================
if TYPE_CHECKING:
Refined = Annotated
else:
import operator
from typing import _GenericAlias, _tp_cache, _type_check, _type_repr
if sys.version_info >= (3, 12):
from typing import Unpack
else:
from typing_extensions import Unpack
if sys.version_info >= (3, 10):
from typing import get_origin
else:
from typing_extensions import get_origin
# Almost an exact reimplementation of Annotated.
@final
class _RefinedGenericAlias(_GenericAlias, _root=True):
if TYPE_CHECKING:
__origin__: type
__refinements__: Tuple[object, ...]
def __init__(self, origin: type, refinements: Tuple[object, ...]):
if isinstance(origin, _RefinedGenericAlias):
refinements = origin.__refinements__ + refinements
origin = origin.__origin__
super().__init__(origin, origin)
self.__refinements__ = refinements
def copy_with(self, params: Tuple[object, ...]):
if len(params) != 1:
raise AssertionError
new_type = params[0]
return _RefinedGenericAlias(new_type, self.__refinements__)
def __repr__(self):
return f"Refined[{_type_repr(self.__origin__)}, {', '.join(repr(r) for r in self.__refinements__)}]"
def __reduce__(self):
return operator.getitem, (Refined, (self.__origin__, *self.__refinements__))
def __eq__(self, other: object, /):
if isinstance(other, type(self)):
if self.__origin__ != other.__origin__:
return False
return self.__refinements__ == other.__refinements__
return NotImplemented
def __hash__(self):
return hash((self.__origin__, self.__refinements))
@final
class Refined:
    """Runtime implementation of ``Refined[T, *refinements]``.

    Subscripting returns a ``_RefinedGenericAlias``. The class itself can be
    neither instantiated nor subclassed.
    """

    __slots__ = ()

    def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError("Type Refined cannot be instantiated.")

    def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError(f"Cannot subclass {cls.__module__}.Refined")

    def __class_getitem__(
        cls,
        params: Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]],
    ) -> _RefinedGenericAlias:
        # A single subscript argument arrives bare; normalise it to a tuple.
        normalized = params if isinstance(params, tuple) else (params,)
        # _class_getitem_inner is not a classmethod, so cls is passed explicitly.
        return cls._class_getitem_inner(cls, *normalized)

    @_tp_cache(typed=True)
    def _class_getitem_inner(
        cls,
        *params: Unpack[Tuple[type, Unpack[Tuple[Union["TypeRefinement", "ValueRefinement"], ...]]]],
    ) -> _RefinedGenericAlias:
        """Validate ``(type, *refinements)`` and build the alias (results cached)."""
        if len(params) < 2:
            raise TypeError("Refined[...] should be used with at least two arguments (a type and an annotation).")
        base, *extras = params
        if not isinstance(base, type) and getattr(base, "__typing_is_unpacked_typevartuple__", False):
            raise TypeError("Refined[...] should not be used with an unpacked TypeVarTuple.")
        # ClassVar[...] and Final[...] are accepted as the refined type without _type_check.
        if get_origin(base) in {ClassVar, Final}:
            origin = base
        else:
            origin = _type_check(base, "Refined[t, ...]: t must be a type.")
        return _RefinedGenericAlias(origin, tuple(extras))
class TypeRefinement(Protocol):
    """Structural protocol for objects exposing ``__supports_type__(t) -> bool``."""

    def __supports_type__(self, t: type) -> bool:
        """Return whether this refinement supports the type *t*."""
        ...
class ValueRefinement(Protocol):
    """Structural protocol for objects exposing ``__supports_value__(o) -> bool``."""

    def __supports_value__(self, o: object) -> bool:
        """Return whether this refinement accepts the value *o*."""
        ...
class NumCmp:
    """Value refinement built from optional comparison bounds.

    Each supplied keyword (eq, ne, gt, ge, lt, le) adds one check of the form
    ``operator.<name>(value, bound)``. Omitted keywords are ignored. A value is
    supported only if every supplied check holds.
    """

    # Maps each constructor keyword to the comparison applied as op(value, bound).
    _op_map: ClassVar = {
        "eq": operator.eq,
        "ne": operator.ne,
        "gt": operator.gt,
        "ge": operator.ge,
        "lt": operator.lt,
        "le": operator.le,
    }

    def __init__(
        self,
        eq: object = _MISSING,
        ne: object = _MISSING,
        gt: object = _MISSING,
        ge: object = _MISSING,
        lt: object = _MISSING,
        le: object = _MISSING,
    ):
        # Unsupplied bounds keep the _MISSING sentinel and are skipped when checking.
        self.eq = eq
        self.ne = ne
        self.gt = gt
        self.ge = ge
        self.lt = lt
        self.le = le

    def __repr__(self) -> str:
        # Show only the bounds that were supplied, e.g. "NumCmp(ge=0, lt=512)".
        bounds = ", ".join(
            f"{name}={value!r}"
            for name in self._op_map
            if (value := getattr(self, name)) is not _MISSING
        )
        return f"{type(self).__name__}({bounds})"

    def __supports_value__(self, o: object) -> bool:
        """Return True if *o* satisfies every supplied bound (vacuously True if none)."""
        # all() stops at the first failing bound and always yields a real bool,
        # unlike ``cond &= ...``, which evaluated every comparison and could
        # return a non-bool (e.g. ``True & 1 == 1``).
        return all(
            cmp_op(o, cmp_val)
            for cmp_name, cmp_op in self._op_map.items()
            if (cmp_val := getattr(self, cmp_name)) is not _MISSING
        )
class RePtrn:
    """Value refinement satisfied by strings that match a regular expression.

    Matching uses ``Pattern.match``, so it is anchored at the start of the
    string but not at its end.
    """

    def __init__(self, pattern: Union[str, Pattern[str]]):
        # Accept either a pattern string or an already-compiled pattern.
        self.pattern = pattern if isinstance(pattern, Pattern) else re.compile(pattern)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.pattern!r})"

    def __supports_value__(self, o: object) -> bool:
        """Return True if *o* is a str matching the pattern from its start.

        Non-str values are reported as unsupported instead of raising
        TypeError. This is also why the parameter is typed ``object``, as in
        the ``ValueRefinement`` protocol.
        """
        if not isinstance(o, str):
            return False
        return self.pattern.match(o) is not None
# =====================================================================================================================
# ==== Implementation of `parse`, superficially matching the semantics of Pydantic's proposed parse/transform use case.
# =====================================================================================================================
class ValidationError(Exception):
    """Raised when a Parser fails to convert its input."""
@final
class Parser(Generic[P, T]):
    """Composable wrapper around a conversion callable (``typer``).

    Calling a Parser forwards its arguments to ``typer``. Any exception raised
    while doing so is reported as ValidationError. Parsers can be chained with
    ``transform``/``parse``, combined as fallbacks with ``|``, and bounded with
    ``ge``/``lt``.
    """

    __slots__ = ("typer",)

    def __init__(self, typer: Callable[P, T]):
        self.typer = typer

    def __init_subclass__(cls, *args: object, **kwargs: object) -> NoReturn:
        raise TypeError(f"Cannot subclass {cls.__module__}.Parser")

    def __or__(self, other: "Parser[P2, T]", /) -> "Parser[P2, T]":
        """Return a parser that tries ``self`` first and falls back to ``other``.

        The returned parser raises ValidationError, chained to the last
        failure, when neither alternative succeeds.
        """
        if not isinstance(other, Parser):  # pyright: ignore [reportUnnecessaryIsInstance]
            return NotImplemented

        def temp(*args: P2.args, **kwargs: P2.kwargs) -> T:
            last_error = None
            for alternative in (self, other):
                try:
                    # Call the Parser, not its raw typer. Ordinary failures (e.g. a
                    # ValueError from int("x")) then become ValidationError and
                    # actually trigger the fallback.
                    return alternative(*args, **kwargs)
                except ValidationError as exc:
                    last_error = exc
                    print(f"Failed to parse {(args, kwargs)} with {alternative.typer}. Attempting next.")  # noqa: T201
            raise ValidationError(f"No alternative could parse {(args, kwargs)}") from last_error

        return Parser(temp)

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Run ``typer`` on the arguments, reporting any failure as ValidationError."""
        try:
            return self.typer(*args, **kwargs)
        except ValidationError:
            # Already a parse failure (e.g. from a nested fallback or a bound
            # check); propagate it unchanged instead of wrapping it again.
            raise
        except Exception as exc:  # noqa: BLE001
            raise ValidationError(f"Failed to parse {(args, kwargs)} with {self.typer}: {exc!r}") from exc

    def _convert(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Return a parser that applies *fn* to this parser's result."""

        def temp(*args: P.args, **kwargs: P.kwargs) -> U:
            return fn(self.typer(*args, **kwargs))

        return Parser(temp)

    def transform(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Post-process the parsed result with *fn* (alias of ``parse``)."""
        return self._convert(fn)

    def parse(self, fn: Callable[[T], U]) -> "Parser[P, U]":
        """Feed the parsed result into a further parsing step *fn* (alias of ``transform``)."""
        return self._convert(fn)

    def ge(self, floor: int) -> "Parser[P, T]":
        """Return a parser whose result must also satisfy ``result >= floor``.

        Raises ValidationError (when the returned parser is called) otherwise.
        """

        def check(value: T) -> T:
            # ``not value >= floor`` rather than ``value < floor`` so unordered
            # values such as NaN are rejected, matching operator.ge semantics.
            if not value >= floor:  # pyright: ignore [reportOperatorIssue]
                raise ValidationError(f"{value!r} is not >= {floor!r}")
            return value

        return self._convert(check)

    def lt(self, ceil: int) -> "Parser[P, T]":
        """Return a parser whose result must also satisfy ``result < ceil``.

        Raises ValidationError (when the returned parser is called) otherwise.
        """

        def check(value: T) -> T:
            # ``not value < ceil`` so unordered values such as NaN are rejected.
            if not value < ceil:  # pyright: ignore [reportOperatorIssue]
                raise ValidationError(f"{value!r} is not < {ceil!r}")
            return value

        return self._convert(check)
def parse(tp: Callable[P, T]) -> Parser[P, T]:
    """Wrap the callable *tp* in a ``Parser``.

    For example, ``parse(int)("42")`` returns ``42``, while ``parse(int)("x")``
    raises ValidationError.
    """
    wrapped: Parser[P, T] = Parser(tp)
    return wrapped
# =====================================================================================================================
# ==== Attempt at an example using the above.
# =====================================================================================================================
# Pretend this class subclasses pydantic.BaseModel rather than being a plain class wrapped by a class decorator.
# This is how Pydantic wants transformers and validators to look.
@base_model
class Before:
    """Example model in the Pydantic-style ``Annotated[type, parser]`` form."""

    # The Parser objects are built when the class body is evaluated, but no
    # converter= wires them into __init__. NOTE(review): they presumably act only
    # as Annotated metadata under attrs; confirm.
    username: Annotated[str, parse(str).transform(str.lower)]
    # Accept an int directly, or a str that is stripped and then parsed as int;
    # .ge(0).lt(512) express the intended bounds on the result.
    birthday: Annotated[int, (parse(int) | parse(str).transform(str.strip).parse(int)).ge(0).lt(512)]
    age: Annotated[int, parse(int)]
# This is an attrs class with PEP 712 active, and imo looks like a better alternative.
@base_model
class After:
    """Example model using attrs converters, with PEP 712 typing support, for parsing."""

    # Each converter is a Parser that attrs applies to the matching __init__
    # argument; the inferred __init__ parameter types are recorded in test().
    username: str = adapter(converter=parse(str).transform(str.lower))
    # NOTE(review): nothing in this file enforces the NumCmp(ge=0, lt=512)
    # refinement at runtime. The converter performs no bound check, and test()
    # passes "1010", which exceeds lt=512.
    birthday: Refined[int, NumCmp(ge=0, lt=512)] = adapter(converter=(parse(int) | parse(str).transform(str.strip).parse(int)))
    age: int = adapter(converter=parse(int))
def test() -> None:
    """Construct an ``After`` and show its converted attribute values.

    The trailing comments record the types pyright reports for each
    ``reveal_type`` call. The pyright directive at the top of the file enables
    the experimental PEP 712 support these types depend on.
    """
    reveal_type(After.__init__)
    # Type of "After.__init__" is "(self: After, username: object, birthday: object, age: str | Buffer | SupportsInt | SupportsIndex | SupportsTrunc) -> None"
    # After's converters map 10 -> "10", "1010" -> 1010 and 1.0 -> 1.
    ex = After(10, "1010", 1.0)
    reveal_type(ex.username) # Type of "ex.username" is "str"
    print(ex.username)
    reveal_type(ex.birthday) # Type of "ex.birthday" is "int"
    print(ex.birthday)
    reveal_type(ex.age) # Type of "ex.age" is "int"
    print(ex.age)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment