@dbpprt
Last active September 6, 2021 13:27
A simple helper function to handle OOM errors while training with PyTorch. On my Windows system I sometimes get strange OutOfMemory errors in the middle of a training job. This wrapper tries to recover by freeing up as much memory as possible and splitting the batch in half.

Usage

optimizer.zero_grad()

def criterion(output, target, steps, batch_size):
    # backward() is called here so gradients accumulate across sub-batches
    # when train_step has to split the batch after an OOM
    loss = F.cross_entropy(output, target)
    loss.backward()
    return loss

output, loss, oom = utils.train_step(model, image, target,
                                     criterion_handler=criterion,
                                     handle_oom=True,
                                     handle_out_dict=False)

optimizer.step()
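
For context, here is a hedged sketch of a full training loop built around the Usage snippet above. The names loader, device, model, optimizer and the utils module are assumptions for illustration; dividing the loss by steps is an optional refinement (not part of the original Usage) that keeps the accumulated gradients roughly on the scale of a single full-batch backward pass.

import torch.nn.functional as F

for image, target in loader:
    image, target = image.to(device), target.to(device)

    optimizer.zero_grad()

    def criterion(output, target, steps, batch_size):
        # cross_entropy averages over the sub-batch; dividing by the number of
        # splits keeps the accumulated gradient close to a full-batch update
        loss = F.cross_entropy(output, target) / steps
        loss.backward()
        return loss

    output, loss, oom = utils.train_step(model, image, target,
                                         criterion_handler=criterion,
                                         handle_oom=True,
                                         handle_out_dict=False)

    optimizer.step()
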
import collections
import gc

import torch


def free_up_memory(reset_counters=False):
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        if reset_counters:
            torch.cuda.reset_peak_memory_stats()

    gc.collect()
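
# Hypothetical usage (not part of the gist): calling free_up_memory(reset_counters=True)
# between epochs drops cached CUDA blocks and resets the peak-memory statistics, e.g.:
#   free_up_memory(reset_counters=True)
#   print(torch.cuda.max_memory_allocated())  # peak stats now track the next epoch only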

def merge_dicts(*dicts):
    res = collections.defaultdict(list)
    for d in dicts:
        for k, v in d.items():
            if len(res[k]) == 0:
                res[k] = v
            else:
                res[k] = torch.cat((res[k], v))
    return res
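
# Example of how merge_dicts combines per-sub-batch outputs from models that return
# dicts of tensors (hypothetical shapes, not taken from the gist):
#   a = {'logits': torch.zeros(2, 10)}
#   b = {'logits': torch.ones(3, 10)}
#   merge_dicts(a, b)['logits'].shape  # -> torch.Size([5, 10])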

# This function is designed to handle and recover from OOM errors while training.
# It does so by freeing up as much memory as possible and splitting the batch before retrying;
# gradients are accumulated across the sub-batches.
# This is mainly a workaround for GPU VRAM fragmentation: it helps to recover from OOMs at the
# start of training, until the memory consumption stabilizes, and from random OOMs later on.
def train_step(model, input, target, criterion_handler, forward_handler=None, stack_loss=True, cat_out=True,
               handle_out_dict=True, handle_oom=True,
               steps=1):
    oom = False

    try:
        if steps == 1:
            # fast path: run the full batch as usual
            if forward_handler is not None:
                out = forward_handler(model, input)
            else:
                out = model(input)

            if criterion_handler is not None:
                loss = criterion_handler(out, target, steps, len(input))
            else:
                loss = None

            return out, loss, False

        # recovery path: split the batch into `steps` sub-batches and retry
        batches = torch.split(input, len(input) // steps)
        targets = torch.split(target, len(target) // steps)

        print('If this error persists, you should lower the batch_size. The training will try to recover from OOM '
              'errors, however it is highly inefficient!')
        print(f'Retrying with batch size {len(target) // steps}')

        free_up_memory()

        results = []
        losses = []
        out = None

        for i, mini_batch in enumerate(batches):
            try:
                if forward_handler is not None:
                    out = forward_handler(model, mini_batch)
                else:
                    out = model(mini_batch)

                if criterion_handler is not None:
                    loss = criterion_handler(out, targets[i], steps, len(mini_batch))
                    losses.append(loss)
            except Exception as e:
                # drop references to intermediate outputs before re-raising,
                # so the outer handler can actually free the memory
                if out is not None:
                    del out
                del results
                raise e

            results.append(out)

        if stack_loss:
            losses = torch.stack(losses).mean()

        if cat_out:
            if handle_out_dict and isinstance(results[0], dict):
                results = merge_dicts(*results)
            else:
                results = torch.cat(results)

        return results, losses, steps > 1
    except RuntimeError as e:
        if "out of memory" in str(e):
            if not handle_oom:
                raise e

            # deleting the exception releases the stack frame it holds on to;
            # otherwise the referenced tensors stay alive and we would keep hitting OOM
            print(f'Exception: {str(e)}')
            print(
                f'OOM occurred while trying to pass {len(input) // steps} into the model '
                f'on device {next(model.parameters()).device}')
            oom = True
            del e

            for p in model.parameters():
                p.grad = None
        else:
            raise e

    if oom:
        # retry with twice as many splits (half the sub-batch size)
        return train_step(model, input, target, criterion_handler, forward_handler, stack_loss, cat_out,
                          handle_out_dict, handle_oom, steps * 2)
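
The third return value reports whether the batch had to be split to recover from an OOM. A caller might use it to log recoveries or to decide to restart with a smaller batch_size; a small hedged sketch (the oom_recoveries counter is made up for illustration):

output, loss, oom = utils.train_step(model, image, target,
                                     criterion_handler=criterion,
                                     handle_oom=True)
if oom:
    # frequent recoveries mean repeated batch splitting; restarting the run with a
    # smaller batch_size is usually cheaper than splitting every step
    oom_recoveries += 1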