Skip to content

Instantly share code, notes, and snippets.

@axhiao
Last active July 28, 2020 20:24
Show Gist options
  • Save axhiao/f4d9a7e846fbf204e1a94f0bd3809624 to your computer and use it in GitHub Desktop.
Save axhiao/f4d9a7e846fbf204e1a94f0bd3809624 to your computer and use it in GitHub Desktop.
fusion model
# Registry of supported backbones:
#   name -> (torchvision model constructor, size of the feature vector that
#            the fusion head receives after the backbone is truncated).
# The densenet/vgg entries are C * 7 * 7, i.e. they assume the truncated
# backbone emits a 7x7 spatial map (presumably 224x224 inputs — confirm).
mods = {
    "resnet50":    (torchvision.models.resnet50,    2048),
    "resnet152":   (torchvision.models.resnet152,   2048),
    "densenet121": (torchvision.models.densenet121, 1024 * 7 * 7),
    "densenet201": (torchvision.models.densenet201, 1920 * 7 * 7),
    "vgg16":       (torchvision.models.vgg16,       512 * 7 * 7),
}

# Key into `mods` selecting the active backbone (read by Fusion2Model).
MOD = "densenet201"
class Fusion2Model(nn.Module):
    """Multi-view classifier: shared truncated backbone + view-mean + MLP head.

    Every view of a sample goes through the same backbone (``imgs_model``
    with its last child layer removed). The per-view feature vectors are
    averaged over the views, and the mean is fed to a small MLP that
    produces 2 unnormalized outputs per sample.
    """

    # Backbone keys whose truncated output still needs flattening before the
    # linear head (their feature sizes in ``mods`` are C * 7 * 7).
    _FLATTEN_MODS = ("densenet121", "densenet201", "vgg16")

    def __init__(self, imgs_model, *, mod=None, input_no=None):
        """Build the fusion model around ``imgs_model``.

        Args:
            imgs_model: backbone module; all of its children except the last
                one are reused, in order, as the feature extractor.
            mod: backbone key deciding whether a Flatten layer is appended.
                Defaults to the module-level ``MOD``.
            input_no: feature size fed to the head. Defaults to
                ``mods[mod][1]``.
        """
        super().__init__()
        if mod is None:
            mod = MOD
        if input_no is None:
            input_no = mods[mod][1]

        # Drop the backbone's final child layer (its original output head).
        self.imgs_model = nn.Sequential(*list(imgs_model.children())[:-1])
        if mod in self._FLATTEN_MODS:
            self.imgs_model.add_module(name="flatten", module=nn.Flatten())

        self.classifier = nn.Sequential(
            nn.Linear(input_no, 400),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(400, 100),
            nn.ReLU(True),
            nn.Dropout(p=0.2),
            nn.Linear(100, 2),
        )

    def forward(self, images):
        """Score each sample from the mean of its per-view features.

        Args:
            images: 5-D tensor indexed as (batch, view, ...); each
                ``images[:, v]`` is one batch of images for the backbone.

        Returns:
            Tensor of shape (batch, 2) with the head's unnormalized outputs.

        Raises:
            ValueError: if ``images`` is not 5-D or contains zero views.
        """
        if images.dim() != 5:
            raise ValueError(
                f"expected a 5-D (batch, view, C, H, W) tensor, "
                f"got shape {tuple(images.shape)}"
            )
        n_views = images.shape[1]
        if n_views == 0:
            raise ValueError("images must contain at least one view")

        # Views are still run one at a time (a batch of b images each), not
        # as a single b*n batch, so any batch-dependent backbone layers see
        # the same batches as before. flatten(start_dim=1) keeps the batch
        # dimension even when b == 1 (torch.squeeze() would drop it).
        feats = [
            torch.flatten(self.imgs_model(images[:, v]), start_dim=1)
            for v in range(n_views)
        ]
        # (view, batch, feat) -> mean over views -> (batch, feat)
        imgs_feat = torch.stack(feats, dim=0).mean(dim=0)
        return self.classifier(imgs_feat)
# Build the pretrained backbone selected by MOD, wrap it in the multi-view
# fusion model, and move both onto `device`.
imgs_model = mods[MOD][0](pretrained=True)
imgs_model = imgs_model.to(device)
fusion2 = Fusion2Model(imgs_model)
fusion2 = fusion2.to(device)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment