Skip to content

Instantly share code, notes, and snippets.

Forked from albanD/
Created April 18, 2023 10:21
Show Gist options
  • Save nyngwang/01c4d16c6a0dfad062332017fdb1aee7 to your computer and use it in GitHub Desktop.
Save nyngwang/01c4d16c6a0dfad062332017fdb1aee7 to your computer and use it in GitHub Desktop.
PyTorch optimizer as hook
import torch
from torch import nn
from torch.optim.sgd import sgd
import gc
import objgraph
import weakref
def all(device="cuda"):
    """Train a tiny linear model with SGD applied from inside autograd hooks.

    Each parameter is updated (and its ``.grad`` released) as soon as its
    gradient has been accumulated during ``backward()``, so gradients never
    have to be held for a separate ``optimizer.step()``. Memory figures
    reported by ``torch.cuda.memory_allocated()`` are printed along the way.

    The name shadows the builtin ``all``; it is kept because the module-level
    driver calls the function by this name.

    Args:
        device: Device on which the model and data are created. Defaults to
            ``"cuda"``, the device the demo originally hard-coded. On a
            non-CUDA device the printed CUDA memory figures do not describe
            the model.

    Returns:
        A ``weakref.ref`` to the model's weight parameter.
    """
    # Only a subset of the args you could have
    def set_sgd_hook(mod, p, lr, weight_decay, momentum):
        """Make autograd apply an SGD step to ``p`` once its grad is accumulated.

        Args:
            mod: Module owning ``p``; used to keep the grad accumulator alive.
            p: Parameter to update in place.
            lr: Learning rate forwarded to ``sgd``.
            weight_decay: Weight decay forwarded to ``sgd``.
            momentum: Momentum factor forwarded to ``sgd``.
        """
        if not p.requires_grad:
            # No grad accumulator exists for such a tensor; nothing to hook.
            return

        # One slot for this parameter's momentum buffer; `sgd` fills it in
        # lazily on the first update.
        buff_list = [None]
        # A view of `p` has a grad_fn whose first input edge is the grad
        # accumulator of `p`.
        acc_grad = p.view_as(p).grad_fn.next_functions[0][0]

        # The grad accumulator is a weak ref, so we need to keep it
        # alive until the Tensor is alive.
        # Store it on the module to avoid uncollectable ref-cycle
        if not hasattr(mod, "_acc_grads"):
            mod._acc_grads = []
        mod._acc_grads.append(acc_grad)

        def sgd_hook(*_unused):
            # Update the params
            sgd([p], [p.grad], buff_list, has_sparse_grad=False, foreach=False,
                weight_decay=weight_decay, momentum=momentum, lr=lr, dampening=0,
                nesterov=False, maximize=False)
            # Free up grad memory
            p.grad = None

        # Hooking the accumulator node makes the update run after the gradient
        # has been written to `p.grad`.
        # NOTE: newer PyTorch releases expose
        # `Tensor.register_post_accumulate_grad_hook` for this purpose; the
        # node hook is kept so that older releases keep working.
        acc_grad.register_hook(sgd_hook)

    print("Startup", torch.cuda.memory_allocated())
    mod = torch.nn.Linear(4, 1).to(device)
    crit = nn.MSELoss()

    for p in mod.parameters():
        set_sgd_hook(mod, p, lr=.01, weight_decay=0., momentum=0.9)

    # Make sure the keepalive works well: drop the loop's leftover reference
    # and force a collection before training starts.
    del p
    gc.collect()

    inp = torch.rand(10, 4, device=device)
    target = torch.rand(10, 1, device=device)

    def eval_one(i):
        # Kept as a function so `pred` and `loss` are released before the
        # end-of-iteration memory report.
        print(f"It {i}, {torch.cuda.memory_allocated()}")
        pred = mod(inp)
        loss = crit(pred, target)
        print("Before backward", torch.cuda.memory_allocated())
        # The hooks fire during this call and update the parameters.
        loss.backward()
        print(f"Loss: {loss.item()}")

    for i in range(11):
        eval_one(i)
        if i == 0:
            print("No memory decrease due to optimizer state lazy initialization")
        print("End of iteration", torch.cuda.memory_allocated())

    return weakref.ref(mod.weight)
# Run the demo; only the weak reference to the model weight that it returns
# is kept at module level.
w = all()
# Report the CUDA memory still allocated once the demo has returned.
print("Done, final memory", torch.cuda.memory_allocated())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment