Code snippet for porting a TensorFlow-trained model to PyTorch
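# Port a TensorFlow-slim DeepLab-LargeFOV checkpoint to an equivalent PyTorch
# model and check that the two networks produce matching outputs on a sample
# image. Assumes TensorFlow 1.x (for tf.contrib.slim) and the `Stage` wrapper
# from the accompanying common_layers module.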
import numpy as np
from PIL import Image
np.random.seed(2)
import torchvision
import torch
# torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import tensorflow as tf
from common_layers import Stage
import matplotlib.pyplot as plt
slim = tf.contrib.slim
def read_images_from_disk(filename):
    # BGR channel means used by the TF model.
    img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
    img_contents = tf.read_file(filename)
    img = tf.image.decode_image(img_contents, channels=3)
    img.set_shape((None, None, 3))
    # Reorder RGB -> BGR.
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Subtract mean.
    img -= img_mean
    img = tf.expand_dims(img, 0)
    return img
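# pt_read_im mirrors the TF input pipeline above, with one deliberate
# difference: the image stays in RGB order, and the BGR mismatch is handled
# later by flipping the input channels of the first conv layer.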
def pt_read_im(filename):
    image = Image.open(filename)
    image = np.array(image)
    # image = image[:, :, ::-1]  # BGR alternative, unused: conv1 weights are flipped instead
    print(image.shape)
    to_tensor = transforms.ToTensor()
    # normalize = transforms.Normalize((104.00698793, 116.66876762, 122.67891434), (1., 1., 1.))
    # The image stays RGB here, so the mean tuple is reversed relative to img_mean above.
    normalize = transforms.Normalize((122.67891434, 116.66876762, 104.00698793), (1., 1., 1.))
    # Subtracting 255. casts the array to float so ToTensor() skips its uint8
    # [0, 1] rescaling; adding 255. back restores the original [0, 255] range.
    return normalize(to_tensor(image - 255.).float() + 255.).unsqueeze(0).float()
def vgg_16_deeplab_st(inputs,
                      num_classes=21,
                      is_training=True,
                      dropout_keep_prob=0.5,
                      scope='vgg_16'):
    """VGG-16 DeepLab-LargeFOV model for a single task.

    Args:
      inputs: a tensor of size [batch_size, height, width, channels].
      num_classes: number of predicted classes.
      is_training: whether or not the model is being trained.
      dropout_keep_prob: the probability that activations are kept in the
        dropout layers during training.
      scope: Optional scope for the variables.

    Returns:
      the last op containing the log predictions and end_points dict.
    """
    with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=end_points_collection):
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool1')
            net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool2')
            net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool3')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
            net = slim.max_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool4')
            # net = slim.repeat(net, 3, conv2d_same, 512, [3, 3], stride=1, rate=2, scope='conv5')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], rate=2, scope='conv5')
            net = slim.max_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool5_max')
            net = slim.avg_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool5_avg')
            # Use conv2d instead of fully_connected layers.
            rate = 12
            net = slim.conv2d(net, 1024, [3, 3], rate=rate, padding='SAME', scope='fc6')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                               scope='dropout6')
            net = slim.conv2d(net, 1024, [1, 1], padding='SAME', scope='fc7')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                               scope='dropout7')
            net = slim.conv2d(net, num_classes, [1, 1],
                              activation_fn=None,
                              normalizer_fn=None,
                              scope='fc8_voc12')
            # Convert end_points_collection into an end_points dict.
            end_points = slim.utils.convert_collection_to_dict(end_points_collection)
            return net, end_points
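# How the slim scopes above map onto the PyTorch modules below:
#   conv1/pool1 ... conv5/pool5_max  ->  self.stages[0] ... self.stages[4]
#   pool5_avg, fc6, fc7, fc8_voc12   ->  self.head
# Both versions use rate-2 atrous convolutions in conv5 and a rate-12 atrous
# fc6, i.e. the DeepLab "large field-of-view" head.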
class DeepLabLargeFOV(nn.Module):
    def __init__(self, in_dim, out_dim, weights='ImageNet', *args, **kwargs):
        super(DeepLabLargeFOV, self).__init__(*args, **kwargs)
        self.stages = []
        layers = []
        stage = [
            nn.Conv2d(in_dim, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(64, stage))
        stage = [
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(128, stage))
        stage = [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(256, stage))
        stage = [
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))
        stage = [
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))
        self.stages = nn.ModuleList(self.stages)
        self.features = nn.Sequential(*layers)
        head = [
            # count_include_pad=False is required so the result matches TF's avg pooling
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=12, dilation=12),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, out_dim, kernel_size=1)
        ]
        self.head = nn.Sequential(*head)
        self.weights = weights
        self.init_weights()
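    # Padding note: ConstantPad2d((0, 1, 0, 1)) + MaxPool2d(3, stride=2)
    # reproduces TF 'SAME' pooling for even input sizes, where TF pads a single
    # row/column on the bottom/right (e.g. H = 160: ceil(160 / 2) = 80 and
    # floor((161 - 3) / 2) + 1 = 80, with identical window placement). For odd
    # sizes TF splits the padding top/bottom, so the two can disagree there.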
    def forward(self, x):
        N, C, H, W = x.size()
        for stage in self.stages:
            x = stage(x)
        x = self.head(x)
        # x = F.interpolate(x, (H, W), mode='bilinear', align_corners=True)
        return x

    def _forward(self, x):
        # Partial forward pass (first three stages only), kept for debugging.
        x = self.stages[0](x)
        x = self.stages[1](x)
        x = self.stages[2](x)
        return x
    def init_weights(self):
        for layer in self.head.children():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=1)
                nn.init.constant_(layer.bias, 0)
        if self.weights == 'ImageNet':
            vgg = torchvision.models.vgg16(pretrained=True)
            state_vgg = vgg.features.state_dict()
            # The ConstantPad2d layers shift this model's Sequential indices relative
            # to torchvision's VGG-16, so remap the parameters by order instead of by
            # key (both state dicts hold the same 13 convolutions' weights and biases).
            own_keys = list(self.features.state_dict().keys())
            self.features.load_state_dict(dict(zip(own_keys, state_vgg.values())))
        elif self.weights == 'DeepLab':
            pretrained_dict = torch.load('weights/vgg_deeplab_lfov/model_final.pkl')
            model_dict = self.state_dict()
            # 1. Filter out unnecessary keys, including the task-specific 'head.7' classifier.
            pretrained_dict = {k.replace('classifier', 'head'): v for k, v in pretrained_dict.items()}
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'head.7' not in k}
            # 2. Overwrite entries in the existing state dict.
            model_dict.update(pretrained_dict)
            # 3. Load the new state dict.
            self.load_state_dict(model_dict)
        elif self.weights == 'TFDeepLab':
            # TODO: Check with a sample input, TF feedforward vs PyTorch feedforward
            checkpoint_path = tf.train.latest_checkpoint('weights/nyu_v2_tf/slim_finetune_seg')
            # checkpoint_path = 'weights/vgg_deeplab_lfov_tf/model.ckpt-init-slim'
            tf_input = read_images_from_disk('0002.png')
            net, end_points = vgg_16_deeplab_st(tf_input, num_classes=40, is_training=False, dropout_keep_prob=1.0)
            # Which variables to load.
            restore_var = tf.global_variables()
            sess = tf.Session()
            init = tf.global_variables_initializer()
            sess.run(init)
            sess.run(tf.local_variables_initializer())
            tf.train.Saver(var_list=restore_var).restore(sess, checkpoint_path)
            print("Restored model parameters from {}".format(checkpoint_path))
            # Copy parameters pairwise: tf.trainable_variables() and
            # self.named_parameters() enumerate the layers in the same order.
            tf_vars = tf.trainable_variables()
            pt_vars = list(self.named_parameters())
            for tf_var, (pt_var_k, pt_var_v) in zip(tf_vars, pt_vars):
                if 'weight' in tf_var.name:
                    weight = tf_var.eval(session=sess)
                    # TF stores conv kernels as HWIO; PyTorch expects OIHW.
                    weight = weight.transpose((3, 2, 0, 1))
                    if 'conv1_1' in tf_var.name:
                        # Flip the input channels of the first conv layer because the
                        # TF model was trained on BGR data, but we want RGB in PyTorch.
                        print("flipping weights")
                        weight = np.flip(weight, axis=1).copy()
                        print(weight.shape)
                    weight = torch.from_numpy(weight).float()
                    pt_var_v.data = weight
                else:
                    assert 'bias' in tf_var.name
                    bias = tf_var.eval(session=sess)
                    bias = torch.from_numpy(bias).float()
                    pt_var_v.data = bias
                # print(tf_var.name, pt_var_k, pt_var_v.data.size())
            # Check that the copied weights reproduce the TF outputs on the same image.
            self.eval()
            # pt_input = tf_input.eval(session=sess)
            # pt_input = torch.tensor(pt_input.transpose((0, 3, 1, 2))).float()
            pt_input = pt_read_im('0002.png')
            tf_result = end_points['vgg_16/fc8_voc12'].eval(session=sess)
            # tf_result = end_points['vgg_16/pool3'].eval(session=sess)
            pt_result = self.forward(pt_input).detach().numpy().transpose((0, 2, 3, 1))
            # pt_result = self._forward(pt_input).detach().numpy().transpose((0, 2, 3, 1))
            # Optional visual diagnostics (argmax prediction maps and their disagreement):
            # tf_pred = tf_result.squeeze().argmax(axis=2)
            # pt_pred = pt_result.squeeze().argmax(axis=2)
            # mask = (tf_pred - pt_pred) != 0
            # plt.matshow(tf_pred); plt.colorbar(); plt.title('tf_pred')
            # plt.matshow(pt_pred); plt.colorbar(); plt.title('pt_pred')
            # plt.matshow(mask.astype('int')); plt.colorbar(); plt.title('diff')
            print(tf_result.shape)
            print(tf_result.max(), pt_result.max())
            print(tf_result.mean(), pt_result.mean())
            diff = np.abs(tf_result - pt_result).squeeze()
            # plt.matshow(tf_result.squeeze().mean(axis=2)); plt.colorbar(); plt.title('tf')
            # plt.matshow(pt_result.squeeze().mean(axis=2)); plt.colorbar(); plt.title('pt')
            # plt.matshow(diff.mean(axis=2)); plt.colorbar(); plt.title('diff')
            # plt.show()
            print(diff.max())
        elif self.weights == '':
            pass
        else:
            raise NotImplementedError
if __name__ == "__main__":
    net = DeepLabLargeFOV(3, 40, weights='TFDeepLab')
    # print(net)
    in_ten = torch.randn(1, 3, 321, 321)
    torch.save(net.state_dict(), 'weights/nyu_v2/tf_finetune_seg.pth')
    # torch.save(net.state_dict(), 'weights/vgg_deeplab_lfov/tf_deeplab.pth')
    # print(out.size())
    # print(net.stages[1])
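# A minimal sketch (hypothetical usage, assuming the save above succeeded) of
# reloading the exported weights for inference without the TF comparison:
#
#     net = DeepLabLargeFOV(3, 40, weights='')
#     net.load_state_dict(torch.load('weights/nyu_v2/tf_finetune_seg.pth'))
#     net.eval()
#     logits = net(pt_read_im('0002.png'))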
A messy example of converting the TensorFlow-trained models from NDDR-CNN to the corresponding PyTorch models.
