Siamese Network with Triplet Loss
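A Keras (TF1) implementation of a Siamese-style triplet network: a shared InceptionV3 tower with a 256-dimensional batch-normalized bottleneck embedding, trained with a margin-based triplet loss computed inside a Lambda layer.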
#!/usr/bin/python
# Author: Siddhartha Gairola (siddhartha dot gairola at iiit dot ac dot in)

from __future__ import division

import tensorflow as tf
tf.set_random_seed(1)

from keras.applications.inception_v3 import InceptionV3
from keras.models import Model, load_model
from keras.layers import BatchNormalization, Activation, Dense, Dropout, Flatten, Input, Lambda
from keras.layers import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.regularizers import l2
from keras.layers.merge import dot, multiply
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras import optimizers

import os
import pickle
import pdb
import json
import numpy as np
from tqdm import tqdm
import random

from my_dataloader_for_triplet import DataGenerator
from utils import read_my_image
#from my_dataloader import DataGenerator
# ALPHA: the triplet loss margin parameter.
def triplet_loss(x, ALPHA=0.2):
    anchor, positive, negative = x
    # Modification to the standard triplet loss: the anchor embedding is scaled by 2.
    anchor = 2 * anchor
    pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), 1)
    neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), 1)
    basic_loss = tf.add(tf.subtract(pos_dist, neg_dist), ALPHA)
    # Hinge over the batch: mean(max(0, pos_dist - neg_dist + ALPHA)), then halved.
    loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)
    loss = tf.divide(tf.maximum(loss, 0.0), 2.0)
    return loss
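
# A minimal NumPy mirror of triplet_loss above, useful for sanity-checking the
# TF graph on random arrays. This helper is an illustrative sketch and is not
# called anywhere by the training code; a, p, n are (batch, 256) embeddings.
def triplet_loss_np(a, p, n, alpha=0.2):
    # Same anchor-doubling modification as the TF version.
    a = 2 * a
    pos_dist = np.sum((a - p) ** 2, axis=1)
    neg_dist = np.sum((a - n) ** 2, axis=1)
    return np.maximum(np.mean(np.maximum(pos_dist - neg_dist + alpha, 0.0)), 0.0) / 2.0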
class StyleNet():

    def __init__(self, input_shape_x, input_shape_y, input_shape_z, n_classes, reg_lambda):
        self.input_shape_x = input_shape_x
        self.input_shape_y = input_shape_y
        self.input_shape_z = input_shape_z
        self.n_classes = n_classes
        self.reg_lambda = reg_lambda

    def create_model(self):
        # Three image inputs: the anchor, positive and negative examples of a triplet.
        anchor_example = Input(shape=(self.input_shape_x, self.input_shape_y, self.input_shape_z), name='input_1')
        positive_example = Input(shape=(self.input_shape_x, self.input_shape_y, self.input_shape_z), name='input_2')
        negative_example = Input(shape=(self.input_shape_x, self.input_shape_y, self.input_shape_z), name='input_3')

        # Shared InceptionV3 tower (no top, global average pooling).
        input_image = Input(shape=(self.input_shape_x, self.input_shape_y, self.input_shape_z))
        base_inception = InceptionV3(input_tensor=input_image, input_shape=(self.input_shape_x, self.input_shape_y, self.input_shape_z), weights=None, include_top=False, pooling='avg')
        base_pool5 = base_inception.output

        ############## Adding the bottleneck layer here ##############
        bottleneck_layer = Dense(256, kernel_regularizer=l2(self.reg_lambda), name='bottleneck_layer')(base_pool5)
        bottleneck_norm = BatchNormalization(name='bottleneck_norm')(bottleneck_layer)
        #bottleneck_relu = Activation('relu', name='bottleneck_relu')(bottleneck_norm)
        #bottleneck_drop = Dropout(0.5)(bottleneck_relu)
        fin = Dense(self.n_classes)(bottleneck_norm)
        fin_norm = BatchNormalization(name='fin_norm')(fin)
        fin_softmax = Activation('softmax')(fin_norm)
        ###############################################################

        ########## Triplet model which learns the bottleneck embedding ##########
        self.triplet_model = Model(input_image, bottleneck_norm)
        positive_embedding = self.triplet_model(positive_example)
        negative_embedding = self.triplet_model(negative_example)
        anchor_embedding = self.triplet_model(anchor_example)
        ##########################################################################

        adam_opt = optimizers.Adam(lr=0.001, clipnorm=1.0, amsgrad=False)
        #self.style_net_classification_model = Model(inputs=base_inception.input, outputs=fin_softmax)
        #self.style_net_classification_model.compile(optimizer=adam_opt, loss='categorical_crossentropy', metrics=['accuracy'])

        # The triplet model which optimizes the triplet loss. The Lambda layer's
        # output *is* the loss value, so the model is compiled against
        # mean_absolute_error with all-zero dummy targets.
        loss = Lambda(triplet_loss, output_shape=(1,))([anchor_embedding, positive_embedding, negative_embedding])
        self.triplet_model_worker = Model(inputs=[anchor_example, positive_example, negative_example], outputs=loss)
        self.triplet_model_worker.compile(loss='mean_absolute_error', optimizer=adam_opt)

        '''
        adam_opt = optimizers.Adam(lr=0.00001, amsgrad=False)
        self.classification_model = Model(input_image, fin_softmax)
        self.triplet_model_worker.compile(loss='mean_absolute_error', optimizer=adam_opt)
        self.classification_model.compile(optimizer=adam_opt, loss='categorical_crossentropy', metrics=['accuracy'])
        print(self.classification_model.summary())
        '''

        print(self.triplet_model_worker.summary())
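
    # Note: the three branches above share one set of weights, because the same
    # self.triplet_model instance is applied to each input. Keras registers the
    # shared tower as a single nested layer (named 'model_1' by default, which
    # the __main__ block below relies on); a sketch of how to check this:
    #
    #   tower = self.triplet_model_worker.get_layer('model_1')
    #   assert tower is self.triplet_model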
    def fit_model(self, pathname='./models/'):
        if not os.path.exists(pathname):
            os.makedirs(pathname)
        if not os.path.exists(pathname + '/weights'):
            os.makedirs(pathname + '/weights')
        if not os.path.exists(pathname + '/tb'):
            os.makedirs(pathname + '/tb')

        # Checkpoint every epoch and log to TensorBoard.
        filepath = pathname + 'weights/{epoch:02d}.hdf5'
        checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=False, mode='auto')
        tensorboard = TensorBoard(log_dir=pathname + '/tb', write_graph=True, write_images=True)
        callbacks_list = [checkpoint, tensorboard]

        # Parameters
        params = {'dim': (224, 224), 'batch_size': 32, 'n_classes': 11, 'n_channels': 3, 'shuffle': True}

        # Datasets
        #partition = pickle.load(open('../../../data/bam_2_partition.pkl', 'rb'))
        #labels = pickle.load(open('../../../data/bam_2_labels.pkl', 'rb'))
        partition = pickle.load(open('../../../data/bam_2_partition_triplet.pkl', 'rb'))
        labels = pickle.load(open('../../../data/bam_2_labels_triplet.pkl', 'rb'))
        print("Size of partition and labels:", len(partition['train']), len(labels.keys()))
        #partition2 = pickle.load(open('../../data/bam_2_partition_triplet_val.pkl', 'rb'))
        #labels2 = pickle.load(open('../../data/bam_2_partition_triplet_val_labels.pkl', 'rb'))

        # Generators
        #training_generator = DataGenerator(partition['train'], labels, 128, (224, 224), 3, 11, True)
        #validation_generator = DataGenerator(partition2['validation'], labels2, 4, (224, 224), 3, 11, True)
        #validation_generator = DataGenerator(partition['validation'], labels, 64, (224, 224), 3, 11, True)
        training_generator = DataGenerator(partition['train'], labels, **params)
        #validation_generator = DataGenerator(partition['validation'], labels, **params)

        #self.classification_model.fit(inputs, output, validation_split=0.2, epochs=50, batch_size=128, callbacks=callbacks_list, verbose=1)
        #self.triplet_model_worker.fit_generator(generator=training_generator, validation_data=validation_generator, epochs=30, use_multiprocessing=True, workers=10, callbacks=callbacks_list, verbose=1)
        self.triplet_model_worker.fit_generator(generator=training_generator, epochs=60, use_multiprocessing=True, workers=10, callbacks=callbacks_list, verbose=1)
        #self.triplet_model_worker.fit_generator(generator=image_generator(partition['train'], labels, 128, (224, 224, 3)), steps_per_epoch=len(partition['train']) // 128, epochs=50, use_multiprocessing=False, callbacks=callbacks_list, verbose=1)
        #self.style_net_classification_model.fit_generator(generator=training_generator, validation_data=validation_generator, epochs=120, use_multiprocessing=True, workers=10, callbacks=callbacks_list, verbose=1)
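
# The DataGenerator from my_dataloader_for_triplet is not included in this
# gist; for fit_generator above to work it needs to follow the
# keras.utils.Sequence protocol and yield batches matching the three inputs,
# plus dummy targets for the MAE-against-loss trick. A sketch of the assumed
# interface (batch and array names are illustrative):
#
#   def __getitem__(self, idx):
#       X = [anchor_batch, positive_batch, negative_batch]  # each (batch_size, 224, 224, 3)
#       y = np.zeros((batch_size, 1))  # all-zero dummy targets
#       return X, y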
if __name__ == "__main__":
    m = StyleNet(224, 224, 3, 11, 0.)
    m.create_model()
    m.triplet_model_worker.load_weights('/scratch/models_inception_stage1/yo/weights/20.hdf5', by_name=True)

    # Inspect the weights of the shared embedding tower.
    m2 = m.triplet_model_worker.get_layer('model_1')
    for layer in m2.layers:
        print("Yes", layer.name)
        weights = layer.get_weights()
        print(weights)
    exit()
    #m.create_model()
    #m = load_model('models_orig_classification/yo/weights/24.hdf5')
    #print(m.triplet_model_worker.summary())
    #m.fit_model('/scratch/models_inception_stage2/yo/')
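
# After training, the shared tower can be pulled back out to embed images;
# a sketch with a random stand-in input (the real pipeline would presumably
# go through read_my_image from utils):
#
#   embedder = m.triplet_model_worker.get_layer('model_1')
#   x = np.random.rand(1, 224, 224, 3).astype('float32')
#   print(embedder.predict(x).shape)  # -> (1, 256) bottleneck embedding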
Why are you multiplying the anchor by 2 in triplet loss?