Last active
November 26, 2019 17:32
-
-
Save Solonarv/ca97f4a7e37eb99d1250db7ee4758e73 to your computer and use it in GitHub Desktop.
Augmented switch statements in Python, using a hybrid context-manager/decorator approach.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABCMeta, abstractmethod | |
def do_nothing(f):
    """Decorator that discards the function it decorates.

    The decorated function is never called, and the decorated name ends
    up bound to None.
    """
    return None
def instantiate(callable):
    """Decorator that replaces a class definition with an instance of that class.

    The decorated name ends up bound to ``callable()``, i.e. the result of
    calling the decorated object with no arguments. This is useful for
    making opaque singleton/sentinel objects:

    >>> @instantiate
    ... class foo: pass
    >>> type(foo).__name__
    'foo'
    >>> isinstance(foo, type)
    False

    The parameter name shadows the ``callable`` builtin. It is kept as-is
    so existing keyword callers keep working.
    """
    return callable()
@instantiate
class fallthrough:
    """Sentinel that a case body returns to keep the switch() going.

    When a matching case body returns this object (instead of any other
    value), the enclosing switch does not finish, and the case
    alternatives that follow are still tried against the scrutinee.
    """
class SwitchFinished(Exception):
    """Internal signal that a case alternative matched and the switch() is done.

    It carries the matching body's return value in ``val``. The switch's
    ``__exit__`` swallows it, so it should never escape to user code.
    """

    def __init__(self, val=None):
        message = "switch finished - you should never see this exception"
        super().__init__(message)
        # Return value of the case body that ended the switch.
        self.val = val
def _just_run_it(body):
    """Call a zero-argument case body and end the switch unless it falls through.

    Raises SwitchFinished carrying the body's return value, unless that
    value is the ``fallthrough`` sentinel, in which case it returns None.
    """
    outcome = body()
    if outcome is fallthrough:
        return None
    raise SwitchFinished(outcome)
class PatternFailedToMatch(Exception):
    """Raised by a Pattern's ``match`` when the scrutinee does not match.

    switch() uses it to decide that a case alternative's body must not
    run. Pattern combinators catch, re-raise, or translate it as needed.
    """
class Pattern(metaclass=ABCMeta):
    """Abstract base class for anything usable as a switch() case pattern.

    A Pattern is matched against a scrutinee via ``match``. On success,
    ``match`` returns either a tuple of values (passed positionally to the
    case body) or None (the body is called with no arguments). On failure
    it raises PatternFailedToMatch.
    """

    def __init__(self):
        # Deliberately NOT abstract. If it were, every subclass would have
        # to override __init__ just to become instantiable; a
        # configuration-free pattern such as Trivial would otherwise raise
        # TypeError on construction. ``match`` stays abstract, so Pattern
        # itself still cannot be instantiated.
        pass

    @abstractmethod
    def match(self, scrutinee):
        """Match *scrutinee*: return a tuple of values or None, or raise PatternFailedToMatch."""
class Eq(Pattern):
    """Match when the scrutinee is equal to a fixed reference value.

    Binds nothing on success: ``match`` returns None.
    """

    def __init__(self, reference):
        self.reference = reference

    def match(self, scrutinee):
        differs = scrutinee != self.reference
        if differs:
            raise PatternFailedToMatch()
class Test(Pattern):
    """Match when an arbitrary predicate is truthy for the scrutinee.

    A generalisation of Eq. Binds nothing on success: ``match`` returns None.
    """

    def __init__(self, test):
        self.test = test

    def match(self, scrutinee):
        passed = self.test(scrutinee)
        if passed:
            return None
        raise PatternFailedToMatch
class Type(Pattern):
    """Match when the scrutinee is an instance of the given type(s).

    On success the scrutinee itself is bound as the single value.
    """

    def __init__(self, *types):
        # A single argument is stored bare and several are kept as a tuple;
        # isinstance() accepts either form.
        self.types = types[0] if len(types) == 1 else types

    def match(self, scrutinee):
        if not isinstance(scrutinee, self.types):
            raise PatternFailedToMatch()
        return (scrutinee,)
class Trivial(Pattern):
    """Pattern that always matches, binding the scrutinee as its single value."""

    def __init__(self):
        # Defined explicitly so that Trivial() is instantiable even when the
        # base class declares __init__ as an abstract method.
        pass

    def match(self, scrutinee):
        return (scrutinee,)
class Apply(Pattern):
    """Match by applying a converter function to the scrutinee.

    The pattern matches when ``func(scrutinee, *args, **kwargs)`` returns
    without raising, and that return value is bound as the single value.
    Exceptions of the type(s) given by *swallow* (default: any Exception)
    are re-raised as PatternFailedToMatch. For example, the ValueError
    from ``int("spam")`` simply means "no match". Exceptions outside
    *swallow* propagate unchanged.

    *swallow* is the second positional parameter, so extra positional
    arguments for *func* must come after it. For example,
    ``Apply(int, ValueError, 16)`` converts hexadecimal strings.
    """

    def __init__(self, func, swallow=Exception, *args, **kwargs):
        self.func = func
        self.swallow = swallow
        # Extra arguments forwarded to func after the scrutinee.
        self.args = args
        self.kwargs = kwargs

    def match(self, scrutinee):
        try:
            converted = self.func(scrutinee, *self.args, **self.kwargs)
        except self.swallow as exc:
            raise PatternFailedToMatch from exc
        return (converted,)
class PatternCombinator(Pattern):
    """Base class for patterns built out of other sub-patterns.

    The sub-patterns are stored, in order, as the tuple ``self.patterns``.
    Subclasses still have to implement ``match``.
    """

    def __init__(self, *subpatterns):
        self.patterns = subpatterns
class All(PatternCombinator):
    """Match only if every sub-pattern matches the scrutinee.

    Returns a tuple holding each sub-pattern's result in order. The first
    failing sub-pattern's PatternFailedToMatch propagates.
    """

    def match(self, scrutinee):
        results = (sub.match(scrutinee) for sub in self.patterns)
        return tuple(results)
class Any(PatternCombinator):
    """Match if at least one sub-pattern matches the scrutinee.

    Sub-patterns are tried in order, and the result of the first one that
    matches is returned. If none match, PatternFailedToMatch is raised.
    """

    def match(self, scrutinee):
        for candidate in self.patterns:
            try:
                outcome = candidate.match(scrutinee)
            except PatternFailedToMatch:
                continue
            else:
                return outcome
        raise PatternFailedToMatch
class Tuple(PatternCombinator):
    """Match a sequence of values element-wise against a sequence of patterns.

    The i-th sub-pattern is matched against the i-th element, and the
    result is the tuple of the sub-pattern results. By default the match
    fails unless the scrutinee has exactly as many elements as there are
    sub-patterns. With ``lax=True`` the lengths may differ, and only the
    overlapping prefix is matched (``zip`` stops at the shorter side).
    A scrutinee with no length (e.g. an int) never matches.
    """

    def __init__(self, *patterns, lax=False):
        self.patterns = patterns
        # Must be stored on the instance: match() reads it. Reading a bare
        # ``lax`` there would raise NameError whenever the lengths differ.
        self.lax = lax

    def match(self, scrutinee):
        try:
            size = len(scrutinee)
        except TypeError:
            # Not a sized sequence, so it cannot match a tuple pattern.
            raise PatternFailedToMatch() from None
        if size != len(self.patterns) and not self.lax:
            raise PatternFailedToMatch()
        return tuple(pat.match(x) for pat, x in zip(self.patterns, scrutinee))
class Chain(PatternCombinator):
    """Match sub-patterns one after another, feeding each result into the next.

    The first sub-pattern is matched against the scrutinee. Every later
    sub-pattern is matched against the value bound by the previous one.
    A sub-pattern that binds nothing (returns None, like Eq or Test) acts
    as a pure check: the previously bound values are passed on unchanged
    instead of crashing the next step. The chain's result is whatever the
    last sub-pattern returned, so a trailing Eq makes the case body take
    no arguments. With no sub-patterns the result is ``(scrutinee,)``.
    """

    def match(self, scrutinee):
        vals = (scrutinee,)
        result = vals
        for pat in self.patterns:
            result = pat.match(*vals)
            if result is not None:
                vals = result
        return result
class switch:
    """A C-style ``switch`` statement augmented with pattern matching.

    Use it as a context manager. The object it yields is a decorator
    factory (conventionally named ``case``) for declaring case
    alternatives::

        with switch(input()) as case:
            @case(Chain, Apply(int), Eq(1))
            def _():
                print("one")

            @case(int)
            def _(ival):        # ival is the int produced by int()
                print(ival*ival)

            @case()             # default case
            def _():
                print("I didn't understand the input ;(")

    Alternatives are tried top to bottom. The first one whose pattern
    matches has its body run immediately. Unless that body returns
    ``fallthrough``, the switch then ends and the body's return value is
    stored in ``self.val``.

    The throwaway ``def _`` functions are unavoidable. PEP 377, which
    would have let a context manager skip the body of a ``with`` block,
    was rejected, and the known workarounds are not portable.
    """

    def __init__(self, scrutinee):
        self.scrut = scrutinee
        # Return value of the body that ended the switch, if any.
        self.val = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, trace):
        # Swallow only our own "a case matched" signal; every other
        # exception propagates (a falsy return does not suppress it).
        if not isinstance(exc_value, SwitchFinished):
            return None
        self.val = exc_value.val
        return True

    @staticmethod
    def _coerce_pattern(pat, args, kwargs):
        """Turn the argument given to ``case(...)`` into a Pattern instance."""
        if isinstance(pat, type) and issubclass(pat, Pattern):
            return pat(*args, **kwargs)
        if isinstance(pat, Pattern):
            return pat
        if callable(pat):
            return Apply(pat, *args, **kwargs)
        return Eq(pat)

    def __call__(self, pat=None, *args, **kwargs):
        """Build the decorator for one case alternative.

        ``pat`` is interpreted as follows: None gives the default case; a
        Pattern subclass is instantiated with ``*args, **kwargs``; a Pattern
        instance is used as-is; any other callable is wrapped in Apply; and
        anything else is compared with Eq. The returned decorator runs the
        body right away if the pattern matches, and discards it otherwise.
        """
        if pat is None:
            return _just_run_it
        pattern = self._coerce_pattern(pat, args, kwargs)
        try:
            bound = pattern.match(self.scrut)
        except PatternFailedToMatch:
            return do_nothing
        if bound is None:
            return _just_run_it

        def run_alt(body):
            result = body(*bound)
            if result is not fallthrough:
                raise SwitchFinished(result)

        return run_alt
# Example usage: read one line from stdin and dispatch on it.
if __name__ == '__main__':
    with switch(input()) as case:
        # Input that int() turns into 1. Chain feeds the int from Apply(int)
        # into Eq(1); Eq binds nothing, so the body takes no arguments.
        @case(Chain, Apply(int), Eq(1))
        def _():
            print("one")

        # Any other input int() accepts; the body receives the int.
        @case(int)
        def _(x):
            print(x*x)

        # Default case: without it, non-numeric input is silently ignored.
        @case()
        def _():
            print("I didn't understand the input ;(")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment