Skip to content

Instantly share code, notes, and snippets.

@gbiz123
Last active July 29, 2024 02:46
Show Gist options
  • Save gbiz123/27f77712a7c4d26b000c41ef795aa2c2 to your computer and use it in GitHub Desktop.
Save gbiz123/27f77712a7c4d26b000c41ef795aa2c2 to your computer and use it in GitHub Desktop.
Downsample a Binary PyTorch Dataset Down to the Size of Its Smaller Class
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Subset
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
import random
def downsample_balance_binary_dataset(dataset: Dataset) -> Dataset:
    """Balance a binary dataset by truncating the larger class.

    Every sample is expected to be indexable, with its label at position 1
    (``sample[1]``), and labels are compared against ``0`` and ``1``.

    The larger class is cut down by keeping its *first* ``k`` samples in
    dataset order, where ``k`` is the size of the smaller class. No random
    sampling is performed, so the result is deterministic.

    Args:
        dataset: An iterable, indexable dataset of ``(data, label, ...)``
            samples.

    Returns:
        ``dataset`` itself, unchanged, when both classes already have the
        same number of samples. Otherwise a ``Subset`` of ``dataset`` whose
        indices list the kept samples of the larger class first, followed by
        every sample of the smaller class. Samples whose label is neither
        ``0`` nor ``1`` are not included in that ``Subset``. If one class is
        empty the returned ``Subset`` is empty.
    """
    class_0_indices = []
    class_1_indices = []

    # Walk the dataset exactly once: fetching a sample may involve disk I/O
    # and transforms, so counting and index collection share a single pass.
    for index, sample in enumerate(dataset):
        label = sample[1]
        if label == 0:
            class_0_indices.append(index)
        elif label == 1:
            class_1_indices.append(index)

    if len(class_0_indices) == len(class_1_indices):
        # Already balanced: hand back the very same object, not a wrapper.
        return dataset

    if len(class_0_indices) > len(class_1_indices):
        majority, minority = class_0_indices, class_1_indices
    else:
        majority, minority = class_1_indices, class_0_indices

    # Slicing to len(minority) guarantees both classes end up the same size.
    all_indices = majority[:len(minority)] + minority
    print(f"Sampled dataset down to {len(all_indices)} samples")
    return Subset(dataset, all_indices)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment