import math

import torch
from torch import nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """Normalize the last dimension; no learnable affine parameters."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return (x - mean) / (std + self.eps)

class BasicBlock(nn.Module):
    """Cross-attention block: every query vector attends over the support set.

    A non-local-block variant that uses nn.Linear projections on
    (batch, sequence, channel) tensors instead of convolutions.
    """

    def __init__(
        self,
        in_channels,
        inter_channels=None,
        sub_sample=False,
        bn_layer=True,
        dropout=0.2,
        out_channels=512,
    ):
        super(BasicBlock, self).__init__()
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inter_channels = inter_channels
        if self.inter_channels is None:
            self.inter_channels = in_channels  # // 2, TODO
        if self.inter_channels == 0:
            self.inter_channels = 1

        # Value (g), query (theta) and key (phi) projections.
        self.g = nn.Linear(self.in_channels, self.inter_channels)
        self.theta = nn.Linear(self.in_channels, self.inter_channels)
        self.phi = nn.Linear(self.in_channels, self.inter_channels)

        # Output projection. The residual connection in forward() requires
        # the output size of W to equal in_channels.
        if bn_layer:
            self.W = nn.Linear(self.inter_channels, out_channels)
        else:
            # Assumption: project back to in_channels (the original gist
            # called a conv constructor on an nn.Linear here, which does not run).
            self.W = nn.Linear(self.inter_channels, self.in_channels)
        # Zero-initialize W so the block starts close to an identity mapping
        # through the residual connection.
        nn.init.constant_(self.W.weight, 0)
        nn.init.constant_(self.W.bias, 0)

        self.dropout = nn.Dropout(dropout)
        self.ln = LayerNorm()

        # Optional subsampling of the support set along its sequence dimension.
        # Assumption: a 1-D max pool applied in forward() replaces the original
        # MaxPool3d, which expects 5-D conv feature maps and fails on (b, s, c) inputs.
        self.pool = nn.MaxPool1d(kernel_size=2) if sub_sample else None
    def forward(self, query, support):
        """
        :param query: (b, q, c), support: (b, s, c)
        :return: (b, q, c)
        """
        assert query.size(2) == support.size(2)
        assert query.size(0) == support.size(0)

        g_x = self.g(support)        # values:  (b, s, inter)
        theta_x = self.theta(query)  # queries: (b, q, inter)
        phi_x = self.phi(support)    # keys:    (b, s, inter)

        if self.pool is not None:
            # Halve the support dimension: (b, s, inter) -> (b, s // 2, inter).
            g_x = self.pool(g_x.transpose(1, 2)).transpose(1, 2)
            phi_x = self.pool(phi_x.transpose(1, 2)).transpose(1, 2)

        phi_x = phi_x.permute(0, 2, 1)       # (b, inter, s)
        f = torch.matmul(theta_x, phi_x)     # attention logits, (b, q, s)
        f = f / math.sqrt(self.in_channels)  # rescale
        f_div_C = F.softmax(f, dim=-1)       # attention weights over the support
        y = torch.matmul(f_div_C, g_x)       # (b, q, inter)
        y = self.ln(y)                       # layer normalization over the last dim
        tmp = F.relu(y)
        tmp = self.W(tmp)
        W_y = self.dropout(tmp)
        z = W_y + query                      # residual connection
        return z
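

# Note (added, not from the original gist): the matmul/softmax core of
# forward() is scaled dot-product cross-attention. With the default
# inter_channels == in_channels, it can be sketched with the built-in
# available in PyTorch >= 2.0:
#
#     attn = F.scaled_dot_product_attention(
#         self.theta(query), self.phi(support), self.g(support)
#     )
#
# This computes softmax(Q K^T / sqrt(d)) V with d = inter_channels, i.e. the
# same result as the explicit path above, before the LayerNorm, ReLU,
# W projection, dropout, and residual are applied.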


if __name__ == "__main__":
    sub_sample = True
    bn_layer = False
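    # Hypothetical smoke test (the gist is truncated at this point; the
    # channel count and sequence lengths below are illustrative assumptions,
    # not the original author's values).
    block = BasicBlock(in_channels=64, sub_sample=sub_sample, bn_layer=bn_layer)
    query = torch.randn(2, 5, 64)    # (b, q, c)
    support = torch.randn(2, 8, 64)  # (b, s, c)
    out = block(query, support)
    print(out.size())  # torch.Size([2, 5, 64])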