Skip to content

Instantly share code, notes, and snippets.

@qexat
Last active September 7, 2024 20:23
Show Gist options
Save qexat/6b04fc28146feabcbe18e1190371607b to your computer and use it in GitHub Desktop.
make pyright cook your CPU speedrun any%
# ruff: noqa: DOC201, DOC501
"""
Combinatorial arithmetic
"""
from __future__ import annotations
import abc
import typing
import attrs
class TypeVisitor[R_co](typing.Protocol):
"""
Represents a visitor of the Type tree.
"""
def visit_zero_type(self, typ: Zero) -> R_co:
"""
Visit the Zero type.
"""
...
def visit_unit_type(self, typ: Unit) -> R_co:
"""
Visit the Unit type.
"""
...
def visit_atomic_type(self, typ: Atomic[typing.LiteralString]) -> R_co:
"""
Visit an Atomic type.
"""
...
def visit_negation_type(self, typ: Negation[Type]) -> R_co:
"""
Visit a Negation type.
"""
...
def visit_product_type(self, typ: Product[Type, Type]) -> R_co:
"""
Visit a Product type.
"""
...
def visit_sum_type(self, typ: Sum[Type, Type]) -> R_co:
"""
Visit a Sum type.
"""
...
@attrs.frozen
class TypeBase(abc.ABC):
"""
Base of the type tree.
"""
def __neg__[T: Type](self: T) -> Negation[T]: # pyright: ignore[reportGeneralTypeIssues]
return Negation(self)
def __add__[T0: Type, T1: Type](self: T0, other: T1, /) -> Sum[T0, T1]: # pyright: ignore[reportGeneralTypeIssues]
return Sum(self, other)
def __mul__[T0: Type, T1: Type](self: T0, other: T1, /) -> Product[T0, T1]: # pyright: ignore[reportGeneralTypeIssues]
return Product(self, other)
@abc.abstractmethod
def accept[R](self, visitor: TypeVisitor[R]) -> R:
"""
Accept of a type visitor and return the result.
"""
@attrs.frozen(init=False)
@typing.final
class Zero(TypeBase):
"""
Represents the Zero type.
"""
def __init__(self) -> None:
message = "the zero type is uninstantiable"
raise TypeError(message)
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_zero_type(self)
# TODO: make it a singleton
@attrs.frozen
@typing.final
class Unit(TypeBase):
"""
Represents the Unit type.
"""
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_unit_type(self)
@attrs.frozen
@typing.final
class Atomic[Name: typing.LiteralString](TypeBase):
"""
Represents an Atomic type.
"""
name: Name
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_atomic_type(self)
@attrs.frozen
@typing.final
class Negation[T: Type](TypeBase):
"""
Represents a Negation type.
"""
typ: T
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_negation_type(self)
@attrs.frozen
@typing.final
class Product[T0: Type, T1: Type](TypeBase):
"""
Represents a Product type.
"""
first: T0
second: T1
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_product_type(
typing.cast(Product[Type, Type], self),
)
@attrs.frozen
@typing.final
class Sum[T0: Type, T1: Type](TypeBase):
"""
Represents a Sum type.
"""
first: T0
second: T1
@typing.override
def accept[R](self, visitor: TypeVisitor[R]) -> R:
return visitor.visit_sum_type(typing.cast(Sum[Type, Type], self))
type Type = (
Zero
| Unit
| Atomic[typing.LiteralString]
| Negation["Type"]
| Product["Type", "Type"]
| Sum["Type", "Type"]
)
def _unsafe_zero() -> Zero:
    """
    Produce a Zero instance by skipping its constructor.

    ``Zero.__init__`` unconditionally raises ``TypeError``. This helper
    calls ``__new__`` directly, so ``__init__`` never runs. Callers such as
    ``sum_id_left_intro`` and ``sum_eps`` use the result only to populate
    Zero-typed positions. Be EXTREMELY CAREFUL with using this instance.

    Returns
    -------
    Zero
        A Zero object created without running ``Zero.__init__``.
    """
    instance = TypeBase.__new__(Zero)
    return instance
def identity[T: Type](typ: T) -> T:
"""
Identity combinator.
"""
return typ
def sum_id_left_intro[T: Type](typ: T) -> Sum[Zero, T]:
"""
Rewriting rule to introduce the sum left identity.
"""
return _unsafe_zero() + typ
def sum_id_left_elim[T: Type](typ: Sum[Zero, T]) -> T:
"""
Rewriting rule to eliminate the sum left identity.
"""
match typ:
case Sum(Zero(), a):
return a
def sum_comm[T0: Type, T1: Type](typ: Sum[T0, T1]) -> Sum[T1, T0]:
"""
Commutativity sum combinator.
"""
match typ:
case Sum(a, b):
return b + a
def sum_assoc_left[T0: Type, T1: Type, T2: Type](
typ: Sum[T0, Sum[T1, T2]],
) -> Sum[Sum[T0, T1], T2]:
"""
Left-associativity sum combinator.
"""
match typ:
case Sum(a, Sum(b, c)):
return (a + b) + c
def sum_assoc_right[T0: Type, T1: Type, T2: Type](
typ: Sum[Sum[T0, T1], T2],
) -> Sum[T0, Sum[T1, T2]]:
"""
Right-associativity sum combinator.
"""
match typ:
case Sum(Sum(a, b), c):
return a + (b + c)
def sum_eta[T: Type](typ: Zero) -> Sum[T, Negation[T]]: # noqa: ARG001
"""
Eta sum combinator.
"""
raise RuntimeError("unreachable")
def sum_eps[T: Type](typ: Sum[T, Negation[T]]) -> Zero: # noqa: ARG001
"""
Eps sum combinator.
"""
return _unsafe_zero()
def prod_id_left_intro[T: Type](typ: T) -> Product[Unit, T]:
"""
Rewriting rule to introduce the product left identity.
"""
return Unit() * typ
def prod_id_left_elim[T: Type](typ: Product[Unit, T]) -> T:
"""
Rewriting rule to eliminate the product left identity.
"""
match typ:
case Product(Unit(), a):
return a
def prod_comm[T0: Type, T1: Type](typ: Product[T0, T1]) -> Product[T1, T0]:
"""
Commutativity product combinator.
"""
match typ:
case Product(a, b):
return b * a
def prod_assoc_left[T0: Type, T1: Type, T2: Type](
typ: Product[T0, Product[T1, T2]],
) -> Product[Product[T0, T1], T2]:
"""
Left-associative product combinator.
"""
match typ:
case Product(a, Product(b, c)):
return (a * b) * c
def prod_assoc_right[T0: Type, T1: Type, T2: Type](
typ: Product[Product[T0, T1], T2],
) -> Product[T0, Product[T1, T2]]:
"""
Right-associative product combinator.
"""
match typ:
case Product(Product(a, b), c):
return a * (b * c)
def prod_zero_left_intro[T: Type](typ: Zero) -> Product[Zero, T]: # noqa: ARG001
"""
Rewriting rule to introduce a left-zero-product term.
"""
raise RuntimeError("unreachable")
def prod_zero_left_elim[T: Type](typ: Product[Zero, T]) -> Zero: # noqa: ARG001
"""
Rewriting rule to eliminate a left-zero-product term.
"""
return Zero()
def sum_distrib_prod_left[T0: Type, T1: Type, T2: Type](
typ: Product[Sum[T0, T1], T2],
) -> Sum[Product[T0, T2], Product[T1, T2]]:
"""
Rewriting rule to distribute a sum product.
"""
match typ:
case Product(Sum(a, b), c):
return (a * c) + (b * c)
class PropVisitor[R_co](typing.Protocol):
"""
Represents a visitor of the proposition tree.
"""
def visit_equality(self, prop: Equality[Type, Type]) -> R_co:
"""
Visit an equality proposition.
"""
...
def visit_conjunction(self, prop: Conjunction[Prop, Prop]) -> R_co:
"""
Visit a proposition conjunction.
"""
...
def visit_disjunction(self, prop: Disjunction[Prop, Prop]) -> R_co:
"""
Visit a proposition disjunction.
"""
...
def visit_implication(self, prop: Implication[Prop, Prop]) -> R_co:
"""
Visit a proposition implication.
"""
...
@attrs.frozen
class PropBase(abc.ABC):
"""
Base of the proposition tree.
"""
def __and__[P0: Prop, P1: Prop](
self: P0, # pyright: ignore[reportGeneralTypeIssues]
other: P1,
/,
) -> Conjunction[P0, P1]:
return Conjunction(self, other)
def __or__[P0: Prop, P1: Prop](
self: P0, # pyright: ignore[reportGeneralTypeIssues]
other: P1,
/,
) -> Disjunction[P0, P1]:
return Disjunction(self, other)
def __rshift__[P0: Prop, P1: Prop](
self: P0, # pyright: ignore[reportGeneralTypeIssues]
other: P1,
/,
) -> Implication[P0, P1]:
return Implication(self, other)
@abc.abstractmethod
def accept[R](self, visitor: PropVisitor[R]) -> R:
"""
Accept a proposition visitor and return the result.
"""
@attrs.frozen
class Equality[T0: Type, T1: Type](PropBase):
"""
Represents an equality.
"""
left: T0
right: T1
@typing.override
def accept[R](self, visitor: PropVisitor[R]) -> R:
return visitor.visit_equality(self)
@attrs.frozen
class Conjunction[P0: Prop, P1: Prop](PropBase):
"""
Represents a proposition conjunction.
"""
left: P0
right: P1
@typing.override
def accept[R](self, visitor: PropVisitor[R]) -> R:
return visitor.visit_conjunction(self)
@attrs.frozen
class Disjunction[P0: Prop, P1: Prop](PropBase):
"""
Represents a proposition disjunction.
"""
left: P0
right: P1
@typing.override
def accept[R](self, visitor: PropVisitor[R]) -> R:
return visitor.visit_disjunction(
typing.cast(Disjunction[Prop, Prop], self),
)
@attrs.frozen
class Implication[P0: Prop, P1: Prop](PropBase):
"""
Represents a proposition implication.
"""
left: P0
right: P1
@typing.override
def accept[R](self, visitor: PropVisitor[R]) -> R:
return visitor.visit_implication(
typing.cast(Implication[Prop, Prop], self),
)
type Prop = (
Equality[Type, Type]
| Conjunction[Prop, Prop]
| Disjunction[Prop, Prop]
| Implication[Prop, Prop]
)
def sum_eq_inj[T0: Type, T1: Type, T2: Type, T3: Type](
left: Sum[T0, T1],
right: Sum[T2, T3],
) -> Implication[
Conjunction[Equality[T0, T2], Equality[T1, T3]],
Equality[Sum[T0, T1], Sum[T2, T3]],
]:
r"""
Sum equality injection.
∀ a, b, c, d : Type, a = c /\ b = d -> a + b = c + d
"""
match (left, right):
case Sum(a, b), Sum(c, d):
return Implication(
Conjunction(Equality(a, c), Equality(b, d)),
Equality(left, right),
)
def main() -> None:
    """
    Entry point of the program.

    Rewrites ``foo + (-foo)`` to the zero type with ``sum_eps``. It then
    passes the result to ``typing.reveal_type``, which the type checker
    reports statically and which also prints at runtime.
    """
    foo = Atomic("foo")
    negated_foo = -foo
    zero = sum_eps(foo + negated_foo)
    typing.reveal_type(zero)  # Revealed type is "Zero"


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment