Created
September 27, 2019 11:30
-
-
Save mrtpk/2d0d09a61690d6c872a71639fa502174 to your computer and use it in GitHub Desktop.
Vanilla implementation of SqueezeNet in Keras
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ref: https://arxiv.org/pdf/1602.07360.pdf
# ref: https://medium.com/@smallfishbigsea/notes-of-squeezenet-4137d51feef4
# ref: https://github.com/cmasch/squeezenet/blob/master/squeezenet.py
# ref: https://towardsdatascience.com/review-squeezenet-image-classification-e7414825581a
# import dependencies
from keras.models import Model
import keras.layers as layer
def get_fire_module(input_tensor, name, s1x1, use_bypass=False, e1x1=None, e3x3=None):
    '''
    Creates a fire module named @param: name -- the basic building block of SqueezeNet.

    A fire module has two stages:
    - Squeeze layer: a 1x1 convolution with @param: s1x1 filters.
    - Expand layer: two parallel convolutions (1x1 and 3x3) over the squeeze
      output, whose feature maps are concatenated along the channel axis.
      The 1x1 branch has @param: e1x1 filters and the 3x3 branch has
      @param: e3x3 filters; both default to 4 * s1x1 as in vanilla SqueezeNet.

    If @param: use_bypass is True, a simple bypass (like a ResNet skip
    connection) adds @param: input_tensor to the module output. This requires
    the input and output channel counts to match; the complex bypass variant
    (a 1x1 convolution to accommodate differing sizes) is not implemented.

    Returns the output tensor of the fire module.
    '''
    name_placeholder = "{}_{}"
    # Defaults follow the paper's vanilla configuration: e1x1 = e3x3 = 4 * s1x1.
    if e1x1 is None:
        e1x1 = int(4 * s1x1)
    if e3x3 is None:
        e3x3 = int(4 * s1x1)
    squeeze_layer = layer.Conv2D(filters=s1x1, kernel_size=(1, 1),
        activation='relu', padding='same', name=name_placeholder.format(name, "squeeze"))(input_tensor)
    expand_layer_1x1 = layer.Conv2D(filters=e1x1, kernel_size=(1, 1),
        activation='relu', padding='same', name=name_placeholder.format(name, "expand_1x1"))(squeeze_layer)
    expand_layer_3x3 = layer.Conv2D(filters=e3x3, kernel_size=(3, 3),
        activation='relu', padding='same', name=name_placeholder.format(name, "expand_3x3"))(squeeze_layer)
    expand_layer = layer.concatenate([expand_layer_1x1, expand_layer_3x3], name=name_placeholder.format(name, "expand"))
    if use_bypass:
        # TODO: implement complex bypass (1x1 conv) to accommodate differing
        # input/output channel counts; simple bypass just adds input to output.
        expand_layer = layer.Add(name=name_placeholder.format(name, "bypass_add"))([expand_layer, input_tensor])
    return expand_layer
def get_squeezenet(input_shape=(224, 224, 3), nb_classes=1000):
    '''
    Builds and returns the vanilla SqueezeNet model.

    @param: input_shape -- shape of the input images (default ImageNet-style 224x224x3).
    @param: nb_classes  -- number of output classes (default 1000).

    For the simple-bypass SqueezeNet variant, pass use_bypass=True to
    fire3, fire5, fire7 and fire9.
    '''
    input_layer = layer.Input(input_shape, name='input_layer')
    # block1: stem convolution + downsampling
    x = layer.Conv2D(filters=96, kernel_size=(7, 7), strides=2,
                     activation='relu', padding='same', name="conv1")(input_layer)
    x = layer.MaxPool2D(pool_size=(3, 3), strides=(2, 2), name="maxpool1")(x)
    # block2: fire2-fire4, then downsample
    x = get_fire_module(x, name="fire2", s1x1=16, use_bypass=False)
    x = get_fire_module(x, name="fire3", s1x1=16, use_bypass=False)
    x = get_fire_module(x, name="fire4", s1x1=32, use_bypass=False)
    x = layer.MaxPool2D(pool_size=(3, 3), strides=(2, 2), name="maxpool4")(x)
    # block3: fire5-fire8, then downsample
    x = get_fire_module(x, name="fire5", s1x1=32, use_bypass=False)
    x = get_fire_module(x, name="fire6", s1x1=48, use_bypass=False)
    x = get_fire_module(x, name="fire7", s1x1=48, use_bypass=False)
    x = get_fire_module(x, name="fire8", s1x1=64, use_bypass=False)
    x = layer.MaxPool2D(pool_size=(3, 3), strides=(2, 2), name="maxpool8")(x)
    # block4: final fire module with dropout for regularisation
    x = get_fire_module(x, name="fire9", s1x1=64, use_bypass=False)
    x = layer.Dropout(rate=0.5, name="fire9_dropout")(x)
    # classification head: 1x1 conv to nb_classes maps, global average pool, softmax
    x = layer.Conv2D(filters=nb_classes, kernel_size=(1, 1), strides=1,
                     activation='relu', padding='same', name="conv10")(x)
    x = layer.GlobalAveragePooling2D(name='avgpool10')(x)
    x = layer.Activation(name="avgpool10_activation", activation='softmax')(x)
    return Model(inputs=input_layer, outputs=x, name="squeezenet")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment