Skip to content

Instantly share code, notes, and snippets.

@mrtpk
Created September 25, 2019 12:19
Show Gist options
  • Save mrtpk/33293e834be648a9a1d3fbffaedef54a to your computer and use it in GitHub Desktop.
Save mrtpk/33293e834be648a9a1d3fbffaedef54a to your computer and use it in GitHub Desktop.
Network inspired from VGG and UNET
# Network inspired by VGG and U-Net
# ref: https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47
# ref: https://tuatini.me/practical-image-segmentation-with-unet/
# ref: https://github.com/zhixuhao/unet/blob/master/model.py
# import dependencies
from keras.models import Model
import keras.layers as layer
def basic_block(input_tensor, bloc_name, n_convs, n_filters, use_batchnorm=False):
    '''
    Stack `n_convs` 3x3 'same'-padded ReLU convolutions on top of `input_tensor`.

    Args:
        input_tensor: Keras tensor the first convolution is applied to.
        bloc_name: prefix for the layer names; the i-th convolution
            (1-based) is named "<bloc_name>_conv<i>".
        n_convs: number of convolutions to chain. With a value below 1 the
            loop body never runs and `input_tensor` is returned unchanged.
        n_filters: number of output filters of every convolution.
        use_batchnorm: when True, each convolution is built without an
            activation and followed by a BatchNormalization layer
            ("<conv name>_bn") and a ReLU activation ("<conv name>_relu").
            The default (False) keeps the plain Conv2D+ReLU stack.

    Returns:
        The output tensor of the last layer in the stack.
    '''
    x = input_tensor
    for idx in range(1, n_convs + 1):
        conv_name = "{}_conv{}".format(bloc_name, idx)
        if use_batchnorm:
            # Normalise the linear convolution output before the
            # non-linearity, hence the activation is a separate layer here.
            x = layer.Conv2D(filters=n_filters, kernel_size=(3, 3),
                             padding='same', name=conv_name)(x)
            x = layer.BatchNormalization(name=conv_name + "_bn")(x)
            x = layer.Activation('relu', name=conv_name + "_relu")(x)
        else:
            x = layer.Conv2D(filters=n_filters, kernel_size=(3, 3),
                             activation='relu', padding='same',
                             name=conv_name)(x)
    return x
def get_unet(input_shape=(224, 224, 3), classes=2):
    '''
    Build a U-Net style model whose contracting path follows a VGG-like layout.

    The contracting path runs four convolution blocks, each followed by a
    2x2 max-pooling, and a fifth block without pooling. The expansive path
    upsamples four times with stride-2 transposed convolutions, each time
    concatenating the matching contracting-path block output before a new
    convolution block. A two-convolution block with 3 filters and a 1x1
    sigmoid convolution with `classes` filters form the head.

    Args:
        input_shape: shape of one input sample, passed to `Input`. The first
            two entries are treated as height and width (channels-last, as
            in the default); each must be a multiple of 16 or None.
        classes: number of filters (output channels) of the final layer.

    Returns:
        A Keras `Model` mapping the input layer to the final sigmoid layer.

    Raises:
        ValueError: if a known height/width is not a multiple of 16.
    '''
    # Four 2x2 poolings floor-halve height and width while each stride-2
    # transposed convolution exactly doubles them, so a size that is not a
    # multiple of 16 makes a skip concatenation fail on mismatched shapes.
    # Fail early with an explicit message instead.
    for dim in input_shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                "input_shape height and width must be multiples of 16, "
                "got {}".format(tuple(input_shape)))

    input_layer = layer.Input(input_shape, name='input_layer')

    # Contracting path: (block name, number of convs, filters).
    encoder_specs = (
        ("block1", 2, 64),
        ("block2", 2, 128),
        ("block3", 3, 256),
        ("block4", 3, 512),
    )
    skips = []
    x = input_layer
    for block_name, n_convs, n_filters in encoder_specs:
        x = basic_block(x, block_name, n_convs=n_convs, n_filters=n_filters)
        skips.append(x)  # pre-pooling output, reused by the expansive path
        x = layer.MaxPool2D(pool_size=(2, 2), strides=(2, 2),
                            name=block_name + "_pool")(x)

    # Deepest block; no pooling follows it.
    x = basic_block(x, "block5", n_convs=3, n_filters=512)

    # Expansive path: (block name, deconv filters, number of convs, filters).
    decoder_specs = (
        ("block6", 512, 3, 256),
        ("block7", 256, 3, 128),
        ("block8", 128, 2, 64),
        ("block9", 64, 2, 64),
    )
    # Skip connections are consumed deepest-first: block4, block3, block2, block1.
    for spec, skip in zip(decoder_specs, reversed(skips)):
        block_name, deconv_filters, n_convs, n_filters = spec
        x = layer.Conv2DTranspose(filters=deconv_filters, kernel_size=(3, 3),
                                  strides=(2, 2), padding='same',
                                  name=block_name + "_deconv1")(x)
        x = layer.concatenate([x, skip])
        x = basic_block(x, block_name + "_basic",
                        n_convs=n_convs, n_filters=n_filters)

    # Segmentation head
    x = basic_block(x, "block10", n_convs=2, n_filters=3)
    final_layer = layer.Conv2D(filters=classes, kernel_size=(1, 1),
                               activation='sigmoid', padding='same',
                               name="final_layer")(x)
    return Model(inputs=input_layer, outputs=final_layer, name="vgg inspired unet")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment