Skip to content

Instantly share code, notes, and snippets.

@vabh
Created November 20, 2017 14:52
Show Gist options
  • Save vabh/2210ac54bc5cb202bfb1133df48ae58b to your computer and use it in GitHub Desktop.
Save vabh/2210ac54bc5cb202bfb1133df48ae58b to your computer and use it in GitHub Desktop.
class ConvLSTMCell(nn.Module):
    """Single-step convolutional LSTM cell with convolutional peephole terms.

    The input ``x``, hidden state ``h`` and cell state ``c`` are each passed
    through a bias-free 2-D convolution; the results are sliced along the
    channel dimension into per-gate pre-activations.

    Args:
        in_channels: Number of channels of the input ``x``.
        out_channels: Number of channels of the hidden and cell states.
        kernel_size: Side length of the (square) convolution kernels.
        padding: Padding passed to every convolution. ``forward`` combines
            the convolution outputs element-wise with ``c``, so the padding
            must keep the spatial size unchanged (e.g. ``kernel_size // 2``
            for an odd kernel; the default of 1 suits ``kernel_size=3``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding=1):
        super(ConvLSTMCell, self).__init__()
        self.k = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding = padding
        # Input-to-gate kernels; 4 gate groups stacked on dim 0: i, f, g, o.
        self.w_i = nn.Parameter(
            torch.empty(4 * out_channels, in_channels, kernel_size, kernel_size))
        # Hidden-to-gate kernels. They convolve ``h``, which has
        # ``out_channels`` planes, so that is their input-channel count.
        self.w_h = nn.Parameter(
            torch.empty(4 * out_channels, out_channels, kernel_size, kernel_size))
        # Peephole kernels applied to the cell state (also ``out_channels``
        # planes); 3 groups stacked on dim 0: i, f, o.
        self.w_c = nn.Parameter(
            torch.empty(3 * out_channels, out_channels, kernel_size, kernel_size))
        # TODO include bias terms
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all kernels from U(-stdv, stdv), stdv = 1/sqrt(4*in_channels*k*k)."""
        n = 4 * self.in_channels * self.k * self.k
        stdv = 1. / math.sqrt(n)
        with torch.no_grad():
            for weight in (self.w_i, self.w_h, self.w_c):
                nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, hx):
        """Advance the cell by one time step.

        Args:
            x: Input tensor; assumed (N, in_channels, H, W).
            hx: Tuple ``(h, c)`` of previous hidden and cell states; each
                assumed (N, out_channels, H, W).

        Returns:
            ``(h_t, (h_t, c_t))``: the new hidden state, followed by the new
            ``(hidden, cell)`` state tuple to feed into the next step.
        """
        h, c = hx
        oc = self.out_channels
        wx = F.conv2d(x, self.w_i, padding=self.padding)
        wh = F.conv2d(h, self.w_h, padding=self.padding)
        wc = F.conv2d(c, self.w_c, padding=self.padding)

        # Channel slices selecting each gate's group in wx / wh.
        in_gate = slice(0, oc)
        forget_gate = slice(oc, 2 * oc)
        cell_gate = slice(2 * oc, 3 * oc)
        out_gate = slice(3 * oc, None)

        i = torch.sigmoid(wx[:, in_gate] + wh[:, in_gate] + wc[:, in_gate])
        f = torch.sigmoid(
            wx[:, forget_gate] + wh[:, forget_gate] + wc[:, forget_gate])
        g = torch.tanh(wx[:, cell_gate] + wh[:, cell_gate])
        c_t = f * c + i * g
        # NOTE(review): the output-gate peephole multiplies conv(c_prev) by
        # c_t rather than convolving c_t; kept as in the original so that
        # existing weights produce the same outputs -- confirm intent.
        o_t = torch.sigmoid(
            wx[:, out_gate] + wh[:, out_gate] + wc[:, 2 * oc:] * c_t)
        h_t = o_t * torch.tanh(c_t)
        return h_t, (h_t, c_t)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment