Skip to content

Instantly share code, notes, and snippets.

@AlessandroMondin
Created December 23, 2022 10:20
Show Gist options
  • Save AlessandroMondin/9c6ca94b57ab5f5b0af41d5e59477d9d to your computer and use it in GitHub Desktop.
Save AlessandroMondin/9c6ca94b57ab5f5b0af41d5e59477d9d to your computer and use it in GitHub Desktop.
class Training_Dataset(Dataset):
    """COCO 2017 dataset constructed using the PyTorch built-in functionalities.

    Each item pairs an RGB image from ``<root>/images/{train,val}`` with the
    boxes read from its ``.txt`` file in ``<root>/labels/{train,val}``.

    Args:
        num_classes: number of object classes (stored as ``self.nc``).
        root_directory: dataset root containing ``images/`` and ``labels/``.
        transform: optional augmentation pipeline called as
            ``transform(image=..., bboxes=...)``; its element at index 1 has
            its ``p`` attribute toggled per batch (see ``_augment``).
        train: use the ``train`` split when True, otherwise ``val``.
        rect_training: when True, per-image target sizes are read from
            columns 1 and 2 of ``self.annotations`` after ``adaptive_shape``;
            otherwise every image is resized to ``default_size`` x ``default_size``.
        default_size: square target size used when ``rect_training`` is False.
        bs: batch size, used to decide which batches get ``transform[1]``.
        bboxes_format: ``"coco"`` (rows ``x y w h class``, classes 1-indexed)
            or ``"yolo"`` (rows ``class x y w h``).
        ultralytics_loss: when True, targets are returned as an ``(n, 6)``
            float tensor with zeros in column 0 and the labels in columns 1:.
    """

    _LABEL_COLS = 5  # one class id + four box coordinates per label row

    def __init__(self,
                 num_classes,
                 root_directory=config.ROOT_DIR,
                 transform=None,
                 train=True,
                 rect_training=False,
                 default_size=640,
                 bs=64,
                 bboxes_format="coco",
                 ultralytics_loss=False,
                 ):
        assert bboxes_format in ["coco", "yolo"], 'bboxes_format must be either "coco" or "yolo"'
        self.bs = bs
        self.batch_range = 64 if bs < 64 else 128
        self.bboxes_format = bboxes_format
        self.ultralytics_loss = ultralytics_loss
        self.root_directory = root_directory
        self.nc = num_classes
        self.transform = transform
        self.rect_training = rect_training
        self.default_size = default_size
        self.train = train

        split = "train" if train else "val"
        self.fname = f"images/{split}"
        self.annot_folder = split
        annot_path = os.path.join(root_directory, "labels", f"annot_{split}.csv")

        try:
            annotations = pd.read_csv(annot_path, header=None, index_col=0).sort_values(by=[0])
            # NOTE(review): this drops the last row after sorting by the index.
            # For a CSV written by _build_annotations (pandas to_csv with its
            # default header row), that row appears to be the header: with
            # header=None it is read as data and its blank index cell sorts
            # last. Confirm before relying on CSVs produced any other way.
            self.annotations = annotations.head(len(annotations) - 1)
        except FileNotFoundError:
            self.annotations = self._build_annotations(annot_path)
        self.len_ann = len(self.annotations)
        if rect_training:
            self.annotations = self.adaptive_shape(self.annotations)

    def _build_annotations(self, annot_path):
        """Build and cache one ``[image_name, height, width]`` row per labelled image.

        Label files with no matching ``.jpg`` image are skipped. The table is
        written to ``annot_path`` so later constructions can read it directly.
        """
        labels_dir = os.path.join(self.root_directory, "labels", self.annot_folder)
        images_dir = os.path.join(self.root_directory, "images", self.annot_folder)
        rows = []
        # os.listdir order is arbitrary; sorting keeps dataset indices stable
        for img_txt in sorted(os.listdir(labels_dir)):
            img = img_txt.split(".txt")[0]
            try:
                w, h = imagesize.get(os.path.join(images_dir, f"{img}.jpg"))
            except FileNotFoundError:
                continue
            rows.append([f"{img}.jpg", h, w])
        annotations = pd.DataFrame(rows)
        annotations.to_csv(annot_path)
        return annotations

    def __len__(self):
        return len(self.annotations)

    def _load_labels(self, img_name):
        """Read the label file matching ``img_name`` as a float array of shape ``(n, 5)``."""
        stem = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.root_directory, "labels", self.annot_folder, stem + ".txt")
        # to avoid an annoying "UserWarning: loadtxt: Empty input file"
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            labels = np.loadtxt(fname=label_path, delimiter=" ", ndmin=2)
        if labels.size == 0:
            # an empty file does not come back as (0, 5); normalise the shape
            labels = np.zeros((0, self._LABEL_COLS))
        # drop every label row that contains a negative value
        labels = labels[np.all(labels >= 0, axis=1), :]
        # Truncate columns 3 and 4 to three decimals.
        # NOTE(review): presumably this keeps normalised yolo w/h from
        # overshooting 1.0 through float error — confirm. For coco-format
        # files these columns are h and the class id.
        labels[:, 3:5] = np.floor(labels[:, 3:5] * 1000) / 1000
        return labels

    def _augment(self, idx, img, labels):
        """Apply ``self.transform`` to the image and its boxes; return ``(img, labels)``."""
        # Force transform[1] on for even-numbered batches and off for odd ones.
        # NOTE(review): the batch is inferred as idx // bs, which only matches
        # the real batch when the sampler is sequential — confirm.
        self.transform[1].p = 1 if (idx // self.bs) % 2 == 0 else 0
        # albumentations requires bboxes to be (x,y,w,h,class_idx)
        augmentations = self.transform(image=img, bboxes=np.roll(labels, axis=1, shift=4))
        img = augmentations["image"]
        labels = np.array(augmentations["bboxes"])
        if len(labels):
            labels = np.roll(labels, axis=1, shift=1)  # back to (class, x, y, w, h)
        else:
            # every box was dropped: keep a 2-D (0, 5) array instead of shape (0,)
            labels = np.zeros((0, self._LABEL_COLS))
        return img, labels

    def __getitem__(self, idx):
        """Return ``(image, targets)`` for dataset index ``idx``.

        ``image`` is a CHW tensor of the resized (and optionally augmented)
        RGB image. ``targets`` is a ``(n, 5)`` array of
        ``(class, x, y, w, h)`` rows or, with ``ultralytics_loss``, an
        ``(n, 6)`` tensor holding those rows in columns 1: and zeros in
        column 0 (presumably filled later, e.g. by a collate function).
        """
        img_name = self.annotations.iloc[idx, 0]
        if self.rect_training:
            tg_height = self.annotations.iloc[idx, 1]
            tg_width = self.annotations.iloc[idx, 2]
        else:
            # honour default_size (this used to be a hard-coded 640)
            tg_height = tg_width = self.default_size

        labels = self._load_labels(img_name)
        img = np.array(Image.open(os.path.join(self.root_directory, self.fname, img_name)).convert("RGB"))

        if self.bboxes_format == "coco":
            labels[:, -1] -= 1  # 0-indexing the classes of coco labels (1-80 --> 0-79)
            labels = np.roll(labels, axis=1, shift=1)  # class id moves to column 0
            # normalized coordinates are scale invariant, hence after resizing the img we don't resize labels
            labels[:, 1:] = coco_to_yolo_tensors(labels[:, 1:], w0=img.shape[1], h0=img.shape[0])
        img = resize_image(img, (int(tg_width), int(tg_height)))

        if self.transform:
            img, labels = self._augment(idx, img, labels)

        img = np.ascontiguousarray(img.transpose((2, 0, 1)))  # HWC -> CHW
        if self.ultralytics_loss:
            out_bboxes = torch.zeros((labels.shape[0], 6))
            if len(labels):
                out_bboxes[..., 1:] = torch.from_numpy(labels)
            return torch.from_numpy(img), out_bboxes
        return torch.from_numpy(img), labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment