Created
June 12, 2024 19:46
-
-
Save ajskateboarder/2285fa9d50bf612f87643e4df290eb43 to your computer and use it in GitHub Desktop.
Rust-like enums in Python, featuring a rough take on Rust's Result type
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from string import ascii_lowercase, ascii_uppercase | |
from itertools import product | |
from dataclasses import dataclass | |
import traceback | |
import inspect | |
class UnwrapError(Exception):
    """Raised by ``Base.unwrap`` when the value being unwrapped is an error result."""
# Attribute names that may still be read on a result object from module-level
# code before it has been unwrapped (see Base.__getattribute__).
UNWRAP_FUNCS = (
    "unwrap",
    "unwrap_or",
    "unwrap_or_else",
    "and_then",
    "__class__",
)
# Class names that mark a struct as a result type; these are the names produced
# by struct.ok / struct.err ("_ok", "_err") and tstruct.ok / tstruct.err
# ("_ok_tuple", "_err_tuple").
RESULT_TYPES = (
    "_err_tuple",
    "_err",
    "_ok",
    "_ok_tuple",
)
@dataclass(slots=True)
class Base:
    """Common base class for the classes generated by ``struct`` and ``tstruct``.

    Provides:
      * runtime type checking of dataclass fields (``__post_init__``);
      * forwarding of attribute reads to a wrapped payload stored in ``_d``;
      * a guard that forbids reading fields of a result object (a class whose
        name is in ``RESULT_TYPES``) from module-level code before unwrapping;
      * Rust-style ``unwrap`` / ``unwrap_or`` / ``unwrap_or_else`` / ``and_then``.
    """

    def __post_init__(self, **args):
        """Check every dataclass field value against its annotated type.

        Raises:
            TypeError: if a field value is not an instance of its annotation.
        """
        from dataclasses import fields

        for field in fields(self):
            if not isinstance(getattr(self, field.name), field.type):
                raise TypeError(f"Incorrect types passed to {type(self).__name__} struct")

    def __getattribute__(self, name):
        """Guard result objects and forward attribute reads to the ``_d`` payload.

        Raises:
            Exception: when a result object's attribute (other than the names in
                ``UNWRAP_FUNCS``) is read directly from module-level code.
        """
        cls = type(self)
        # Verify that a result value is not opened before it is unwrapped.
        if cls.__name__ in RESULT_TYPES and name not in UNWRAP_FUNCS:
            # Only the caller's function name is needed; inspect.stack() would
            # build every frame record and read source lines on each access.
            frame = inspect.currentframe()
            try:
                caller = frame.f_back if frame is not None else None
                caller_name = caller.f_code.co_name if caller is not None else None
            finally:
                # Drop our own frame reference to avoid a reference cycle.
                del frame
            if caller_name == "<module>":
                raise Exception("Cannot access properties until result is unwrapped")
        # The Result API (and __class__) must resolve on the wrapper itself:
        # forwarding them to the payload would bind unwrap() & co. to the inner
        # dataclass, whose class name is not a result type.
        if name not in UNWRAP_FUNCS:
            try:
                payload = object.__getattribute__(self, "_d")
                return object.__getattribute__(payload, name)
            except AttributeError:
                pass
        return object.__getattribute__(self, name)

    def unwrap(self):
        """Return the success payload or raise on an error result.

        Returns:
            The ``_a`` tuple for ``_ok_tuple`` objects, or the ``_d`` payload for
            ``_ok`` objects.

        Raises:
            UnwrapError: for ``_err_tuple`` (carrying the single value, or the
                whole tuple when there are several) and for ``_err`` (carrying
                a ``"(field=value, ...)"`` rendering of the payload).
            NotImplementedError: for any non-result struct.
        """
        kind = type(self).__name__
        if kind == "_err_tuple":
            values = self._a
            raise UnwrapError(values[0] if len(values) == 1 else values)
        if kind == "_err":
            from dataclasses import fields

            payload = self._d
            rendered = ", ".join(
                f"{field.name}={getattr(payload, field.name)!r}" for field in fields(payload)
            )
            raise UnwrapError(f"({rendered})")
        if kind == "_ok_tuple":
            return self._a
        if kind == "_ok":
            return self._d
        raise NotImplementedError("Result struct members must use .ok and .err to unwrap")

    def unwrap_or(self, value):
        """Return ``unwrap()``, or *value* if it raises ``UnwrapError``."""
        try:
            return self.unwrap()
        except UnwrapError:
            return value

    def unwrap_or_else(self, cb):
        """Return ``unwrap()``, or ``cb()`` if it raises ``UnwrapError``."""
        try:
            return self.unwrap()
        except UnwrapError:
            return cb()

    def and_then(self, cb):
        """Return ``cb(unwrap())`` for a success result, or ``None`` for an error.

        Only the unwrap itself is guarded, so an ``UnwrapError`` raised inside
        *cb* propagates instead of being silently swallowed.
        """
        try:
            value = self.unwrap()
        except UnwrapError:
            return None
        return cb(value)
def struct(name, **args):
    """Create a named-field struct class called *name*.

    Instances are created with keyword arguments matching *args*. The values are
    stored on an inner dataclass instance (``self._d``), a ``Base`` subclass,
    so ``Base.__post_init__`` type-checks each field. ``Base.__getattribute__``
    forwards attribute reads on the wrapper to that payload.

    Args:
        name: Class name of the returned wrapper; ``"_ok"`` / ``"_err"`` make it
            a result type (see ``struct.ok`` / ``struct.err``).
        **args: Mapping of field name to the class its value must be.

    Returns:
        A new ``Base`` subclass named *name*.

    Raises:
        AssertionError: if any value in *args* is not a class.
    """
    assert all(inspect.isclass(t) for t in args.values()), "Invalid parameter types"

    # Build the payload dataclass once, from the real type objects. The previous
    # exec()-generated source resolved types by __name__ in module globals (so
    # user classes defined elsewhere failed). It also relied on exec() writing
    # into locals(), which stopped working in Python 3.13 (PEP 667).
    payload_cls = dataclass(type(f"{name}_t", (Base,), {"__annotations__": dict(args)}))

    def __init__(self, **fields):
        self._d = payload_cls(**fields)

    def __getitem__(self, i):
        return self._d[i]

    return type(name, (Base,), {"__init__": __init__, "__getitem__": __getitem__})
# Shorthand constructors for result-style named structs: the generated classes
# are named "_ok" / "_err", which Base treats as result types.
struct.ok = lambda **fields: struct("_ok", **fields)
struct.err = lambda **fields: struct("_err", **fields)
def tstruct(name, *arg_types):
    """Create a tuple-style struct class called ``f"{name}_tuple"``.

    The class's constructor takes one argument per entry in *arg_types*. These
    are positional, or keyword under the generated names ``aaa``, ``aab``, …,
    each defaulting to ``None``. The values are stored as the tuple
    ``self._a``, and instances support ``obj[i]``.

    Args:
        name: Prefix of the class name; ``"_ok"`` / ``"_err"`` make it a result
            type (see ``tstruct.ok`` / ``tstruct.err``).
        *arg_types: Expected class of each positional value.

    Returns:
        A new ``Base`` subclass.

    The constructor raises ``TypeError`` when an argument is missing (``None``),
    when a value is not an instance of its declared type, or when too many or
    unknown arguments are passed.
    """
    # Keep the historical parameter names ("aaa", "aab", ...) so positional and
    # keyword calls behave as with the previously exec()-generated __init__.
    params = [
        "".join(letters)
        for _, letters in zip(arg_types, product(ascii_lowercase + ascii_uppercase, repeat=3))
    ]
    signature = inspect.Signature(
        [inspect.Parameter(p, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None) for p in params]
    )
    # Hold the real type objects; the old generated source looked them up by
    # __name__ in module globals and relied on exec() writing into locals(),
    # which no longer works in Python 3.13+ (PEP 667).
    expected = tuple(arg_types)

    def __init__(self, *args, **kwargs):
        bound = signature.bind(*args, **kwargs)  # TypeError on surplus/unknown args
        bound.apply_defaults()
        values = tuple(bound.arguments[p] for p in params)
        # Check for missing arguments first; otherwise a missing (None) value
        # would always be reported as a type mismatch.
        if any(v is None for v in values):
            raise TypeError(f"{name} member is missing required arguments")
        # Every value must match its declared type. The old `not any(...)` check
        # accepted the call as soon as a single value matched. isinstance keeps
        # this consistent with Base.__post_init__.
        if not all(isinstance(v, t) for t, v in zip(expected, values)):
            raise TypeError(f"Incorrect types passed to {name} tuple struct")
        self._a = values

    def __getitem__(self, i):
        return self._a[i]

    return type(f"{name}_tuple", (Base,), {"__init__": __init__, "__getitem__": __getitem__})
# Shorthand constructors for result-style tuple structs: the generated classes
# are named "_ok_tuple" / "_err_tuple", which Base treats as result types.
tstruct.ok = lambda *types: tstruct("_ok", *types)
tstruct.err = lambda *types: tstruct("_err", *types)
qhar
cul
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
ok