Last active
July 10, 2024 22:19
-
-
Save gpleiss/d3244db99e6d5adad467291232fbb256 to your computer and use it in GitHub Desktop.
ImageZipDataset for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import csv | |
from zipfile import ZipFile | |
from contextlib import contextmanager | |
import torch | |
from PIL import Image | |
class _ImageZipDataset(torch.utils.data.Dataset):
    """Map-style dataset that decodes images out of an already-open zip archive.

    Instances are created by ``ImageZipDatasetWrapper.dataset()``, which owns
    the archive and closes it when its ``with`` block exits.

    NOTE(review): one ZipFile handle is shared by every ``__getitem__`` call;
    confirm concurrent reads are safe before using this dataset with a
    multi-worker DataLoader.

    Args:
        zip_file (zipfile.ZipFile): open archive the images are read from
        samples (sequence): ``(member_path, target)`` pairs, indexed by position
        transform (callable, optional): applied to each decoded RGB image
        target_transform (callable, optional): applied to each target
    """

    def __init__(self, zip_file, samples, transform=None, target_transform=None):
        self.zip_file, self.samples = zip_file, samples
        self.transform, self.target_transform = transform, target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        member, label = self.samples[index]
        with self.zip_file.open(member) as handle:
            # convert() makes PIL decode the pixels now, while the member is still open.
            image = Image.open(handle).convert('RGB')
        sample = image if self.transform is None else self.transform(image)
        target = label if self.target_transform is None else self.target_transform(label)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        def _indented(label, obj):
            # Push continuation lines of a multi-line repr under the label's end.
            return label + repr(obj).replace('\n', '\n' + ' ' * len(label))

        text = 'Dataset ' + type(self).__name__ + '\n'
        text += ' Number of datapoints: {}\n'.format(len(self))
        text += _indented(' Transforms (if any): ', self.transform) + '\n'
        text += _indented(' Target Transforms (if any): ', self.target_transform)
        return text
class ImageZipDatasetWrapper(torch.utils.data.Dataset):
    """
    A dataset where images are stored in a zip file:

        <filename>.zip@/<img_1>.JPEG
        <filename>.zip@/<img_2>.JPEG
        <filename>.zip@/<img_3>.JPEG
        ...

    And the class assignments are stored in a TSV file:

        .../<filename>.zip@/<img_1>.JPEG <class_a>
        .../<filename>.zip@/<img_2>.JPEG <class_b>
        .../<filename>.zip@/<img_3>.JPEG <class_c>

    Only the final path component of the first TSV column is kept
    (``os.path.basename``), so images must sit at the root of the archive.
    The second column must be an integer class index. Blank rows are ignored.
    Samples are stored sorted by (class index, member name).

    This wrapper does not implement ``__getitem__`` itself. Use
    ``with wrapper.dataset() as ds:`` to get an indexable dataset; the
    archive is open only for the duration of that ``with`` block.

    Args:
        zip_path (string): path to zip file
        info_path (string): filename of TSV file with class assignments
        transform (callable, optional): transforms to apply to each image
        target_transform (callable, optional): transforms to apply to each target
        info_encoding (string, optional): text encoding of the TSV file (default 'utf8')

    Raises:
        RuntimeError: if ``zip_path`` or ``info_path`` does not exist.
    """

    def __init__(self, zip_path, info_path, transform=None, target_transform=None, info_encoding='utf8'):
        for path in (zip_path, info_path):
            if not os.path.exists(path):
                raise RuntimeError('%s does not exist' % path)
        self.zip_path = zip_path
        self.info_path = info_path
        self.transform = transform
        self.target_transform = target_transform

        # newline='' is what the csv module requires for file objects it reads.
        # csv.reader yields [] for a blank line (e.g. a trailing empty line);
        # skip those instead of crashing on row[1].
        with open(self.info_path, 'r', encoding=info_encoding, newline='') as f:
            reader = csv.reader(f, delimiter='\t')
            samples = [(os.path.basename(row[0]), int(row[1])) for row in reader if row]
        # Deterministic order: by class index first, then by member name.
        self.samples = sorted(samples, key=lambda x: (x[1], x[0]))

    @contextmanager
    def dataset(self):
        """Open the zip archive and yield an indexable dataset backed by it.

        The archive is closed when the ``with`` block exits, so indexing the
        yielded dataset afterwards will fail.
        """
        with ZipFile(self.zip_path, 'r') as zip_file:
            res = _ImageZipDataset(
                zip_file=zip_file,
                samples=self.samples,
                transform=self.transform,
                target_transform=self.target_transform,
            )
            yield res

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        def _indented(label, obj):
            # Align continuation lines of a multi-line repr under the label.
            return label + repr(obj).replace('\n', '\n' + ' ' * len(label))

        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(len(self))
        fmt_str += ' Zip Location: {}\n'.format(self.zip_path)
        fmt_str += ' Info Location: {}\n'.format(self.info_path)
        fmt_str += _indented(' Transforms (if any): ', self.transform) + '\n'
        fmt_str += _indented(' Target Transforms (if any): ', self.target_transform)
        return fmt_str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision | |
from image_zip_dataset import ImageZipDatasetWrapper | |
# Build the wrapper from the archive path and the TSV of class assignments.
dataset_wrapper = ImageZipDatasetWrapper(
    'path/to/file.zip',
    'path/to/info_file.txt',
    transform=torchvision.transforms.ToTensor(),
    target_transform=None,
)

# The zip archive is only open inside this ``with`` block.
with dataset_wrapper.dataset() as dataset:
    # Index ``dataset`` here: dataset[i] -> (transformed image, class index).
    pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment