@ksivaman
Created August 9, 2019 23:36
Extract features from the VGG19 network
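def get_features(image, model, layers=None):
    """Run an image forward through a model and collect the features for
    a set of layers. The default layers are for VGGNet, matching Gatys et al. (2016).

    `model` is expected to be a sequential stack (e.g. the `features` part of
    torchvision's VGG19) whose children are named by their string index.
    """
    # Layers used for the content and style representations of an image
    if layers is None:
        layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  # content representation
                  '28': 'conv5_1'}

    features = {}
    x = image
    # named_children() yields (name, module) pairs in order; for a Sequential
    # the names are the string indices '0', '1', '2', ...
    for name, layer in model.named_children():
        x = layer(x)
        if name in layers:
            # With torchvision's VGG, whose ReLUs are in-place, this stored
            # tensor is overwritten by the following ReLU (post-activation).
            features[layers[name]] = x
    return features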
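The string keys in the default `layers` dict are positions inside the flat module list, not layer names. A minimal sketch of where those numbers come from, assuming the VGG19 layout used by torchvision (16 conv layers, each followed by a ReLU, with a max-pool closing each of the five blocks, and no batch norm). `VGG19_CFG` and `conv_layer_indices` are hypothetical names used only for illustration:

```python
# Assumption: VGG19 layer configuration as in torchvision (config "E", no
# batch norm). Integers are conv output channels; "M" is a 2x2 max-pool.
VGG19_CFG = [64, 64, "M",
             128, 128, "M",
             256, 256, 256, 256, "M",
             512, 512, 512, 512, "M",
             512, 512, 512, 512, "M"]


def conv_layer_indices(cfg):
    """Map each conv layer's position in the flat `features` Sequential
    to a Gatys-style name such as 'conv4_2' (hypothetical helper)."""
    names = {}
    index = 0  # position of the next module in the Sequential
    block = 1  # conv block number (the X in convX_Y)
    conv = 1   # conv number within the block (the Y in convX_Y)
    for item in cfg:
        if item == "M":
            index += 1  # one MaxPool2d module
            block += 1
            conv = 1
        else:
            names[str(index)] = f"conv{block}_{conv}"
            index += 2  # Conv2d followed by ReLU
            conv += 1
    return names


all_convs = conv_layer_indices(VGG19_CFG)
wanted = {"conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv4_2", "conv5_1"}
print({k: v for k, v in all_convs.items() if v in wanted})
# {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
```

Printing the full `all_convs` mapping is a quick way to pick other layers (e.g. `'2'` for `conv1_2`) if a different set of features is wanted.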