Skip to content

Instantly share code, notes, and snippets.

@ezyang
Created August 16, 2024 03:03
Show Gist options
  • Save ezyang/f466d9c23452d968c0a3b17eb845d07e to your computer and use it in GitHub Desktop.
Save ezyang/f466d9c23452d968c0a3b17eb845d07e to your computer and use it in GitHub Desktop.
import json
import re
import sys
from pathlib import Path
# Matches a complete `if TEMPLATE...: <body> else: <body>` statement so that
# `replace` can collapse it to the `if` body alone.
#   group 1: the whitespace in front of `if`; reused through \1 to locate the
#            matching `else:` and the first line after the statement.
#   group 2: the `if` body -- lines that start with group 1 plus four more
#            whitespace characters, or empty lines; matched lazily up to the
#            `else:` line.
#   (no group): the `else` body, same line shape, also matched lazily.
#   group 3: group 1 once more, accepted only when a non-whitespace character
#            or an end of line follows it.
# NOTE(review): \s also matches newlines, so group 1 and the \s{4} runs are
# not strictly limited to same-line indentation.
pattern = re.compile(r'^(\s*)if TEMPLATE[^:]*:\s*\n((?:\1\s{4}.*\n|\n)*?)\1else:\s*\n(?:\1\s{4}.*\n|\n)*?(\1(?=\S|$))', re.MULTILINE)
def replace(match):
    """Build the replacement text for one match of the module-level `pattern`.

    The `if` body (group 2) is shifted left by one 4-character level so that
    it sits at the indentation of the removed `if` line (group 1); the
    trailing indentation captured as group 3 is appended unchanged.

    Args:
        match: a match object exposing groups 1-3 as described above.

    Returns:
        The dedented `if` body followed by group 3.
    """
    indent = match.group(1)
    true_branch = match.group(2)
    # Strip exactly one extra level from the start of every line. The class
    # [^\S\r\n] means "whitespace that is not a line break": a plain \s would
    # let the match begin on an empty line, consume its newline plus part of
    # the next line's indentation, and so delete the blank line and leave the
    # following line mis-indented.
    body_indent = re.compile(
        r'^[^\S\r\n]{' + str(len(indent) + 4) + '}', re.MULTILINE
    )
    true_branch = body_indent.sub(indent, true_branch)
    return true_branch + match.group(3)
# Sample notebook-cell source for trying `pattern` / `replace` by hand.
# The nested lines must be indented: `pattern` recognises the two branches
# purely by their indentation relative to the `if TEMPLATE` line.
test = """
@torch.compile(backend="eager", fullgraph=True)
def cf_check(x):
    u0, u1 = x.tolist()
    if TEMPLATE and False:
        pass
    else:
        torch._check(u0 * 2 == u1 * 3)
    # Do not modify the code below here (imagine it's in framework code you can't edit)
    # NB: In future exercises, we'll use force_guard as a shorthand for this pattern.
    if u0 * 2 == u1 * 3:
        return torch.tensor(True)
    else:
        return torch.tensor(False)
@run_test
def test_check():
    assert cf_check(torch.tensor([12, 8])).item()
"""
# Manual smoke test: uncomment both lines to print the rewritten sample and
# exit before any notebook file is read or written.
#print(pattern.sub(replace, test))
#sys.exit(0)
def process_notebook(notebook_path):
    """Apply `pattern` / `replace` to every code cell and save a copy.

    The notebook is read as JSON, the source of each cell whose `cell_type`
    is `code` is rewritten with `pattern.sub(replace, ...)`, and the result
    is written next to the input as `new_<original file name>`. The input
    file itself is not modified.

    Args:
        notebook_path: path of the notebook to read (str or path-like).

    Note:
        A rewritten cell's `source` is always stored as one string, even if
        it was a list of lines in the input.
    """
    notebook_path = Path(notebook_path)
    with open(notebook_path, 'r', encoding='utf-8') as file:
        notebook = json.load(file)

    for cell in notebook['cells']:
        if cell['cell_type'] != 'code':
            continue
        source = cell['source']
        # `source` may be a list of line fragments; the regex needs one string.
        if isinstance(source, list):
            source = ''.join(source)
        cell['source'] = pattern.sub(replace, source)

    # Prefix only the file name. Prefixing the whole path would turn
    # "dir/nb.ipynb" into "new_dir/nb.ipynb", i.e. a different directory.
    output_path = notebook_path.with_name(f"new_{notebook_path.name}")
    with open(output_path, 'w', encoding='utf-8') as file:
        json.dump(notebook, file, indent=1)
# Usage: rewrite the notebook named below; process_notebook writes the result
# to a separate file whose name starts with "new_".
# NOTE(review): these statements are unguarded, so they also run when this
# file is imported, not only when it is executed as a script.
file_path = 'puzzlers-aug15.ipynb'
process_notebook(file_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment