@avostryakov
Last active May 12, 2017 09:10
Implementation of the "Recurrent Batch Normalization" article in Lasagne
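For context: the layer below follows the Recurrent Batch Normalization recipe of normalizing the LSTM pre-activations with statistics kept separately for each time step, plus running averages for use at test time. The following NumPy sketch (illustrative names and shapes, not taken from the code) mirrors the per-time-step normalization that the step function applies to the hidden-to-hidden pre-activations:

import numpy as np

def bn_step(hidden_preact, gamma, running_mean, running_inv_std, t,
            alpha=0.1, epsilon=1e-4, deterministic=False):
    # hidden_preact: (batch_size, 4 * num_units) pre-activations, i.e. dot(h_prev, W_hid)
    # gamma:         (4 * num_units,) learned scale for time step t
    # running_*:     (num_steps, 4 * num_units) running statistics, updated in place
    if deterministic:
        mean = running_mean[t]
        inv_std = running_inv_std[t]
    else:
        mean = hidden_preact.mean(axis=0)
        inv_std = 1.0 / np.sqrt(hidden_preact.var(axis=0) + epsilon)
        # exponential moving average of the statistics for time step t
        running_mean[t] = (1 - alpha) * running_mean[t] + alpha * mean
        running_inv_std[t] = (1 - alpha) * running_inv_std[t] + alpha * inv_std
    return (hidden_preact - mean) * (gamma * inv_std)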
from lasagne import nonlinearities, init
from lasagne.layers.normalization import BatchNormLayer
from lasagne.layers.recurrent import Gate, Layer, MergeLayer, LSTMLayer
from lasagne.utils import unroll_scan
import numpy as np
import theano
import theano.tensor as T
class BatchNormalizedLSTMLayer(LSTMLayer):
    def __init__(self, incoming, num_units,
                 ingate=Gate(),
                 forgetgate=Gate(),
                 cell=Gate(W_cell=None, nonlinearity=nonlinearities.tanh),
                 outgate=Gate(),
                 nonlinearity=nonlinearities.tanh,
                 cell_init=init.Constant(0.),
                 hid_init=init.Constant(0.),
                 backwards=False,
                 learn_init=False,
                 peepholes=True,
                 gradient_steps=-1,
                 grad_clipping=0,
                 unroll_scan=False,
                 precompute_input=True,
                 mask_input=None,
                 only_return_final=False,
                 batch_axes=(0,),
                 **kwargs):
        # Initialize parent layer
        super(BatchNormalizedLSTMLayer, self).__init__(
            incoming, num_units, ingate, forgetgate, cell, outgate,
            nonlinearity, cell_init, hid_init, backwards, learn_init,
            peepholes, gradient_steps, grad_clipping, unroll_scan,
            precompute_input, mask_input, only_return_final, **kwargs)

        input_shape = self.input_shapes[0]
        # Create a BN layer for the input projection with input shape
        # (n_steps, batch_size, 4*num_units) and the given axes
        self.bn_input = BatchNormLayer(
            (input_shape[1], input_shape[0], 4 * self.num_units), beta=None,
            gamma=init.Constant(0.1), axes=batch_axes)
        # Adopt the BN layer's parameters as this layer's parameters
        self.params.update(self.bn_input.params)

        # Create batch normalization parameters for the hidden-to-hidden
        # pre-activations; statistics are kept separately for every time
        # step, so the shape is (time_steps, 4*num_units)
        self.epsilon = np.float32(1e-4)
        self.alpha = theano.shared(np.float32(0.1))
        shape = (input_shape[1], 4 * num_units)
        self.gamma = self.add_param(init.Constant(0.1), shape, 'gamma',
                                    trainable=True, regularizable=True)
        self.running_mean = self.add_param(init.Constant(0), shape,
                                           'running_mean',
                                           trainable=False,
                                           regularizable=False)
        self.running_inv_std = self.add_param(init.Constant(1), shape,
                                              'running_inv_std',
                                              trainable=False,
                                              regularizable=False)
        self.running_mean_clone = theano.clone(self.running_mean,
                                               share_inputs=False)
        self.running_inv_std_clone = theano.clone(self.running_inv_std,
                                                  share_inputs=False)
        self.running_mean_clone.default_update = self.running_mean_clone
        self.running_inv_std_clone.default_update = self.running_inv_std_clone
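
    # Note on the running statistics: theano.clone(..., share_inputs=False)
    # creates copies of the running mean / inverse-std shared variables.
    # Inside step() below, their default_update attributes are overwritten
    # with exponentially averaged per-time-step statistics, and the clones are
    # pulled into the graph via the "mean += 0 * ..." terms, with the intent
    # that the running statistics get updated as a side effect whenever a
    # Theano function is compiled from the layer's output (the same
    # default_update mechanism Lasagne's BatchNormLayer relies on).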

    def get_output_for(self, inputs, deterministic=False, **kwargs):
        # Retrieve the layer input
        input = inputs[0]
        # Retrieve the mask and initial states when they are supplied
        mask = None
        hid_init = None
        cell_init = None
        if self.mask_incoming_index > 0:
            mask = inputs[self.mask_incoming_index]
        if self.hid_init_incoming_index > 0:
            hid_init = inputs[self.hid_init_incoming_index]
        if self.cell_init_incoming_index > 0:
            cell_init = inputs[self.cell_init_incoming_index]

        # Treat all dimensions after the second as flattened feature dimensions
        if input.ndim > 3:
            input = T.flatten(input, 3)

        # Because scan iterates over the first dimension we dimshuffle to
        # (n_time_steps, n_batch, n_features)
        input = input.dimshuffle(1, 0, 2)
        seq_len, num_batch, _ = input.shape

        # Stack input weight matrices into a (num_inputs, 4*num_units)
        # matrix, which speeds up computation
        W_in_stacked = T.concatenate(
            [self.W_in_to_ingate, self.W_in_to_forgetgate,
             self.W_in_to_cell, self.W_in_to_outgate], axis=1)

        # Same for hidden weight matrices
        W_hid_stacked = T.concatenate(
            [self.W_hid_to_ingate, self.W_hid_to_forgetgate,
             self.W_hid_to_cell, self.W_hid_to_outgate], axis=1)

        # Stack biases into a (4*num_units) vector
        b_stacked = T.concatenate(
            [self.b_ingate, self.b_forgetgate,
             self.b_cell, self.b_outgate], axis=0)

        # Batch-normalize the precomputed input projection
        input = self.bn_input.get_output_for(T.dot(input, W_in_stacked)) + b_stacked

        # At each call to scan, input_n will be (n_batch, 4*num_units).
        # We define a slicing function that extracts the input to each LSTM gate
        def slice_w(x, n):
            return x[:, n * self.num_units:(n + 1) * self.num_units]

        # Create single recurrent computation step function
        # input_n is the n'th vector of the input
        def step(input_n, gamma, time_step, cell_previous, hid_previous, *args):
            hidden = T.dot(hid_previous, W_hid_stacked)

            # Batch normalization of the hidden-to-hidden pre-activations
            if deterministic:
                mean = self.running_mean[time_step]
                inv_std = self.running_inv_std[time_step]
            else:
                mean = hidden.mean(0)
                inv_std = T.inv(T.sqrt(hidden.var(0) + self.epsilon))
                self.running_mean_clone.default_update = T.set_subtensor(
                    self.running_mean_clone.default_update[time_step],
                    (1 - self.alpha) * self.running_mean_clone.default_update[time_step] + self.alpha * mean)
                self.running_inv_std_clone.default_update = T.set_subtensor(
                    self.running_inv_std_clone.default_update[time_step],
                    (1 - self.alpha) * self.running_inv_std_clone.default_update[time_step] + self.alpha * inv_std)
                # Pull the clones (and hence their default updates) into the graph
                mean += 0 * self.running_mean_clone[time_step]
                inv_std += 0 * self.running_inv_std_clone[time_step]

            gamma = gamma.dimshuffle('x', 0)
            mean = mean.dimshuffle('x', 0)
            inv_std = inv_std.dimshuffle('x', 0)

            # Normalize
            normalized = (hidden - mean) * (gamma * inv_std)

            # Calculate gate pre-activations and slice
            gates = input_n + normalized

            # Clip gradients
            if self.grad_clipping:
                gates = theano.gradient.grad_clip(
                    gates, -self.grad_clipping, self.grad_clipping)

            # Extract the pre-activation gate values
            ingate = slice_w(gates, 0)
            forgetgate = slice_w(gates, 1)
            cell_input = slice_w(gates, 2)
            outgate = slice_w(gates, 3)

            if self.peepholes:
                # Compute peephole connections
                ingate += cell_previous * self.W_cell_to_ingate
                forgetgate += cell_previous * self.W_cell_to_forgetgate

            # Apply nonlinearities
            ingate = self.nonlinearity_ingate(ingate)
            forgetgate = self.nonlinearity_forgetgate(forgetgate)
            cell_input = self.nonlinearity_cell(cell_input)

            # Compute new cell value
            cell = forgetgate * cell_previous + ingate * cell_input

            if self.peepholes:
                outgate += cell * self.W_cell_to_outgate
            outgate = self.nonlinearity_outgate(outgate)

            # Compute new hidden unit activation
            hid = outgate * self.nonlinearity(cell)
            return [cell, hid]

        def step_masked(input_n, mask_n, gamma, time_step, cell_previous, hid_previous, *args):
            cell, hid = step(input_n, gamma, time_step, cell_previous, hid_previous, *args)
            # Skip over any input with mask 0 by copying the previous
            # hidden state; proceed normally for any input with mask 1.
            cell = T.switch(mask_n, cell, cell_previous)
            hid = T.switch(mask_n, hid, hid_previous)
            return [cell, hid]

        if mask is not None:
            # mask is given as (batch_size, seq_len). Because scan iterates
            # over first dimension, we dimshuffle to (seq_len, batch_size) and
            # add a broadcastable dimension
            mask = mask.dimshuffle(1, 0, 'x')
            sequences = [input, mask]
            step_fun = step_masked
        else:
            sequences = [input]
            step_fun = step

        # gamma and the time-step index are passed as additional sequences
        time_steps = np.asarray(np.arange(self.input_shapes[0][1]), dtype=np.int32)
        sequences.extend([self.gamma, time_steps])

        ones = T.ones((num_batch, 1))
        if not isinstance(self.cell_init, Layer):
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            cell_init = T.dot(ones, self.cell_init)
        if not isinstance(self.hid_init, Layer):
            # Dot against a 1s vector to repeat to shape (num_batch, num_units)
            hid_init = T.dot(ones, self.hid_init)

        # The hidden-to-hidden weight matrix is always used in step
        non_seqs = [W_hid_stacked]
        # The "peephole" weight matrices are only used when self.peepholes=True
        if self.peepholes:
            non_seqs += [self.W_cell_to_ingate,
                         self.W_cell_to_forgetgate,
                         self.W_cell_to_outgate]
        non_seqs += [self.running_mean, self.running_inv_std]

        if self.unroll_scan:
            # Retrieve the dimensionality of the incoming layer
            input_shape = self.input_shapes[0]
            # Explicitly unroll the recurrence instead of using scan
            cell_out, hid_out = unroll_scan(
                fn=step_fun,
                sequences=sequences,
                outputs_info=[cell_init, hid_init],
                go_backwards=self.backwards,
                non_sequences=non_seqs,
                n_steps=input_shape[1])
        else:
            # Scan op iterates over first dimension of input and repeatedly
            # applies the step function
            cell_out, hid_out = theano.scan(
                fn=step_fun,
                sequences=sequences,
                outputs_info=[cell_init, hid_init],
                go_backwards=self.backwards,
                truncate_gradient=self.gradient_steps,
                non_sequences=non_seqs,
                strict=True)[0]

        # When it is requested that we only return the final sequence step,
        # we need to slice it out immediately after scan is applied
        if self.only_return_final:
            hid_out = hid_out[-1]
        else:
            # dimshuffle back to (n_batch, n_time_steps, n_features)
            hid_out = hid_out.dimshuffle(1, 0, 2)
            # if scan is backwards, reverse the output
            if self.backwards:
                hid_out = hid_out[:, ::-1]

        return hid_out
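
A minimal usage sketch, assuming the class above is importable as BatchNormalizedLSTMLayer and a fixed-length input of shape (batch_size, seq_len, num_inputs); all sizes below are made up for illustration. Per the discussion in the comments below, unroll_scan=True (which requires a fully specified input shape) appears to be necessary for the recurrence to compile:

import lasagne

BATCH_SIZE, SEQ_LEN, NUM_INPUTS, NUM_UNITS = 32, 20, 100, 128

l_in = lasagne.layers.InputLayer((BATCH_SIZE, SEQ_LEN, NUM_INPUTS))
l_mask = lasagne.layers.InputLayer((BATCH_SIZE, SEQ_LEN))
l_lstm = BatchNormalizedLSTMLayer(l_in, num_units=NUM_UNITS,
                                  mask_input=l_mask,
                                  unroll_scan=True,  # see the comments below
                                  only_return_final=True)
l_out = lasagne.layers.DenseLayer(l_lstm, num_units=10,
                                  nonlinearity=lasagne.nonlinearities.softmax)

# Training-mode output (batch statistics are used and running averages updated)
train_out = lasagne.layers.get_output(l_out, deterministic=False)
# Inference-mode output (the stored running statistics are used)
test_out = lasagne.layers.get_output(l_out, deterministic=True)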

webeng commented Jan 11, 2017

Thanks for implementing it.

I'm getting an error when using this implementation:

Traceback (most recent call last):
  File "qann.py", line 660, in <module>
    qann.build_model()
  File "qann.py", line 334, in build_model
    network_output = lasagne.layers.get_output(network_start)
  File "/Applications/MAMP/htdocs/qann/env/lib/python2.7/site-packages/lasagne/layers/helper.py", line 191, in get_output
    all_outputs[layer] = layer.get_output_for(layer_inputs, **kwargs)
  File "/Applications/MAMP/htdocs/qann/models/batch_normalized_lstm_layer.py", line 236, in get_output_for
    strict=True)[0]
  File "/Applications/MAMP/htdocs/qann/env/lib/python2.7/site-packages/theano/scan_module/scan.py", line 557, in scan
    scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
IndexError: failed to coerce slice entry of type TensorVariable to integer

Do you have any idea why it occurs?


nikostr commented May 12, 2017

@webeng, in the discussion at Lasagne/Lasagne#577 the author mentions that the code works when unroll_scan=True. We got that exact error when attempting to run the above code without unrolling the scan.
