Last active
July 28, 2020 12:57
-
-
Save sunnychugh/046280a39eb9a685090d77c675b20a2a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import os | |
import time | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
def load_data_using_keras(folders):
    """
    Build a Keras data generator for each dataset split.

    Images are rescaled to [0, 1] and yielded in batches; only the
    training split is shuffled so validation stays deterministic.

    Relies on module-level globals: ``dir_path`` (dataset root),
    ``batch_size`` and ``img_dims`` ([height, width]).

    Args:
        folders: Iterable of split names (e.g. ['train', 'val']); each
            must be a sub-directory of ``dir_path`` containing one
            sub-directory per class.

    Returns:
        dict mapping split name -> Keras DirectoryIterator.

    Note: Keras may need the 'pillow' library installed (pip install pillow).
    """
    data_generator = {}
    for split in folders:
        # Rescale raw uint8 pixels from [0, 255] to floats in [0, 1].
        image_gen = ImageDataGenerator(rescale=1./255)
        data_generator[split] = image_gen.flow_from_directory(
            batch_size=batch_size,
            directory=os.path.join(dir_path, split),
            # Shuffle only the training split.
            shuffle=(split == 'train'),
            target_size=(img_dims[0], img_dims[1]),
            class_mode='categorical')
    return data_generator
def load_data_using_tfdata(folders):
    """
    Build a ``tf.data`` input pipeline for each dataset split.

    Typically faster than the Keras generator path; supports on-disk
    caching of the decode/resize preprocessing work.

    Relies on module-level globals: ``dir_path`` (dataset root),
    ``batch_size`` and ``img_dims`` ([height, width]).

    Args:
        folders: Iterable of split names (e.g. ['train', 'val']); each
            must be a sub-directory of ``dir_path`` containing one
            sub-directory per class.

    Returns:
        dict mapping split name -> batched, repeating tf.data.Dataset
        yielding (image, one-hot boolean label) pairs.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # Class labels come from the sub-directory names of the train split.
    # Hoisted out of parse_image so os.listdir runs once, not per image.
    class_names = np.array(os.listdir(os.path.join(dir_path, 'train')))

    def parse_image(file_path):
        """Map a file path to an (image, label) tensor pair."""
        parts = tf.strings.split(file_path, os.path.sep)
        # The second-to-last path component is the class directory.
        label = parts[-2] == class_names
        # Load the raw bytes and decode to a 3-channel uint8 tensor.
        img = tf.io.read_file(file_path)
        img = tf.image.decode_jpeg(img, channels=3)
        # Convert to floats in the [0, 1] range.
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, [img_dims[0], img_dims[1]])
        return img, label

    def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
        """Cache, shuffle, repeat, batch and prefetch a dataset."""
        # For a small dataset, cache in memory (cache=True); for one that
        # doesn't fit, pass a filename string to cache on disk.
        if cache:
            ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()
        ds = ds.shuffle(buffer_size=shuffle_buffer_size)
        # Repeat forever; training code controls the number of steps.
        ds = ds.repeat()
        ds = ds.batch(batch_size)
        # Prefetch lets the pipeline fetch batches in the background
        # while the model is training.
        ds = ds.prefetch(buffer_size=autotune)
        return ds

    data_generator = {}
    for split in folders:
        list_ds = tf.data.Dataset.list_files(
            str(os.path.join(dir_path, split, '*', '*')))
        # num_parallel_calls lets multiple images be decoded in parallel.
        labeled_ds = list_ds.map(parse_image, num_parallel_calls=autotune)
        # BUG FIX: each split must get its OWN cache file. The original
        # reused './data.tfcache' for every split, so train and val
        # clashed on (and could serve each other's) cached data.
        data_generator[split] = prepare_for_training(
            labeled_ds, cache=f'./data_{split}.tfcache')
    return data_generator
def timeit(ds, steps=1000):
    """
    Benchmark image-loading throughput of a data pipeline.

    Pulls ``steps`` batches from ``ds`` and prints the elapsed time and
    images/second (uses the module-level global ``batch_size``).

    Args:
        ds: An iterable of batches (Keras generator or tf.data.Dataset).
        steps: Number of batches to draw.
    """
    start = time.time()
    it = iter(ds)
    for i in range(steps):
        next(it)
        # BUG FIX: the progress denominator was hard-coded to '/1000'
        # and ignored the `steps` argument.
        print(' >> ', i, f'/{steps}', end='\r')
    duration = time.time() - start
    print(f'{steps} batches: '
          f'{datetime.timedelta(seconds=int(duration))}')
    print(f'{round(batch_size*steps/duration)} Images/s')
if __name__ == '__main__':
    # --- configuration: change these for your dataset ---
    dir_path = '/home/sun/data/dog_vs_cat'  # root containing the splits
    folders = ['train', 'val']
    load_data_using = 'tfdata'              # 'keras' or 'tfdata'
    batch_size = 32
    img_dims = [256, 256]                   # [height, width]

    if load_data_using == 'keras':
        data_generator = load_data_using_keras(folders)
    elif load_data_using == 'tfdata':
        data_generator = load_data_using_tfdata(folders)
    else:
        # ROBUSTNESS FIX: fail fast on an unknown backend instead of
        # falling through to a confusing NameError at timeit() below.
        raise ValueError(f"Unknown load_data_using value: {load_data_using!r}")
    timeit(data_generator['train'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment