Skip to content

Instantly share code, notes, and snippets.

@BlackHC
Created August 4, 2024 11:59
Show Gist options
  • Save BlackHC/30bb8438e7c8c82dcd6b12fbd298e1c0 to your computer and use it in GitHub Desktop.
Save BlackHC/30bb8438e7c8c82dcd6b12fbd298e1c0 to your computer and use it in GitHub Desktop.
ImageNet v2 Loader for PyTorch
import typing
import datasets
import torch
import torch.utils.data
def load_imagenet_v2(
    split: typing.Literal[
        "threshold0.7", "top-images", "matching-frequency"
    ] = "threshold0.7"
) -> torch.utils.data.Dataset:
    """Load one ImageNet v2 variant as a dataset of ``(image, label)`` pairs.

    The archive ``imagenetv2-{split}.tar.gz`` is fetched from the
    ``vaishaal/ImageNetV2`` repository through ``datasets.load_dataset``.
    Only the first record of the resulting ``"train"`` split is read; every
    field of that record other than ``"__key__"`` and ``"__url__"`` is
    treated as one sample.

    Args:
        split: Name of the ImageNet v2 variant; selects which archive is
            loaded.

    Returns:
        A ``torch.utils.data.StackDataset`` pairing each field value with a
        scalar integer tensor parsed from the second ``/``-separated
        component of that field's name.

    NOTE(review): this presumably relies on the loader exposing the whole
    archive as fields of a single record, keyed by a path whose second
    component is the class index -- confirm for each split and for the
    installed ``datasets`` version. The type of the image values is
    whatever ``datasets`` yields for those fields.
    """
    archive_name = f"imagenetv2-{split}.tar.gz"
    download_config = datasets.DownloadConfig(resume_download=True)
    loaded = datasets.load_dataset(
        "vaishaal/ImageNetV2",
        data_files=[archive_name],
        download_config=download_config,
    )
    record = loaded["train"][0]

    images = []
    labels = []
    for field_name, field_value in record.items():
        # Bookkeeping fields of the record are not samples.
        if field_name in ("__key__", "__url__"):
            continue
        class_index = int(field_name.split("/")[1])
        labels.append(torch.tensor(class_index))
        images.append(field_value)

    return torch.utils.data.StackDataset(images, labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment