Last active
June 24, 2020 17:44
-
-
Save etienne87/e65b6bb2493213f436bf4a5b43b943ca to your computer and use it in GitHub Desktop.
Very scruffy script to showcase SIREN networks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from ranger import Ranger | |
import numpy as np | |
import random | |
import cv2 | |
import math | |
import kornia | |
from functools import partial | |
# kornia's Laplacian filter with the kernel size pre-bound to 3 via functools.partial.
laplace_filter = partial(kornia.filters.laplacian, kernel_size=3)
# Module-level alias for kornia's Sobel filter.
gradient_filter = kornia.filters.sobel
def make_dataset(filname=None, centered=False):
    """
    Load an image and turn it into a per-pixel regression dataset.

    Args:
        filname: path of the image to read; 'dali.jpg' is used when None.
            (The misspelled parameter name is kept so keyword callers keep working.)
        centered: when True, rescale the colour targets from [0, 1] to [-1, 1]
            before the gradient is computed.

    Returns:
        x: (height*width, 2) pixel coordinates in [0, 1], row coordinate first.
        y: (height*width, c) colour values.
        g: (height*width, c, 2) per-channel image derivatives as returned by
            kornia.spatial_gradient.
        height, width: size of the image after one pyrDown step.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    path = 'dali.jpg' if filname is None else filname
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError("could not read image: {}".format(path))
    img = cv2.pyrDown(img)
    height, width, c = img.shape

    xv, yv = torch.meshgrid([torch.linspace(0, 1, height), torch.linspace(0, 1, width)])
    xv = xv.contiguous()
    yv = yv.contiguous()

    y = torch.from_numpy(img) / 255.0
    if centered:
        y = (y * 2) - 1

    # assumes kornia returns (B, C, 2, H, W) as documented — TODO confirm for the
    # installed kornia version
    g = kornia.spatial_gradient(y.permute(2, 0, 1)[None])
    # (C, 2, H, W) -> (H, W, C, 2) -> (H*W, C, 2).  The previous
    # view(-1, 2, 3).permute(0, 2, 1) reinterpreted the (C, 2) memory as (2, 3)
    # and so mixed channels with derivative directions.
    g = g[0].permute(2, 3, 0, 1).reshape(-1, c, 2)

    y = y.view(-1, c)
    x = torch.cat([xv.view(-1)[:, None], yv.view(-1)[:, None]], dim=1)
    return x, y, g, height, width
def siren_init(tensor, use_this_fan_in=None):
    """
    Fill ``tensor`` in place with samples from U(-b, b), b = sqrt(6 / fan_in).

    This is the same bound torch.nn.init.kaiming_uniform_ uses with
    mode='fan_in' and the ReLU gain. To initialize a whole nn.Module use
    'apply_siren_init'.

    Args:
        tensor: tensor to initialize (modified in place).
        use_this_fan_in: explicit fan-in; when None it is derived from
            ``tensor`` itself.

    Returns:
        The same tensor, after initialization.
    """
    fan_in = use_this_fan_in
    if fan_in is None:
        fan_in = nn.init._calculate_correct_fan(tensor, "fan_in")
    limit = math.sqrt(6.0 / fan_in)
    # no_grad so the in-place fill is allowed on parameters that require grad
    with torch.no_grad():
        tensor.uniform_(-limit, limit)
    return tensor
def apply_siren_init(layer: nn.Module):
    """
    Siren-initialize a layer's weight and, if it has one, its bias.

    The bias is drawn with the fan-in of the weight, so both share the same
    bound.
    """
    weight, bias = layer.weight, layer.bias
    siren_init(weight)
    if bias is None:
        return
    weight_fan_in = nn.init._calculate_correct_fan(weight, "fan_in")
    siren_init(bias, use_this_fan_in=weight_fan_in)
class SinusLayer(nn.Module):
    """
    Linear layer followed by a sine: ``sin(linear(w0 * x))``.

    ``w0`` is stored as a learnable scalar parameter that scales the input
    before the linear map.
    """

    def __init__(self, cin, cout, w0=1):
        super().__init__()
        self.linear = nn.Linear(cin, cout)
        # special init: w_i ~ U(-sqrt(6 / n), sqrt(6 / n)) with n the fan-in;
        # the bias keeps nn.Linear's default initialization
        limit = math.sqrt(6.0 / self.linear.in_features)
        self.w0 = nn.Parameter(torch.tensor(w0).float())
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        scaled = self.w0 * x
        return self.linear(scaled).sin()
class Siren(nn.Module):
    """
    Stack of SinusLayer blocks followed by a linear head and a sigmoid.

    Args:
        cin: size of the input coordinates.
        cout: number of output channels.
        hiddens: widths of the sine layers, in order; must not be empty.

    Raises:
        ValueError: if ``hiddens`` is empty.
    """

    # The default is a tuple so no mutable object is shared between calls.
    def __init__(self, cin=2, cout=3, hiddens=(128, 64, 64, 32, 16)):
        super().__init__()
        widths = list(hiddens)
        if not widths:
            raise ValueError("hiddens must contain at least one layer width")
        # first layer scales its input by w0=30, the remaining ones use the default
        self.prepare = SinusLayer(cin, widths[0], w0=30)
        # called "residuals", but forward() applies them as a plain sequential
        # stack without skip connections
        self.residuals = nn.ModuleList(
            [SinusLayer(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )
        self.out = nn.Linear(widths[-1], cout)

    def forward(self, x):
        x = self.prepare(x)
        for layer in self.residuals:
            x = layer(x)
        # sigmoid keeps the prediction in (0, 1)
        return torch.sigmoid(self.out(x))
def show_pred(pred, height, width, centered=False):
    """
    Convert a flat prediction tensor into a displayable uint8 image.

    Args:
        pred: tensor holding height*width pixels; the channel count is
            inferred from its size.
        height, width: spatial size of the image.
        centered: when True, map the values through (v + 1) / 2 before the
            min-max normalization.

    Returns:
        numpy array of shape (height, width, channels), dtype uint8, with the
        values min-max normalized to [0, 255]. A constant image yields all
        zeros. ``pred`` is never modified.
    """
    img = pred.detach().cpu().reshape(height, width, -1).numpy()
    if centered:
        # out of place: the numpy array may share memory with the caller's tensor
        img = (img + 1) / 2
    lo = img.min()
    span = img.max() - lo
    if span > 0:
        img = (img - lo) / span
    else:
        # constant image: avoid 0/0 and the resulting NaN -> uint8 cast
        img = np.zeros_like(img)
    return (img * 255).astype(np.uint8)
def jacobian_in_batch(y, x):
    '''
    Compute the Jacobian matrix in batch form.

    Args:
        y: output tensor of shape (B, ...) computed from ``x``; a 1-D ``y`` of
            shape (B,) is treated as (B, 1).
        x: input tensor that requires grad, with B as its first dimension.

    Return (B, D_y, D_x), where D_y and D_x are the flattened per-sample sizes.

    Each output component is differentiated with a grad_outputs vector of
    ones, i.e. its sum over the batch. The result is therefore the per-sample
    Jacobian only when y[b] depends on x[b] alone.
    '''
    batch = y.shape[0]
    y = y.reshape(batch, -1)
    ones = torch.ones(batch).to(y)
    # Compute Jacobian row by row.
    # dy_i / dx -> dy / dx
    # (B, D_x) per output component, stacked to (B, D_y, D_x)
    rows = [
        torch.autograd.grad(
            y[:, i],
            x,
            grad_outputs=ones,
            retain_graph=True,   # the graph is reused for the next component
            create_graph=True,   # keep the Jacobian itself differentiable
        )[0].reshape(batch, -1)
        for i in range(y.shape[1])
    ]
    return torch.stack(rows, dim=1)
# ---- data, model and optimizer setup ----
batch_size = 1024
x, y, gradient, height, width = make_dataset()
# show_pred gets a clone so the training targets cannot be touched by it
img_dataset = show_pred(y.clone(), height, width)
N = len(x)
# Use the GPU when one is present instead of assuming it, so the script also
# runs on CPU-only machines.
cuda = torch.cuda.is_available()
net = Siren(cout=3, hiddens=[256, 128, 64, 64, 32])
# criterion = torch.nn.MSELoss()
# NOTE: criterion is created but the training loop computes its own squared error.
criterion = torch.nn.SmoothL1Loss()
if cuda:
    x = x.cuda()
    y = y.cuda()
    gradient = gradient.cuda()
    net.cuda()
    criterion.cuda()
#gradient = (gradient-gradient.mean())/(gradient.std()+1e-5)
#gradient = (gradient-gradient.min())/(gradient.max()-gradient.min())
# Hand-picked scale on the ground-truth gradient; applied on every device so the
# data is the same whether or not CUDA is used.
gradient *= 7
net.train()
opt = optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-4)
# opt = Ranger(net.parameters(), lr=0.001)
# Shuffled index, only referenced by the commented-out sequential sampling.
idx = np.arange(0, len(y))
random.shuffle(idx)
# Per-pixel sampling probabilities, uniform to start with.
probas = [1./N] * N
probas = np.array(probas)
# ---- training loop with error-driven pixel sampling ----
# NOTE(review): indentation was lost in this copy of the file; the nesting below
# (preview once per epoch, print inside the batch loop) is reconstructed — confirm.
for epoch in range(100):
    for i in range(0, N, batch_size):
        # Draw pixels in proportion to their most recently recorded error.
        # Passing the int is documented to sample as if from np.arange(len(y)),
        # without rebuilding that array on every batch.
        jdx = np.random.choice(len(y), size=batch_size, p=probas)
        #low = i
        #high = (i+batch_size)%N
        #jdx = idx[low:high]
        #jdx = torch.from_numpy(jdx)
        bx = x[jdx]
        by = y[jdx]
        bg = gradient[jdx] #ground truth gradient!
        opt.zero_grad()
        # requires_grad on the inputs is only needed by the disabled Jacobian loss below
        bx = Variable(bx, requires_grad=True)
        out = net(bx)
        # bx = Variable(x[jdx], requires_grad=True)
        # numerical gradient (not analytical, really easy to bp through, but not exact)
        # sizes = [width, height]
        # grad = [None,None]
        # for r in [-1,1]:
        #     for dim in [0,1]:
        #         bx2 = bx.clone()
        #         bx2[:,dim] += r/100
        #         o = net(bx2)
        #         if grad[dim] is None:
        #             grad[dim] = torch.zeros_like(o)
        #         grad[dim] += r * o
        # grad_errors = ((bg[:,0] - grad[0])**2 + (bg[:,1] - grad[1])**2).mean(dim=1)
        # jacob = jacobian_in_batch(out, bx)
        # grad_errors = (jacob - bg)**2
        # grad_errors = grad_errors.mean(dim=[1,2])
        errors = ((out - by) ** 2).mean(dim=1)  # per-pixel squared error
        loss = errors.mean()
        # feed the errors back into the sampling distribution, then renormalize
        probas[jdx] = errors.data.cpu().numpy() / len(errors)
        probas /= probas.sum()
        loss.backward()
        opt.step()
        # i advances in steps of batch_size, not by 1, so this does not fire
        # every 10 batches
        if i%10 == 0:
            print('loss1: ', loss.item())
    #showcase
    # The preview needs no gradients; no_grad avoids building an autograd graph
    # for the whole image.
    with torch.no_grad():
        pred = net(x)
    img = show_pred(pred, height, width)
    cv2.imshow('ground_truth', img_dataset)
    cv2.imshow('prediction', img)
    cv2.waitKey(5)
# ---- query the trained network on a 2x denser coordinate grid ----
net.eval()
with torch.no_grad():
    height2, width2 = height*2, width*2
    xv, yv = torch.meshgrid([torch.linspace(0,1,height2), torch.linspace(0,1,width2)])
    xv = xv.contiguous()
    yv = yv.contiguous()
    x2 = torch.cat([xv.view(-1)[:,None], yv.view(-1)[:,None]], dim=1)
    # send the grid to wherever the model lives instead of assuming CUDA
    x2 = x2.to(next(net.parameters()).device)
    pred = net(x2)
    img2 = show_pred(pred, height2, width2)
cv2.imshow('ground_truth', img_dataset)
cv2.imshow('ground_truth_x2_bicubic', cv2.pyrUp(img_dataset))
cv2.imshow('prediction_x2', img2)
# img is the last preview rendered by the training loop
cv2.imshow('prediction_x1', img)
cv2.waitKey(0)
Author
etienne87
commented
Jun 24, 2020
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment