"""Testing the concept of zero-initialized CNN layers from the ControlNet paper.

GitHub gist tljstewart/29803625455a4a6df3ae760fe655cef7, last active
September 16, 2023 02:15.

GitHub flagged this file as possibly containing bidirectional Unicode text
that may be interpreted or compiled differently than it appears; review it
in an editor that reveals hidden Unicode characters.
"""
import torch
import torch.nn as nn
import torch.optim as optim

# Seed the global RNG so the random data, the model initialisations and
# therefore the printed weights are identical from run to run.
torch.manual_seed(0)

# Generate some random data: a batch of random images and random targets
# for the MSE objective used below.
batch_size = 32
channels = 3
height = 32
width = 32
x_train = torch.randn(batch_size, channels, height, width)
y_train = torch.randn(batch_size, 10)
# Define the original model
class OriginalCNN(nn.Module):
    """Baseline network: one 3x3 convolution followed by a linear head.

    The defaults reproduce this script's setup (3-channel 32x32 inputs,
    16 feature maps, 10 outputs). Pass other values to reuse the model on
    differently sized data instead of editing a hard-coded ``16*32*32``.

    Args:
        in_channels: Channels of the input images (the script uses 3).
        height: Input image height. The conv uses a 3x3 kernel with
            stride 1 and padding 1, so the feature map keeps this height.
        width: Input image width, preserved by the conv in the same way.
        num_outputs: Size of the linear head's output.
        hidden_channels: Feature maps produced by ``conv1``. The
            zero-initialized ``conv2`` added later in this script expects 16.
    """

    def __init__(self, in_channels=3, height=32, width=32, num_outputs=10,
                 hidden_channels=16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3,
                               stride=1, padding=1)
        # The spatial size is unchanged by conv1 (see above), so the flattened
        # feature vector has hidden_channels * height * width entries.
        self.fc1 = nn.Linear(hidden_channels * height * width, num_outputs)

    def forward(self, x):
        """Map an ``(N, in_channels, height, width)`` batch to ``(N, num_outputs)``."""
        x = self.conv1(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        return self.fc1(x)
def insert_optimizer_params(optimizer, new_layer):
    """Register ``new_layer``'s parameters with an existing optimizer.

    The parameters are added as a single new param group through
    ``Optimizer.add_param_group``. That call fills in the optimizer's
    default hyper-parameters (e.g. ``lr``) and refuses parameters the
    optimizer already manages. Appending the parameters to every existing
    group instead would, with more than one group, make each new parameter
    be stepped once per group.

    Note: the new group takes ``optimizer.defaults``. If a scheduler has
    since changed an existing group's ``lr``, the new group does not
    inherit that change.

    Args:
        optimizer: A ``torch.optim.Optimizer``, extended in place.
        new_layer: Module whose ``parameters()`` should start being trained.

    Raises:
        ValueError: Raised by ``add_param_group`` if any of the parameters
            already belongs to one of the optimizer's groups.
    """
    optimizer.add_param_group({"params": list(new_layer.parameters())})
# Create the baseline model and fit it to the random data with plain SGD
# on an MSE objective for 100 full-batch steps.
model = OriginalCNN()
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()
for _epoch in range(100):
    optimizer.zero_grad()
    criterion(model(x_train), y_train).backward()
    optimizer.step()

# Show the first filter of conv1 as a reference point for later.
print("Weights of the original convolutional layer:")
print(model.conv1.weight.detach()[0])
# Build a new 16->16 convolution and zero its weights and bias (the
# ControlNet "zero convolution" idea), then attach it to the model.
# Assigning an nn.Module attribute registers it as a submodule, so
# model.parameters() now includes conv2's parameters as well.
zero_conv = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
nn.init.zeros_(zero_conv.weight)
nn.init.zeros_(zero_conv.bias)
model.conv2 = zero_conv
# Update forward method to use the new layer.
def new_forward(self, x, residual=True):
    """Forward pass with the zero-initialized ``conv2`` placed after ``conv1``.

    ControlNet's zero convolutions sit on a residual branch. While ``conv2``
    is all zeros, ``conv1(x) + conv2(conv1(x))`` equals ``conv1(x)``, so
    right after insertion the network computes exactly what it did before.
    ``conv2`` can still receive non-zero gradients and start learning.

    Wiring the all-zero layer in series instead (``residual=False``)
    replaces the features with zeros, so every prediction collapses to
    ``fc1.bias`` until training moves ``conv2`` away from zero.

    Args:
        x: Input batch, passed straight to ``conv1``.
        residual: If True (default), add ``conv2``'s output to its input.
            If False, use ``conv2``'s output on its own (series wiring).

    Returns:
        The output of ``fc1`` for the batch.
    """
    features = self.conv1(x)
    delta = self.conv2(features)
    features = features + delta if residual else delta
    features = features.view(features.size(0), -1)
    return self.fc1(features)
# Swap in the new forward pass. conv2's parameters must also reach the
# optimizer; otherwise they are never updated and stay at zero.
model.forward = new_forward.__get__(model)
REUSEOPTIMIZER = True
if not REUSEOPTIMIZER:
    # Build a fresh optimizer over every parameter, conv2 included.
    optimizer = optim.SGD(model.parameters(), lr=0.01)
else:
    # Keep the existing optimizer and hand it conv2's parameters.
    insert_optimizer_params(optimizer, model.conv2)
# Train the extended model (conv1 -> zero-initialized conv2 -> fc1) for
# another 100 full-batch SGD steps.
for _epoch in range(100):
    optimizer.zero_grad()
    prediction = model(x_train)
    criterion(prediction, y_train).backward()
    optimizer.step()

# Print the first filter of both convolutional layers after training.
for header, layer in (
    ("Weights of the original convolutional layer AFTER train:\n", model.conv1),
    ("Weights of the new convolutional layer AFTER train:\n", model.conv2),
):
    print(header)
    print(layer.weight.detach()[0])
# To comment on the original gist, sign in to GitHub (sign-up is free).