Skip to content

Instantly share code, notes, and snippets.

@yaroslavvb
Created December 3, 2017 22:52
Show Gist options
  • Save yaroslavvb/45178ed685f63293104ef54fe61fec76 to your computer and use it in GitHub Desktop.
PyTorch tied autoencoder with l-BFGS
import util as u
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
# todo: make images global
# Training-progress state shared (via `global`) between benchmark() and the
# L-BFGS closure it builds — reset at the top of every benchmark() call.
step = 0  # number of closure evaluations performed so far
final_loss = None  # loss captured on the iters-th closure evaluation
def benchmark(batch_size, iters, seed=1, cuda=True, verbose=False):
    """Train a tied-weight MNIST autoencoder with L-BFGS and return the final loss.

    Args:
        batch_size: number of MNIST images to train on (full-batch L-BFGS).
        iters: maximum number of L-BFGS closure evaluations.
        seed: RNG seed applied to torch, numpy and (if enabled) CUDA.
        cuda: move the model and data to the GPU when True.
        verbose: print per-step loss/timing and a timing summary.

    Returns:
        The final MSE reconstruction loss as a Python float.
    """
    global step, final_loss
    step = 0
    final_loss = None

    torch.manual_seed(seed)
    np.random.seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

    visible_size = 28 * 28
    hidden_size = 196

    # u.get_mnist_images returns images column-major; transpose to (batch, pixels).
    images = torch.Tensor(u.get_mnist_images(batch_size).T)
    images = images[:batch_size]
    if cuda:
        images = images.cuda()
    data = Variable(images)

    class Net(nn.Module):
        """Autoencoder whose decoder weight is the transpose of the encoder (tied weights)."""

        def __init__(self):
            super(Net, self).__init__()
            self.encoder = nn.Parameter(torch.rand(visible_size, hidden_size))

        def forward(self, input):
            x = input.view(-1, visible_size)
            x = torch.sigmoid(torch.mm(x, self.encoder))
            # Decode with the transposed encoder weight (weight tying).
            x = torch.sigmoid(torch.mm(x, torch.transpose(self.encoder, 0, 1)))
            return x.view_as(input)

    # Initialize model and overwrite the random encoder with u.ng_init weights.
    model = Net()
    model.encoder.data = torch.Tensor(u.ng_init(visible_size, hidden_size))
    if cuda:
        model.cuda()
    model.train()

    optimizer = optim.LBFGS(model.parameters(), max_iter=iters,
                            history_size=100, lr=1.0)

    def closure():
        """One full-batch forward/backward pass; records loss and timing."""
        global step, final_loss
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, data)
        # .item() replaces the deprecated loss.data[0], which raises
        # "invalid index of a 0-dim tensor" on PyTorch >= 0.4.
        if verbose:
            print("Step %3d loss %6.5f msec %6.3f" % (step, loss.item(),
                                                      u.last_time()))
        step += 1
        if step == iters:
            final_loss = loss.item()
        loss.backward()
        u.record_time()
        return loss

    optimizer.step(closure)

    # L-BFGS may converge before `iters` closure evaluations, in which case
    # final_loss was never set inside the closure; compute it from the
    # trained model instead of returning None.
    if final_loss is None:
        final_loss = F.mse_loss(model(data), data).item()
    if verbose:
        u.summarize_time()
    return final_loss
def main():
    """CLI entry point: run the benchmark with shared args and print the result."""
    import common_gd
    args = common_gd.args
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    result = benchmark(batch_size=args.batch_size,
                       iters=args.iters,
                       seed=args.seed,
                       cuda=args.cuda,
                       verbose=True)
    print(result)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment