Last active
January 8, 2019 05:50
-
-
Save davidbau/67cd2b0a7e7438262e3a7879b269bc34 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch.utils.data as data | |
from torchvision.datasets.folder import default_loader, is_image_file | |
from PIL import Image | |
def grayscale_loader(path): | |
with open(path, 'rb') as f: | |
return Image.open(f).convert('L') | |
class FeatureFolder(data.Dataset):
    """
    A dataset that pairs each source image with a same-named target image
    living under a parallel directory tree, e.g.:
        photo/park/004234.jpg    <->  feature/park/004234.png
        photo/park/004236.jpg    <->  feature/park/004236.png
        photo/park/004237.jpg    <->  feature/park/004237.png
    Each item is a (source, target) tuple, each side optionally transformed.
    """
    def __init__(self, source_root, target_root,
            source_transform=None, target_transform=None,
            source_loader=default_loader, target_loader=grayscale_loader):
        pairs = make_feature_dataset(source_root, target_root)
        if not pairs:
            raise RuntimeError("Found 0 images within: %s" % source_root)
        self.imagepairs = pairs
        self.root = source_root
        self.target_root = target_root
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.source_loader = source_loader
        self.target_loader = target_loader

    @staticmethod
    def _apply(transform, item):
        # Apply an optional transform; pass the item through when absent.
        return item if transform is None else transform(item)

    def __getitem__(self, index):
        source_path, target_path = self.imagepairs[index]
        source = self._apply(self.source_transform, self.source_loader(source_path))
        target = self._apply(self.target_transform, self.target_loader(target_path))
        return source, target

    def __len__(self):
        return len(self.imagepairs)
class FeatureAndClassFolder(data.Dataset):
    """
    A dataset over parallel source/target directory trees, e.g.:
        photo/park/004234.jpg    <->  feature/park/004234.png
        photo/park/004236.jpg    <->  feature/park/004236.png
        photo/park/004237.jpg    <->  feature/park/004237.png
    Each item is (source, (class_index, target)); the class index is
    derived from the source image's parent directory name.
    """
    def __init__(self, source_root, target_root,
            source_transform=None, target_transform=None,
            source_loader=default_loader, target_loader=grayscale_loader):
        classes, class_to_idx = find_classes(source_root)
        triples = make_triples(source_root, target_root, class_to_idx)
        if not triples:
            raise RuntimeError("Found 0 images within: %s" % source_root)
        self.imagetriples = triples
        self.root = source_root
        self.target_root = target_root
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.source_transform = source_transform
        self.target_transform = target_transform
        self.source_loader = source_loader
        self.target_loader = target_loader

    def __getitem__(self, index):
        source_path, class_index, target_path = self.imagetriples[index]
        source = self.source_loader(source_path)
        if self.source_transform is not None:
            source = self.source_transform(source)
        target = self.target_loader(target_path)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return source, (class_index, target)

    def __len__(self):
        return len(self.imagetriples)
class CachedImageFolder(data.Dataset):
    """
    A variant of torchvision.datasets.ImageFolder that can take advantage
    of a cached filename list instead of rescanning the directory tree:
        photo/park/004234.jpg
        photo/park/004236.jpg
        photo/park/004237.jpg
    Each item is an (image, class_index) pair.
    """
    def __init__(self, root,
            transform=None,
            loader=default_loader):
        classes, class_to_idx = find_classes(root)
        images = make_class_dataset(root, class_to_idx)
        if not images:
            raise RuntimeError("Found 0 images within: %s" % root)
        self.imgs = images
        self.root = root
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        path, class_index = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, class_index

    def __len__(self):
        return len(self.imgs)
class StackFeatureChannels(object):
    """
    Transform that reshapes a tall (1 x C*H x W)-style tensor so the
    vertically concatenated feature planes become separate channels.
    """
    def __init__(self, channels=None, keep_only=None):
        # channels: explicit channel count; when None it is inferred by
        #   assuming each plane's height equals the tensor's width.
        # keep_only: if set, keep only the first keep_only channels.
        self.channels = channels
        self.keep_only = keep_only

    def __call__(self, tensor):
        width = tensor.shape[2]
        if self.channels:
            num_channels = self.channels
            plane_height = tensor.shape[1] // num_channels
        else:
            # No explicit count: treat each plane as width-tall and derive
            # the channel count from the stacked height.
            plane_height = width
            num_channels = tensor.shape[1] // plane_height
        stacked = tensor.view(num_channels, plane_height, width)
        if self.keep_only:
            stacked = stacked[:self.keep_only, ...]
        return stacked
class SoftExpScale(object):
    """
    Transform mapping x -> exp(x * 255 / alpha) - 1, so zero stays zero
    and larger values grow exponentially.
    """
    def __init__(self, alpha=45.0):
        # Precompute the multiplier once; alpha controls the growth rate.
        self.scale = 255.0 / alpha

    def __call__(self, tensor):
        # The multiplication yields a fresh tensor, so the in-place ops
        # below never touch the caller's input.
        scaled = tensor * self.scale
        return scaled.exp_().sub_(1)
def is_npy_file(path):
    """Return True when *path* names a numpy array file ('.npy' or '.NPY')."""
    return path.endswith(('.npy', '.NPY'))
def walk_image_files(rootdir):
    """
    List every image (or .npy) file under *rootdir*.

    If a cache file '<rootdir>.txt' exists, its lines are used instead of
    scanning the directory tree; those entries are sorted and then shuffled
    with a fixed seed so the order is deterministic.  Otherwise the tree is
    walked and matching filenames are returned in sorted order.
    """
    listing = '%s.txt' % rootdir
    if os.path.isfile(listing):
        print('Loading file list from %s.txt instead of scanning dir' % rootdir)
        basedir = os.path.dirname(rootdir)
        with open(listing) as f:
            result = sorted(os.path.join(basedir, line.strip()) for line in f)
        import random
        # Fixed seed keeps the shuffled order reproducible across runs.
        random.Random(1).shuffle(result)
        return result
    found = []
    for dirname, _, fnames in sorted(os.walk(rootdir)):
        found.extend(os.path.join(dirname, fname)
                     for fname in sorted(fnames)
                     if is_image_file(fname) or is_npy_file(fname))
    return found
def find_classes(dir):
    """
    Return (classes, class_to_idx): the sorted names of the immediate
    subdirectories of *dir*, and a mapping from each name to its index.
    """
    classes = sorted(entry for entry in os.listdir(dir)
                     if os.path.isdir(os.path.join(dir, entry)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx
def make_feature_dataset(source_root, target_root):
    """
    Pair every image under source_root with the image at the same relative
    path (ignoring file extension) under target_root.

    Returns a list of (source_path, target_path) tuples.  Raises
    RuntimeError if any source image lacks a target counterpart.
    """
    source_root = os.path.expanduser(source_root)
    target_root = os.path.expanduser(target_root)
    # Index targets by their extension-less relative path.
    target_by_key = {}
    for target_path in walk_image_files(target_root):
        key = os.path.splitext(os.path.relpath(target_path, target_root))[0]
        target_by_key[key] = target_path
    pairs = []
    for source_path in walk_image_files(source_root):
        key = os.path.splitext(os.path.relpath(source_path, source_root))[0]
        target_path = target_by_key.get(key)
        if target_path is None:
            raise RuntimeError('%s has no matching target %s.*' %
                    (source_path, os.path.join(target_root, key)))
        pairs.append((source_path, target_path))
    return pairs
def make_triples(source_root, target_root, class_to_idx):
    """
    Build (source_path, class_index, target_path) triples by matching
    source and target images on their extension-less relative paths.
    The class index comes from the source image's parent directory name.
    Raises RuntimeError if any source image lacks a target counterpart.
    """
    source_root = os.path.expanduser(source_root)
    target_root = os.path.expanduser(target_root)
    # Index targets by their extension-less relative path.
    target_by_key = {}
    for target_path in walk_image_files(target_root):
        key = os.path.splitext(os.path.relpath(target_path, target_root))[0]
        target_by_key[key] = target_path
    triples = []
    for source_path in walk_image_files(source_root):
        key = os.path.splitext(os.path.relpath(source_path, source_root))[0]
        target_path = target_by_key.get(key)
        if target_path is None:
            raise RuntimeError('%s has no matching target %s.*' %
                    (source_path, os.path.join(target_root, key)))
        # The class name is the directory component just above the file.
        classname = os.path.basename(os.path.dirname(key))
        triples.append((source_path, class_to_idx[classname], target_path))
    return triples
def make_class_dataset(source_root, class_to_idx):
    """
    Return a list of (image_path, class_index) pairs for every image
    under source_root, where the class is taken from the image's
    parent directory name via *class_to_idx*.
    """
    root = os.path.expanduser(source_root)
    return [(path, class_to_idx[os.path.basename(os.path.dirname(path))])
            for path in walk_image_files(root)]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment