Created
January 23, 2024 10:36
-
-
Save jgs03177/7d9a24a47afd4f10d97246d2d1ddea31 to your computer and use it in GitHub Desktop.
PyTorch DataLoader example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import errno | |
import cv2 | |
import torch | |
import numpy as np | |
import random | |
from torch.utils.data import Dataset, ConcatDataset, DataLoader | |
from typing import Callable | |
class ImageDataset(Dataset):
    """Image dataset read from a directory tree, with optional pickle caching.

    If `label` is None, images live in per-class subfolders whose first two
    characters encode the integer class id (layout: set_dirname/labelname/imgname).
    If `label` is given, every file directly under `set_dirname` receives that
    label (layout: set_dirname/imgname).
    If `picklename` is not None, decoded arrays are cached to / loaded from
    ./pickles/<picklename> via torch.save / torch.load.

    Args:
        set_dirname: root directory to read images from.
        picklename: cache file name, or None to disable caching.
        label: fixed label for all images, or None to derive labels from folders.
        reg: if True, divide pixel values by 256 and prefix the cache name with
            "r_" so raw and regularized caches do not collide.
        crop: probability of a random 32x32 crop in __getitem__ (otherwise a
            fixed center crop is taken).
        erase: random-erasing probability (stored, but not implemented).
        mapper: optional function applied to the label in __getitem__.
    """

    def __init__(self,
                 set_dirname: str,
                 picklename: str = None,
                 label: int = None,
                 reg: bool = True,
                 crop: float = 0,
                 erase: float = 0,
                 mapper: Callable[[int], int] = None):
        # Ensure the cache directory exists (tolerate a concurrent mkdir).
        if not os.path.exists("./pickles"):
            try:
                os.makedirs("./pickles")
            except OSError as error:
                if error.errno != errno.EEXIST:
                    raise
        self.list_img = None
        self.list_label = None
        self.list_name = None
        self.crop = crop
        self.erase = erase
        self.mapper = mapper
        # Pixel regularizer: divide by 256 when reg is True.
        regc = 256 if reg else 1
        # BUG FIX: the original prefixed and joined picklename unconditionally,
        # raising TypeError whenever picklename was None (its documented default).
        if picklename is not None:
            if reg:
                # Keep separate caches for regularized vs. raw pixel values.
                picklename = "r_" + picklename
            picklename = os.path.join("./pickles/", picklename)
        # Use the cached pickle if it exists.
        if picklename is not None and os.path.isfile(picklename):
            pickle = torch.load(picklename)
            self.list_img = pickle["img"]
            self.list_label = pickle["label"]
            self.list_name = pickle["name"]
        else:
            self.list_img = list()
            self.list_name = list()
            self.list_label = list()
            # Read folder names (= labels), images and image names (for csv).
            set_listdir = os.listdir(set_dirname)
            if label is None:
                for folder in set_listdir:
                    if folder[0] == '.':
                        continue  # skip hidden entries such as .DS_Store
                    # First two characters of the folder name are the class id.
                    classid_label = folder[0:2]
                    subset_dirname = os.path.join(set_dirname, folder)
                    subset_listdir = os.listdir(subset_dirname)
                    for subset_filename in subset_listdir:
                        img_dirname = os.path.join(subset_dirname, subset_filename)
                        # HWC -> CWH, then pad W and H by 4 on each side so
                        # __getitem__ can crop back to 32x32.
                        # BUG FIX: np.float was removed in NumPy 1.24; np.float64
                        # is the equivalent explicit dtype.
                        img = np.transpose(cv2.imread(img_dirname), (2, 1, 0)).astype(np.float64) / regc
                        img = np.pad(img, ((0,), (4,), (4,)), 'constant', constant_values=0)
                        self.list_img.append(img)
                        self.list_label.append(int(classid_label))
                        self.list_name.append(subset_filename)
            # Read images and image names (for csv); every image gets `label`.
            else:
                for set_filename in set_listdir:
                    img_dirname = os.path.join(set_dirname, set_filename)
                    img = np.transpose(cv2.imread(img_dirname), (2, 1, 0)).astype(np.float64) / regc
                    img = np.pad(img, ((0,), (4,), (4,)), 'constant', constant_values=0)
                    self.list_img.append(img)
                    self.list_name.append(set_filename)
                    self.list_label.append(label)
            # Generate the pickle cache for the next run.
            if picklename is not None:
                pickle = dict()
                pickle["img"] = self.list_img
                pickle["label"] = self.list_label
                pickle["name"] = self.list_name
                torch.save(pickle, picklename)
        self.len = len(self.list_img)
        print("length: ", self.len)

    def __getitem__(self, item):
        """Return (img, label, name) with a 32x32 crop and optional label mapping."""
        output_img = self.list_img[item]
        output_label = self.list_label[item] if self.list_label is not None else -1
        output_name = self.list_name[item]
        # Data augmentation:
        # - horizontal flipping: implemented outside (ImageDatasetFlipper)
        # - random erasing (arXiv:1708.04896): not implemented
        # - 4-padding (done in __init__) and 32-cropping, below.
        # Assumes source images are 32x32, padded to 40x40 -- TODO confirm.
        p_padcrop = random.uniform(0, 1)
        if p_padcrop < self.crop:
            # BUG FIX: valid crop offsets for a 40 -> 32 crop are 0..8 inclusive;
            # the original randint(0, 7) could never select offset 8.
            dw = random.randint(0, 8)
            dh = random.randint(0, 8)
        else:
            # Deterministic center crop.
            dw = 4
            dh = 4
        output_img = output_img[:, dw:32 + dw, dh:32 + dh]
        # Label mapping.
        if self.mapper is not None:
            output_label = self.mapper(output_label)
        return output_img, output_label, output_name

    def __len__(self):
        return self.len
# horizontal flipping
class ImageDatasetFlipper(Dataset):
    """Dataset view that horizontally flips every item of a wrapped dataset.

    Item names are prefixed with "f" so flipped samples can be told apart
    (e.g. in the csv output).
    """

    def __init__(self, imagedataset: ImageDataset):
        # Wrapped dataset; items are (img(C, W, H), label, name) triples.
        self.original = imagedataset

    def __getitem__(self, item):
        output_img, output_label, output_name = self.original.__getitem__(item)
        # BUG FIX: the original called np.flip(output_img, 2) without using
        # the return value, so the image was never actually flipped.
        output_img = np.flip(output_img, 2)
        return output_img, output_label, "f" + output_name

    def __len__(self):
        return self.original.__len__()
# training dataset
def get_tr_imgds(classid: int=None, mirror=False, reg=True, crop=0, mapper=None):
    """Build the training ImageDataset from ./all/tr/.

    With classid=None the whole tree is loaded (labels from folder names);
    otherwise only that class folder is loaded with `classid` as the label.
    With mirror=True the result is concatenated with its flipped view.
    """
    class_dirs = [
        "00_cup_n03147509",
        "01_coffee_n07929519",
        "02_bed_n02818832",
        "03_tree_n13104059",
        "04_bird_n01503061",
        "05_chair_n03001627",
        "06_tea_n07933274",
        "07_bread_n07679356",
        "08_bicycle_n02834778",
        "09_sail_n04127904"
    ]
    if classid is None:
        dataset = ImageDataset("./all/tr/", picklename="p_tr.bin",
                               reg=reg, crop=crop, mapper=mapper)
    else:
        dataset = ImageDataset("./all/tr/" + class_dirs[classid],
                               picklename="p_tr_" + str(classid) + ".bin",
                               label=classid, reg=reg, crop=crop, mapper=mapper)
    return concatmirror(dataset) if mirror else dataset
# test dataset
def get_ts_imgds(reg=True):
    """Build the test ImageDataset from ./all/ts/; every item is labeled -1."""
    return ImageDataset("./all/ts/", picklename="p_ts.bin", label=-1, reg=reg)
# validation dataset
def get_val_imgds(classid: int=None, reg=True, mapper=None):
    """Build the validation ImageDataset from ./all/val/.

    With classid=None the whole tree is loaded (labels from folder names);
    otherwise only that class folder is loaded with `classid` as the label.
    """
    class_dirs = [
        "00_cup_n03147509",
        "01_coffee_n07929519",
        "02_bed_n02818832",
        "03_tree_n13104059",
        "04_bird_n01503061",
        "05_chair_n03001627",
        "06_tea_n07933274",
        "07_bread_n07679356",
        "08_bicycle_n02834778",
        "09_sail_n04127904"
    ]
    if classid is None:
        return ImageDataset("./all/val/", picklename="p_val.bin",
                            reg=reg, mapper=mapper)
    return ImageDataset("./all/val/" + class_dirs[classid],
                        picklename="p_val_" + str(classid) + ".bin",
                        label=classid, reg=reg, mapper=mapper)
# horizontal flipper generator
def concatmirror(data: ImageDataset):
    """Return `data` concatenated with a horizontally flipped view of itself."""
    return ConcatDataset([data, ImageDatasetFlipper(data)])
# dataloader tester
if __name__ == '__main__':
    device = "cuda"
    # Pinned host memory only pays off when batches are copied to a CUDA device.
    mem = (device == "cuda")
    _tr_dsloader = DataLoader(get_tr_imgds(), batch_size=100, shuffle=True, pin_memory=mem)
    _ts_dsloader = DataLoader(get_ts_imgds(), batch_size=100, shuffle=False, pin_memory=mem)
    _val_dsloader = DataLoader(get_val_imgds(1), batch_size=100, shuffle=False, pin_memory=mem)
    # Smoke-test iteration over the class-1 validation loader.
    for img, label, name in _val_dsloader:
        print(img, label, name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment