Bizarre PyTorch memory issue
import torch
import pickle
import os
import random
import sourcenode
import torch.nn as nn
import torch.utils.checkpoint
import torch.utils.data
import math
import objprocessor
import ctypes
import gc

# Maximum number of asm instructions: 12216
# Maximum number of c nodes: 74035

if os.name == 'nt':
    ctypes.cdll.LoadLibrary('caffe2_nvrtc.dll')

torch.manual_seed(0)
random.seed(0)

MAX_INPUT_SIZE = 5000

use_cuda = True
if use_cuda:
    current_device = torch.device('cuda')
else:
    current_device = torch.device('cpu')

class PickledDataset(torch.utils.data.Dataset):
    def __init__(self, pickled_data_directory):
        self.data_file_paths = []
        for dir_name, subdir_list, file_list in os.walk(pickled_data_directory):
            for file_name in file_list:
                (name, extension) = os.path.splitext(file_name)
                if extension == ".pickle":
                    self.data_file_paths.append(os.path.join(dir_name, file_name))
        self.data_file_paths.sort()

    def __getitem__(self, item):
        with open(self.data_file_paths[item], 'rb') as f:
            return pickle.load(f)

    def __len__(self):
        return len(self.data_file_paths)

neg_one = torch.tensor([[-1.0]])
padding_id = torch.tensor([sourcenode.Padding().get_id()], dtype=torch.long)
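# neg_one pads the ASM feature rows and padding_id pads the C node id sequences in the
# collator below; the masks returned by the collator record which positions are padding.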

def collator(batch):
    # an ASM tensor has shape (S, objprocessor.MAX_INSTR_SIZE)
    asm_tensors = [data[0] for data in batch]
    max_num_asm = max([ten.shape[0] for ten in asm_tensors])
    asm_mask = torch.zeros((len(asm_tensors), max_num_asm), dtype=torch.bool)
    for i in range(len(asm_tensors)):
        ten = asm_tensors[i]
        asm_mask[i, ten.shape[0]:max_num_asm] = True
    asm_tensors = [torch.cat([ten, neg_one.expand((max_num_asm - ten.shape[0], ten.shape[1]))]) for ten in asm_tensors]
    # a C tensor has shape (T,) and holds node ids in [0, sourcenode.NODE_ID_END)
    c_tensors = [data[1] for data in batch]
    max_num_c = max([ten.shape[0] for ten in c_tensors])
    c_mask = torch.zeros((len(c_tensors), max_num_c), dtype=torch.bool)
    for i in range(len(c_tensors)):
        ten = c_tensors[i]
        c_mask[i, ten.shape[0]:max_num_c] = True
    c_tensors = [torch.cat([ten, padding_id.expand(max_num_c - ten.shape[0])]) for ten in c_tensors]
    return (torch.stack(asm_tensors, dim=1), torch.stack(c_tensors, dim=1), asm_mask, c_mask)
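
# Note: True entries in asm_mask and c_mask mark padding positions, which is the convention
# nn.Transformer expects for src_key_padding_mask / tgt_key_padding_mask.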

dataset = PickledDataset("../pickledtrainingdata")
ten_percent = int(0.1 * len(dataset))
training_size = len(dataset) - 2 * ten_percent
(testing_dataset, validation_dataset, training_dataset) = torch.utils.data.random_split(dataset, [ten_percent, ten_percent, training_size])

batch_size = 4
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, collate_fn=collator, shuffle=True)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
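        # pe now holds the standard sinusoidal encoding: PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        # and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); div_term equals 10000^(-2i/d_model).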
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, nhead=8, dim_feedforward=1024, num_layers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = 512
        self.input_embedding1 = nn.Linear(objprocessor.MAX_INSTR_SIZE, self.d_model)
        self.relu1 = nn.ReLU()
        self.input_embedding2 = nn.Linear(self.d_model, self.d_model)
        self.relu2 = nn.ReLU()
        self.input_embedding3 = nn.Linear(self.d_model, self.d_model)
        self.output_embedding1 = nn.Linear(sourcenode.NODE_ID_END, self.d_model)
        self.pos_encoder = PositionalEncoding(self.d_model, dropout, max_len=MAX_INPUT_SIZE)
        self.transformer = nn.Transformer(d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)
        self.output_linear = nn.Linear(self.d_model, sourcenode.NODE_ID_END)
    # torch.utils.checkpoint only computes gradients through the checkpointed segment if at
    # least one of its inputs has requires_grad=True; here the embedded src and tgt satisfy that.
    def forward(self, src, tgt, src_padding_mask=None, tgt_padding_mask=None):
        # src is a tensor of shape (S, N, objprocessor.MAX_INSTR_SIZE), where N = batch size,
        # S = sequence length of input, objprocessor.MAX_INSTR_SIZE = channel size
        # tgt is a tensor of shape (T, N, sourcenode.NODE_ID_END), where N = batch size,
        # T = sequence length of output, sourcenode.NODE_ID_END = channel size
        tgt_len = tgt.shape[0]
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)
        tgt_mask = tgt_mask.to(tgt.device)
        src = self.input_embedding1(src)
        src = self.relu1(src)
        src = self.input_embedding2(src)
        src = self.relu2(src)
        src = self.input_embedding3(src)
        src = self.pos_encoder(src)
        tgt = self.output_embedding1(tgt)
        tgt = self.pos_encoder(tgt)
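        # Checkpointing the transformer trades compute for memory: activations inside the
        # wrapped call are not stored during the forward pass and are recomputed during backward.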
        def wrapper(src, tgt, tgt_mask, src_padding_mask, tgt_padding_mask):
            return self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask)

        output = torch.utils.checkpoint.checkpoint(wrapper, src, tgt, tgt_mask, src_padding_mask, tgt_padding_mask)
        output = self.output_linear(output)
        return output

model = TransformerModel(dropout=0.1)
#model = nn.DataParallel(model, dim=1)
if use_cuda:
    model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
model.train()

MAX_NUM_EPOCHS = 50
num_batches = math.ceil(len(training_dataset) / batch_size)
MODEL_SAVE_DIR = "models"
#MODEL_SAVE_DIR = "/content/drive/My Drive/models"
#torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_0.pt")

for epoch in range(MAX_NUM_EPOCHS):
    i = 0
    loss_sum = 0.0
    loss_n = 0
    exp_ma_loss = None
    for (src, tgt_indices, src_padding_mask, tgt_padding_mask) in training_loader:
        # src has shape (S, N, objprocessor.MAX_INSTR_SIZE) where S = sequence length of input, N = batch size,
        # objprocessor.MAX_INSTR_SIZE = number of input channels
        # tgt_indices has shape (T, N) where T = sequence length of output, N = batch size
        # src_padding_mask has shape (N, S)
        # tgt_padding_mask has shape (N, T)
        if src.shape[0] > MAX_INPUT_SIZE or tgt_indices.shape[0] > MAX_INPUT_SIZE:
            print("Input " + str(i) + " / " + str(num_batches) + " exceeded maximum input size")
            # Still count the skipped batch so the progress counter stays aligned with num_batches
            i += 1
            continue
        try:
            # Really make sure that no local variables escape the scope
            def run():
                global src, tgt_indices, src_padding_mask, tgt_padding_mask, i, exp_ma_loss, loss_sum, loss_n
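                # Wrapping the training step in a function means every local created here is
                # dropped when run() returns, so GPU tensors cannot linger in the loop's frame
                # between iterations; the batch tensors are rebound as globals so they can be
                # moved to the GPU and then freed with del below.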
                if use_cuda:
                    src = src.cuda()
                    tgt_indices = tgt_indices.cuda()
                    src_padding_mask = src_padding_mask.cuda()
                    tgt_padding_mask = tgt_padding_mask.cuda()
                # Convert the indices to one hot vectors for use as input in the transformer model
                tgt = torch.zeros((tgt_indices.shape[0], tgt_indices.shape[1], sourcenode.NODE_ID_END), device=current_device)
                r1 = torch.arange(0, tgt_indices.shape[0], device=current_device).unsqueeze(1).expand_as(tgt_indices)
                r2 = torch.arange(0, tgt_indices.shape[1], device=current_device).unsqueeze(0).expand_as(tgt_indices)
                tgt[r1, r2, tgt_indices] = 1.0
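                # r1 and r2 enumerate the (time, batch) grid, so the advanced-indexing assignment
                # above sets exactly one channel per position; torch.nn.functional.one_hot(tgt_indices,
                # num_classes=sourcenode.NODE_ID_END).float() would be an equivalent construction.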
                del r1
                del r2
                optimizer.zero_grad()
                output = model(src, tgt, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
                del src
                del tgt
                del src_padding_mask
                # Now remove the output that corresponds to padding entries
                tgt_indices_padding_mask = (~tgt_padding_mask).t()
                del tgt_padding_mask
                tgt_indices_no_padding = torch.masked_select(tgt_indices, tgt_indices_padding_mask)
                del tgt_indices
                output_padding_mask = tgt_indices_padding_mask.unsqueeze(2)
                output = torch.masked_select(output, output_padding_mask).view(-1, sourcenode.NODE_ID_END)
                del output_padding_mask
                print(torch.argmax(output.view(-1, sourcenode.NODE_ID_END), dim=1))
                loss = criterion(output, tgt_indices_no_padding)
                del tgt_indices_no_padding
                del output
                print("Loss " + str(epoch) + " - " + str(i) + " / " + str(num_batches), loss.item())
                if exp_ma_loss is None:
                    exp_ma_loss = loss.item()
                else:
                    coefficient = 0.001
                    exp_ma_loss = coefficient * loss.item() + (1.0 - coefficient) * exp_ma_loss
                if i % 2500 == 0:
                    print("Exp MA Loss: " + str(epoch) + " - " + str(i) + " / " + str(num_batches) + " - " + str(exp_ma_loss))
                loss_sum += loss.item()
                loss_n += 1
                loss.backward()
                del loss
                optimizer.step()
                if i % 20000 == 0:
                    torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_" + str(epoch) + "_" + str(i) + ".pt")
                i += 1

            run()
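            # Aggressively reclaim memory between batches: collect every gc generation in case
            # reference cycles are keeping tensors alive, release memory held through CUDA IPC,
            # and return unused cached blocks to the driver with empty_cache().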
            gc.collect(generation=0)
            gc.collect(generation=1)
            gc.collect(generation=2)
            torch.cuda.ipc_collect()
            torch.cuda.synchronize(device=None)
            if use_cuda:
                torch.cuda.empty_cache()
        except RuntimeError as exc:
            if use_cuda and str(exc).startswith("CUDA out of memory"):
                # Somehow asking for the memory summary fixes the issue
                print(torch.cuda.memory_summary(device=None, abbreviated=True))
                print("CUDA ran out of memory on " + str(i) + " / " + str(num_batches))
                i += 1
            else:
                raise
    print("Average loss at the end of epoch " + str(epoch) + ": " + str(loss_sum / loss_n))
    torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_" + str(epoch) + ".pt")