A random sampler that weights training examples by their losses from the previous batch, built with the pre-1.0 fastai library.
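The idea: keep one weight per training example, set to that example's most recent raw loss, and draw each epoch's indices with torch.multinomial over those weights. A smoothing term (ls_param) is added to every weight so low-loss examples still get sampled; it is initialised from the first epoch's loss and decays as the training loss falls. The Stepper patch at the bottom exposes the unreduced per-example losses that the samplers need.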
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from torch.utils.data.sampler import RandomSampler
# Callback, to_gpu, to_np, ImageData, ArraysIndexDataset and DataLoader are
# assumed to come from the pre-1.0 fastai library, e.g.:
from fastai.conv_learner import *

class WeightedLossSampler(Sampler, Callback):
    def __init__(self, data_source, replacement=True, bs=64, init_fac=1, ls_init_fac=1e-2):
        self.num_samples = len(data_source)
        self.weights = to_gpu(torch.ones(self.num_samples)*init_fac)
        self.replacement = replacement
        self.i = 0
        self.bs = bs
        self.ls_param = 0  # additive smoothing term, set from epoch losses
        self.init_fac, self.ls_init_fac = init_fac, ls_init_fac

    def __iter__(self):
        # Draw a whole epoch of indices, each example in proportion to its
        # smoothed loss from the previous pass.
        self.idxes = to_gpu(torch.multinomial(self.weights.add(self.ls_param),
                                              self.num_samples, self.replacement))
        return iter(self.idxes)

    def __len__(self):
        return self.num_samples

    def set_weights_per_batch(self, wts):
        # Store the raw per-example losses of the batch just processed.
        end = min(self.i+self.bs, self.num_samples)
        self.weights[self.idxes[self.i:end]] = wts
        self.i += self.bs

    def on_epoch_end(self, metrics):
        if not hasattr(self, 'prev_loss'):
            self.prev_loss = metrics[0][0]
            self.ls_param = self.prev_loss * self.ls_init_fac
        else:
            cur_loss = metrics[0][0]
            # Assume a normal learning curve: as the loss falls, shrink the
            # smoothing term so sampling tracks the weights more closely.
            ls_fac = np.exp((cur_loss - self.prev_loss) / self.prev_loss)
            self.ls_param = self.ls_param * ls_fac
            self.prev_loss = cur_loss
        self.i = 0

    def on_batch_end(self, raw_losses):
        self.set_weights_per_batch(raw_losses.data)
class SortishSampler(Sampler, Callback):
    def __init__(self, data_source, bs):
        self.data_source,self.bs = data_source,bs
        self.i = 0
        self.num_samples = len(self.data_source)
        self.weights = to_gpu(torch.ones(self.num_samples))

    def __len__(self): return len(self.data_source)

    def __iter__(self):
        # Shuffle, then sort each chunk of bs*50 indices by current weight,
        # so high-loss examples come first within a chunk while the epoch
        # ordering as a whole stays randomised.
        idxs = np.random.permutation(len(self.data_source))
        sz = self.bs*50
        ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]
        weights = to_np(self.weights)
        sort_idx = sum([sorted(s, key=lambda x: weights[x], reverse=True) for s in ck_idx], [])
        self.idxes = torch.LongTensor(np.array(sort_idx)).cuda()
        return iter(self.idxes)

    def set_weights_per_batch(self, wts):
        end = min(self.i+self.bs, self.num_samples)
        self.weights[self.idxes[self.i:end]] = wts
        if end == self.num_samples:
            self.i = 0
        else:
            self.i += self.bs

    def on_batch_end(self, raw_losses):
        self.set_weights_per_batch(raw_losses.data)
class ImageClassifierData(ImageData):
    def __init__(self, path, datasets, bs, num_workers, classes):
        trn_ds,val_ds,fix_ds,aug_ds,test_ds,test_aug_ds = datasets
        self.path,self.bs,self.num_workers,self.classes = path,bs,num_workers,classes
        # Swap in the loss-weighted sampler for the training DataLoader:
        # self.our_sampler = RandomSampler(trn_ds)
        self.our_sampler = WeightedLossSampler(trn_ds, replacement=True, bs=64)
        # self.our_sampler = SortishSampler(trn_ds, bs=64)
        self.trn_dl = self.get_dl(trn_ds, False, self.our_sampler)
        self.val_dl, self.fix_dl,self.aug_dl,self.test_dl,self.test_aug_dl = [
            self.get_dl(ds,shuf) for ds,shuf in [
                (val_ds,False),(fix_ds,False),(aug_ds,False),
                (test_ds,False),(test_aug_ds,False)
            ]
        ]

    @classmethod
    def from_arrays(cls, path, trn, val, bs=64, tfms=(None,None), classes=None, num_workers=4, test=None):
        datasets = cls.get_ds(ArraysIndexDataset, trn, val, tfms, test=test)
        return cls(path, datasets, bs, num_workers, classes=classes)

    def get_dl(self, ds, shuffle, sampler=None):
        if ds is None: return None
        # Note: passing a sampler and shuffle=True are mutually exclusive in DataLoader.
        return DataLoader(ds, batch_size=self.bs, shuffle=shuffle,
            num_workers=self.num_workers, pin_memory=False, sampler=sampler)
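
To train with the sampler it must be registered twice: as the DataLoader's sampler (done by ImageClassifierData above) and as a callback, so it sees the per-batch losses. A minimal wiring sketch, assuming the patched fit() from the diff below; net, x_trn, y_trn, x_val and y_val are hypothetical placeholders:

import torch.nn as nn
import torch.optim as optim

data = ImageClassifierData.from_arrays('data/', (x_trn, y_trn), (x_val, y_val), bs=64)
net = to_gpu(net)
opt = optim.SGD(net.parameters(), lr=1e-2)
crit = nn.CrossEntropyLoss(reduce=False)  # must return per-example losses, not a scalar
# Pass the sampler as a callback so its on_batch_end/on_epoch_end hooks fire:
fit(net, data, 3, opt, crit, callbacks=[data.our_sampler])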
diff --git a/fastai/model.py b/fastai/model.py
index a9b41bc..4f82b6d 100644
--- a/fastai/model.py
+++ b/fastai/model.py
@@ -48,8 +48,8 @@ class Stepper():
         output = self.m(*xs)
         if isinstance(output,tuple): output,*xtra = output
         if self.fp16: self.m.zero_grad()
-        else: self.opt.zero_grad()
-        loss = raw_loss = self.crit(output, y)
+        else: self.opt.zero_grad()
+        loss = raw_loss = torch.mean(self.crit(output, y))
         if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
         if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
         loss.backward()
@@ -64,10 +64,32 @@ class Stepper():
         torch.cuda.synchronize()
         return torch_item(raw_loss.data)
 
+    def step_with_raw_loss(self, xs, y, epoch):
+        xtra = []
+        output = self.m(*xs)
+        if isinstance(output,tuple): output,*xtra = output
+        if self.fp16: self.m.zero_grad()
+        else: self.opt.zero_grad()
+        raw_loss = raw_loss_1 = self.crit(output, y)
+        loss = raw_loss = torch.mean(raw_loss)
+        if self.loss_scale != 1: assert(self.fp16); loss = loss*self.loss_scale
+        if self.reg_fn: loss = self.reg_fn(output, xtra, raw_loss)
+        loss.backward()
+        if self.fp16: update_fp32_grads(self.fp32_params, self.m)
+        if self.loss_scale != 1:
+            for param in self.fp32_params: param.grad.data.div_(self.loss_scale)
+        if self.clip:   # Gradient clipping
+            nn.utils.clip_grad_norm(trainable_params_(self.m), self.clip)
+        self.opt.step()
+        if self.fp16:
+            copy_fp32_to_model(self.m, self.fp32_params)
+            torch.cuda.synchronize()
+        return torch_item(raw_loss.data), raw_loss_1
+
     def evaluate(self, xs, y):
         preds = self.m(*xs)
         if isinstance(preds,tuple): preds=preds[0]
-        return preds, self.crit(preds, y)
+        return preds, torch.mean(self.crit(preds, y))
 
 def set_train_mode(m):
     if (hasattr(m, 'running_mean') and (getattr(m,'bn_freeze',False)
@@ -125,13 +147,13 @@ def fit(model, data, n_epochs, opt, crit, metrics=None, callbacks=None, stepper=
         for (*x,y) in t:
             batch_num += 1
             for cb in callbacks: cb.on_batch_begin()
-            loss = model_stepper.step(V(x),V(y), epoch)
+            loss, raw_losses = model_stepper.step_with_raw_loss(V(x),V(y), epoch)
             avg_loss = avg_loss * avg_mom + loss * (1-avg_mom)
             debias_loss = avg_loss / (1 - avg_mom**batch_num)
             t.set_postfix(loss=debias_loss)
             stop=False
             los = debias_loss if not all_val else [debias_loss] + validate_next(model_stepper,metrics, val_iter)
-            for cb in callbacks: stop = stop or cb.on_batch_end(los)
+            for cb in callbacks: stop = stop or cb.on_batch_end(raw_losses)
             if stop: return
             if batch_num >= cnt_phases[phase]:
                 for cb in callbacks: cb.on_phase_end()
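
Note that both torch.mean calls in the patch imply crit is constructed to return an unreduced, per-example loss vector (e.g. with reduce=False, as in the sketch above): the mean restores the scalar used for backprop and loss reporting, while the unreduced raw_loss_1 is handed to on_batch_end so the samplers can update their per-example weights.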