Allow shifts and scales of the Poincaré distance, which usually lies on the unit disc
import itertools
import math
import gc
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.autograd import Function
from scipy.spatial.distance import pdist
from joblib import Parallel, delayed

from core.config import cfg
import nn as mynn
import utils.net as net_utils
from modeling.sparse_activations import Sparsemax
from poincare_embeddings.hype import poincare
from hyperbolic_cones import my_poincare_model as mpm

DEBUG = False
class GradScaler(Function):
    """
    Gradient scaler layer: identity in the forward pass, scales gradients
    by `scaler` in the backward pass.
    Based off:
    https://discuss.pytorch.org/t/solved-reverse-gradients-in-backward-pass/3589/4
    """
    @staticmethod
    def forward(ctx, x, scaler=0.0):
        ctx.scaler = scaler
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient w.r.t. the scaler itself
        return grad_output * ctx.scaler, None


def grad_scale(x, scaler=0.0):
    return GradScaler.apply(x, scaler)
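
# Usage sketch (illustrative values, not from the original gist): gradients
# flowing back through grad_scale are multiplied by `scaler`; scaler=0.0
# blocks them entirely, like a "frozen" feature extractor:
#   feat = torch.randn(4, 8, requires_grad=True)
#   grad_scale(feat, 0.1).sum().backward()
#   assert torch.allclose(feat.grad, 0.1 * torch.ones(4, 8))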
class Poincare(nn.Module):
    """
    Poincare distance computed directly with torch autograd:
        d(u, v) = arcosh(1 + 2 * ||u - v||^2 / ((1 - ||u||^2) * (1 - ||v||^2)))
    """
    def __init__(self):
        super(Poincare, self).__init__()
        self.eps = 1e-5

    def forward(self, u, v):
        eps = self.eps
        squnorm = torch.clamp(torch.sum(u * u, dim=-1), 0, 1 - eps)
        sqvnorm = torch.clamp(torch.sum(v * v, dim=-1), 0, 1 - eps)
        sqdist = torch.sum(torch.pow(u - v, 2), dim=-1)
        x = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh(x) = log(x + sqrt(x^2 - 1))
        z = torch.sqrt(torch.pow(x, 2) - 1)
        return torch.log(x + z)
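
# Sanity-check sketch (illustrative values): the distance from a point to
# itself is 0, and it grows without bound near the boundary of the disc:
#   pc = Poincare()
#   u = torch.tensor([[0.1, 0.2]])
#   pc(u, u)   # tensor([0.])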
##### Self-Attention Relation Networks Recreation ######
class SelfAttnMat(nn.Module):
    """ Visual appearance features to compute self-attention matrix """
    def __init__(self, feat_dim=2048, proj_dim=256, T=1.0, use_poincare=False):
        super(SelfAttnMat, self).__init__()
        self.proj_dim = proj_dim
        self.T = T
        self.sparsemax = Sparsemax(dim=2)
        self.d_k_sqrt = math.sqrt(self.proj_dim)
        self.proj_w1 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)  # [feat_dim x proj_dim]
        self.proj_w2 = nn.Conv2d(feat_dim, self.proj_dim, 1, stride=1)
        self._init_weights()
        self.use_poincare = use_poincare
        self.pm = poincare.PoincareManifold()
        self.my_pc = Poincare()
        # EDIT: learnable scale and shift of the distance on the Poincare disc
        self.poinc_scale = nn.Parameter(torch.tensor([1.0]))
        self.poinc_shift = nn.Parameter(torch.tensor([0.0]))

    def _init_weights(self):
        mynn.init.XavierFill(self.proj_w1.weight)
        init.constant_(self.proj_w1.bias, 0)
        mynn.init.XavierFill(self.proj_w2.weight)
        init.constant_(self.proj_w2.bias, 0)
    def forward(self, region_feature, num_imgs, iou_mat=[]):
        """ Return adjacency matrix as scaled dot-product self-attention """
        # Send scaled (or zero) gradients to rest of net
        region_feature = GradScaler.apply(region_feature, cfg.TRAIN.GRL_SCALER)
        # Project down the region features [n_img*n_region x n_dim x 1 x 1]
        feat_key = self.proj_w1(region_feature)
        feat_query = self.proj_w2(region_feature)
        # Reshape from (n_img*n_region, n_dim, 1, 1) to (n_img, n_region, n_dim, 1, 1)
        sz = feat_key.shape
        feat_key = feat_key.view(num_imgs, sz[0] // num_imgs, sz[1], sz[2], sz[3])
        feat_query = feat_query.view(num_imgs, sz[0] // num_imgs, sz[1], sz[2], sz[3])
        use_poincare = self.use_poincare
        if use_poincare:
            if DEBUG:
                start = time.time()
            device_id = feat_key.get_device()
            n_img = feat_key.shape[0]
            n_region = feat_key.shape[1]
            R_new = []
            for im in range(n_img):
                u = feat_key[im].squeeze(-1).squeeze(-1)    # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                v = feat_query[im].squeeze(-1).squeeze(-1)  # [n_region x n_dim x 1 x 1] -> [n_region x n_dim]
                # L1-normalize: ||x||_2 <= ||x||_1 = 1, so rows land inside the unit disc
                u = F.normalize(u, p=1, dim=1)
                v = F.normalize(v, p=1, dim=1)
                # pairwise distances, allocated on the same device as the features
                A = torch.zeros((n_region, n_region), device=u.device)
                # slow version: iterate through regions
                for i in range(n_region):
                    A[i, :] = self.pm.distance(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v)
                # slow version sped up with multiprocessing -- pytorch DataLoader threads cry
                # pool = Parallel(n_jobs=2)(
                #     delayed(self.pm.distance)(u[i, :].unsqueeze(0).expand(n_region, u.shape[1]), v) for i in range(n_region)
                #     )
                # broadcast version -- poincare grad throws an error
                # A = self.pm.distance(u.unsqueeze(1), v.unsqueeze(1).transpose(0, 1))
                # broadcast with my Poincare (above): directly uses torch autograd, not the poincare grad
                # A = self.my_pc(u.unsqueeze(1), v.unsqueeze(1).transpose(0, 1))
                R_new.append(A.unsqueeze(0))
                del A
            R_new = torch.cat(R_new).cuda(device_id)  # [n_img x n_region x n_region]
            R_new = 1 - R_new  # distance -> similarity
            if DEBUG:
                print('== poincare attention took {:.4f}s'.format(time.time() - start))
        else:
            feat_key = feat_key.unsqueeze(2)         # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.unsqueeze(2)     # [n_img x n_region x 1 x n_dim x 1 x 1]
            feat_query = feat_query.transpose(1, 2)  # [n_img x 1 x n_region x n_dim x 1 x 1]
            # broadcast: [n_img x n_region x n_region x n_dim x 1 x 1]
            R_new = feat_key * feat_query
            R_new = R_new.squeeze(-1).squeeze(-1)  # [n_img x n_region x n_region x n_dim]
            R_new = R_new.sum(3) / self.d_k_sqrt   # self-attention/relation networks scale by sqrt(d_k)
        if cfg.TRAIN.ATTN_IOU_THRESH:
            # mask out region pairs with IoU > TRAIN.IOU_THRESH
            assert num_imgs == 1  # TODO - extend to multiple images
            assert len(iou_mat) > 0
            device_id = R_new.get_device()
            mask = (iou_mat < cfg.TRAIN.IOU_THRESH).astype('float32')
            np.fill_diagonal(mask, 1.0)
            mask = Variable(torch.from_numpy(mask), requires_grad=False).cuda(device_id)
            mask = mask.view(R_new.shape)
            R_new = R_new * mask
            R_new = R_new.contiguous()
        R_new = R_new / self.T  # softmax temperature
        if cfg.TRAIN.SPARSEMAX:
            out = self.sparsemax(R_new)
        else:
            if cfg.TRAIN.DROPOUT > 0:
                R_new = F.dropout(R_new, p=cfg.TRAIN.DROPOUT, inplace=True)
            if use_poincare:
                # normalize so each row of similarities sums to 1
                out = R_new / R_new.sum(dim=2).unsqueeze(2)
                # EDIT: if using "softmax" poincare, instead shift/scale the distance
                # R_new = (-self.poinc_scale * R_new) + self.poinc_shift
                # and apply softmax in place of the row-sum normalization above
            else:
                out = F.softmax(R_new, 2)
        return out
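
# Usage sketch (hypothetical shapes; assumes the detectron-style cfg.TRAIN.* flags
# used above -- GRL_SCALER, ATTN_IOU_THRESH, SPARSEMAX, DROPOUT -- are configured):
#   attn_mat = SelfAttnMat(feat_dim=2048, proj_dim=256, T=1.0)
#   feats = torch.randn(2 * 100, 2048, 1, 1)   # 2 images x 100 regions
#   A = attn_mat(feats, num_imgs=2)            # [2 x 100 x 100], each row sums to 1
# With use_poincare=True, rows are row-normalized (1 - Poincare distance) similarities;
# the learnable poinc_scale / poinc_shift are meant for the commented-out softmax variant.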
class SelfAttn_basic(nn.Module):
    """ Self Attention Network for visual context """
    def __init__(self, num_A=1, feat_dim=2048, input_feat=2048, output_feat=2048,
                 visual_proj_dim=256, combine='add'):
        """
        num_A       - Number of adjacency matrices (multi attention heads)
        feat_dim    - Size of "appearance" features for each region (roi)
        input_feat  - Size of ROI-pooled features (can be different from feat_dim)
        output_feat - Size of the output from each attention head
        """
        super(SelfAttn_basic, self).__init__()
        self.num_A = num_A
        self.proj_dim = visual_proj_dim
        self.output_feat = int(output_feat / self.num_A)
        self.combine = combine
        if cfg.TRAIN.CONTEXT_BBOX:
            feat_dim += 4  # bbox coords are appended to visual feat
        if self.num_A >= 1:
            assert cfg.TRAIN.ATTN_W  # multi-heads need down-projection with W
            for i in range(self.num_A):
                module_AdjMat = SelfAttnMat(feat_dim=feat_dim,
                                            proj_dim=self.proj_dim,
                                            T=cfg.TRAIN.SOFTMAX_T[i],
                                            use_poincare=(i in cfg.TRAIN.POINCARE))
                self.add_module('compute_AdjMat{}'.format(i), module_AdjMat)
                linear_out = nn.Conv2d(input_feat, self.output_feat, 1, stride=1)
                self._init_weights_multi(linear_out)
                self.add_module('linear_out{}'.format(i), linear_out)
        else:
            raise ValueError('num_A must be >= 1')

    def _init_weights_multi(self, linear_out):
        mynn.init.XavierFill(linear_out.weight)
        init.constant_(linear_out.bias, 0)
    def forward(self, visual_feature, x, num_imgs, iou_mat=[], bboxes=[]):
        """
        Returns features incorporating visual context from all other rois
        visual_feature - appearance feature tensor [num_rois, feat_dim, 1, 1]
        x              - region (box) feature [num_rois, feat_dim, 1, 1]
        num_imgs       - number of images per batch
        iou_mat        - (optional) IoU between regions [num_rois, num_rois]
        Visual feature and "x" can come from the same or different CNN layers.
        Image IDs are typically obtained from rpn_ret['rois'] in model_builder.py
        """
        # Scale down or zero out gradients to the rest of the (pre-trained) network
        x = GradScaler.apply(x, cfg.TRAIN.GRL_SCALER)
        num_rois = int(visual_feature.shape[0] / num_imgs)
        if cfg.TRAIN.ATTN_IOU_THRESH:
            assert len(iou_mat) > 0
        if cfg.TRAIN.CONTEXT_BBOX:
            bboxes = bboxes.unsqueeze(-1).unsqueeze(-1).cuda(visual_feature.get_device())
            visual_feature = torch.cat((visual_feature, bboxes), 1)
        visual_feature = GradScaler.apply(visual_feature, cfg.TRAIN.GRL_SCALER)
        z = []
        for i in range(self.num_A):
            # z = A.X.W
            A_i = self._modules['compute_AdjMat{}'.format(i)](visual_feature, num_imgs,
                                                              iou_mat=iou_mat)
            z_i = torch.bmm(A_i, x.view(num_imgs, num_rois, -1, 1, 1).squeeze(-1).squeeze(-1))
            z_i = z_i.view(num_imgs * num_rois, -1)
            z_i = z_i.unsqueeze(-1).unsqueeze(-1)
            z_i = self._modules['linear_out{}'.format(i)](z_i)  # [n_img*n_region, output_feat, 1, 1]
            z.append(z_i)
        z = torch.cat(z, 1)  # [n_img*n_region, num_A*output_feat, 1, 1]
        if self.combine == 'add':
            y = x + z
        elif self.combine == 'concat':
            y = torch.cat([x, z], 1)
        else:
            raise NotImplementedError
        y = F.relu(y, inplace=True)
        return y
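
# Usage sketch (hypothetical shapes; assumes a configured cfg and, for the
# Poincare branch, a GPU). With combine='add', output_feat must match the
# channel size of x:
#   ctx_net = SelfAttn_basic(num_A=2, feat_dim=2048, input_feat=2048, output_feat=2048)
#   feats = torch.randn(100, 2048, 1, 1)   # 1 image, 100 rois
#   y = ctx_net(feats, feats, num_imgs=1)  # [100, 2048, 1, 1]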
##### END: Self-Attention Relation Networks Recreation ######
def _gen_timing_signal(length, channels=64, min_timescale=1.0, max_timescale=1.0e3):
    """
    Generates a [1, length, channels] timing signal consisting of sinusoids
    Adapted from:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/layers/common_attention.py
    """
    position = np.arange(length)
    num_timescales = channels // 2
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (float(num_timescales) - 1))
    inv_timescales = min_timescale * np.exp(
        np.arange(num_timescales).astype(np.float64) * -log_timescale_increment)
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, 0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    # zero-pad the channel dimension if `channels` is odd
    signal = np.pad(signal, [[0, 0], [0, channels % 2]],
                    'constant', constant_values=[0.0, 0.0])
    signal = signal.reshape([1, length, channels])
    return torch.from_numpy(signal).type(torch.FloatTensor)
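
# Usage sketch (shapes are illustrative): a sinusoidal positional signal,
# e.g. added to a sequence of 10 features of dimension 64:
#   sig = _gen_timing_signal(10, channels=64)   # FloatTensor, [1, 10, 64]
#   feats = torch.randn(1, 10, 64)
#   feats = feats + sig                         # inject position information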