Skip to content

Instantly share code, notes, and snippets.

@Alwinfy
Created August 31, 2024 04:15
Show Gist options
  • Save Alwinfy/a778ef86cdad44e0bebbf6e1304b8bbf to your computer and use it in GitHub Desktop.
Save Alwinfy/a778ef86cdad44e0bebbf6e1304b8bbf to your computer and use it in GitHub Desktop.
def ziprec(parms, vals):
    """Match a (possibly nested) parameter pattern against values.

    Yields ``(name, value)`` bindings. Any non-list pattern (normally a name
    string) binds the whole value. A list pattern is zipped element-wise with
    ``vals`` and each pair is matched recursively, so ``[["a", "b"]]``
    destructures a single two-element argument. As with ``zip``, surplus
    elements on either side are silently ignored.
    """
    if not isinstance(parms, list):
        yield parms, vals
        return
    for sub_pattern, sub_value in zip(parms, vals):
        yield from ziprec(sub_pattern, sub_value)
class Env(dict):
    """One lexical scope: a dict of bindings plus a link to the enclosing scope.

    ``parent`` is the enclosing Env, or None for the outermost scope.
    """
    __slots__ = ["parent"]

    def __init__(self, l, parent):
        # ``l`` is anything dict() accepts: a mapping, or an iterable of
        # (name, value) pairs such as the generator returned by ``ziprec``.
        super().__init__(l)
        self.parent = parent

    def find(self, n):
        """Return the value bound to ``n`` in the nearest scope that binds it.

        Raises NameError when no scope in the chain binds ``n`` (previously
        the lookup fell off the chain and raised an opaque AttributeError on
        ``None.find``).
        """
        scope = self
        while scope is not None:
            # Membership test rather than .get so names bound to None resolve.
            if n in scope:
                return scope[n]
            scope = scope.parent
        raise NameError("unbound name: %r" % (n,))
class Lam:
    """A closure: a lambda's parameter pattern and body plus its defining scope.

    Constructed from a whole ``["lambda", PARAMS, BODY]`` form. Calling it
    with an argument sequence binds PARAMS (via ``ziprec``) in a fresh scope
    whose parent is the defining scope, and returns ``ev``'s
    ``(value, resume)`` pair for the body so the trampoline can run it.
    """
    __slots__ = ["args", "body", "env"]

    def __init__(self, env, args):
        # ``args`` is the entire lambda form; element 0 is the "lambda" head.
        self.env = env
        self.args = args[1]
        self.body = args[2]

    def send(self, _):
        # Coroutine-style hook returning the bound __call__; it does not
        # appear to be referenced by the visible interpreter code.
        return self.__call__

    def __call__(self, args):
        scope = Env(ziprec(self.args, args), self.env)
        return ev(scope, self.body)
# The type of a bound-method object (the same object as types.MethodType),
# obtained by binding a throwaway function to an arbitrary instance. It is
# not referenced anywhere else in the visible file.
bind1 = type((lambda _=None: _).__get__(object()))
def coerce_corot(co):
    """Adapt a generator function to the evaluator's ``(None, resume)`` form.

    Each call of the returned function starts a fresh generator ``co(*args)``
    and hands back ``(None, gen.send)``. The trampoline's first
    ``resume(None)`` call then primes the generator.
    """
    def start(*args):
        gen = co(*args)
        return None, gen.send
    return start
def ev_(env, dat):
    """Evaluate expression ``dat`` in ``env`` as a generator-based coroutine.

    Sub-evaluations are yielded to the trampoline (``yield ev(...)``) rather
    than made as nested Python calls. The generator returns a
    ``(value, resume)`` pair: ``resume is None`` means ``value`` is final;
    otherwise it is a tail call for the trampoline to continue with.

    * str: a symbol, resolved with ``env.find``.
    * callable: an embedded Python function, adapted to the interpreter's
      calling convention (takes the argument list, returns ``(result, None)``).
    * list: a special form when the head names one in ``sp_forms``; otherwise
      a call. The head is evaluated first, then each operand left to right,
      and the head's value is applied to the list of operand values.
    * anything else: a self-evaluating literal.
    """
    if isinstance(dat, str):
        return env.find(dat), None
    if callable(dat):
        return (lambda args: (dat(*args), None)), None
    if not isinstance(dat, list):
        return dat, None
    head = dat[0]
    if isinstance(head, str) and head in sp_forms:
        # Special forms receive the unevaluated form.
        return sp_forms[head](env, dat)
    fn = yield ev(env, head)
    operands = []
    for operand in dat[1:]:
        operands.append((yield ev(env, operand)))
    # Tail position: the callee's (value, resume) pair goes to the trampoline.
    return fn(operands)
# ev(env, expr) starts an ev_ coroutine and returns (None, resume), where
# resume is that coroutine's send. This is the pair every evaluator yields
# and that trampoline drives.
ev = coerce_corot(ev_)
def ev_cond(env, clauses):
    """Evaluate ``["cond", test1, expr1, test2, expr2, ..., default?]``.

    Tests are evaluated in order. The expression paired with the first truthy
    test is returned as a tail call. A trailing unpaired element serves as
    the default. With no match and no default, the result is ``None``.
    """
    last = len(clauses) - 1
    for pos in range(1, last, 2):
        matched = yield ev(env, clauses[pos])
        if matched:
            return ev(env, clauses[pos + 1])
    # An unpaired final element exists exactly when ``last`` is a positive odd index.
    if last > 0 and last % 2 == 1:
        return ev(env, clauses[last])
    return None, None
def ev_do(env, clauses):
    """Evaluate ``["do", expr1, ..., exprN]``.

    expr1..expr(N-1) are evaluated in order and their values discarded. exprN
    is returned as a tail call, so its value is the value of the whole form.
    An empty ``["do"]`` evaluates to ``None``, matching a ``cond`` with no
    matching arm, instead of raising IndexError.
    """
    body = clauses[1:]
    if not body:
        return None, None
    for expr in body[:-1]:
        yield ev(env, expr)
    return ev(env, body[-1])
def ev_quote(env, val):
    """Coroutine form of ``quote``: finishes at once with ``(val[1], None)``.

    The unreachable ``yield`` only makes this a generator function. The
    ``"quote"`` entry of ``sp_forms`` uses an equivalent plain lambda instead.
    """
    return val[1], None
    yield
# Special forms: list heads whose operands are handed over unevaluated. Each
# handler takes (env, whole_form) and returns a (value, resume) pair, exactly
# like ev; resume is None for an immediate value.
sp_forms = {
    # ["lambda", PARAMS, BODY] -> a closure over the current scope.
    "lambda": lambda env, form: (Lam(env, form), None),
    # ["cond", ...] and ["do", ...] run as coroutines of their own.
    "cond": coerce_corot(ev_cond),
    "do": coerce_corot(ev_do),
    # ["let", PARAMS, VALUES, BODY] is sugar for [["lambda", PARAMS, BODY], *VALUES].
    "let": lambda env, form: ev(env, [["lambda", form[1], form[3]]] + form[2]),
    # ["quote", X] -> X, unevaluated.
    "quote": lambda env, form: (form[1], None),
}
def trampoline(val):
    """Run an interpreter computation to completion on an explicit stack.

    ``val`` is a ``(value, resume)`` pair as produced by ``ev``. When
    ``resume`` is None, ``value`` is already the result. Otherwise ``resume``
    (a coroutine's ``send``) is called with the pending value, and either:

    * it yields a new pair. A non-None resume is a sub-computation, pushed
      to run first. A None resume is an immediate value, sent straight back
      into the same coroutine.
    * it finishes (StopIteration). Its return pair is either the final value
      for the coroutine below it or a tail call that takes its place.

    Returns the value left over when the stack empties.
    """
    dat, first = val
    # An immediate result needs no driving (calling None would raise TypeError).
    stack = [] if first is None else [first]
    while stack:
        try:
            dat, next_cmp = stack[-1](dat)
        except StopIteration as si:
            stack.pop()
            dat, next_cmp = si.value
        # A None continuation means "dat is ready": deliver it to the
        # coroutine now on top instead of pushing None onto the stack.
        if next_cmp is not None:
            stack.append(next_cmp)
    return dat
def evaluate(atom):
    """Evaluate ``atom`` in a fresh top-level scope and return its value.

    The top-level scope binds only ``apply``. It is stored raw in the
    interpreter's calling convention (it receives the evaluated argument
    list): ``["apply", f, args]`` calls ``f`` with ``args`` as its argument
    list and returns whatever pair ``f`` produces.
    """
    def apply_(args):
        fn, fn_args = args[0], args[1]
        return fn(fn_args)
    return trampoline(ev(Env({"apply": apply_}, None), atom))
def wrap_fun(lam):
    """Adapt an interpreter function to an ordinary Python callable.

    The result packs its positional arguments into a tuple, passes that to
    ``lam`` and runs the resulting computation to completion with the
    trampoline.
    """
    return lambda *args: trampoline(lam(args))
def foreach(fn, l):
    """Call ``fn`` on every element of ``l`` in order, for side effects only.

    Always returns None.
    """
    for element in l:
        fn(element)
    return None
def foldl(fn, acc, it):
    """Left fold: ``fn(...fn(fn(acc, x0), x1)..., xN)`` over the items of ``it``.

    Returns ``acc`` unchanged when ``it`` is empty. This delegates to
    ``functools.reduce`` with ``acc`` as the initial value, which performs
    exactly this loop.
    """
    from functools import reduce
    return reduce(fn, it, acc)
def join(d1, d2):
    """Store the triple ``d2 = (outer, inner, value)`` into nested dict ``d1``.

    Sets ``d1[outer][inner] = value``, creating the inner dict on demand, and
    returns ``d1`` itself so the function can be used as a fold step.
    """
    outer_key, inner_key, value = d2
    inner = d1.setdefault(outer_key, {})
    inner[inner_key] = value
    return d1
# Small plain-Python accessors used by the embedded program (for speed).
# ``snd`` is passed quoted straight to the builtin ``map``, so it must remain
# an ordinary one-argument Python callable.
def pair(x, y):
    """Return the 2-tuple ``(x, y)``."""
    return x, y


def snd(x):
    """Return the second element of ``x``."""
    return x[1]
# The graph algorithm, written as a program for the interpreter above.
# Evaluating it yields a Lam of (graph, root). The Lam returns a dict mapping
# every key of ``graph`` to the number of minimum-weight paths from ``root``
# among the paths whose total weight is even (0 when there is none; root -> 1).
# Assumes ``graph`` maps node -> {neighbor: weight}. Weights go through
# int.__add__, so non-int weights raise TypeError.
# NOTE(review): Ymemo caches a result only after the call returns, so a cycle
# reachable backwards from a node recurses without bound. The input
# presumably must be a DAG; confirm with callers.
code = [
    "let",
    ["Ymemo", "rev_graph", "inf", "="],
    [
        # Ymemo(f): memoizing fixed-point combinator. ``f`` receives a
        # recursive handle; results are cached in ``table`` keyed by
        # tuple(args), written once the recursive call has returned.
        ["lambda", ["f"],
            ["let", ["table"], [[dict]],
                [["lambda", ["m"], ["m", "m"]],
                    ["lambda", ["m"], ["f", ["lambda", "args",
                        ["let", ["hashable"], [[tuple, "args"]],
                            ["cond",
                                [dict.__contains__, "table", "hashable"], [dict.__getitem__, "table", "hashable"],
                                ["let", ["result"], [["apply", ["m", "m"], "args"]],
                                    ["do",
                                        [dict.__setitem__, "table", "hashable", "result"],
                                        "result"]]]]]]]]]],
        # rev_graph(graph): every edge reversed, as {target: {source: weight}}.
        ["lambda", ["graph"],
            ["let", ["rev_graph"], [[dict]],
                ["do",
                    [foreach,
                        [wrap_fun, ["lambda", [["node", "connections"]],
                            [foreach,
                                [wrap_fun, ["lambda", [["neighbor", "weight"]],
                                    [dict.__setitem__,
                                        [dict.setdefault, "rev_graph", "neighbor", [dict]],
                                        "node",
                                        "weight"]]],
                                [dict.items, "connections"]]]],
                        [dict.items, "graph"]],
                    "rev_graph"]]],
        # inf: the literal 1e309 overflows to float("inf"); it marks "no path".
        1e309,
        # "=": plain Python equality, adapted to the interpreter convention by ev_.
        lambda x, y: x == y],
    ["lambda", ["graph", "root"],
        ["let", ["rgraph"], [["rev_graph", "graph"]],
            ["let", ["walk"],
                # walk(node, odd) -> (count, distance). Among root->node paths
                # whose total weight has parity ``odd``, it gives the minimum
                # weight (inf if none) and how many paths attain it, recursing
                # over the edges that point into ``node``.
                [["Ymemo", ["lambda", ["walk"],
                    ["lambda", ["node", "odd"],
                        ["cond", ["=", "node", "root"],
                            # The empty path at the root has weight 0, which is even.
                            ["cond", "odd", [pair, 0, "inf"], [pair, 1, 0]],
                            ["let", ["neighbors"],
                                [[list,
                                    [map,
                                        [wrap_fun, ["lambda", [["neighbor", "weight"]],
                                            ["let", [["count", "distance"]], [["walk", "neighbor", [bool, [(1).__and__, [int.__add__, "weight", "odd"]]]]],
                                                [pair, "count", [sum, [pair, "weight", "distance"]]]]]],
                                        # Nodes with no incoming edges never get an rgraph entry.
                                        # Looking them up with a {} default sends them to the
                                        # "no neighbors" arm below instead of raising KeyError.
                                        [dict.items, [dict.get, "rgraph", "node", [dict]]]]]],
                                ["cond", "neighbors",
                                    # snd is quoted so map receives it as a raw Python callable.
                                    ["let", ["min_dist"], [[min, [map, ["quote", snd], "neighbors"]]],
                                        [pair,
                                            [sum, [map,
                                                [wrap_fun, ["lambda", [["count", "dist"]],
                                                    ["cond", ["=", "dist", "min_dist"],
                                                        "count",
                                                        0]]],
                                                "neighbors"]],
                                            "min_dist"]],
                                    [pair, 0, "inf"]]]]]]]],
                # Result: {key: walk(key, even)[0]} for every key of graph.
                [dict,
                    [map,
                        [wrap_fun, ["lambda", ["key"],
                            [pair, "key", [tuple.__getitem__, ["walk", "key", False], 0]]]],
                        "graph"]]]]]]
# Python entry point, built once when this module is executed. Evaluating
# ``code`` yields a Lam of (graph, root), and wrap_fun makes it callable as
# num_opt_even_weight_paths(graph, root). It returns a dict with one count
# per key of ``graph``. Presumably ``graph`` maps node -> {neighbor: int
# weight}; TODO confirm with callers.
num_opt_even_weight_paths = wrap_fun(evaluate(code))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment