Last active
June 3, 2019 13:48
-
-
Save CrackerHax/06025c08bddf277696e26979b0b93e5d to your computer and use it in GitHub Desktop.
Create TFRecord files, labeled by category, from directories of images.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import glob
import os
import sys
from random import shuffle

import cv2
import numpy as np
import tensorflow as tf
# Name of the project directory that holds all of the images.  It is expected
# to sit under ./images/train (or change the `path` variable below).
name = 'mountains'

# Side length of the images - they should be square (i.e. 256x256).
image_size = 256

# The script reads image files organized by label and saves them in TFRecord
# form.  Directories should look like this (e.g. for a project named
# 'portraits'):
#   images/train/portraits/male/old
#   images/train/portraits/male/young
#   images/train/portraits/female/old
#   images/train/portraits/female/young

# Leave these alone.
path = 'images/train/'
addrs = []
labels = []
all_categories = []
def _int64_feature(value):
    """Build a `tf.train.Feature` holding an int64 list.

    `value` is handed to `tf.train.Int64List` as-is, so it must already be a
    sequence of integers (it is not wrapped in a list here).
    """
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Build a `tf.train.Feature` holding a one-element bytes list.

    The single `value` is wrapped in a list before being passed to
    `tf.train.BytesList`.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def load_image(addr):
    """Read the image at `addr`, resize it and return it in RGB channel order.

    The image is resized to `image_size` x `image_size` (module-level setting)
    using bicubic interpolation.  Returns None when `cv2.imread` yields None
    for the given path.
    """
    raw = cv2.imread(addr)
    if raw is None:
        return None

    target_shape = (image_size, image_size)
    resized = cv2.resize(raw, target_shape, interpolation=cv2.INTER_CUBIC)
    # cv2 loads images as BGR; convert to RGB before handing the pixels back.
    return cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
def createDataRecord(out_filename, addrs, labels):
    """Write one `tf.train.Example` per readable image to a TFRecord file.

    Args:
        out_filename: Path of the TFRecord file to create.
        addrs: Sequence of image file paths.
        labels: Sequence of integer label vectors, parallel to `addrs`.

    Raises:
        ValueError: If `addrs` and `labels` differ in length.

    Each example stores the raw RGB pixel bytes under 'image' and the label
    vector under 'label'.  Paths that `load_image` cannot read are reported
    and skipped.
    """
    if len(addrs) != len(labels):
        raise ValueError(
            'addrs and labels must have the same length ({} != {})'.format(
                len(addrs), len(labels)))

    # tf.python_io was removed in TensorFlow 2.x; use tf.io when it provides
    # the writer and fall back to the 1.x location otherwise.
    try:
        writer_cls = tf.io.TFRecordWriter
    except AttributeError:
        writer_cls = tf.python_io.TFRecordWriter

    last_index = len(addrs) - 1
    # The context manager closes the TFRecords file even if an image fails
    # part-way through, so the handle is never leaked.
    with writer_cls(out_filename) as writer:
        for i, (addr, label) in enumerate(zip(addrs, labels)):
            print('Train data: {}/{}'.format(i, last_index))
            print('--- Path:'+addr+' Labels:'+str(label))
            sys.stdout.flush()

            # Load the image
            img = load_image(addr)
            if img is None:
                print("Error: no image")
                continue

            # Create a feature.  tobytes() replaces the deprecated tostring(),
            # which no longer exists in NumPy 2.0.
            feature = {
                'image': _bytes_feature(img.tobytes()),
                'label': _int64_feature(label),
            }
            # Create an example protocol buffer
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # Serialize to string and write on the file
            writer.write(example.SerializeToString())
    sys.stdout.flush()
# --
# Index the labels.
def _list_names(directory, want_dirs):
    """Return the sorted names of the sub-directories (or files) in `directory`."""
    return sorted(
        entry.name
        for entry in os.scandir(directory)
        if (entry.is_dir() if want_dirs else entry.is_file())
    )


dataset_root = path + name

# Names are sorted so that a label's column index is the same on every run
# (os.scandir yields entries in arbitrary order).
image_categories = _list_names(dataset_root, want_dirs=True)

# Collect the union of subcategories over *all* categories, not just the last
# one visited; this also leaves the list defined when there are no categories.
image_subcategories = sorted({
    subcategory
    for category in image_categories
    for subcategory in _list_names(dataset_root + '/' + category, want_dirs=True)
})
all_categories = image_categories + image_subcategories
print(all_categories)

# Column of each label inside the multi-hot vector.  Subcategories are offset
# past the categories so a subcategory that shares a category's name still
# gets its own column.
num_labels = len(all_categories)
category_index = {cat: i for i, cat in enumerate(image_categories)}
subcategory_index = {
    sub: len(image_categories) + i for i, sub in enumerate(image_subcategories)
}

# Get image paths and the (1 or 2) labels for each image.
if not image_subcategories:
    # Flat layout: images sit directly inside each category directory.
    for category in image_categories:
        file_path = dataset_root + '/' + category
        for file_name in _list_names(file_path, want_dirs=False):
            label = np.zeros(num_labels, dtype=int)
            label[category_index[category]] = 1
            labels.append(label)
            addrs.append(file_path + '/' + file_name)
else:
    # Nested layout: images sit inside category/subcategory directories.
    for category in image_categories:
        for subcategory in image_subcategories:
            file_path = dataset_root + '/' + category + '/' + subcategory
            # Not every category necessarily contains every subcategory.
            if not os.path.isdir(file_path):
                continue
            for file_name in _list_names(file_path, want_dirs=False):
                label = np.zeros(num_labels, dtype=int)
                label[category_index[category]] = 1
                label[subcategory_index[subcategory]] = 1
                addrs.append(file_path + '/' + file_name)
                labels.append(label)

# Shuffle the (path, label) pairs together so they stay aligned.
pairs = list(zip(addrs, labels))
if not pairs:
    # zip(*[]) cannot be unpacked into two names; stop with a clear message.
    raise SystemExit('No images found under ' + dataset_root)
shuffle(pairs)
addrs, labels = zip(*pairs)

# makedirs also creates the parent 'datasets/' directory when it is missing.
os.makedirs('datasets/'+name+'/', exist_ok=True)
createDataRecord('datasets/'+name+'/'+name+'_train.tfrecords', addrs, labels)
print('saved in ./datasets/'+name+'/')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment