An implementation of the MS-D network without custom convolutions. Allows comparing the custom convolutions with PyTorch-provided convolutions.
#!/usr/bin/env python3
"""
Requirements: msd_pytorch
Installation instruction can be found at https://github.com/ahendriksen/msd_pytorch
"""
import torch
import numpy as np
from msd_pytorch.msd_block import MSDBlockImpl2d


class MSDBlock2d(torch.nn.Module):
    def __init__(self, in_channels, dilations, width=1, vanilla=False):
        """Multi-scale dense block

        Parameters
        ----------
        in_channels : int
            Number of input channels
        dilations : tuple of int
            Dilation for each convolution-block
        width : int
            Number of channels per convolution.
        vanilla : bool
            If True, use plain PyTorch convolutions (MSDVanillaBlockImpl2d);
            otherwise use the custom convolutions provided by msd_pytorch
            (MSDBlockImpl2d).

        Notes
        -----
        The number of output channels is in_channels + depth * width
        """
        super().__init__()

        self.kernel_size = (3, 3)
        self.width = width
        self.dilations = dilations
        self.vanilla = vanilla
        depth = len(self.dilations)

        self.bias = torch.nn.Parameter(torch.Tensor(depth * width))
        self.weights = []
        for i in range(depth):
            n_in = in_channels + width * i
            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))
            self.register_parameter("weight{}".format(i), weight)
            self.weights.append(weight)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            torch.nn.init.kaiming_uniform_(weight, a=np.sqrt(5))
        if self.bias is not None:
            # TODO: improve
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights[0])
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # We need to obtain weights in this way, because self.weights
        # may become obsolete when used in multi-gpu settings when the
        # weights are automatically transferred (by, e.g.,
        # torch.nn.DataParallel). In that case, self.weights may
        # continue to point to the weight parameters on the original
        # device, even when the weight parameters have been
        # transferred to a different gpu.
        #
        # To be compatible with torch.nn.utils.prune, we obtain the
        # weights using attributes. Previously, we used
        # `self.parameters()`, but this returns the original
        # (unmasked) parameters.
        bias = self.bias
        weights = (getattr(self, "weight{}".format(i)) for i in range(len(self.weights)))
        if self.vanilla:
            return MSDVanillaBlockImpl2d.apply(input, self.dilations, bias, *weights)
        else:
            return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)
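

# A minimal usage sketch (not part of the original gist). The helper below is
# hypothetical and only illustrates the expected shapes. It uses the vanilla
# (pure PyTorch) path so that it runs on CPU; the output has
# in_channels + len(dilations) * width channels, as documented above.
def _example_block_shapes():
    in_channels, width = 1, 1
    dilations = [1, 2, 4]
    block = MSDBlock2d(in_channels, dilations, width=width, vanilla=True)
    x = torch.randn(2, in_channels, 32, 32)
    y = block(x)
    # Expected output shape: (2, in_channels + len(dilations) * width, 32, 32)
    assert y.shape == (2, in_channels + len(dilations) * width, 32, 32)
    return y.shape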


class MSDVanillaBlockImpl2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dilations, bias, *weights):
        depth = len(dilations)
        assert depth == len(weights), "number of weights does not match depth"
        num_out_channels = sum(w.shape[0] for w in weights)
        assert (
            len(bias) == num_out_channels
        ), "number of biases does not match number of output channels from weights"

        ctx.dilations = dilations
        ctx.depth = depth

        result = input.new_empty(
            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
        )
        # Copy input into result buffer
        result[:, : input.shape[1]] = input

        result_start = input.shape[1]
        bias_start = 0
        for i in range(depth):
            # Extract variables
            sub_input = result[:, :result_start]
            sub_weight = weights[i]
            blocksize = sub_weight.shape[0]
            sub_bias = bias[bias_start : bias_start + blocksize]
            sub_result = result[:, result_start : result_start + blocksize]
            dilation = ctx.dilations[i]
            # Compute convolution with bias and relu. Write the output into
            # its slice of the result buffer in place, so that later layers
            # see it as part of their input. (The bias is added before the
            # relu, matching the gradient computation in backward.)
            sub_result[:] = torch.relu(
                torch.conv2d(
                    sub_input,
                    sub_weight,
                    bias=sub_bias,
                    stride=1,
                    dilation=dilation,
                    padding=dilation,
                )
            )
            # Update steps etc
            result_start += blocksize
            bias_start += blocksize

        ctx.save_for_backward(bias, result, *weights)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        bias, result, *weights = ctx.saved_tensors
        depth = ctx.depth

        grad_bias = torch.zeros_like(bias)
        # XXX: Could we just overwrite grad_output instead of clone?
        gradients = grad_output.clone()
        grad_weights = []

        result_end = result.shape[1]
        bias_end = len(bias)
        for i in range(depth):
            idx = depth - 1 - i
            # Get subsets
            sub_weight = weights[idx]
            blocksize = sub_weight.shape[0]
            result_start = result_end - blocksize
            bias_start = bias_end - blocksize

            sub_grad_output = gradients[:, result_start:result_end]
            sub_grad_input = gradients[:, :result_start]
            sub_result = result[:, result_start:result_end]
            sub_input = result[:, :result_start]
            dilation = ctx.dilations[idx]

            # Gradient w.r.t. input: mask the incoming gradient with the relu
            # derivative, then compute the gradient w.r.t. sub_input and add
            # it to sub_grad_input.
            sub_grad_output *= (sub_result > 0.0).to(dtype=torch.float)
            sub_grad_input += torch.conv_transpose2d(
                sub_grad_output, sub_weight, stride=1, dilation=dilation, padding=dilation
            )
            # cc.conv_relu_backward_x(
            #     sub_result, sub_grad_output, sub_weight, sub_grad_input, dilation
            # )

            # Gradient w.r.t. weights (weights[idx] is argument idx + 3 of forward)
            IDX_WEIGHT_START = 3  # The first weight has index 3 in the forward pass.
            if ctx.needs_input_grad[idx + IDX_WEIGHT_START]:
                c_out, c_in = blocksize, result_start
                sub_grad_weight = torch.nn.grad.conv2d_weight(
                    sub_input,
                    (c_out, c_in, 3, 3),
                    sub_grad_output,
                    stride=1,
                    dilation=dilation,
                    padding=dilation,
                )
                grad_weights.insert(0, sub_grad_weight)
            else:
                grad_weights.insert(0, None)

            # Gradient w.r.t. bias
            if ctx.needs_input_grad[2]:
                sub_grad_bias = grad_bias[bias_start:bias_end]
                sub_grad_bias += sub_grad_output.sum(dim=(0, 2, 3))
                # cc.conv_relu_backward_bias(
                #     sub_result, sub_grad_output, sub_grad_bias
                # )

            # Update positions etc
            result_end -= blocksize
            bias_end -= blocksize

        grad_input = gradients[:, : weights[0].shape[1]]
        return (grad_input, None, grad_bias, *grad_weights)
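

# A hedged consistency check (not part of the original gist). The helper below
# is hypothetical and sketches how the vanilla block could be compared against
# the custom msd_pytorch implementation. It assumes a CUDA device is available,
# since MSDBlockImpl2d runs on the GPU. The two modules share parameters via
# load_state_dict, so outputs and gradients should agree up to floating-point
# tolerance.
def _example_compare_vanilla_and_custom(N=64):
    dilations = [1, 2, 3]
    custom = MSDBlock2d(1, dilations, width=1, vanilla=False).cuda()
    vanilla = MSDBlock2d(1, dilations, width=1, vanilla=True).cuda()
    vanilla.load_state_dict(custom.state_dict())

    x = torch.randn(1, 1, N, N, device="cuda")
    y_custom = custom(x)
    y_vanilla = vanilla(x)
    print("outputs match:", torch.allclose(y_custom, y_vanilla, atol=1e-5))

    y_custom.sum().backward()
    y_vanilla.sum().backward()
    print("bias gradients match:",
          torch.allclose(custom.bias.grad, vanilla.bias.grad, atol=1e-5))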


from timeit import default_timer as timer


def time(N, n, vanilla=False, backward=True):
    torch.cuda.synchronize()
    start = timer()

    # Constant time (to be ignored)
    dilations = [1, 2, 3, 4, 5, 6, 7, 8, 9] * 10  # 90 dilated convolution layers
    net = MSDBlock2d(1, dilations, width=1, vanilla=vanilla).cuda()
    x = torch.randn(1, 1, N, N).cuda()
    tgt = torch.randn(1, len(dilations) + 1, N, N).cuda()

    # Variable time (variable of interest)
    for _ in range(n):
        y = net(x)
        loss = torch.nn.functional.mse_loss(y, tgt)
        if backward:
            loss.backward()

    torch.cuda.synchronize()
    return timer() - start
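

# A possible way to use the timing helper above (not part of the original
# gist): sweep over both implementations and both passes. The helper below is
# hypothetical and, like time(), assumes a CUDA device is available.
def _example_timing_sweep(N=256, n=10):
    for vanilla in (False, True):
        for backward in (False, True):
            t = time(N, n, vanilla=vanilla, backward=backward)
            print(f"vanilla={vanilla!s:<5} backward={backward!s:<5} time={t:.3f}s")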


def print_args(args):
    assert isinstance(args, Namespace), "expected namespace"
    args_dict = dict(args._get_kwargs())
    print("Parameters:")
    for k, v in args_dict.items():
        print(f" {k:<20}: {v}")
    print()


if __name__ == '__main__':
    from argparse import ArgumentParser, Namespace

    parser = ArgumentParser()
    parser.add_argument('--vanilla', dest='vanilla', action='store_true')
    parser.add_argument('--custom', dest='vanilla', action='store_false')
    parser.set_defaults(vanilla=False)
    parser.add_argument('--forward', dest='backward', action='store_false')
    parser.add_argument('--backward', dest='backward', action='store_true')
    parser.set_defaults(backward=False)
    parser.add_argument('--N', type=int, default=100)

    args = parser.parse_args()
    print_args(args)
    print("time: ",
          time(args.N, 1, vanilla=args.vanilla, backward=args.backward)
    )
    print("peak memory usage: ", torch.cuda.max_memory_allocated() / 1e9, "GB")