Skip to content

Instantly share code, notes, and snippets.

@zeakey
Created January 2, 2020 09:32
Show Gist options
  • Save zeakey/3fd5d7c8bf0be8ee4ccd562d84df2c79 to your computer and use it in GitHub Desktop.
class ExclusiveLinear(nn.Module):
    """Cosine-similarity classification head with an "exclusive" regularizer.

    The classifier weights are L2-normalized, so the logits are (scaled)
    cosine similarities between the input features and each class weight.
    In addition to the logits, ``forward`` returns an auxiliary
    ``exclusive_loss``: the mean cosine similarity between each class
    weight and its nearest (most similar) other class weight. Minimizing
    it pushes class weights apart on the unit hypersphere.

    Args:
        feat_dim: dimensionality of the input features.
        num_class: number of classes (rows of the weight matrix).
        norm_data: if True, L2-normalize the input features and scale
            them by ``radius`` before the linear projection.
        radius: scale factor applied to normalized features.
    """

    def __init__(self, feat_dim=512, num_class=10572, norm_data=True, radius=20):
        super(ExclusiveLinear, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.norm_data = norm_data
        self.radius = float(radius)
        self.weight = nn.Parameter(torch.randn(self.num_class, self.feat_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(feat_dim), 1/sqrt(feat_dim)] (nn.Linear style)."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """Return ``(logits, exclusive_loss)`` for a batch of features ``x``.

        Args:
            x: float tensor of shape ``(batch, feat_dim)``.

        Returns:
            logits: ``(batch, num_class)`` tensor of (scaled) cosine logits.
            exclusive_loss: scalar tensor, mean nearest-neighbour cosine
                similarity between class weights.
        """
        weight_norm = torch.nn.functional.normalize(self.weight, p=2, dim=1)
        # Pairwise cosine similarity between class weights (num_class x num_class).
        # BUG FIX: clamp() is out-of-place; the original discarded its result,
        # so similarities were never actually clamped to [-1, 1].
        cos = torch.mm(weight_norm, weight_norm.t()).clamp(-1, 1)
        # BUG FIX: detach() shares storage with `cos`; the original's in-place
        # scatter_ silently wrote -100 onto `cos`'s diagonal too. Clone first.
        cos1 = cos.detach().clone()
        # Mask out self-similarity so argmax finds the nearest OTHER class.
        # Generalized from hard-coded .cuda() to device-aware tensor creation,
        # so the module also works on CPU (and any other device).
        diag_idx = torch.arange(self.num_class, device=cos1.device).view(-1, 1)
        cos1.scatter_(1, diag_idx, -100)
        # For each class, the index of its most similar other class.
        # (cos1 is symmetric, so max over dim=0 equals max over dim=1.)
        _, indices = torch.max(cos1, dim=0)
        # One-hot mask selecting each class's nearest neighbour.
        mask = torch.zeros((self.num_class, self.num_class), device=cos.device)
        mask.scatter_(1, indices.view(-1, 1).long(), 1)
        # Mean nearest-neighbour similarity; gradients flow through `cos`.
        exclusive_loss = torch.dot(cos.view(cos.numel()), mask.view(mask.numel())) / self.num_class
        if self.norm_data:
            x = torch.nn.functional.normalize(x, p=2, dim=1)
            x = x * self.radius
        return torch.nn.functional.linear(x, weight_norm), exclusive_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment