Created
August 9, 2019 23:36
-
-
Save ksivaman/237ca738a1f74a78e2b6b6b130b27ae5 to your computer and use it in GitHub Desktop.
Extract features from the VGG19 network.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_features(image, model, layers=None):
    """Run an image forward through a model and collect selected activations.

    The model's direct sub-modules are applied to the input one after the
    other, in registration order, and the output of every sub-module whose
    name is a key of ``layers`` is recorded.

    Args:
        image: Input handed to the first sub-module of ``model``.
        model: Object exposing a ``_modules`` mapping of name -> callable
            layer (as ``torch.nn.Module`` containers do).
        layers: Mapping from sub-module name to the label under which its
            output is stored. Defaults to the VGGNet layers matching
            Gatys et al. (2016).

    Returns:
        dict mapping each label to the output of the corresponding layer.
        Names in ``layers`` that the model does not contain are simply
        absent from the result.
    """
    ## Need the layers for the content and style representations of an image
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  ## content representation
                  '28': 'conv5_1'}

    features = {}
    x = image
    # Number of requested layers not seen yet; once it reaches zero the rest
    # of the forward pass cannot contribute to the result and is skipped.
    remaining = len(layers)
    # model._modules is a dictionary holding each module in the model.
    # It is iterated directly (rather than via named_children()) so that a
    # module instance registered under several names is applied every time.
    for name, layer in model._modules.items():
        if remaining == 0:
            break
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
            remaining -= 1
    return features
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment