Last active July 25, 2024 22:56
TinyJit visualization (WIP)
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import JIT
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters
class TinyNet:
  """Two-layer bias-free MLP: 784 -> 128 -> leaky ReLU -> 10."""

  def __init__(self):
    self.l1 = nn.Linear(784, 128, bias=False)
    self.l2 = nn.Linear(128, 10, bias=False)

  def __call__(self, x):
    # First projection, then the leaky-ReLU nonlinearity, then the output projection.
    hidden = self.l1(x).leakyrelu()
    return self.l2(hidden)
# Global model instance (called by train_step) and an SGD optimiser over all of its parameters.
net = TinyNet()
optim = SGD(get_parameters(net))
# NOTE(review): JIT=2 presumably keeps TinyJit from batching the captured kernels into a graph
# runner, so train_step.jit_cache holds one entry per kernel (the visualisation reads
# ei.prg.p.uops per entry) — confirm against tinygrad's engine/jit.py.
JIT.value = 2
@TinyJit
def train_step(batch, labels):
  """Run one SGD step of `net` on (batch, labels) with a mean-squared-error loss.

  Wrapped in TinyJit: the visualisation code reads the captured kernels from
  `train_step.jit_cache` and `train_step.input_replace`, which only exist on the
  jitted function.

  Returns:
    The scalar loss tensor. Returning it also gives the JIT an output to realize.
  """
  optim.zero_grad()
  x = net(batch)
  loss = x.sub(labels).square().mean()
  loss.backward()
  # Updates the parameters; tinygrad's optimiser expects to run inside Tensor.train().
  optim.step()
  return loss
# Fixed random inputs (64 x 784) and targets (64 x 10), reused for every call.
ia, ib = Tensor.randn(64, 784), Tensor.randn(64, 10)
with Tensor.train():
  # NOTE(review): presumably several calls are needed (warm-up, capture, replay) before
  # train_step.jit_cache is populated — confirm the count of 3 against TinyJit.
  for _ in range(3):
    r = train_step(ia, ib)
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.engine.realize import BufferCopy
from dataclasses import dataclass
from collections import defaultdict
import re
import os
import networkx as nx
def strip_coloring(text):
  """Return `text` with ESC-introduced (ANSI) escape sequences, such as colour codes, removed."""
  ansi_escape = r'\x1B[@-_][0-?]*[ -/]*[@-~]'
  return re.sub(ansi_escape, '', text)
@dataclass
class BufInfo:
  """Access flags for one kernel buffer, filled in from the kernel's LOAD/STORE uops."""
  loaded: bool = False  # set when some LOAD reads from this buffer
  stored: bool = False  # set when some STORE writes to this buffer
def _global_buf_access(uops):
  """Record which global buffers a kernel's uops LOAD from and STORE to.

  Returns a defaultdict of BufInfo keyed by `arg[0]` of the DEFINE_GLOBAL that
  is the first source of each LOAD/STORE. The loop below looks entries up by
  position in `ei.bufs`. Accesses through DEFINE_LOCAL are skipped.
  """
  access = defaultdict(BufInfo)
  for uop in uops:
    if uop.op is not UOps.LOAD and uop.op is not UOps.STORE: continue
    base = uop.src[0]
    if base.op is UOps.DEFINE_LOCAL: continue
    assert base.op is UOps.DEFINE_GLOBAL
    if uop.op is UOps.LOAD: access[base.arg[0]].loaded = True
    else: access[base.arg[0]].stored = True
  return access

# Dataflow graph of the captured JIT: one node per kernel and one per buffer, with an
# edge buffer -> kernel for each read and kernel -> buffer for each write.
G = nx.DiGraph()
for ex, ei in enumerate(train_step.jit_cache):
  # if isinstance(ei.prg, BufferCopy): continue # TODO: runners without `.p.uops` would raise below
  # ei.prg.p.uops.print()
  buf_info = _global_buf_access(ei.prg.p.uops)
  prg_id = f'{ex} {id(ei.prg)}'
  G.add_node(prg_id, label=f'#{ex} ' + strip_coloring(ei.prg.display_name))
  for bufx, buf in enumerate(ei.bufs):
    if buf is None: continue
    # if buf.size == 1: continue
    # ':' is replaced because networkx's pydot export rejects unquoted colons in node attributes.
    G.add_node(id(buf), label=str(buf).replace(':', '_'))
    if buf_info[bufx].loaded: G.add_edge(id(buf), prg_id, label=str(bufx))
    if buf_info[bufx].stored: G.add_edge(prg_id, id(buf), label=str(bufx))
  # JIT inputs feeding this kernel. Keys are (kernel index, buffer index), values are the input
  # index (as read here; presumably the buffers swapped in on replay — confirm).
  for key, input_idx in train_step.input_replace.items():
    if key[0] != ex: continue
    G.add_node(f'input_{input_idx}', shape='rectangle')
    G.add_edge(f'input_{input_idx}', prg_id, label=str(key[1])) # TODO: dir
import subprocess

# Write the graph as Graphviz DOT, then render it to SVG with the `dot` command-line tool.
fn = '_jit'
nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
# Argument list (no shell) with check=True: a missing or failing `dot` raises instead of
# silently leaving no SVG behind, which is what happened when os.system's status was ignored.
subprocess.run(['dot', '-Tsvg', f'{fn}.dot', '-o', f'{fn}.svg'], check=True)
