Created
March 19, 2017 00:43
-
-
Save faroit/92ba12373440d092e1096967b530a5b8 to your computer and use it in GitHub Desktop.
threadsafe_generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import threading | |
def threadsafe_generator(lock=None):
    """Build a decorator that makes a generator function thread-safe.

    Calling the decorated function runs the original generator function and
    hands its result to ``threadsafe_iter`` together with ``lock``.

    Adapted from
    http://anandology.com/blog/using-iterators-and-generators/

    Args:
        lock: Lock object forwarded to every ``threadsafe_iter`` created by
            the decorated function. ``None`` (the default) is forwarded
            unchanged, leaving the choice of lock to ``threadsafe_iter``.

    Returns:
        A decorator taking a generator function ``f`` and returning a
        function with the same call signature.
    """
    import functools  # local import keeps this block self-contained

    def wrap(f):
        # Preserve the name and docstring of the decorated generator function
        # so introspection and debugging output still refer to ``f``.
        @functools.wraps(f)
        def g(*a, **kw):
            return threadsafe_iter(f(*a, **kw), lock=lock)
        return g
    return wrap
class threadsafe_iter:
    """Make an iterator/generator thread-safe.

    Every step of the wrapped iterator is taken while holding a lock, so only
    one thread at a time can advance it.

    Args:
        it: The iterator or generator to wrap.
        lock: Lock (context manager) guarding each step. When ``None``, a new
            ``threading.Lock`` is created for this instance.
    """

    def __init__(self, it, lock=None):
        self.it = it
        if lock is None:
            lock = threading.Lock()
        self.lock = lock

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next item of the wrapped iterator under the lock."""
        with self.lock:
            # The builtin next() works for both Python 2 and Python 3
            # iterators, unlike calling ``.next()`` directly on the iterator.
            return next(self.it)

    # Python 2 spelling of the iterator protocol; kept so existing callers
    # that invoke ``.next()`` explicitly keep working.
    next = __next__
@threadsafe_generator(lock=None)
def DataGenerator(
    model,
    shuffle=False,
    seed=None,
    subset='train',
    validation=False,
    split=0.1
):
    """Endlessly yield ``(X, y)`` batches loaded via ``load_from_h5``.

    Relies on ``dataset``, ``train_valid_split``, ``batcher`` and
    ``load_from_h5`` being defined elsewhere.
    NOTE(review): their exact contracts are not visible here -- confirm.

    Args:
        model: Not used by the code in this function.
        shuffle: When truthy, the full sample index is permuted once up
            front; in training mode it is additionally permuted again at the
            start of every pass.
        seed: When not ``None``, seeds NumPy's global random state.
        subset: Not used by the code in this function.
        validation: When truthy, batches are drawn from the validation part
            of the split and never re-permuted between passes.
        split: Forwarded to ``train_valid_split``.

    Yields:
        The two values returned by ``load_from_h5`` for each batch of
        indices, as a tuple.
    """
    # Seeding is optional; note that it affects the global NumPy RNG.
    if seed is not None:
        np.random.seed(seed)

    # One entry per sample along the first axis of the dataset.
    sample_ids = np.arange(dataset.shape[0])

    # Single up-front permutation, applied before the train/validation split.
    if shuffle:
        sample_ids = np.random.permutation(sample_ids)

    train_ids, valid_ids = train_valid_split(sample_ids, split)

    # Validation draws from the validation part, training from the rest.
    active_ids = valid_ids if validation else train_ids

    # Only the training index is re-permuted between passes.
    reshuffle_each_pass = shuffle and not validation

    # Never terminates: each iteration of the outer loop is one full pass.
    while True:
        if reshuffle_each_pass:
            active_ids = np.random.permutation(active_ids)
        for batch_ids in batcher(active_ids):
            inputs, targets = load_from_h5(indices=batch_ids)
            yield inputs, targets
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment