Created
April 21, 2020 02:24
-
-
Save VyBui/05ec95c67c2af008975170be3fa90365 to your computer and use it in GitHub Desktop.
bps_train
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import time
import traceback

import segmentation_models as sm
# Segmentation Models: using `keras` framework.
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.client import device_lib

from config import cfg
from create_tf_records_bps import input_fn
from losses import schp_loss
from vgg19 import build_vgg19_model
from warm_start import get_learning_rate
# Every model in this script is built for channels-last image tensors.
keras.backend.set_image_data_format(data_format='channels_last')
def get_available_gpus():
    """
    List the local devices that TensorFlow reports with device type 'GPU'.

    :return: list with the ``name`` of every local device whose
        ``device_type`` equals ``'GPU'`` (empty when there is none).
    """
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names
def train():
    """
    Train the PSPNet parsing model with a custom ``tf.distribute`` loop.

    The training dataset is built by ``input_fn`` from ``cfg.TF_RECORD_PATH``,
    distributed with ``tf.distribute.MirroredStrategy`` and used to optimise
    the Dice + focal parsing loss of a ``resnet101`` PSPNet for ``cfg.EPOCHS``
    epochs. The learning rate is refreshed from ``get_learning_rate`` at the
    start of every epoch.

    Any exception raised during setup or training is reported (message on
    stdout, traceback on stderr) and swallowed, so this function never raises.

    :return: None
    """
    gpus = get_available_gpus()
    print(gpus)
    try:
        # NOTE: GPU memory growth is left disabled. Enabling it needs the
        # PhysicalDevice objects from tf.config, not the device-name strings
        # returned by get_available_gpus().
        strategy = tf.distribute.MirroredStrategy()
        batch_size_per_replica = 2
        global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
        print("global batch_size is: {}".format(global_batch_size))

        # Build the input pipeline once and let the strategy split each global
        # batch across the replicas.
        params = {'batch_size': global_batch_size, 'tf_records_path': cfg.TF_RECORD_PATH}
        train_dataset = input_fn(mode="train", params=params)
        train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

        with strategy.scope():
            # Model and optimizer variables must be created exactly once and
            # inside the strategy scope so that they are mirrored; creating
            # them inside the train step would also reset Adam's moments on
            # every batch.
            # The model is driven by tf.keras optimizers and GradientTape, so
            # segmentation_models has to use the tf.keras backend as well.
            sm.set_framework('tf.keras')
            segmentation_model = sm.PSPNet('resnet101', encoder_weights='imagenet')
            schp_optimizer = tf.keras.optimizers.Adam(
                get_learning_rate(cfg.max_learning_rate, cfg.min_learning_rate, 0, cfg.EPOCHS),
                beta_1=0.5)
            dice_loss = sm.losses.DiceLoss()
            focal_loss = sm.losses.CategoricalFocalLoss()
            parsing_loss = dice_loss + (1 * focal_loss)
            # NOTE(review): schp_loss(loss_edges, parsing_loss, loss_consistent)
            # cannot be used yet -- nothing in this script produces the edge or
            # consistency terms, so only the parsing loss is optimised.
            # NOTE(review): Keras callbacks (ModelCheckpoint, ReduceLROnPlateau)
            # only run through Model.fit; this custom loop saves no checkpoints.

        def compute_loss(images, labels, training):
            """
            Run the model and return this replica's share of the parsing loss.

            :param images: batch of input images for the model.
            :param labels: ground-truth masks; assumed to have the same shape
                as the model output (one-hot) -- TODO confirm against input_fn.
            :param training: forwarded to the model call.
            :return: scalar loss divided by the number of replicas, so that a
                SUM reduction over replicas yields the mean loss.
            """
            prediction = segmentation_model(images, training=training)
            labels = tf.cast(labels, prediction.dtype)
            return parsing_loss(labels, prediction) / strategy.num_replicas_in_sync

        def train_step(batch):
            """
            Run one optimisation step on a single replica.

            :param batch: ``(input_image, label, file_name)`` tuple.
            :return: the scaled loss of this replica.
            """
            input_image, label, _ = batch
            with tf.GradientTape() as tape:
                loss = compute_loss(input_image, label, training=True)
            variables = segmentation_model.trainable_variables
            gradients = tape.gradient(loss, variables)
            schp_optimizer.apply_gradients(zip(gradients, variables))
            return loss

        def test_step(batch, step):
            """
            Evaluate one batch on a single replica.

            :param batch: ``(image, label, image_name)`` tuple.
            :param step: validation step index (currently unused; kept for
                summary logging).
            :return: the scaled loss of this replica.
            """
            image, label, _ = batch
            return compute_loss(image, label, training=False)

        # `Strategy.run` replaced `experimental_run_v2` in TF 2.2; support both.
        run_on_replicas = getattr(strategy, 'run', None) or strategy.experimental_run_v2

        @tf.function
        def distributed_train_step(dataset_inputs):
            """Run ``train_step`` on every replica and sum the scaled losses."""
            per_replica_loss = run_on_replicas(train_step, args=(dataset_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

        def distributed_test_step(dataset_inputs, step):
            """Run ``test_step`` on every replica and sum the scaled losses."""
            per_replica_loss = run_on_replicas(test_step, args=(dataset_inputs, step))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

        def fit(train_dist_dataset, epochs, test_dist_dataset):
            """
            Run the epoch loop.

            :param train_dist_dataset: distributed training dataset.
            :param epochs: number of epochs to run.
            :param test_dist_dataset: distributed validation dataset, or
                ``None`` to skip validation.
            :return: None
            """
            total_step = int(cfg.total_tfrecords_for_training / global_batch_size)
            for epoch in range(epochs):
                print("Epoch: ", epoch)
                # Update the existing optimizer instead of creating a new one,
                # so its accumulated state survives across epochs.
                schp_optimizer.learning_rate = get_learning_rate(
                    cfg.max_learning_rate, cfg.min_learning_rate, epoch, cfg.EPOCHS)

                total_loss = 0.0
                steps_done = 0
                print("The number of total steps for train: {}".format(total_step))
                # zip() stops at whichever ends first, so neither a repeating
                # nor a too-short dataset breaks the epoch.
                for step, batch in zip(range(total_step), train_dist_dataset):
                    print('....', end='')
                    total_loss += distributed_train_step(batch)
                    steps_done = step + 1
                    if step % 10 == 0:
                        print("Step {}, loss: {}".format(step, float(total_loss / steps_done)))
                if steps_done:
                    print("Epoch {}, loss: {}".format(epoch + 1, float(total_loss / steps_done)))

                # Validate every 5 epochs, but only when a dataset was given.
                if test_dist_dataset is not None and epoch % 5 == 0:
                    total_test_steps = int(cfg.total_viton_tfrecords_for_testing / global_batch_size)
                    print("The number of total steps for test: {}".format(total_test_steps))
                    val_loss = 0.0
                    val_steps = 0
                    for val_steps, batch in zip(range(1, total_test_steps + 1), test_dist_dataset):
                        val_loss += distributed_test_step(
                            batch, tf.convert_to_tensor(val_steps, dtype=tf.int64))
                    if val_steps:
                        print("Epoch {}, val_loss: {}".format(epoch + 1, float(val_loss / val_steps)))

        fit(train_dist_dataset, cfg.EPOCHS, None)
    except Exception as e:
        # Keep the non-raising behaviour, but do not lose the traceback.
        print(e)
        traceback.print_exc()
if __name__ == '__main__':
    # Script entry point: run the whole training pipeline.
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment