Skip to content

Instantly share code, notes, and snippets.

@nilesh0109
Last active February 5, 2023 20:21
Show Gist options
  • Save nilesh0109/f98eed779844c6b570740d5ef78868a3 to your computer and use it in GitHub Desktop.
Save nilesh0109/f98eed779844c6b570740d5ef78868a3 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
## Fisheye Transformation
def get_of_fisheye(height, width, center, magnitude):
    """Build a fisheye sampling grid for ``F.grid_sample``.

    Starting from an identity grid over [-1, 1] x [-1, 1], every point p is
    displaced by ``d * |d| * magnitude`` where ``d = center - p``, i.e. it is
    pulled toward ``center`` by an amount growing with the squared distance.

    Args:
        height: Number of grid rows (output image height).
        width: Number of grid columns (output image width).
        center: ``(cx, cy)`` in normalized [-1, 1] coordinates. May be a
            tensor or any sequence accepted by ``torch.as_tensor``.
        magnitude: Distortion strength; ``0`` yields the identity grid.

    Returns:
        Tensor of shape ``(1, height, width, 2)`` whose last axis holds the
        ``(x, y)`` sample coordinates, as ``F.grid_sample`` expects.
    """
    xx = torch.linspace(-1, 1, width)
    yy = torch.linspace(-1, 1, height)
    # indexing="ij" spells out the historical default (rows follow yy,
    # columns follow xx) and silences the warning newer torch emits.
    gridy, gridx = torch.meshgrid(yy, xx, indexing="ij")
    grid = torch.stack([gridx, gridy], dim=-1)  # (H, W, 2), last axis = (x, y)

    # Match the grid dtype so integer or float64 centers behave the same.
    center = torch.as_tensor(center, dtype=grid.dtype)
    d = center - grid  # (cx - x, cy - y)
    dist = torch.sqrt((d ** 2).sum(dim=-1, keepdim=True))  # Euclidean distance
    grid += d * dist * magnitude
    # grid_sample needs a 4-D grid: (N, H_out, W_out, 2).
    return grid.unsqueeze(0)
## Horizontal Wave Transformation
def get_of_horizontalwave(height, width, freq, amplitude):
    """Build a wave-distortion sampling grid for ``F.grid_sample``.

    Starting from an identity grid over [-1, 1] x [-1, 1], each point's y
    coordinate is offset by ``amplitude * cos(freq * x)``; x is unchanged.

    Args:
        height: Number of grid rows (output image height).
        width: Number of grid columns (output image width).
        freq: Angular frequency of the cosine along the x axis.
        amplitude: Peak vertical offset in normalized coordinates.

    Returns:
        Tensor of shape ``(1, height, width, 2)`` whose last axis holds the
        ``(x, y)`` sample coordinates, as ``F.grid_sample`` expects.
    """
    xx = torch.linspace(-1, 1, width)
    yy = torch.linspace(-1, 1, height)
    # indexing="ij" spells out the historical default (rows follow yy,
    # columns follow xx) and silences the warning newer torch emits.
    gridy, gridx = torch.meshgrid(yy, xx, indexing="ij")
    grid = torch.stack([gridx, gridy], dim=-1)  # (H, W, 2), last axis = (x, y)
    # The offset depends only on x, so every column shifts by a fixed amount.
    grid[..., 1] += amplitude * torch.cos(freq * grid[..., 0])
    # grid_sample needs a 4-D grid: (N, H_out, W_out, 2).
    return grid.unsqueeze(0)
## UTILITY FUNCTIONS
## Create Image Batch
def get_image_batch(img):
    """Turn a single PIL image into a batch tensor of one image.

    ``transforms.ToTensor`` yields a ``(C, H, W)`` tensor; a leading batch
    axis of size 1 is added so the result is ``(1, C, H, W)``, the layout
    ``F.grid_sample`` consumes.
    """
    to_tensor = transforms.ToTensor()
    return to_tensor(img).unsqueeze(0)
def _to_display_array(batch):
    """Convert the first image of an (N, C, H, W) tensor to an imshow-ready array."""
    # detach/cpu keep .numpy() working for grad-tracking or GPU tensors.
    arr = np.moveaxis(batch[0].detach().cpu().numpy(), 0, -1)  # -> (H, W, C)
    # imshow rejects (H, W, 1); single-channel images must be passed as (H, W).
    if arr.shape[-1] == 1:
        arr = arr[..., 0]
    return arr


def plot(img, fisheye_output, hwave_output, input_title='Input Image(Checkerboard)'):
    """Show the input image next to its fisheye and horizontal-wave warps.

    Args:
        img: The original image, passed to ``imshow`` as-is.
        fisheye_output: Batch tensor ``(N, C, H, W)``; only item 0 is shown.
        hwave_output: Batch tensor ``(N, C, H, W)``; only item 0 is shown.
        input_title: Title for the first panel (defaults to the original
            checkerboard caption).
    """
    panels = [
        (img, input_title),
        (_to_display_array(fisheye_output), 'Fisheye'),
        (_to_display_array(hwave_output), 'Horizontal Wave Tfms'),
    ]
    fig, ax = plt.subplots(1, 3, figsize=(16, 4))
    for axis, (image, title) in zip(ax, panels):
        axis.imshow(image)
        axis.set_title(title)
    plt.show()
img = Image.open('checkerboard.png')
imgs = get_image_batch(img)
N, C, H, W = imgs.shape

# Fisheye: sampling grid pulled toward the center (0, 0) with strength 0.4.
fisheye_grid = get_of_fisheye(H, W, center=torch.tensor([0, 0]), magnitude=0.4)
fisheye_output = F.grid_sample(imgs, fisheye_grid, align_corners=True)

# Horizontal wave: sample y offset by 0.1 * cos(10 * x).
hwave_grid = get_of_horizontalwave(H, W, freq=10, amplitude=0.1)
hwave_output = F.grid_sample(imgs, hwave_grid, align_corners=True)

plot(img, fisheye_output, hwave_output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment