Skip to content

Instantly share code, notes, and snippets.

@huchenxucs
Created August 4, 2020 07:07
Show Gist options
Save huchenxucs/e0c70624862b05c25b3b8766c6c2213c to your computer and use it in GitHub Desktop.
Masked Conv1d
class Conv1dWithMask(nn.Module):
    """1-D convolution whose input window is re-weighted per output frame.

    For output frame ``t`` of batch item ``b``, the ``kernel_size``-wide input
    window around ``t`` is multiplied element-wise by the matching slice of
    row ``mask[b, t]``, and ``self.conv`` (no internal padding) is applied to it.
    Input frames that fall outside ``[0, T)`` are treated as zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, w_init_gain='linear'):
        """
        :param in_channels: number of input channels H.
        :param out_channels: number of output channels H'.
        :param kernel_size: convolution width K (must be > 1). Odd K gives a
            window centred on each output frame. For even K, the extra frame
            is taken from the right.
        :param bias: whether the inner ``Conv1d`` has a bias.
        :param w_init_gain: nonlinearity name passed to ``calculate_gain`` for the
            Xavier-uniform weight initialisation.
        """
        super().__init__()
        assert kernel_size > 1, "Conv1dWithMask kernel size must greater than 1"
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x, mask=None):
        """
        :param x: [B, H, T] tensor, or a 2-element list/tuple ``[x, mask]``.
        :param mask: [B, T, T] tensor (float, int or bool). ``mask[b, t, s]``
            multiplies input frame ``s`` when output frame ``t`` is computed.
            Example (block-causal):
            tensor([[[1., 1., 0., 0., 0., 0., 0., 0.],
                     [1., 1., 0., 0., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 0., 0., 0., 0.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1., 1., 1., 1.]], ...])
        :return: [B, H', T]
        :raises AssertionError: if ``mask`` is missing or the list form does not
            have exactly two elements.
        :raises ValueError: if ``mask`` is not shaped [B, T, T] for the given ``x``.
        """
        if isinstance(x, (list, tuple)):
            assert len(x) == 2
            x, mask = x[0], x[1]
        assert mask is not None
        kernel_size = self.kernel_size
        B, H, T = x.shape
        if mask.shape != (B, T, T):
            raise ValueError(
                f"mask must have shape {(B, T, T)} to match x of shape "
                f"{tuple(x.shape)}, got {tuple(mask.shape)}")

        # Split the K-1 padding frames the way PyTorch does for padding='same':
        # equal for odd K, with the extra frame on the right for even K. The
        # original symmetric K//2 padding yielded T+1 windows for even K and crashed.
        pad_left = (kernel_size - 1) // 2
        pad_right = kernel_size - 1 - pad_left

        # Cast before padding so bool/int masks are accepted and the product
        # below keeps x's dtype (a fixed .float() breaks half/bf16 models).
        mask_pad = F.pad(mask.to(x.dtype), [pad_left, pad_right])  # [B, T, T+K-1]
        # band[b, t, k] = mask_pad[b, t, t + k]: the K mask weights aligned
        # with output frame t's input window.
        offsets = (torch.arange(T, device=mask_pad.device)[:, None]
                   + torch.arange(kernel_size, device=mask_pad.device)[None, :])  # [T, K]
        band = mask_pad.gather(2, offsets.expand(B, T, kernel_size))  # [B, T, K]
        band = band.reshape(B * T, 1, kernel_size)  # [B*T, 1, K]

        x_pad = F.pad(x, [pad_left, pad_right])  # [B, H, T+K-1]
        windows = x_pad.unfold(2, kernel_size, 1)  # [B, H, T, K]
        windows = windows.permute(0, 2, 1, 3).reshape(B * T, H, kernel_size)  # [B*T, H, K]
        out = self.conv(windows * band)  # [B*T, H', 1]
        return out.reshape(B, T, self.out_channels).permute(0, 2, 1)  # [B, H', T]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment