Last active: April 30, 2018 16:39
import csv
import hashlib
import re

import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.python.util import compat
class DataHandler:
    # 112120 images
    # 70% training, 10% validation, 20% testing
    # ~78484 training, ~11212 validation, ~22424 testing
    # 75712 training, 10812 validation, 25596 testing <- actual splits.
    def __init__(self, multi_label=True):
        self.MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M
        self.TOTAL_IMAGES = 112120
        self.training_percentage = 70
        self.validation_percentage = 10
        self.testing_percentage = 20
        self.multi_label = multi_label
        if multi_label:
            self.GROUND_TRUTHS = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
                                  'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
                                  'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']
            self.image_list = self.create_multilabel_label_dict()
        else:
            self.GROUND_TRUTHS = ['Pathology', 'No Pathology']
            self.image_list = self.create_singlelabel_label_dict()
    def create_multilabel_label_dict(self):
        '''
        1. Create a mapping filename -> dataset from the txt files, e.g. x = {"001.png": "testing"}. O(n)
        2. Build the label list by iterating the CSV line by line, using the mapping to tell which set
           each image belongs to. O(n)
        3. For the train/val images, hash the patient number to get an approximate split. -> O(2n) creation.
        '''
        image_list = {
            'training': [],
            'validation': [],
            'testing': []
        }
        file_mapping = {}
        with open('./train_val_list.txt') as file:
            train_files = file.read().splitlines()
            for file_name in train_files:
                file_mapping[file_name] = 1
        with open('./test_list.txt') as file:
            test_files = file.read().splitlines()
            for file_name in test_files:
                file_mapping[file_name] = 0
        first_line = True
        with open('../data/Data_Entry_2017.csv', 'r') as csvfile:
            reader = csv.reader(csvfile)
            for row in reader:
                if first_line:
                    first_line = False
                    continue
                # row[0] = filename
                # row[1] = ground truths
                file_name = row[0]
                try:
                    if file_mapping[file_name] == 1:
                        # Train/validation set, need to hash to split
                        percentage_hash = self.get_percentage_hash(row[0])
                        if percentage_hash < 12.5:  # 10% of the total data is 12.5% of the remaining 80%
                            image_list['validation'].append((file_name, self.new_y_array(row[1])))
                        else:
                            image_list['training'].append((file_name, self.new_y_array(row[1])))
                    else:
                        image_list['testing'].append((file_name, self.new_y_array(row[1])))
                except KeyError:
                    pass
        return image_list
    def create_singlelabel_label_dict(self):
        return []

    def get_percentage_hash(self, file_name):
        # Hash only the patient number so that multiple images from the same patient
        # compute the same hash and are therefore placed in the same subset.
        file_name = re.sub(r"_[0-9]{3}\.png", "", file_name)
        file_name_hashed = hashlib.sha1(compat.as_bytes(file_name)).hexdigest()
        percentage_hash = ((int(file_name_hashed, 16) %
                            (self.MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / self.MAX_NUM_IMAGES_PER_CLASS))
        return percentage_hash

    def new_y_array(self, truth_string):
        array = np.zeros(len(self.GROUND_TRUTHS), dtype=np.float32)
        if self.multi_label:
            labels_array = truth_string.split('|')
            for label in labels_array:
                try:
                    label_index = self.GROUND_TRUTHS.index(label)
                    array[label_index] = 1
                except ValueError:
                    pass  # do nothing, it's 'No Finding' which we encode as all zeros
        return array
    def image_parse_function(self, filename, label):
        image_string = tf.read_file('../data/images/multi-label/' + filename)
        image_decoded = tf.image.decode_png(image_string, channels=1)
        image_resized = tf.image.resize_images(image_decoded, [256, 256])
        image_cropped = tf.image.crop_to_bounding_box(image_resized, 16, 16, 224, 224)
        return image_cropped, label

    def get_dataset(self, data_type='training', num_examples=0):
        if num_examples < 0:
            raise ValueError('Invalid num_examples: %d' % num_examples)
        size = len(self.image_list[data_type])
        features = []
        labels = []
        if num_examples == 0 or num_examples >= size:
            for feature, label in self.image_list[data_type]:
                features.append(feature)
                labels.append(label)
        else:
            for index in range(num_examples):
                feature, label = self.image_list[data_type][index]
                features.append(feature)
                labels.append(label)
        return features, labels

    def get_pathology_counts(self, data_type='validation'):
        image_dict = {}
        pathology_dict = {
            'multi-label': []
        }
        with open('./' + data_type + '_images.txt') as file:
            images = file.read().splitlines()
            for image in images:
                image_dict[image] = 1
        with open('../data/Data_Entry_2017.csv') as file:
            first_line = True
            reader = csv.reader(file)
            for row in reader:
                if first_line:
                    first_line = False
                    continue
                # row[0] = filename
                # row[1] = ground truths
                if row[0] in image_dict:
                    labels = row[1].split('|')
                    if len(labels) > 1:
                        pathology_dict['multi-label'].append(row[0])
                    else:
                        if labels[0] not in pathology_dict:
                            pathology_dict[labels[0]] = []
                        pathology_dict[labels[0]].append(row[0])
        return pathology_dict
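The split in create_multilabel_label_dict hinges on get_percentage_hash stripping the per-image suffix before hashing, so all images of one patient land in the same subset. A minimal standalone sketch of that behaviour (the patient id and filenames below are made up for illustration):

import hashlib
import re

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1

def percentage_hash(file_name):
    # Same idea as DataHandler.get_percentage_hash: strip the image index,
    # hash the remaining patient id, and map it to a stable 0-100 bucket.
    patient_id = re.sub(r"_[0-9]{3}\.png", "", file_name)
    hashed = hashlib.sha1(patient_id.encode('utf-8')).hexdigest()
    return (int(hashed, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1)) * (100.0 / MAX_NUM_IMAGES_PER_CLASS)

# Two images of the same (hypothetical) patient always hash to the same bucket,
# so they always end up in the same training/validation subset.
print(percentage_hash('00000001_000.png') == percentage_hash('00000001_001.png'))  # True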
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np


# See: https://arxiv.org/pdf/1409.4842.pdf
class GoogLeNet:
    def __init__(self):
        self.num_labels = 14
        self.NAME = "GoogLeNet"
        # training params
        self.batch_size = 64
        self.learning_rate = 0.1
        self.weight_decay = 0.001
    def construct_graph(self, x, y):
        self.graph = tf.get_default_graph()
        self.lr = tf.placeholder(tf.float32, shape=[], name='LR')
        self.keep_prob = tf.placeholder(tf.float32, shape=[], name='keep_prob')
        self.is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
        model = self.conv(x, filters=64, kernel_size=7, stride=2, name='conv1_k7_s2')
        model = self.max_pool(model, pool_size=3, stride=2, name='maxpool1_p3_s2')
        model = tf.nn.local_response_normalization(input=model, alpha=0.0001, beta=0.75)
        model = self.conv(model, filters=64, kernel_size=1, stride=1, name='conv2_k1_s1')
        model = self.conv(model, filters=192, kernel_size=3, stride=1, name='conv2_k3_s1')
        model = tf.nn.local_response_normalization(input=model, alpha=0.0001, beta=0.75)
        model = self.max_pool(model, pool_size=3, stride=2, name='maxpool2_p3_s2')
        model = self._inception_module(model, filters=[64, 96, 128, 16, 32, 32],
                                       name='inception3a')
        model = self._inception_module(model, filters=[128, 128, 192, 32, 96, 64],
                                       name='inception3b')
        model = self.max_pool(model, pool_size=3, stride=2, name='maxpool3_p3_s2')
        model = self._inception_module(model, filters=[192, 96, 208, 16, 48, 64],
                                       name='inception4a')
        model = self._inception_module(model, filters=[160, 112, 224, 24, 64, 64],
                                       name='inception4b')
        model = self._inception_module(model, filters=[128, 128, 256, 24, 64, 64],
                                       name='inception4c')
        model = self._inception_module(model, filters=[112, 144, 288, 32, 64, 64],
                                       name='inception4d')
        model = self._inception_module(model, filters=[256, 160, 320, 32, 128, 128],
                                       name='inception4e')
        model = self.max_pool(model, pool_size=3, stride=2, name='maxpool4_p3_s2')
        model = self._inception_module(model, filters=[256, 160, 320, 32, 128, 128],
                                       name='inception5a')
        model = self._inception_module(model, filters=[384, 192, 384, 48, 128, 128],
                                       name='inception5b')
        model = self.avg_pool(model, pool_size=7, stride=1, name='avgpool5_p7_s1')
        #model = tf.reshape(model, [-1, 7 * 7 * 1024])
        logits = self.fully_connected(model)
        self.ys_pred = tf.nn.sigmoid(logits, name='prediction')
        with tf.name_scope('loss'):
            # Bp and Bn weight the positive and negative terms by their inverse frequency in the
            # batch, so the (rare) positive labels are not drowned out by the many negatives.
            total_labels = tf.to_float(tf.multiply(self.batch_size, self.num_labels))
            num_positive_labels = tf.count_nonzero(y, dtype=tf.float32)
            num_negative_labels = tf.subtract(total_labels, num_positive_labels)
            Bp = tf.divide(total_labels, num_positive_labels)
            Bn = tf.divide(total_labels, num_negative_labels)
            cross_entropy = -tf.reduce_sum((tf.multiply(Bp, y * tf.log(self.ys_pred + 1e-9))) +
                                           (tf.multiply(Bn, (1 - y) * tf.log(1 - self.ys_pred + 1e-9))),
                                           name="cross_entropy")
            l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
            # The loss function: weighted cross-entropy plus L2 weight decay
            self.loss = cross_entropy + l2 * self.weight_decay
        # Train the network with Adagrad (the Adam optimizer below is kept for reference).
        #self.train_step = tf.train.AdamOptimizer(
        #    learning_rate=self.lr,
        #    beta1=0.9,
        #    beta2=0.999).minimize(self.loss)
        self.train_step = tf.train.AdagradOptimizer(learning_rate=self.lr).minimize(self.loss)
    # Define some wrapper functions for brevity/readability
    def conv(self, inputs, filters, kernel_size, stride, name, padding='SAME',
             activation=tf.nn.relu):
        return tf.layers.conv2d(
            inputs=inputs,
            filters=filters,
            kernel_size=[kernel_size, kernel_size],
            strides=stride,
            padding=padding,
            activation=activation,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
            name=name)

    def max_pool(self, inputs, pool_size, stride, name):
        return tf.layers.max_pooling2d(
            inputs=inputs,
            pool_size=[pool_size, pool_size],
            strides=stride,
            padding='SAME',
            name=name)

    def avg_pool(self, inputs, pool_size, stride, name):
        return tf.layers.average_pooling2d(
            inputs=inputs,
            pool_size=[pool_size, pool_size],
            strides=stride,
            padding='VALID',
            name=name)

    def fully_connected(self, inputs):
        dropout = tf.layers.dropout(inputs, rate=1 - self.keep_prob, training=self.is_training)
        # Need to reshape dropout to a 2D tensor for the FC layer; multiply the dimensions
        # excluding the batch size
        new_shape = int(np.prod(self._get_tensor_shape(dropout)[1:]))
        dropout = tf.reshape(dropout, [-1, new_shape])
        return tf.layers.dense(dropout, self.num_labels)

    def _get_tensor_shape(self, tensor):
        return tensor.get_shape().as_list()
    def _inception_module(self, inputs, filters, name):
        if len(filters) != 6:
            raise ValueError('Invalid filters input')
        # From left to right in the graph @ https://arxiv.org/pdf/1409.4842.pdf fig.3
        with tf.name_scope(name):
            conv1_k1_s1 = self.conv(inputs, filters=filters[0], kernel_size=1, stride=1,
                                    name=name + '_conv1_k1_s1')
            conv2_k1_s1 = self.conv(inputs, filters=filters[1], kernel_size=1, stride=1,
                                    name=name + '_conv2_k1_s1')
            conv3_k3_s1 = self.conv(conv2_k1_s1, filters=filters[2], kernel_size=3, stride=1,
                                    name=name + '_conv3_k3_s1')
            conv4_k1_s1 = self.conv(inputs, filters=filters[3], kernel_size=1, stride=1,
                                    name=name + '_conv4_k1_s1')
            conv5_k5_s1 = self.conv(conv4_k1_s1, filters=filters[4], kernel_size=5, stride=1,
                                    name=name + '_conv5_k5_s1')
            pool1_p3_s1 = self.max_pool(inputs, pool_size=3, stride=1, name=name + '_pool1_p3_s1')
            conv6_k1_s1 = self.conv(pool1_p3_s1, filters=filters[5], kernel_size=1, stride=1,
                                    name=name + '_conv6_k1_s1')
            tensor_list = [conv1_k1_s1, conv3_k3_s1, conv5_k5_s1, conv6_k1_s1]
            return tf.concat(tensor_list, axis=3, name=name + '_merge')
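A quick way to sanity-check the model above is to build the graph on dummy placeholders and confirm the prediction tensor has one sigmoid output per label. A rough sketch, assuming this file is saved as GoogLeNet.py and TensorFlow 1.x is available:

import tensorflow as tf

import GoogLeNet  # assumes the class above lives in GoogLeNet.py

net = GoogLeNet.GoogLeNet()
# 224x224 single-channel inputs and 14 binary labels, matching the training script.
x = tf.placeholder(tf.float32, shape=[None, 224, 224, 1])
y = tf.placeholder(tf.float32, shape=[None, 14])
net.construct_graph(x, y)
print(net.ys_pred.shape)  # expected: (?, 14), i.e. one sigmoid score per pathology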
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np


class NaiveCNN:
    def __init__(self):
        self.num_labels = 14
        self.NAME = "NaiveCNN"
        # training params
        self.batch_size = 64
        self.learning_rate = 0.001
        self.weight_decay = 0.0001

    def construct_graph(self, x, y):
        self.graph = tf.get_default_graph()
        # 224x224 grayscale input data, cropped from a 256x256 8-bit greyscale PNG
        #self.xs = tf.placeholder(tf.float32, shape=[None, 224, 224, 1])
        # 14 possible pathologies
        #self.ys = tf.placeholder(tf.float32, shape=[None, self.num_labels])
        self.lr = tf.placeholder(tf.float32, shape=[])
        self.keep_prob = tf.placeholder(tf.float32, shape=[])
        self.is_training = tf.placeholder(tf.bool, shape=[])
        model = tf.layers.conv2d(
            inputs=x,
            filters=64,  # number of outputs
            kernel_size=[7, 7],
            strides=2,
            padding='SAME',
            activation=tf.nn.relu,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
            name="conv_1_7_2")  # conv num 1, size 7, stride 2
        model = tf.layers.max_pooling2d(
            inputs=model,
            pool_size=[3, 3],
            strides=2,
            name="pool_1_3_2")  # pool num 1, size 3, stride 2
        model = tf.layers.conv2d(
            inputs=model,
            filters=64,
            kernel_size=[3, 3],
            strides=1,
            padding='SAME',
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.001),
            activation=tf.nn.relu,
            name="conv_2_3_1")  # conv num 2, size 3, stride 1
        model = tf.layers.max_pooling2d(
            inputs=model,
            pool_size=[3, 3],
            strides=2,
            name="pool_2_3_2")
        # Flatten the input tensor before the dense layer: spatially 224 -> 112 -> 55 -> 55 -> 27,
        # with 64 channels, so the flattened size is 27 * 27 * 64.
        model = tf.reshape(model, [-1, 27 * 27 * 64])
        model = tf.layers.dense(inputs=model, units=1024, activation=tf.nn.relu)
        model = tf.layers.dropout(inputs=model, rate=1 - self.keep_prob, training=self.is_training)
        model = tf.layers.dense(inputs=model, units=self.num_labels)
        # The layer used to get predictions from the network.
        # We will use this to calculate AUROC in testing.
        self.ys_pred = tf.nn.sigmoid(model, name="prediction")
        # OLD, unweighted loss kept for reference:
        #with tf.name_scope('loss'):
        #    cross_entropy = -tf.reduce_sum((y * tf.log(self.ys_pred + 1e-9)) + ((1-y) * tf.log(1-self.ys_pred + 1e-9)), name="cross_entropy")
        #    l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
        #    # The loss function, an element-wise sigmoid non-linearity
        #    self.loss = cross_entropy + l2 * self.weight_decay
        with tf.name_scope('loss'):
            # Bp and Bn weight the positive and negative terms by their inverse frequency in the
            # batch, so the (rare) positive labels are not drowned out by the many negatives.
            total_labels = tf.to_float(tf.multiply(self.batch_size, self.num_labels))
            num_positive_labels = tf.count_nonzero(y, dtype=tf.float32)
            num_negative_labels = tf.subtract(total_labels, num_positive_labels)
            Bp = tf.divide(total_labels, num_positive_labels)
            Bn = tf.divide(total_labels, num_negative_labels)
            cross_entropy = -tf.reduce_sum((tf.multiply(Bp, y * tf.log(self.ys_pred + 1e-9))) +
                                           (tf.multiply(Bn, (1 - y) * tf.log(1 - self.ys_pred + 1e-9))),
                                           name="cross_entropy")
            l2 = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
            # The loss function: weighted cross-entropy plus L2 weight decay
            self.loss = cross_entropy + l2 * self.weight_decay
        # Train the network with Adagrad (the Adam optimizer below is kept for reference).
        #self.train_step = tf.train.AdamOptimizer(
        #    learning_rate=self.lr,
        #    beta1=0.9,
        #    beta2=0.999).minimize(self.loss)
        self.train_step = tf.train.AdagradOptimizer(learning_rate=self.lr).minimize(self.loss)
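The 27 * 27 * 64 flatten size above follows from the layer arithmetic: 'SAME' convolutions keep ceil(size / stride), while tf.layers.max_pooling2d defaults to 'valid' padding, giving floor((size - pool) / stride) + 1. A small standalone sketch reproducing that calculation:

import math

def same_conv(size, stride):
    # Output size of a 'SAME'-padded convolution.
    return math.ceil(size / stride)

def valid_pool(size, pool, stride):
    # Output size of a 'valid'-padded pooling layer.
    return (size - pool) // stride + 1

size = 224
size = same_conv(size, 2)      # conv 7x7 /2 -> 112
size = valid_pool(size, 3, 2)  # pool 3x3 /2 -> 55
size = same_conv(size, 1)      # conv 3x3 /1 -> 55
size = valid_pool(size, 3, 2)  # pool 3x3 /2 -> 27
print(size * size * 64)        # 27 * 27 * 64 = 46656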
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import re

import numpy as np
import tensorflow as tf

import DataHandler
import GoogLeNet

NUM_EPOCHS = 30
VALIDATION_SET_SIZE = 10000
def get_num_trainable_params():
    total_parameters = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    return total_parameters
def add_summary_ops(ground_truth):
    # We round the network's predictions such that >50% presence is a positive, <=50% presence is a negative
    thresholds = [0.5] * 14
    # For inference, we will display the actual percentages
    p, _ = tf.metrics.precision_at_thresholds(labels=ground_truth, predictions=network.ys_pred, thresholds=thresholds)
    r, _ = tf.metrics.recall_at_thresholds(labels=ground_truth, predictions=network.ys_pred, thresholds=thresholds)
    # Using F1 because false negatives and false positives are equally bad in medicine
    precision = tf.reduce_mean(p)
    recall = tf.reduce_mean(r)
    f1 = 2 * precision * recall / (precision + recall)
    with tf.name_scope("summaries"):
        tf.summary.scalar("loss", network.loss)
        # Plotting the learning rate forces us to feed it even when we don't train.
        tf.summary.scalar("learning_rate", network.lr)
        tf.summary.scalar("precision", precision)
        tf.summary.scalar("recall", recall)
        tf.summary.scalar("f1_score", f1)
    network.summary_op = tf.summary.merge_all()
    return p, r, f1
# Initialise network values
network = GoogLeNet.GoogLeNet()
# Get our list of files and their labels, and create our placeholders to feed
data = DataHandler.DataHandler()
train_features, train_labels = data.get_dataset('training')
val_features, val_labels = data.get_dataset('validation')
VALIDATION_SET_SIZE = len(val_features)
features_placeholder = tf.placeholder(tf.string, shape=[None])
labels_placeholder = tf.placeholder(tf.float32, shape=[None, len(data.GROUND_TRUTHS)])
# Create a dataset from our placeholders
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# Map the filenames to the actual image data
dataset = dataset.map(data.image_parse_function)
# Split the dataset into batches depending on the network's specified batch size.
dataset = dataset.batch(network.batch_size)
# Create an iterator for our datasets
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset.output_types, dataset.output_shapes)
train_iterator = dataset.make_initializable_iterator()
val_iterator = dataset.make_initializable_iterator()
# Get our final image data and labels from the iterator, pass them to the network and let
# the network build its graph, followed by the summary ops
(x, y) = iterator.get_next()
network.construct_graph(x, y)
p, r, f1 = add_summary_ops(y)
# Create our summary file writers so we can track our progress on TensorBoard
train_writer = tf.summary.FileWriter('./train_logs/' + network.NAME + '/train', network.graph)
val_writer = tf.summary.FileWriter('./train_logs/' + network.NAME + '/val', network.graph)
# Start a session
with tf.Session(graph=network.graph) as sess:
    # Create a saver so we can save/load model checkpoints after epochs
    saver = tf.train.Saver()
    batches_completed = 0
    epochs_completed = 0
    # Look for an existing checkpoint file, otherwise initialise from scratch
    available_ckpts = [int(re.match(r"(?:[a-zA-Z]*_)([0-9]*)(?:\.ckpt\.txt)", f).group(1))
                       for f in os.listdir('./checkpoints/' + network.NAME + '/')
                       if f.endswith('.ckpt.txt')]
    if len(available_ckpts) > 0:
        # Sort the list of checkpoint numbers in descending order so the first entry is the latest
        available_ckpts.sort(reverse=True)
        print('Restoring from epoch {0}'.format(available_ckpts[0]))
        saver.restore(sess, './checkpoints/{0}/{0}_{1}.ckpt'.format(network.NAME, available_ckpts[0]))
        # Load epoch and batch values from the old model
        with open('./checkpoints/{0}/{0}_{1}.ckpt.txt'.format(network.NAME, available_ckpts[0])) as info_file:
            values = info_file.read().splitlines()
            if len(values) == 4:
                batches_completed = int(values[1])
                epochs_completed = int(values[3])
    else:
        # Initialise our global vars (W and b)
        sess.run(tf.global_variables_initializer())
    # Initialise our local vars (for calculating training/validation precision/recall/f1)
    sess.run(tf.local_variables_initializer())
    # Print the current model's number of trainable params
    print("Total training params: %.1fM" % (get_num_trainable_params() / 1e6))
    # Get the iterator handles to feed for train/val/test
    train_handle = sess.run(train_iterator.string_handle())
    val_handle = sess.run(val_iterator.string_handle())
    # For each batch --- learning rate drops to 0.01 at epoch 150 and 0.001 at epoch 225?
    no_improvement_last_epoch = False
    old_loss = 2 ** 32 - 1  # A large number in case this is our first run
    # Compute for NUM_EPOCHS
    while epochs_completed < NUM_EPOCHS:
        # Initialise our iterators with data (this also resets them to the beginning of their dataset)
        sess.run(train_iterator.initializer,
                 feed_dict={features_placeholder: train_features, labels_placeholder: train_labels})
        sess.run(val_iterator.initializer,
                 feed_dict={features_placeholder: val_features, labels_placeholder: val_labels})
        while True:
            try:
                # Every 1000 batches, also trace runtime statistics for debugging memory usage/compute time
                if batches_completed % 1000 == 0:
                    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()
                    _, loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run(
                        [network.train_step, network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1],
                        feed_dict={
                            handle: train_handle,
                            network.lr: network.learning_rate,
                            network.is_training: True,
                            network.keep_prob: 0.8
                        },
                        options=run_options,
                        run_metadata=run_metadata)
                    train_writer.add_run_metadata(run_metadata, 'batch{0}'.format(batches_completed))
                    train_writer.add_summary(summary, global_step=batches_completed)
                # Otherwise just train normally
                else:
                    _, loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run(
                        [network.train_step, network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1],
                        feed_dict={
                            handle: train_handle,
                            network.lr: network.learning_rate,
                            network.is_training: True,
                            network.keep_prob: 0.8
                        })
                    train_writer.add_summary(summary, global_step=batches_completed)
                # Also run a validation batch every 20 batches for TensorBoard
                if batches_completed % 20 == 0:
                    loss, prediction, summary, _x, _y, _p, _r, _f1 = sess.run(
                        [network.loss, network.ys_pred, network.summary_op, x, y, p, r, f1],
                        feed_dict={
                            handle: val_handle,
                            network.lr: network.learning_rate,
                            network.is_training: False,
                            network.keep_prob: 1.0
                        })
                    val_writer.add_summary(summary, global_step=batches_completed)
                batches_completed = batches_completed + 1
            # If we ran out of data, that's the end of our epoch
            except tf.errors.OutOfRangeError:
                break
        # After our epoch, calculate the mean loss over the full validation set
        sess.run(val_iterator.initializer,
                 feed_dict={features_placeholder: val_features, labels_placeholder: val_labels})
        total_loss = 0
        while True:
            try:
                loss, _preds, _y, _x, _p, _r, _f1 = sess.run(
                    [network.loss, network.ys_pred, y, x, p, r, f1],
                    feed_dict={
                        handle: val_handle,
                        network.lr: network.learning_rate,
                        network.is_training: False,
                        network.keep_prob: 1.0
                    })
                total_loss += loss
            # Run predictions until the validation set is exhausted
            except tf.errors.OutOfRangeError:
                break
        # Compare this epoch to the previous model's, and either drop the learning rate or stop early if there is no improvement
        mean_loss = total_loss / VALIDATION_SET_SIZE
        try:
            # Try to read the old loss from the previous checkpoint
            with open('./checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt.txt' % (epochs_completed - 1), mode='r') as file:
                ckpt_info = file.read().splitlines()
                old_loss = float(ckpt_info[0])
                old_learning_rate = float(ckpt_info[2])
        except:
            # Must be the first checkpoint
            pass
        # If we didn't improve
        if mean_loss >= old_loss:
            # and we just dropped the learning rate last epoch
            if no_improvement_last_epoch:
                # Stop training early
                print("We're done! Best model was after {0} epochs at {1} mean loss.".format((epochs_completed - 2), old_loss))
                break
            else:  # Decay the learning rate by a factor of 10, and take the previous weights
                network.learning_rate = network.learning_rate * 0.1
                saver.restore(sess, './checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt' % (epochs_completed - 1))
                mean_loss = old_loss
                # If we still don't improve next time after lowering the learning rate, stop
                no_improvement_last_epoch = True
        else:
            no_improvement_last_epoch = False
        # Save this model as a new checkpoint
        file_name = './checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt' % epochs_completed
        save_path = saver.save(sess, file_name)
        # Also save the current mean loss, global step, learning rate and epoch in an associated text file
        with open('./checkpoints/' + network.NAME + '/' + network.NAME + '_%d.ckpt.txt' % epochs_completed, mode='w') as out_file:
            out_file.write('{0}\n{1}\n{2}\n{3}'.format(mean_loss, batches_completed, network.learning_rate, epochs_completed))
        epochs_completed = epochs_completed + 1

train_writer.close()
val_writer.close()
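The script above only trains. As a rough sketch (not part of the original gist) of how a saved checkpoint could be reloaded for inference, reusing the same graph, iterators and placeholders built above; the epoch number 5 is only an example:

with tf.Session(graph=network.graph) as sess:
    saver = tf.train.Saver()
    saver.restore(sess, './checkpoints/{0}/{0}_{1}.ckpt'.format(network.NAME, 5))
    val_handle = sess.run(val_iterator.string_handle())
    sess.run(val_iterator.initializer,
             feed_dict={features_placeholder: val_features, labels_placeholder: val_labels})
    preds = sess.run(network.ys_pred,
                     feed_dict={handle: val_handle,
                                network.is_training: False,
                                network.keep_prob: 1.0})
    print(preds.shape)  # one sigmoid score per pathology for each image in the batch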
Really nice implementation!