Skip to content

Instantly share code, notes, and snippets.

@jgs03177
Created January 23, 2024 10:36
Show Gist options
  • Save jgs03177/7d9a24a47afd4f10d97246d2d1ddea31 to your computer and use it in GitHub Desktop.
Save jgs03177/7d9a24a47afd4f10d97246d2d1ddea31 to your computer and use it in GitHub Desktop.
pytorch dataloader example
import os
import errno
import cv2
import torch
import numpy as np
import random
from torch.utils.data import Dataset, ConcatDataset, DataLoader
from typing import Callable
class ImageDataset(Dataset):
    """Image dataset read from a directory tree, with optional pickle caching.

    Directory layouts:
      * ``label is None`` -- sub-folder names encode the label
        (format: ``set_dirname/<NN>_name/imgname``; the first two characters
        of each folder name are parsed as the integer class id).
      * ``label`` given   -- every file directly under ``set_dirname`` gets
        that fixed label (format: ``set_dirname/imgname``).

    Args:
        set_dirname: root directory of the image set.
        picklename: if not None, cache the decoded arrays under
            ``./pickles/<picklename>`` (prefixed with ``r_`` when ``reg``)
            and load from that cache when it already exists.
        label: fixed label for a flat directory; None derives labels from
            folder names.
        reg: divide pixel values by 256 (image regularizer).
        crop: probability of taking a random 32x32 crop from the 4-padded
            image in __getitem__ instead of the centered crop.
        erase: random-erasing probability (kept for interface compatibility;
            the augmentation itself is not implemented).
        mapper: optional label-remapping function applied in __getitem__.
    """
    def __init__(self,
                 set_dirname: str,
                 picklename: str = None,
                 label: int = None,
                 reg: bool = True,
                 crop: float = 0,
                 erase: float = 0,
                 mapper: Callable[[int], int] = None):
        # Cache folder; exist_ok also covers the mkdir race the original
        # guarded with an EEXIST try/except.
        os.makedirs("./pickles", exist_ok=True)
        self.list_img = None
        self.list_label = None
        self.list_name = None
        self.crop = crop
        self.erase = erase
        self.mapper = mapper
        # Image regularizer: divide raw pixel values by 256.
        regc = 256 if reg else 1
        # BUG FIX: the original prefixed and os.path.join()ed picklename
        # unconditionally, raising TypeError whenever picklename was None,
        # which made the documented "no pickle" mode unusable.
        if picklename is not None:
            if reg:
                picklename = "r_" + picklename
            picklename = os.path.join("./pickles/", picklename)
        if picklename is not None and os.path.isfile(picklename):
            # Cache hit: reuse previously decoded arrays.
            pickle = torch.load(picklename)
            self.list_img = pickle["img"]
            self.list_label = pickle["label"]
            self.list_name = pickle["name"]
        else:
            self.list_img = list()
            self.list_name = list()
            self.list_label = list()
            set_listdir = os.listdir(set_dirname)
            if label is None:
                # Folder names are labels: read foldernames(=label),
                # images, and image names (for csv).
                for folder in set_listdir:
                    if folder[0] == '.':
                        continue  # skip hidden entries (e.g. .DS_Store)
                    classid_label = folder[0:2]
                    subset_dirname = os.path.join(set_dirname, folder)
                    for subset_filename in os.listdir(subset_dirname):
                        img_dirname = os.path.join(subset_dirname, subset_filename)
                        self.list_img.append(self._load_img(img_dirname, regc))
                        self.list_label.append(int(classid_label))
                        self.list_name.append(subset_filename)
            else:
                # Flat directory: read images and image names (for csv).
                for set_filename in set_listdir:
                    img_dirname = os.path.join(set_dirname, set_filename)
                    self.list_img.append(self._load_img(img_dirname, regc))
                    self.list_name.append(set_filename)
                    self.list_label.append(label)
            # Populate the cache for the next run.
            if picklename is not None:
                pickle = {"img": self.list_img,
                          "label": self.list_label,
                          "name": self.list_name}
                torch.save(pickle, picklename)
        self.len = len(self.list_img)
        print("length: ", self.len)

    @staticmethod
    def _load_img(img_dirname, regc):
        """Decode one image to a float array with axes transposed to
        (channel, width, height), scaled by 1/regc and zero-padded by 4 on
        both spatial axes (32x32 input becomes 40x40)."""
        # BUG FIX: np.float was removed in NumPy 1.24; the builtin float
        # produces the identical float64 dtype.
        img = np.transpose(cv2.imread(img_dirname), (2, 1, 0)).astype(float) / regc
        return np.pad(img, ((0,), (4,), (4,)), 'constant', constant_values=0)

    def __getitem__(self, item):
        """Return (img, label, name) with a 32x32 crop and label mapping.

        Augmentation notes:
          * horizontal flipping: implemented outside (ImageDatasetFlipper)
          * random erasing (arXiv:1708.04896): not implemented
        """
        output_img = self.list_img[item]
        output_label = self.list_label[item] if self.list_label is not None else -1
        output_name = self.list_name[item]
        # 4-padding and 32-cropping: with probability `crop` pick a random
        # offset into the padded image, otherwise take the centered crop.
        if random.uniform(0, 1) < self.crop:
            # NOTE(review): randint(0, 7) never reaches offset 8, so the
            # rightmost/bottom crop position of a 40x40 image is never
            # sampled -- kept as-is to preserve behavior; confirm intent.
            dw = random.randint(0, 7)
            dh = random.randint(0, 7)
        else:
            dw = 4
            dh = 4
        output_img = output_img[:, dw:32 + dw, dh:32 + dh]
        # Label mapping
        if self.mapper is not None:
            output_label = self.mapper(output_label)
        return output_img, output_label, output_name

    def __len__(self):
        return self.len
# horizontal flipping
class ImageDatasetFlipper(Dataset):
    """Lazy horizontally-flipped view of an ImageDataset.

    Yields the wrapped dataset's items with the image flipped along its
    last (axis 2) dimension and an "f" prefixed to the item name.
    """
    def __init__(self, imagedataset: 'ImageDataset'):
        self.original = imagedataset

    def __getitem__(self, item):
        output_img, output_label, output_name = self.original[item]
        # BUG FIX: np.flip returns a new (view) array -- the original code
        # discarded the result, so the "flipped" dataset actually returned
        # unflipped images.  .copy() avoids handing negatively-strided
        # arrays to downstream consumers (e.g. torch's default collate).
        output_img = np.flip(output_img, 2).copy()
        return output_img, output_label, "f" + output_name

    def __len__(self):
        return len(self.original)
# training dataset
def get_tr_imgds(classid: int=None, mirror=False, reg=True, crop=0, mapper=None):
    """Build the training dataset from ./all/tr/.

    classid: load only that class's sub-folder (with label=classid);
             None loads every class with folder-derived labels.
    mirror:  additionally concatenate a horizontally-flipped copy.
    reg/crop/mapper: forwarded to ImageDataset.
    """
    class_dirs = (
        "00_cup_n03147509",
        "01_coffee_n07929519",
        "02_bed_n02818832",
        "03_tree_n13104059",
        "04_bird_n01503061",
        "05_chair_n03001627",
        "06_tea_n07933274",
        "07_bread_n07679356",
        "08_bicycle_n02834778",
        "09_sail_n04127904",
    )
    if classid is None:
        data = ImageDataset("./all/tr/", picklename="p_tr.bin",
                            reg=reg, crop=crop, mapper=mapper)
    else:
        data = ImageDataset("./all/tr/" + class_dirs[classid],
                            picklename="p_tr_" + str(classid) + ".bin",
                            label=classid, reg=reg, crop=crop, mapper=mapper)
    return concatmirror(data) if mirror else data
# test dataset
def get_ts_imgds(reg=True):
    """Build the (unlabelled, label fixed to -1) test dataset from ./all/ts/."""
    return ImageDataset("./all/ts/", picklename="p_ts.bin", label=-1, reg=reg)
# validation dataset
def get_val_imgds(classid: int=None, reg=True, mapper=None):
    """Build the validation dataset from ./all/val/.

    classid: load only that class's sub-folder (with label=classid);
             None loads every class with folder-derived labels.
    reg/mapper: forwarded to ImageDataset.
    """
    class_dirs = (
        "00_cup_n03147509",
        "01_coffee_n07929519",
        "02_bed_n02818832",
        "03_tree_n13104059",
        "04_bird_n01503061",
        "05_chair_n03001627",
        "06_tea_n07933274",
        "07_bread_n07679356",
        "08_bicycle_n02834778",
        "09_sail_n04127904",
    )
    if classid is None:
        return ImageDataset("./all/val/", picklename="p_val.bin",
                            reg=reg, mapper=mapper)
    return ImageDataset("./all/val/" + class_dirs[classid],
                        picklename="p_val_" + str(classid) + ".bin",
                        label=classid, reg=reg, mapper=mapper)
# horizontal flipper generator
def concatmirror(data: ImageDataset):
    """Return `data` concatenated with a lazily-flipped copy of itself."""
    return ConcatDataset([data, ImageDatasetFlipper(data)])
# dataloader tester
if __name__ == '__main__':
    device = "cuda"
    # Pin host memory only when batches will be shipped to the GPU.
    pin = device == "cuda"
    _tr_dsloader = DataLoader(get_tr_imgds(), batch_size=100, shuffle=True, pin_memory=pin)
    _ts_dsloader = DataLoader(get_ts_imgds(), batch_size=100, shuffle=False, pin_memory=pin)
    _val_dsloader = DataLoader(get_val_imgds(1), batch_size=100, shuffle=False, pin_memory=pin)
    for batch in _val_dsloader:
        img, label, name = batch
        print(img, label, name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment