import mxnet as mx
# LSTMState, LSTMParam and the lstm() cell come from MXNet's example/rnn lstm.py;
# adjust this import to wherever that helper file lives in your checkout.
from lstm import LSTMState, LSTMParam, lstm

# used for training
def bi_lstm_unroll(seq_len, input_size, num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")

    # initial cell/hidden states: layer 0 is the forward LSTM, layer 1 the backward LSTM
    last_states = []
    last_states.append(LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")))
    last_states.append(LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h")))
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))

    # embed the token ids and slice the embedded sequence into one symbol per time step
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('softmax_label')
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    # forward pass: left to right
    forward_hidden = []
    for seqidx in range(seq_len):
        hidden = wordvec[seqidx]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=dropout)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    # backward pass: right to left; insert at the front so indices align with forward_hidden
    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        hidden = wordvec[k]
        next_state = lstm(num_hidden, indata=hidden,
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=dropout)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    # concatenate forward and backward states per step, then stack all steps along the batch axis
    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)

    pred = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                                 weight=cls_weight, bias=cls_bias, name='pred')
    label = mx.sym.transpose(data=label)
    label = mx.sym.Reshape(data=label, target_shape=(0,))
    sm = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')
    return sm
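
# --- Example (not from the original gist) ---
# A minimal sketch of checking that the unrolled training symbol wires up; the
# sequence length, vocabulary size and layer sizes below are made-up placeholder values.
def _check_train_symbol():
    seq_len, vocab_size, num_hidden, num_embed, num_label = 10, 100, 128, 64, 100
    batch_size = 32
    sym = bi_lstm_unroll(seq_len, vocab_size, num_hidden, num_embed, num_label, dropout=0.2)
    init_states = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(2)] + \
                  [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(2)]
    input_shapes = dict(init_states + [('data', (batch_size, seq_len)),
                                       ('softmax_label', (batch_size, seq_len))])
    arg_shapes, out_shapes, aux_shapes = sym.infer_shape(**input_shapes)
    print(list(zip(sym.list_arguments(), arg_shapes)))
    print(out_shapes)  # expected: [(batch_size * seq_len, num_label)]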
# used for inference
class BiLSTMInferenceModel(object):
    def __init__(self,
                 seq_len,
                 input_size,
                 num_hidden,
                 num_embed,
                 num_label,
                 arg_params,
                 ctx=mx.cpu(),
                 dropout=0.):
        self.sym = bi_lstm_inference_symbol(input_size, seq_len,
                                            num_hidden,
                                            num_embed,
                                            num_label,
                                            dropout)
        print("input size:", input_size)
        batch_size = 1
        init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(2)]
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(2)]
        # the symbol is unrolled over the full sequence, so bind data as (batch, seq_len)
        data_shape = [("data", (batch_size, seq_len))]

        input_shapes = dict(init_c + init_h + data_shape)
        print(input_shapes)
        self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes)

        # copy the trained parameters into the executor's argument arrays
        for key in self.executor.arg_dict.keys():
            if key in arg_params:
                print(key, arg_params[key].shape, self.executor.arg_dict[key].shape)
                arg_params[key].copyto(self.executor.arg_dict[key])

        state_name = []
        for i in range(2):
            state_name.append("l%d_init_c" % i)
            state_name.append("l%d_init_h" % i)

        # outputs[0] is the softmax; the remaining outputs are the final LSTM states
        self.states_dict = dict(zip(state_name, self.executor.outputs[1:]))
        self.input_arr = mx.nd.zeros(data_shape[0][1])

    def forward(self, input_data, new_seq=False):
        if new_seq:
            # reset the recurrent states when starting a new sequence
            for key in self.states_dict.keys():
                self.executor.arg_dict[key][:] = 0.
        # copy the input into the bound NDArray; rebinding the dict entry would not
        # update the executor's memory
        input_data.copyto(self.executor.arg_dict["data"])
        self.executor.forward()
        # feed the final states back as the next call's initial states
        for key in self.states_dict.keys():
            self.states_dict[key].copyto(self.executor.arg_dict[key])
        prob = self.executor.outputs[0].asnumpy()
        return prob
def bi_lstm_inference_symbol(input_size, seq_len,
                             num_hidden, num_embed, num_label, dropout=0.):
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    last_states = [LSTMState(c=mx.sym.Variable("l0_init_c"), h=mx.sym.Variable("l0_init_h")),
                   LSTMState(c=mx.sym.Variable("l1_init_c"), h=mx.sym.Variable("l1_init_h"))]
    forward_param = LSTMParam(i2h_weight=mx.sym.Variable("l0_i2h_weight"),
                              i2h_bias=mx.sym.Variable("l0_i2h_bias"),
                              h2h_weight=mx.sym.Variable("l0_h2h_weight"),
                              h2h_bias=mx.sym.Variable("l0_h2h_bias"))
    backward_param = LSTMParam(i2h_weight=mx.sym.Variable("l1_i2h_weight"),
                               i2h_bias=mx.sym.Variable("l1_i2h_bias"),
                               h2h_weight=mx.sym.Variable("l1_h2h_weight"),
                               h2h_bias=mx.sym.Variable("l1_h2h_bias"))

    data = mx.sym.Variable("data")
    embed = mx.sym.Embedding(data=data, input_dim=input_size,
                             weight=embed_weight, output_dim=num_embed, name='embed')
    wordvec = mx.sym.SliceChannel(data=embed, num_outputs=seq_len, squeeze_axis=1)

    # forward pass: left to right (dropout is disabled at inference time)
    forward_hidden = []
    for seqidx in range(seq_len):
        next_state = lstm(num_hidden, indata=wordvec[seqidx],
                          prev_state=last_states[0],
                          param=forward_param,
                          seqidx=seqidx, layeridx=0, dropout=0.0)
        hidden = next_state.h
        last_states[0] = next_state
        forward_hidden.append(hidden)

    # backward pass: right to left
    backward_hidden = []
    for seqidx in range(seq_len):
        k = seq_len - seqidx - 1
        next_state = lstm(num_hidden, indata=wordvec[k],
                          prev_state=last_states[1],
                          param=backward_param,
                          seqidx=k, layeridx=1, dropout=0.0)
        hidden = next_state.h
        last_states[1] = next_state
        backward_hidden.insert(0, hidden)

    # concatenate per-step forward/backward states and stack the steps
    hidden_all = []
    for i in range(seq_len):
        hidden_all.append(mx.sym.Concat(*[forward_hidden[i], backward_hidden[i]], dim=1))
    hidden_concat = mx.sym.Concat(*hidden_all, dim=0)
    fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=num_label,
                               weight=cls_weight, bias=cls_bias, name='pred')
    sm = mx.sym.SoftmaxOutput(data=fc, name='softmax')

    # group the softmax with the final cell/hidden states so the caller can read
    # the states back after each forward pass
    output = [sm]
    for state in last_states:
        output.append(state.c)
        output.append(state.h)
    return mx.sym.Group(output)
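
# --- Example (not from the original gist) ---
# A rough sketch of running the inference model on one encoded sequence. It assumes a
# checkpoint saved with prefix "sort" at epoch 1 and the same placeholder sizes as above;
# adjust these to match however the model was actually trained.
def _run_inference_example():
    seq_len, vocab_size, num_hidden, num_embed, num_label = 10, 100, 128, 64, 100
    _, arg_params, _aux = mx.model.load_checkpoint("sort", 1)
    model = BiLSTMInferenceModel(seq_len, vocab_size, num_hidden, num_embed,
                                 num_label, arg_params, ctx=mx.cpu(), dropout=0.)
    # one sequence of token ids, shape (1, seq_len) to match the bound data shape
    data = mx.nd.array([[3, 1, 4, 1, 5, 9, 2, 6, 5, 3]])
    prob = model.forward(data, new_seq=True)
    print(prob.argmax(axis=1))  # predicted label for each of the seq_len positions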