Skip to content

Instantly share code, notes, and snippets.

@simgt
Created June 23, 2016 12:43
Show Gist options
  • Save simgt/c990c755d7b97f75a677cca259238f3c to your computer and use it in GitHub Desktop.
Save simgt/c990c755d7b97f75a677cca259238f3c to your computer and use it in GitHub Desktop.
local torch =require 'torch'
local nn =require 'nn'
local rnn =require 'rnn'
local gnuplot = require 'gnuplot'
-- Ask Torch for four CPU threads, then report the count it actually uses.
torch.setnumthreads(4)
local activeThreads = torch.getnumthreads()
print('number of threads: ' .. activeThreads)
-- Training hyperparameters. Declared local so they do not leak into the
-- global environment; later code in this file still sees them as
-- file-level locals.
local batchSize = 16 -- number of sequences per mini-batch
local rho = 16       -- sequence length
local lr = 0.01      -- learning rate passed to updateParameters
--
-- Model
--
-- Layer sizes are file-level locals rather than accidental globals.
local inputSize = 1
local hiddenSize = 32
local outputSize = 1

-- Recurrent layer; rho (the sequence length) is passed as its last argument.
local recurrent = nn.Recurrent(
  hiddenSize,                         -- start
  nn.Linear(inputSize, hiddenSize),   -- input
  nn.Linear(hiddenSize, hiddenSize),  -- feedback
  nn.Sigmoid(),                       -- transfer
  rho
)

-- Per-time-step network: recurrent layer followed by a linear read-out.
local stepModel = nn.Sequential()
  :add(recurrent)
  :add(nn.Linear(hiddenSize, outputSize))

-- Wrap the per-step network so it is applied across a whole sequence.
-- NOTE: this local intentionally takes the name `rnn`, shadowing the handle
-- returned by require 'rnn'; the rest of the script refers to the model by
-- this name.
local rnn = nn.Sequencer(stepModel)
-- load a model previously trained
--rnn = torch.load('sine-waves-model.dat', 'ascii', true)
print(rnn)

-- MSE criterion wrapped so it is evaluated over a whole sequence.
local criterion = nn.SequencerCriterion(nn.MSECriterion())
--
-- Dataset
--
-- Input is sin(t) and the target is sin(t/2), both sampled on the same
-- evenly spaced time axis covering numPeriods periods of the input.
local numSamples, numPeriods = 1024, 10
local tEnd = numPeriods * 2 * math.pi
local t = torch.linspace(0, tEnd, numSamples)

local input = torch.Tensor(numSamples, inputSize)
local output = torch.Tensor(numSamples, outputSize)

-- Fill the first (and only) channel of each signal.
input[{ {}, 1 }]:copy(torch.sin(t))
--input:select(2, 2):copy(torch.sin(t/2))
output[{ {}, 1 }]:copy(torch.sin(t / 2))
--output:select(2, 2):copy(torch.sin(t*2))
--
-- Training
--
-- Runs forever: each outer pass draws fresh random start offsets, performs
-- `updatesPerCheckpoint` BPTT updates, saves the model, and plots the
-- network's estimate over the full signal.
local updatesPerCheckpoint = 2000

--- Draw one random start index per batch element.
-- @treturn torch.LongTensor batchSize indices, each in [1, input:size(1)]
local function sampleOffsets()
  local offsets = torch.LongTensor(batchSize)
  local limit = input:size(1)
  for i = 1, batchSize do
    -- math.random(n) returns an integer in [1, n]. The former
    -- math.ceil(math.random() * n) could yield 0 (an invalid 1-based index)
    -- because math.random() may return exactly 0.
    offsets[i] = math.random(limit)
  end
  return offsets
end

--- Build a batch of sequences of rho time-steps.
-- Advances `offsets` in place, wrapping back to 1 past the last sample, so
-- consecutive calls continue where the previous batch stopped.
-- @param offsets torch.LongTensor of current sample indices (mutated)
-- @return x, y tables of rho input / target tensors, one per time-step
local function nextBatch(offsets)
  local x, y = {}, {}
  for step = 1, rho do
    -- index() copies the selected rows, so mutating offsets below is safe
    x[step] = input:index(1, offsets)
    y[step] = output:index(1, offsets)
    -- increment indices in place (avoids allocating a tensor every step)
    offsets:add(1)
    for j = 1, batchSize do
      if offsets[j] > numSamples then
        offsets[j] = 1
      end
    end
  end
  return x, y
end

--- Forward the whole signal through the model and write the comparison plot.
local function plotFullSequence()
  -- Inference only: switch to evaluation mode for this pass, then restore
  -- training mode so the next round of updates is unaffected.
  rnn:evaluate()
  local estimate = rnn:forward(input)
  gnuplot.pngfigure('sine-waves-test.png')
  gnuplot.plot(
    {'input', t, input:select(2, 1), '-'},
    {'truth', t, output:select(2, 1), '-'},
    {'estimate', t, estimate:select(2, 1), '-'}
  )
  gnuplot.plotflush()
  rnn:training()
end

local it = 1
while true do
  local offsets = sampleOffsets()
  for _ = 1, updatesPerCheckpoint do
    local x, y = nextBatch(offsets)
    -- forward the sequence
    local z = rnn:forward(x)
    local err = criterion:forward(z, y)
    print(string.format("[%d] err = %f", it, err / rho))
    -- backward the sequence (i.e. BPTT) in reverse order of forward calls
    rnn:zeroGradParameters()
    local gz = criterion:backward(z, y)
    rnn:backward(x, gz)
    -- update
    rnn:updateParameters(lr)
    it = it + 1
  end
  -- save the model
  print("Saving...")
  torch.save('sine-waves-model.dat', rnn, 'ascii', true)
  -- test on the full sequence
  plotFullSequence()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment