Last active
October 23, 2020 18:54
-
-
Save ahendriksen/a3ffa953ae15c5db3eea154bcc6ccbb5 to your computer and use it in GitHub Desktop.
An implementation of the MS-D network that uses PyTorch's built-in convolutions instead of custom ones, so that the custom convolutions can be compared against the PyTorch-provided ones.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Requirements: msd_pytorch | |
Installation instruction can be found at https://github.com/ahendriksen/msd_pytorch | |
""" | |
import torch | |
import numpy as np | |
from msd_pytorch.msd_block import MSDBlockImpl2d | |
class MSDBlock2d(torch.nn.Module):
    """Mixed-scale dense block of dilated 3x3 convolutions.

    Layer ``i`` convolves the concatenation of the block input and the
    outputs of all previous layers, using dilation ``dilations[i]``. The
    actual computation is delegated to either the custom msd_pytorch
    implementation or the PyTorch-convolution implementation.
    """

    def __init__(self, in_channels, dilations, width=1, vanilla=False):
        """Multi-scale dense block

        Parameters
        ----------
        in_channels : int
            Number of input channels
        dilations : tuple of int
            Dilation for each convolution-block
        width : int
            Number of channels per convolution.
        vanilla : bool
            If True, ``forward`` uses ``MSDVanillaBlockImpl2d`` (built from
            PyTorch convolutions); otherwise it uses ``MSDBlockImpl2d`` from
            msd_pytorch.

        Notes
        -----
        The number of output channels is in_channels + depth * width,
        where depth = len(dilations).
        """
        super().__init__()
        self.kernel_size = (3, 3)
        self.width = width
        self.dilations = dilations
        self.vanilla = vanilla
        depth = len(self.dilations)
        # One bias vector for all layers; layer i owns the slice
        # [i * width, (i + 1) * width).
        self.bias = torch.nn.Parameter(torch.Tensor(depth * width))
        # Plain-list references to the per-layer weights. Each is also
        # registered as "weight{i}" so that it is a module parameter.
        self.weights = []
        for i in range(depth):
            # Layer i sees the block input plus the outputs of the i
            # previous layers.
            n_in = in_channels + width * i
            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))
            self.register_parameter("weight{}".format(i), weight)
            self.weights.append(weight)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize all weights and biases in place.

        Weights use Kaiming-uniform initialization with ``a=sqrt(5)``. The
        bias slice of each layer is drawn from U(-1/sqrt(fan_in),
        1/sqrt(fan_in)) using that layer's own fan-in, mirroring the default
        initialization of ``torch.nn.Conv2d``. Works for an empty block
        (no dilations) as well.
        """
        for weight in self.weights:
            torch.nn.init.kaiming_uniform_(weight, a=np.sqrt(5))
        if self.bias is not None:
            with torch.no_grad():
                for i, weight in enumerate(self.weights):
                    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(weight)
                    bound = 1 / np.sqrt(fan_in)
                    layer_bias = self.bias[i * self.width : (i + 1) * self.width]
                    layer_bias.uniform_(-bound, bound)

    def forward(self, input):
        """Apply the block; returns input concatenated with all layer outputs."""
        # We need to obtain weights in this way, because self.weights
        # may become obsolete when used in multi-gpu settings when the
        # weights are automatically transferred (by, e.g.,
        # torch.nn.DataParallel). In that case, self.weights may
        # continue to point to the weight parameters on the original
        # device, even when the weight parameters have been
        # transferred to a different gpu.
        #
        # To be compatible with torch.nn.utils.prune, we obtain the
        # weights using attributes. Previously, we used
        # `self.parameters()`, but this returns the original
        # (unmasked) parameters.
        bias = self.bias
        weights = (getattr(self, "weight{}".format(i)) for i in range(len(self.weights)))
        if self.vanilla:
            return MSDVanillaBlockImpl2d.apply(input, self.dilations, bias, *weights)
        else:
            return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)
class MSDVanillaBlockImpl2d(torch.autograd.Function):
    """Mixed-scale dense block implemented with PyTorch's own convolution ops.

    Layer ``i`` computes ``relu(conv2d(x_i, w_i, b_i))`` with dilation and
    padding ``dilations[i]``. Here ``x_i`` is every channel produced so far
    (the block input followed by the outputs of layers ``0..i-1``). Its
    output is written into a single preallocated buffer right after ``x_i``.
    The padding equals the dilation, which preserves the spatial size for
    3x3 kernels (the kernel size used by ``MSDBlock2d``).
    """

    @staticmethod
    def forward(ctx, input, dilations, bias, *weights):
        """Run all layers and return input concatenated with their outputs.

        ``bias`` holds the biases of all layers back to back. Layer i uses
        the next ``weights[i].shape[0]`` entries.
        """
        depth = len(dilations)
        assert depth == len(weights), "number of weights does not match depth"
        num_out_channels = sum(w.shape[0] for w in weights)
        assert (
            len(bias) == num_out_channels
        ), "number of biases does not match number of output channels from weights"
        ctx.dilations = dilations
        ctx.depth = depth
        # Remembered so backward can slice the input gradient even when depth == 0.
        ctx.in_channels = input.shape[1]

        result = input.new_empty(
            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
        )
        # Copy input into result buffer
        result[:, : input.shape[1]] = input

        result_start = input.shape[1]
        bias_start = 0
        for i in range(depth):
            sub_input = result[:, :result_start]
            sub_weight = weights[i]
            blocksize = sub_weight.shape[0]
            sub_bias = bias[bias_start : bias_start + blocksize]
            dilation = dilations[i]
            # Bias is applied before the ReLU, matching backward (mask taken
            # from the ReLU output, bias grad = masked sum). The result must
            # be assigned into the buffer: the following layers read it
            # through `sub_input`.
            result[:, result_start : result_start + blocksize] = torch.relu(
                torch.conv2d(
                    sub_input,
                    sub_weight,
                    sub_bias,
                    stride=1,
                    padding=dilation,
                    dilation=dilation,
                )
            )
            result_start += blocksize
            bias_start += blocksize

        ctx.save_for_backward(bias, result, *weights)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through the layers from last to first.

        ``gradients`` accumulates, per channel, the gradient w.r.t. that
        channel of the output buffer. Later layers add their contribution to
        the channels they read before those channels are processed.
        """
        bias, result, *weights = ctx.saved_tensors
        depth = ctx.depth
        # Position of weights[0] among forward's inputs (input, dilations, bias, *weights).
        IDX_WEIGHT_START = 3

        grad_bias = torch.zeros_like(bias)
        # XXX: Could we just overwrite grad_output instead of clone?
        gradients = grad_output.clone()
        grad_weights = [None] * depth

        result_end = result.shape[1]
        bias_end = len(bias)
        for idx in reversed(range(depth)):
            sub_weight = weights[idx]
            blocksize = sub_weight.shape[0]
            result_start = result_end - blocksize
            bias_start = bias_end - blocksize

            sub_grad_output = gradients[:, result_start:result_end]
            sub_grad_input = gradients[:, :result_start]
            sub_result = result[:, result_start:result_end]
            sub_input = result[:, :result_start]
            dilation = ctx.dilations[idx]

            # ReLU backward: keep the gradient only where the output was
            # positive (mask cast to the gradient's dtype, not fixed float32).
            sub_grad_output.mul_((sub_result > 0.0).to(sub_grad_output.dtype))
            # Gradient w.r.t. the layer input, accumulated into the channels
            # that this layer read.
            sub_grad_input += torch.conv_transpose2d(
                sub_grad_output, sub_weight, stride=1, padding=dilation, dilation=dilation
            )

            # Gradient w.r.t. this layer's weight (index idx, not the loop counter).
            if ctx.needs_input_grad[IDX_WEIGHT_START + idx]:
                grad_weights[idx] = torch.nn.grad.conv2d_weight(
                    sub_input,
                    sub_weight.shape,
                    sub_grad_output,
                    stride=1,
                    padding=dilation,
                    dilation=dilation,
                )

            # Gradient w.r.t. this layer's bias slice.
            if ctx.needs_input_grad[2]:
                grad_bias[bias_start:bias_end] += sub_grad_output.sum(dim=(0, 2, 3))

            result_end = result_start
            bias_end = bias_start

        grad_input = gradients[:, : ctx.in_channels]
        return (grad_input, None, grad_bias, *grad_weights)
from timeit import default_timer as timer | |
def time(N, n, vanilla=False, backward=True):
    """Benchmark an MS-D block on one random N x N image on the GPU.

    Parameters
    ----------
    N : int
        Height and width of the random (1, 1, N, N) input image.
    n : int
        Number of forward passes (each optionally followed by a backward
        pass) to run.
    vanilla : bool
        Passed to ``MSDBlock2d``: True selects the PyTorch-convolution
        implementation, False the custom msd_pytorch one.
    backward : bool
        If True, also call ``loss.backward()`` in every pass.

    Returns
    -------
    float
        Elapsed wall-clock seconds. The measurement starts before the
        one-time setup (network construction and host-to-device copies), so
        that constant cost is included in the returned value.
    """
    torch.cuda.synchronize()
    start = timer()

    # Setup: constant cost, independent of n (but still inside the timed region).
    dilations = [1, 2, 3, 4, 5, 6, 7, 8, 9] * 10  # 90-layer deep network
    in_channels = 1
    width = 1
    net = MSDBlock2d(in_channels, dilations, width=width, vanilla=vanilla).cuda()
    x = torch.randn(1, in_channels, N, N).cuda()
    # The block outputs in_channels + depth * width channels.
    tgt = torch.randn(1, in_channels + len(dilations) * width, N, N).cuda()

    # Variable time (the quantity of interest): grows with n.
    for _ in range(n):
        y = net(x)
        loss = torch.nn.functional.mse_loss(y, tgt)
        if backward:
            loss.backward()

    # CUDA kernels run asynchronously; wait for them before stopping the clock.
    torch.cuda.synchronize()
    return timer() - start
def print_args(args):
    """Print every attribute of an argparse ``Namespace``, sorted by name.

    Raises
    ------
    TypeError
        If ``args`` is not an ``argparse.Namespace``.
    """
    # Local import: the module-level script only imports Namespace under
    # `if __name__ == '__main__'`, so relying on it breaks when imported.
    from argparse import Namespace

    if not isinstance(args, Namespace):
        raise TypeError("expected namespace")
    print("Parameters:")
    # sorted(vars(...).items()) gives the same order as Namespace._get_kwargs().
    for k, v in sorted(vars(args).items()):
        print(f" {k:<20}: {v}")
    print()
if __name__ == '__main__':
    # Namespace is imported here as well as ArgumentParser because
    # print_args may look it up as a module-level name.
    from argparse import ArgumentParser, Namespace

    cli = ArgumentParser()
    # --vanilla / --custom select the implementation (default: custom).
    cli.add_argument('--vanilla', dest='vanilla', action='store_true')
    cli.add_argument('--custom', dest='vanilla', action='store_false')
    # --forward / --backward select whether backward passes are timed
    # (default: forward only).
    cli.add_argument('--forward', dest='backward', action='store_false')
    cli.add_argument('--backward', dest='backward', action='store_true')
    # Side length of the square input image.
    cli.add_argument('--N', type=int, default=100)
    cli.set_defaults(vanilla=False, backward=False)

    parsed = cli.parse_args()
    print_args(parsed)

    elapsed = time(parsed.N, 1, vanilla=parsed.vanilla, backward=parsed.backward)
    print("time: ", elapsed)

    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print("peak memory usage: ", peak_gb, "GB")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment