Last active
April 8, 2023 19:59
-
-
Save coinconclusive/d07bded9a0bac7c3db7a9ec7a2e8b75f to your computer and use it in GitHub Desktop.
lambda calculus in python 3.10
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# by coinconclusive. in public domain. | |
from __future__ import annotations | |
from dataclasses import dataclass | |
from typing import ClassVar | |
@dataclass
class Lam:
    """Lambda abstraction node: a parameter name together with a body.

    Rendered as ``λname.body``.
    """

    name: str   # name of the bound parameter
    body: Expr  # expression forming the abstraction's body

    def __str__(self):
        # Guard against a node that directly contains an equal copy of itself.
        assert self.body != self
        param, inner = self.name, self.body
        return f'λ{param}.{inner}'

    def __repr__(self):
        return f'Expr({self})'
@dataclass
class Var:
    """Variable reference node; rendered as its bare name."""

    name: str  # identifier this node refers to

    def __str__(self):
        return f'{self.name}'

    def __repr__(self):
        return f'Expr({self})'
@dataclass
class Keep:
    """An expression tagged with a name.

    Rendered as ``[name]``, or as ``[name = expr]`` when the class-level
    ``str_expand`` flag is switched on.
    """

    # Class-wide switch (not an instance field): when True, __str__ also
    # prints the wrapped expression.
    str_expand: ClassVar[bool] = False

    name: str   # label shown when the node is printed
    expr: Expr  # the wrapped expression

    def __str__(self):
        if not Keep.str_expand:
            return f'[{self.name}]'
        return f'[{self.name} = {self.expr!s}]'

    def __repr__(self):
        return f'Expr({self})'
@dataclass
class App:
    """Application node: ``func`` applied to the argument ``subs``.

    Rendered as ``(func subs)``; an operand that is a lambda abstraction is
    additionally wrapped in its own parentheses.
    """

    func: Expr  # expression in function position
    subs: Expr  # argument expression

    def __str__(self):
        # Guard against a node that directly contains an equal copy of itself.
        assert self.func != self
        assert self.subs != self

        def show(part):
            # Abstractions get their own parentheses; everything else is
            # printed as-is.
            text = str(part)
            return f'({text})' if isinstance(part, Lam) else text

        return f'({show(self.func)} {show(self.subs)})'

    def __repr__(self):
        return f'Expr({self})'
# Any expression node: abstraction, variable, application, or named wrapper.
Expr = Lam | Var | App | Keep
import sys | |
def read_name(s: str) -> tuple[str, str]:
    """Split a leading name off the front of ``s``.

    A name is the longest prefix that contains no whitespace and none of the
    delimiter characters: lambda sign, backslash, ``.``, ``(``, ``)``, ``:``
    and ``=``.

    Returns:
        ``(name, rest)`` where ``rest`` is the unconsumed remainder of ``s``.

    Raises:
        ValueError: if ``s`` is empty or starts with a delimiter/whitespace.
    """
    i = 0
    # Stop at the end of the input as well as at a delimiter, so a name that
    # runs to the end of the string is returned instead of raising IndexError.
    while i < len(s) and s[i] not in 'λ\\.():=' and not s[i].isspace():
        i += 1
    if i == 0:
        raise ValueError('expected name.')
    return s[:i], s[i:]
def parse_atom(s: str) -> tuple[Expr, str]:
    """Parse one atom from the front of ``s`` (leading whitespace is skipped).

    An atom is one of:
      * a lambda abstraction introduced by the lambda sign or a backslash,
        followed by a name, a ``.`` and a body parsed with ``parse_expr``;
      * a parenthesised expression;
      * a bare name, which becomes a ``Var``.

    Returns:
        ``(expr, rest)`` where ``rest`` is the unconsumed input.

    Raises:
        ValueError: on malformed or truncated input. Empty input is reported
            here explicitly rather than surfacing as an IndexError.
    """
    s = s.lstrip()
    if not s:
        raise ValueError('expected expression.')
    if s[0] in 'λ\\':
        name, s = read_name(s[1:])
        s = s.lstrip()
        # startswith() is safe on an exhausted string, unlike s[0].
        if not s.startswith('.'):
            raise ValueError('expected "." after lambda.')
        e, s = parse_expr(s[1:])
        return Lam(name, e), s
    if s[0] == '(':
        e, s = parse_expr(s[1:])
        s = s.lstrip()
        if not s.startswith(')'):
            raise ValueError('expected matching ")" after "(".')
        return e, s[1:]
    name, s = read_name(s)
    return Var(name), s
def parse_expr(s: str) -> tuple[Expr, str]:
    """Parse a left-associative chain of applications from the front of ``s``.

    The first atom is mandatory, so a failure to parse it propagates to the
    caller. Further atoms are consumed greedily and folded into nested
    ``App`` nodes; the chain ends at the first position where no atom can
    be parsed.

    Returns:
        ``(expr, rest)`` where ``rest`` starts at the first unparsed character.
    """
    a, s = parse_atom(s)
    while s:
        try:
            b, rest = parse_atom(s)
        except (ValueError, IndexError):
            # A parse failure marks the end of the application chain. Only
            # parse errors are caught, so KeyboardInterrupt and friends
            # still propagate. `s` is left untouched for the caller.
            break
        a, s = App(a, b), rest
    return a, s
@dataclass
class PragmaStmt:
    """A pragma statement: a pragma name plus its argument expressions."""

    name: str           # pragma identifier (the text following the leading ":")
    params: list[Expr]  # argument atoms, in source order
@dataclass
class AssignStmt:
    """An assignment statement binding ``name`` to the expression ``value``."""

    name: str    # identifier being defined
    value: Expr  # expression bound to the identifier
# Statement kind collected into the statement list by parse_prog
# (assignments are stored in the name table instead).
Stmt = PragmaStmt
def parse_stmt(s: str) -> tuple[Stmt | AssignStmt, str]:
    """Parse one statement from the front of ``s``.

    Two forms are recognised:
      * ``:name atom ...`` is a pragma; atoms are collected up to (but not
        including) the terminating ``.``.
      * ``name = expr`` or ``name := expr`` is an assignment; with ``:=``
        the value is wrapped in ``Keep(name, expr)``.

    Returns:
        ``(stmt, rest)``. The statement's terminating ``.`` is not consumed
        here; the caller is expected to check for it.

    Raises:
        ValueError: on malformed or truncated input. Exhausted input is
            reported explicitly rather than surfacing as an IndexError.
    """
    s = s.lstrip()
    if not s:
        raise ValueError('expected statement.')
    if s[0] == ':':
        name, s = read_name(s[1:].lstrip())
        s = s.lstrip()
        params: list[Expr] = []
        while not s.startswith('.'):
            if not s:
                # Input ended before the pragma was terminated.
                raise ValueError('expected "." after pragma.')
            e, s = parse_atom(s)
            params.append(e)
            s = s.lstrip()
        return PragmaStmt(name, params), s

    name, s = read_name(s)
    s = s.lstrip()
    # ":=" marks a definition whose value is wrapped in Keep.
    keep = s.startswith(':')
    if keep:
        s = s[1:]
    if not s.startswith('='):
        raise ValueError('expected "=" after name.')
    expr, s = parse_expr(s[1:])
    return AssignStmt(name, Keep(name, expr) if keep else expr), s
def parse_prog(s: str) -> tuple[dict[str, Expr], list[Stmt]]: | |
if not s.strip(): return {}, [] | |
stmts: list[Stmt] = [] | |
names: dict[str, Expr] = {} | |
while True: | |
try: | |
stmt, s = parse_stmt(s) | |
match stmt: | |
case AssignStmt(name, value): names[name] = value | |
case _: stmts.append(stmt) | |
s = s.lstrip() | |
if s[0] != '.': raise ValueError('expected "." after statement.') | |
s = s[1:] | |
except: break | |
return names, stmts | |
def replace(e: Expr, n: str, s: Expr) -> Expr:
    """Substitute ``s`` for occurrences of the name ``n`` inside ``e``.

    * A ``Var`` named ``n`` becomes ``s``; other variables are unchanged.
    * A ``Lam`` that binds ``n`` shadows it and is returned untouched;
      otherwise substitution continues into its body.
    * A ``Keep`` named ``n`` becomes ``s``. Any other ``Keep`` is returned
      as-is when substitution leaves its inner expression unchanged, and is
      replaced by the substituted inner expression (wrapper dropped)
      otherwise.
    * An ``App`` is rebuilt from its substituted parts.

    NOTE(review): binders are never renamed, so free variables of ``s`` can
    be captured by a ``Lam`` in ``e``.
    """
    if isinstance(e, Var):
        return s if e.name == n else e
    if isinstance(e, Lam):
        if e.name == n:
            return e
        return Lam(e.name, replace(e.body, n, s))
    if isinstance(e, Keep):
        if e.name == n:
            return s
        inner = replace(e.expr, n, s)
        return e if inner == e.expr else inner
    if isinstance(e, App):
        new_func = replace(e.func, n, s)
        new_subs = replace(e.subs, n, s)
        return App(new_func, new_subs)
    # Any other value falls through and yields None.
@dataclass
class Context:
    """Evaluation context: the global name table plus optional literal info."""

    @dataclass
    class NatLitInfo:
        """The pair of ``Keep`` definitions used to build numeric literals."""

        zero: Keep  # definition standing for the literal 0
        succ: Keep  # definition of the successor function

    names: dict[str, Expr]            # global definitions by name
    natlit: NatLitInfo | None = None  # None until numeric literals are set up
def nat_literal(i: int, info: Context.NatLitInfo) -> Expr:
    """Build the expression for the natural-number literal ``i``.

    ``0`` is ``Keep('0', info.zero)``; each larger number ``k`` is
    ``Keep(str(k), App(info.succ.expr, <literal k-1>))``.

    The literal is built iteratively, so construction does not depend on the
    interpreter's recursion limit.

    Raises:
        ValueError: if ``i`` is negative (the recursive formulation would
            never reach its base case).
    """
    if i < 0:
        raise ValueError('natural-number literal must be non-negative.')
    lit: Expr = Keep('0', info.zero)
    for k in range(1, i + 1):
        lit = Keep(str(k), App(info.succ.expr, lit))
    return lit
def resolve_var(name: str, ctx: Context) -> Expr:
    """Resolve the variable ``name`` against ``ctx``.

    Lookup order:
      1. a global definition in ``ctx.names``, which is evaluated with ``run``;
      2. an all-digit name, turned into a numeric literal when ``ctx.natlit``
         is configured;
      3. otherwise the name stays an unresolved ``Var``.
    """
    if name not in ctx.names:
        literals_enabled = ctx.natlit is not None
        if name.isdigit() and literals_enabled:
            return nat_literal(int(name), ctx.natlit)
        return Var(name)
    definition = ctx.names[name]
    return run(definition, ctx)
def is_transitive_keep(e: Expr, ctx: Context) -> bool:
    """Return True if ``e`` is a ``Keep`` or a variable that resolves to one.

    Any other kind of expression yields False.
    """
    if isinstance(e, Keep):
        return True
    if isinstance(e, Var):
        # Resolve once and inspect the result. resolve_var hands back the
        # variable itself when the name is unbound, so recursing on its
        # result would never terminate for unbound names.
        return isinstance(resolve_var(e.name, ctx), Keep)
    return False
def run(e: Expr, ctx: Context, indent: int = 0) -> Expr:
    """Evaluate ``e`` in ``ctx``, printing a trace of every application.

    * A ``Var`` is resolved through ``resolve_var``.
    * An ``App`` evaluates its function part, unwraps ``Keep`` wrappers until
      a ``Lam`` is reached, substitutes the (unevaluated) argument into the
      body and evaluates the result.
    * Anything else is returned unchanged.

    Shortcut: while unwrapping, if numeric literals are configured, the
    function is the ``Keep`` named like the configured successor, and the
    argument is a (transitive) ``Keep`` whose name is all digits, the next
    literal is returned directly.

    ``indent`` is the nesting depth used to indent trace lines.
    """
    if isinstance(e, Var):
        return resolve_var(e.name, ctx)
    if not isinstance(e, App):
        return e

    pad = ' ' * indent
    func, subs = e.func, e.subs
    print(pad + 'apply', func, '<-', subs)
    f = run(func, ctx, indent + 1)
    print(pad + ' -1->', f, '<-', subs)
    while isinstance(f, Keep):
        # Evaluation order matters: the checks short-circuit left to right.
        succ_of_literal = (
            ctx.natlit is not None
            and is_transitive_keep(subs, ctx)
            and f.name == ctx.natlit.succ.name
            and subs.name.isdigit()
        )
        if succ_of_literal:
            print(pad + 'succ', subs.name, end=' ')
            r = nat_literal(int(subs.name) + 1, ctx.natlit)
            print('=', r)
            return r
        f = run(f.expr, ctx, indent + 1)
    # Only an abstraction can be applied.
    assert isinstance(f, Lam)
    print(pad + ' -2->', f, '<-', subs)
    reduced = replace(f.body, f.name, subs)
    r = run(reduced, ctx, indent + 1)
    print(pad + ' -3->', r)
    return r
if __name__ == '__main__': | |
ctx = Context({}) | |
ctx.names, stmts = parse_prog(''' | |
id = λx.x. | |
flip := λx.λy.y x. | |
zero := λz.λs.z. | |
succ := λx.λz.λs.s x. | |
:natlit zero succ. | |
add := λa. a id λa'.λb. add a' (succ b). | |
main = add 1 1. | |
''') | |
for name, expr in ctx.names.items(): | |
print(name, '=', expr) | |
for stmt in stmts: | |
match stmt: | |
case PragmaStmt('natlit', [zero, succ]): | |
zero = run(zero, ctx) | |
succ = run(succ, ctx) | |
assert isinstance(zero, Keep) | |
assert isinstance(succ, Keep) | |
ctx.natlit = Context.NatLitInfo(zero, succ) | |
case PragmaStmt(name, _): | |
raise ValueError(f'unknown pragma: "{name}"') | |
print(run(ctx.names['main'], ctx)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment