Skip to content

Instantly share code, notes, and snippets.

@oliland
Last active June 10, 2021 21:49
Show Gist options
  • Save oliland/f29f4014369d23ba7d2549e95bb8a51e to your computer and use it in GitHub Desktop.
Save oliland/f29f4014369d23ba7d2549e95bb8a51e to your computer and use it in GitHub Desktop.
class Sensor:
    """Diffuser-camera measurement model built from a point-spread function (PSF).

    ``forward`` models taking a picture through the diffuser. It zero-pads the
    scene to twice the PSF size, convolves it with the PSF via FFTs, and crops
    back to the central window. ``adjoint`` applies pad -> correlate with the
    PSF -> crop. For a real-valued PSF this is the transpose of ``forward``,
    which is what gradient-based solvers need to send residuals back.

    The doubled canvas is presumably there to keep the circular wrap-around
    of the FFT-based convolution away from the cropped window (TODO confirm).

    Note: ``fft`` and ``ifft`` both use ``norm='ortho'``. As a result,
    ``convolve`` returns the circular convolution scaled by
    ``1 / sqrt(prod(pad_size))``. ``convolve_adj`` carries the same scale, so
    ``forward``/``adjoint`` remain a consistent pair.
    """

    def __init__(self, psf):
        """Precompute the padding geometry and the PSF's transfer function.

        Args:
            psf: array whose last two axes hold the PSF. Any leading axes are
                carried through ``pad`` and broadcast in ``convolve``.
        """
        size = tuple(psf.shape[-2:])
        # The FFT canvas is twice the PSF size along each spatial axis.
        pad_size = tuple(2 * n for n in size)
        # Top-left corner of the original-size window centred in the canvas.
        row0 = (pad_size[0] - size[0]) // 2
        col0 = (pad_size[1] - size[1]) // 2

        self.axes = (-2, -1)
        self.size = size
        self.pad_size = pad_size
        self.y_center = slice(row0, row0 + size[0])
        self.x_center = slice(col0, col0 + size[1])

        # Transfer function of the PSF (and its conjugate for the adjoint),
        # shared by the forward model, the adjoint and the solvers.
        self.psf = psf
        self.h = self.fft(self.pad(self.psf))
        self.h_conj = self.h.conj()

    def pad(self, image):
        """Embed ``image`` in the centre of a zero canvas of shape ``pad_size``.

        Leading (non-spatial) axes are preserved. The result is a new
        complex64 array.

        Raises:
            ValueError: if the last two axes of ``image`` differ from ``size``.
        """
        if image.shape[-2:] != self.size:
            raise ValueError("image shape does not match psf shape")
        canvas = np.zeros(image.shape[:-2] + self.pad_size, dtype=np.complex64)
        canvas[..., self.y_center, self.x_center] = image
        return canvas

    def crop(self, image):
        """Return the central ``size`` window of a padded image (a view).

        Raises:
            ValueError: if the last two axes of ``image`` differ from ``pad_size``.
        """
        if image.shape[-2:] != self.pad_size:
            raise ValueError("padded image shape does not match padded psf shape")
        return image[..., self.y_center, self.x_center]

    def fft(self, image):
        """Orthonormal 2-D FFT over the last two axes.

        ``ifftshift`` first moves the canvas centre to index 0, so the centred
        PSF acts as the convolution origin.
        """
        return fft.fft2(fft.ifftshift(image, axes=self.axes), axes=self.axes, norm='ortho')

    def ifft(self, image):
        """Orthonormal inverse 2-D FFT; ``fftshift`` re-centres the result."""
        return fft.fftshift(fft.ifft2(image, axes=self.axes, norm='ortho'), axes=self.axes)

    def autocorrelation(self):
        """Return the PSF's circular autocorrelation on the padded canvas.

        Computed as ``ifft(conj(H) * H) = ifft(|H|**2)``. Because of the
        shifts in ``fft``/``ifft``, the zero-lag term lands at index
        ``pad_size // 2``. The value carries the same ``1/sqrt(N)`` scale
        as ``convolve``.
        """
        # Correlation needs conj(H) * H; H * H would be the autoconvolution,
        # which differs for any PSF that is not point-symmetric.
        return self.ifft(self.h_conj * self.h).real

    def convolve(self, image):
        """Circularly convolve a padded ``image`` with the PSF; keep the real part."""
        spectrum = self.h * self.fft(image)
        return self.ifft(spectrum).real

    def convolve_adj(self, image):
        """Circularly correlate a padded ``image`` with the PSF (adjoint of ``convolve``).

        The real part is kept.
        """
        spectrum = self.h_conj * self.fft(image)
        return self.ifft(spectrum).real

    def forward(self, image):
        """Simulate a measurement through the diffuser: pad, convolve, crop.

        ``image`` must have the PSF's trailing shape. The result has that
        same trailing shape (real-valued).
        """
        # The sensor only sees the central window, hence the final crop.
        return self.crop(self.convolve(self.pad(image)))

    def adjoint(self, image):
        """Send a sensor-sized residual back through the diffuser: pad, correlate, crop.

        This is the adjoint of ``forward`` (for a real-valued PSF), as needed
        for least-squares gradients.
        """
        return self.crop(self.convolve_adj(self.pad(image)))
class FISTA:
    """Projected FISTA solver for nonnegative deconvolution.

    Minimises ``0.5 * ||sensor.forward(x) - image||**2`` subject to
    ``x >= 0``. The gradient is ``sensor.adjoint(sensor.forward(x) - image)``,
    and the projection onto the constraint is ``np.maximum(., 0)``.
    """

    def __init__(self, sensor):
        """Store the measurement model.

        Args:
            sensor: object exposing ``forward``, ``adjoint``, ``h`` and
                ``h_conj`` (e.g. a ``Sensor``).
        """
        self.sensor = sensor

    def solve(self, image, iters, step_scale=1.8):
        """Reconstruct a nonnegative scene from the measurement ``image``.

        Args:
            image: measured sensor image. It must have the shape
                ``sensor.forward`` accepts and returns.
            iters: number of FISTA iterations; 0 returns all zeros.
            step_scale: the gradient step is ``step_scale / max(|h|**2)``.
                For a ``Sensor``, ``max(|h|**2)`` bounds the Lipschitz
                constant ``L`` of the gradient. The default 1.8 keeps the
                original behaviour, but it exceeds the ``1/L`` step assumed
                by FISTA's convergence proof. Pass 1.0 for the textbook
                setting.

        Returns:
            The last projected iterate ``x_k``, with the same shape as
            ``image`` and no negative entries.
        """
        step = step_scale / np.max(np.real(self.sensor.h_conj * self.sensor.h))
        x = np.zeros_like(image)  # projected iterate x_k
        y = np.zeros_like(image)  # extrapolated (momentum) point y_k
        t = 1
        for _ in range(iters):
            x_prev = x
            # Gradient step on the least-squares data term, taken at y.
            residual = self.sensor.forward(y) - image
            y = y - step * self.sensor.adjoint(residual)
            # Project onto the nonnegative orthant.
            x = np.maximum(y, 0)
            # Nesterov/FISTA momentum update.
            t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
            y = x + (t - 1) / t_next * (x - x_prev)
            t = t_next
        # Return the projected iterate, not the momentum point y, because
        # y = x + beta * (x - x_prev) can dip below zero.
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment