Skip to content

Instantly share code, notes, and snippets.

@ajskateboarder
Created June 12, 2024 19:46
Show Gist options
  • Save ajskateboarder/2285fa9d50bf612f87643e4df290eb43 to your computer and use it in GitHub Desktop.
Save ajskateboarder/2285fa9d50bf612f87643e4df290eb43 to your computer and use it in GitHub Desktop.
Rust-like enums in Python, including a rough approximation of Rust's Result type
import inspect
import traceback
from dataclasses import dataclass, fields, make_dataclass
from itertools import product
from string import ascii_lowercase, ascii_uppercase
class UnwrapError(Exception): pass
UNWRAP_FUNCS = ("unwrap", "unwrap_or", "unwrap_or_else", "and_then", "__class__")
RESULT_TYPES = ("_err_tuple", "_err", "_ok", "_ok_tuple")
@dataclass(slots=True)
class Base:
def __post_init__(self, **args):
for k, v in self.__annotations__.items():
if not isinstance(getattr(self, k), v):
raise TypeError(f"Incorrect types passed to {self.__class__.__name__} struct")
def __getattribute__(self, name):
stack = inspect.stack()
caller = inspect.getframeinfo(stack[1][0])
cla = object.__getattribute__(self, "__class__")
# verify that you didn't try to open a result value before unwrap
if cla.__name__ in RESULT_TYPES and caller.function == "<module>" and name not in UNWRAP_FUNCS:
raise Exception("Cannot access properties until result is unwrapped")
try:
x = object.__getattribute__(self, "_d")
return object.__getattribute__(x, name)
except AttributeError:
return object.__getattribute__(self, name)
def unwrap(self):
x = self.__class__.__name__
if x == "_err_tuple":
raise UnwrapError(self._a[0] if len(self._a) == 1 else self._a)
elif x == "_err":
raise UnwrapError(repr(self).replace("_err", "", 1))
elif x == "_ok_tuple":
return self._a
elif x == "_ok":
return self._d
else:
raise NotImplementedError("Result struct members must use .ok and .err to unwrap")
def unwrap_or(self, value):
try:
return self.unwrap()
except UnwrapError:
return value
def unwrap_or_else(self, cb):
try:
return self.unwrap()
except UnwrapError:
return cb()
def and_then(self, cb):
try:
return cb(self.unwrap())
except UnwrapError:
pass
def struct(name, **args):
    """Create a struct class named ``name`` with type-checked keyword fields.

    ``struct("Point", x=int, y=int)`` returns a ``Base`` subclass.  Calling
    ``Point(x=1, y=2)`` stores the values in an inner dataclass ``Point_t`` on
    ``self._d``; that holder inherits ``Base.__post_init__``, which checks each
    value against its declared type.  ``point[0]`` / ``point["x"]`` read field
    values by declaration position or by name.

    Args:
        name: name of the generated class.
        **args: field name -> expected type; every type must be a class.

    Returns:
        The generated class.

    Raises:
        AssertionError: if a field type is not a class.
    """
    assert all(inspect.isclass(t) for t in args.values()), "Invalid parameter types"
    # Build the field holder once, directly from the type objects, instead of
    # exec'ing generated source.  Types then need not be resolvable by name in
    # this module, and nothing has to be read back through locals(): inside a
    # function, exec() can no longer write into locals() on Python 3.13+
    # (PEP 667).
    holder = make_dataclass(f"{name}_t", list(args.items()), bases=(Base,))

    def __init__(self, **kwargs):
        self._d = holder(**kwargs)

    def __getitem__(self, key):
        # The holder dataclass is not subscriptable itself: index its field
        # values in declaration order, or look one up by field name.
        data = self._d
        values = {f.name: getattr(data, f.name) for f in fields(data)}
        if isinstance(key, str):
            return values[key]
        return tuple(values.values())[key]

    return type(name, (Base,), {"__init__": __init__, "__getitem__": __getitem__})
# Result-variant shortcuts: Base.unwrap keys on the "_ok" / "_err" class names.
setattr(struct, "ok", lambda **members: struct("_ok", **members))
setattr(struct, "err", lambda **members: struct("_err", **members))
def tstruct(name, *arg_types):
    """Create a tuple struct class ``<name>_tuple`` with positional members.

    ``tstruct("Pair", int, str)`` returns a ``Base`` subclass named
    ``Pair_tuple``.  ``Pair_tuple(1, "a")`` stores ``(1, "a")`` on ``self._a``,
    and ``pair[i]`` indexes that tuple.

    Args:
        name: class name prefix (``"_tuple"`` is appended).
        *arg_types: expected type of each positional member, in order.

    Returns:
        The generated class.  Its constructor raises ``TypeError`` when given
        too many members, a missing (or ``None``) member, or a member whose
        type does not match.
    """
    expected = tuple(arg_types)

    def __init__(self, *values):
        if len(values) > len(expected):
            raise TypeError(
                f"{name} member takes {len(expected)} arguments but {len(values)} were given"
            )
        # None counts as a missing member; report that before any type
        # mismatch so the more specific message wins.
        if len(values) < len(expected) or any(v is None for v in values):
            raise TypeError(f"{name} member is missing required arguments")
        # Every member must match its declared type (isinstance, as
        # Base.__post_init__ does for struct fields) -- not just one of them.
        if not all(isinstance(v, t) for v, t in zip(values, expected)):
            raise TypeError(f"Incorrect types passed to {name} tuple struct")
        self._a = values

    def __getitem__(self, i):
        return self._a[i]

    # Built with type() rather than exec'd source: no by-name type lookup and
    # no locals() round-trip (which breaks on Python 3.13+, PEP 667).
    return type(f"{name}_tuple", (Base,), {"__init__": __init__, "__getitem__": __getitem__})
# Tuple Result-variant shortcuts; tstruct appends "_tuple", so these produce
# the "_ok_tuple" / "_err_tuple" classes that Base.unwrap recognises.
setattr(tstruct, "ok", lambda *types: tstruct("_ok", *types))
setattr(tstruct, "err", lambda *types: tstruct("_err", *types))
@Supernoodles99
Copy link

ok

@EngineerRunner
Copy link

qhar

@SoupleCodes
Copy link

cul

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment