dbpprt / oom.md
Last active September 6, 2021, 13:27
A simple helper function for handling OOM errors while training with PyTorch. On my Windows system I sometimes get strange OutOfMemory errors in the middle of a training job. The wrapper tries to recover by freeing up as much memory as possible and splitting the batch in half.

Usage

import torch.nn.functional as F

optimizer.zero_grad()

# The criterion performs the backward pass itself, so gradients accumulate
# correctly even when the wrapper splits a batch and calls it several times.
def criterion(output, target, steps, batch_size):
    loss = F.cross_entropy(output, target)
    loss.backward()
    return loss
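
For illustration, here is a minimal sketch of what such an OOM-recovery wrapper could look like. The name `forward_backward_with_oom_retry`, its signature, and the `max_splits` limit are assumptions made for this sketch, not the gist's actual implementation; it only relies on the `criterion` signature shown above, frees cached CUDA memory on an out-of-memory error, and retries with the batch split in half.

```python
import torch


def forward_backward_with_oom_retry(model, criterion, data, target,
                                    max_splits=3, _depth=0):
    # Sketch only: hypothetical name and signature, not the gist's exact code.
    # Run the forward pass and the user-supplied criterion for one batch; on a
    # CUDA out-of-memory error, free cached memory and retry with the batch
    # split in half, up to max_splits times.
    try:
        output = model(data)
        # criterion calls loss.backward() itself (see Usage above), so the
        # gradients simply accumulate across the split halves.
        return criterion(output, target, _depth, data.size(0))
    except RuntimeError as exc:
        if ("out of memory" not in str(exc)
                or _depth >= max_splits
                or data.size(0) < 2):
            raise
    # Retry outside the except block so the exception object (and the tensors
    # referenced by its traceback) can be garbage collected before the retry.
    torch.cuda.empty_cache()
    half = data.size(0) // 2
    loss_first = forward_backward_with_oom_retry(
        model, criterion, data[:half], target[:half], max_splits, _depth + 1)
    loss_second = forward_backward_with_oom_retry(
        model, criterion, data[half:], target[half:], max_splits, _depth + 1)
    return (loss_first + loss_second) / 2
```

With the criterion from the usage snippet, a training step would then call optimizer.zero_grad(), invoke the wrapper once for the full batch, and finish with optimizer.step().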