Last active
October 25, 2020 10:12
-
-
Save RaphaelMeudec/74d7889e0dea467b0d8107c64792ce8d to your computer and use it in GitHub Desktop.
Create a simple tf.data Dataset for an image deblurring task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
import tensorflow as tf | |
def select_patch(sharp, blur, patch_size_x, patch_size_y):
    """
    Crop one random patch at the same location from a sharp/blur image pair.

    Both images are stacked along a new leading axis and cropped in a single
    ``tf.image.random_crop`` call, so the same random offset is applied to
    each of them.

    Args:
        sharp (tf.Tensor): Tensor for the sharp image
        blur (tf.Tensor): Tensor for the blur image
        patch_size_x (int): Size of the patch along the first image axis
        patch_size_y (int): Size of the patch along the second image axis

    Returns:
        Tuple[tf.Tensor, tf.Tensor]: (sharp patch, blur patch), each with
            shape (patch_size_x, patch_size_y, 3)
    """
    # The leading 2 keeps both images in the crop; the trailing 3 is the
    # channel count. NOTE(review): assumes both inputs are 3-channel images
    # of identical shape -- confirm against callers.
    crop_shape = [2, patch_size_x, patch_size_y, 3]
    image_pair = tf.stack([sharp, blur], axis=0)
    cropped_pair = tf.image.random_crop(image_pair, size=crop_shape)

    sharp_patch = cropped_pair[0]
    blur_patch = cropped_pair[1]
    return (sharp_patch, blur_patch)
class TensorflowDatasetLoader:
    """
    Build a tf.data pipeline of aligned (sharp, blur) image patches.

    Sharp images are discovered under ``dataset_path`` with the glob
    ``*/sharp/*.png``. The blurred counterpart of each image is expected at
    the same location relative to ``dataset_path``, with ``sharp`` replaced by
    ``blur``.

    The pipeline is exposed as ``self.dataset``. It yields
    ``(sharp_batch, blur_batch)`` tuples and repeats indefinitely.

    Args:
        dataset_path (str or Path): Root folder of the dataset
        batch_size (int): Number of patch pairs per batch
        patch_size (Tuple[int, int]): Patch sizes forwarded to ``select_patch``
            as ``(patch_size_x, patch_size_y)``
        n_epochs (int): Currently unused (the dataset repeats without a
            count); kept so existing callers keep working
        n_images (int, optional): If given, only the first ``n_images`` sharp
            images (in sorted path order) are used
        dtype (tf.DType): Dtype the decoded images are converted to.
            NOTE(review): a floating-point dtype is assumed by the rescaling
            done in ``load_image`` -- confirm before passing anything else.

    Raises:
        ValueError: If no sharp image is selected
    """

    def __init__(
        self,
        dataset_path,
        batch_size=4,
        patch_size=(256, 256),
        n_epochs=10,
        n_images=None,
        dtype=tf.float32,
    ):
        root = Path(dataset_path)

        # List all sharp images. glob() order is arbitrary, so sort to make
        # the `n_images` subset reproducible between runs.
        sharp_paths = sorted(root.glob("*/sharp/*.png"))
        if n_images is not None:
            sharp_paths = sharp_paths[:n_images]
        if not sharp_paths:
            raise ValueError(
                "No sharp images selected under {} with pattern '*/sharp/*.png'".format(root)
            )
        sharp_images_paths = [str(path) for path in sharp_paths]

        # Generate corresponding blurred images paths. The substitution is
        # restricted to the part of the path below `root`, so a dataset root
        # that itself contains "sharp" is left untouched.
        blur_images_paths = [
            str(root / str(path.relative_to(root)).replace("sharp", "blur"))
            for path in sharp_paths
        ]

        # Load sharp and blurred images
        sharp_dataset = tf.data.Dataset.from_tensor_slices(sharp_images_paths).map(
            lambda path: self.load_image(path, dtype),
        )
        blur_dataset = tf.data.Dataset.from_tensor_slices(blur_images_paths).map(
            lambda path: self.load_image(path, dtype),
        )
        dataset = tf.data.Dataset.zip((sharp_dataset, blur_dataset))

        # Select the same patch on the sharp image and its corresponding blurred
        dataset = dataset.map(
            lambda sharp_image, blur_image: select_patch(
                sharp_image, blur_image, patch_size[0], patch_size[1]
            )
        )

        # Define dataset characteristics (batch_size, shuffling, repetition).
        # shuffle() runs after batch(), so whole batches are reordered rather
        # than individual samples; repeat() without a count never ends.
        dataset = dataset.batch(batch_size)
        dataset = dataset.shuffle(buffer_size=50)
        dataset = dataset.repeat()

        self.dataset = dataset

    @staticmethod
    def load_image(image_path, dtype):
        """
        Read a PNG file and return it as a rescaled 3-channel tensor.

        Args:
            image_path (str or tf.Tensor): Path of the PNG file to read
            dtype (tf.DType): Dtype passed to ``tf.image.convert_image_dtype``

        Returns:
            tf.Tensor: Decoded image after ``(image - 0.5) * 2``
        """
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.convert_image_dtype(image, dtype)
        # Presumably maps [0, 1] onto [-1, 1]; this holds when `dtype` is a
        # float type, for which convert_image_dtype scales values to [0, 1].
        image = (image - 0.5) * 2
        return image
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment