@patrickmineault
Created December 28, 2020 19:37
Downsample a stack of 2d images in PyTorch
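import torch
import torch.nn.functional as F


def downsample_2d(X, sz):
    """
    Downsamples a stack of square images.

    Args:
        X: a stack of images, a tensor of shape (batch, channels, ny, ny).
        sz: the desired output side length of the images.

    Returns:
        The downsampled images, a tensor of shape (batch, channels, sz, sz).
    """
    # 3x3 binomial smoothing kernel, repeated once per channel so that each
    # channel is filtered independently (depthwise conv via groups=channels).
    kernel = torch.tensor([[.25, .5, .25],
                           [.5, 1., .5],
                           [.25, .5, .25]],
                          dtype=X.dtype, device=X.device).reshape(1, 1, 3, 3)
    kernel = kernel.repeat((X.shape[1], 1, 1, 1))
    while sz < X.shape[-1] / 2:
        # Downsample by a factor of 2 with smoothing.
        # The mask holds the total kernel weight that falls inside the image at
        # each output pixel (4 in the interior, less at edges and corners).
        # It must live on the same device and dtype as X.
        mask = torch.ones(1, *X.shape[1:], dtype=X.dtype, device=X.device)
        mask = F.conv2d(mask, kernel, groups=X.shape[1], stride=2, padding=1)
        X = F.conv2d(X, kernel, groups=X.shape[1], stride=2, padding=1)
        # Normalize the edges and corners (and the interior, since the kernel
        # weights sum to 4) so every output pixel is a weighted average.
        X = X / mask
    # Final resize to exactly sz x sz.
    return F.interpolate(X, size=sz, mode='bilinear', align_corners=False)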
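The division by `mask` in `downsample_2d` is what keeps borders from darkening: with zero padding, output pixels on the image border overlap fewer in-image kernel weights than interior ones. The minimal standalone sketch below isolates that step for a single-channel 8x8 image; the image size and the constant value 5 are arbitrary choices for illustration, not part of the original snippet.

```python
import torch
import torch.nn.functional as F

# The 3x3 binomial kernel from downsample_2d; its weights sum to 4.
kernel = torch.tensor([[.25, .5, .25],
                       [.5, 1., .5],
                       [.25, .5, .25]]).reshape(1, 1, 3, 3)

# Convolving an all-ones image yields, at each output pixel, the total kernel
# weight that overlaps the image (zero padding contributes nothing).
ones = torch.ones(1, 1, 8, 8)
mask = F.conv2d(ones, kernel, stride=2, padding=1)
# mask[0, 0] is 4x4: 2.25 at the top-left corner, 3.0 elsewhere along the
# top row and left column, 4.0 everywhere else.

# A constant image (value 5, chosen arbitrarily) should stay constant.
X = torch.full((1, 1, 8, 8), 5.0)
smoothed = F.conv2d(X, kernel, stride=2, padding=1)
normalized = smoothed / mask

print(smoothed[0, 0])    # uneven: 11.25 at the top-left corner, 20.0 inside
print(normalized[0, 0])  # 5.0 everywhere
```

With `kernel_size=3`, `stride=2` and `padding=1`, each pass maps a side of length n to ceil(n/2), and the loop stops once the side is at most 2 * sz. For example, ny = 64 and sz = 10 gives 64 → 32 → 16 by convolution, followed by a bilinear resize from 16 to 10.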