Created
January 23, 2017 14:57
-
-
Save psycharo-zz/58717872a3a00284fbbcd9575d265785 to your computer and use it in GitHub Desktop.
example of an efficient and simple input pipeline in tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
def _int64_feature(value):
    """Wrap a scalar integer in a tf.train.Feature holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a bytes value in a tf.train.Feature holding a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _convert_example(rgb_path, label_path):
    """Serialize one (rgb, label) PNG pair into a tf.train.Example string.

    Args:
        rgb_path: path to the RGB PNG file
        label_path: path to the label PNG file

    Returns:
        The serialized tf.train.Example as a bytes string.
    """
    # Use context managers so the file handles are closed deterministically;
    # the original `open(path).read()` leaked the descriptors.
    with open(rgb_path, 'rb') as f:
        rgb_png = f.read()
    with open(label_path, 'rb') as f:
        label_png = f.read()
    example = tf.train.Example(features=tf.train.Features(feature={
        'rgb_png': _bytes_feature(rgb_png),
        'label_png': _bytes_feature(label_png)
    }))
    return example.SerializeToString()
def _convert_dataset_shard(filenames, output_path):
    """A per-thread unit of work for dataset processing.

    Args:
        filenames: a list of (rgb_path, label_path) tuples
        output_path: where to store the records
    """
    writer = tf.python_io.TFRecordWriter(output_path)
    try:
        for rgb_path, label_path in filenames:
            writer.write(_convert_example(rgb_path, label_path))
    finally:
        # Always close so the record file is flushed and the handle released,
        # even if reading/serializing one example raises mid-loop.
        writer.close()
def _to_filenames(raw_data_dir, tag, city, fid): | |
"""returns a tuple of filenames""" | |
rgb_path = ('%s/leftImg8bit/%s/%s/%s_leftImg8bit.png' % | |
(raw_data_dir, tag, city, fid)) | |
label_path = ('%s/gtFine/%s/%s/%s_gtFine_labelIds.png' % | |
(raw_data_dir, tag, city, fid)) | |
return rgb_path, label_path | |
def convert_dataset(raw_data_dir, processed_dir, tag, num_threads=1,
                    max_num_examples=10):
    """Converts the dataset into TFRecords.

    Args:
        raw_data_dir: directory with the unprocessed dataset
        processed_dir: where to store TFRecords
        tag: "train"|"test"|"val"
        num_threads: number of threads to use in parallel
        max_num_examples: maximum number of examples to load
    """
    cities = sorted(os.listdir('%s/leftImg8bit/%s' % (raw_data_dir, tag)))
    # Sort the per-city listing as well: os.listdir order is arbitrary, so
    # without it the subset picked by max_num_examples (and the shard
    # contents) would differ between runs/filesystems.
    fids = [(city, p.rsplit('_', 1)[0])
            for city in cities
            for p in sorted(os.listdir('%s/leftImg8bit/%s/%s' %
                                       (raw_data_dir, tag, city)))]
    filenames = [_to_filenames(raw_data_dir, tag, city, fid)
                 for (city, fid) in fids[:max_num_examples]]
    # Split the work into num_threads contiguous, (almost) equally sized shards.
    filenames_sliced = []
    slices = np.linspace(0, len(filenames), num_threads + 1).astype(np.int32)
    for i in range(num_threads):
        filenames_sliced.append(filenames[slices[i]:slices[i + 1]])
    coord = tf.train.Coordinator()
    threads = []
    for i in range(num_threads):
        output_path = (processed_dir +
                       '/%s-%02d-of-%02d.tfrecord' % (tag, i, num_threads))
        args = (filenames_sliced[i], output_path)
        thread = threading.Thread(target=_convert_dataset_shard, args=args)
        thread.start()
        threads.append(thread)
    # Block until every writer thread has finished its shard.
    coord.join(threads)
# inputs | |
def read_and_decode(filename_queue, height=1024, width=2048):
    """Read one serialized record from the queue and decode both PNGs.

    Args:
        filename_queue: queue of TFRecord filenames to read from
        height: image height used to pin the static shape
        width: image width used to pin the static shape

    Returns:
        A (rgb, label) pair of image tensors with static shapes
        [height, width, 3] and [height, width, 1].
    """
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)
    feature_spec = {
        'rgb_png': tf.FixedLenFeature([], tf.string),
        'label_png': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)
    rgb_image = tf.image.decode_png(parsed['rgb_png'], channels=3)
    label_image = tf.image.decode_png(parsed['label_png'], channels=1)
    # decode_png leaves spatial dims unknown; pin them for downstream batching.
    rgb_image.set_shape([height, width, 3])
    label_image.set_shape([height, width, 1])
    return rgb_image, label_image
def input_pipeline(file_pattern, batch_size, min_values_dequeue,
                   num_epochs=None, num_reader_threads=1):
    """Creates the input pipeline: glob files, read, decode, and batch.

    Args:
        file_pattern: pattern for input files, e.g. 'train-??-of-10.tfrecord'
        batch_size: number of examples per batch
        min_values_dequeue: extra buffer headroom added to the queue capacity
        num_epochs: how many times to go through all the data (None = forever)
        num_reader_threads: number of parallel reader/decoder ops feeding
            the batching queue

    Returns:
        A (rgb_batch, label_batch) pair of batched image tensors.

    NOTE(review): shuffle=False and batch_join mean no shuffling happens,
    despite the buffer-size wording above — confirm whether
    shuffle_batch_join was intended for training.
    """
    # TODO: is_training flag?
    filenames = tf.gfile.Glob(file_pattern)
    filename_queue = tf.train.string_input_producer(filenames, num_epochs,
                                                    shuffle=False,
                                                    name='filename_queue')
    # One reader/decoder per thread; batch_join dequeues from all of them.
    example_list = [read_and_decode(filename_queue)
                    for t in range(num_reader_threads)]
    capacity = min_values_dequeue + 64 * batch_size
    rgb_batch, label_batch = tf.train.batch_join(example_list, batch_size, capacity)
    return rgb_batch, label_batch
# Build the training input pipeline graph.
# NOTE(review): processed_data_dir is defined elsewhere (outside this chunk).
tf.reset_default_graph()
train_file_pattern = os.path.join(processed_data_dir, 'train-??-of-??.tfrecord')
rgb_batch, label_batch = input_pipeline(train_file_pattern,
                                        batch_size=8,
                                        min_values_dequeue=3200,
                                        num_epochs=1)
# string_input_producer with num_epochs set creates a local epoch-counter
# variable, so local variables must be initialized before starting the queues.
init_op = tf.local_variables_initializer()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment