Last active
June 10, 2021 21:49
-
-
Save oliland/f29f4014369d23ba7d2549e95bb8a51e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Sensor:
    """FFT-based imaging model for a camera that sees the scene through a diffuser.

    A measurement is modelled as the scene convolved with the point-spread
    function (PSF) and then cropped to the sensor aperture.  Convolution is done
    in the Fourier domain on a canvas zero-padded to twice the PSF size.  Both
    operands then fit in that canvas, so the circular FFT convolution holds the
    full linear convolution without wrap-around.

    Attributes:
        axes: the two trailing axes every transform operates on.
        size: ``(rows, cols)`` of the PSF and of every unpadded image.
        pad_size: ``(rows, cols)`` of the padded canvas (``2 * size``).
        y_center, x_center: slices locating the unpadded image in the canvas.
        psf: the PSF as supplied.
        h: centred, orthonormal 2-D FFT of the padded PSF (transfer function).
        h_conj: complex conjugate of ``h``, used by the adjoint model.
    """

    def __init__(self, psf):
        """Precompute the padding geometry and the PSF's transfer function.

        Args:
            psf: array whose last two axes hold the PSF.
        """
        size = psf.shape[-2:]
        # pad image bounds to double PSF size so FFT convolution does not wrap
        pad_size = [s * 2 for s in size]
        # start/end of the centred, PSF-sized window inside the padded canvas
        ys = (pad_size[0] - size[0]) // 2
        ye = ys + size[0]
        xs = (pad_size[1] - size[1]) // 2
        xe = xs + size[1]
        # store dimensions for the other methods
        self.axes = (-2, -1)
        self.size = tuple(size)
        self.pad_size = tuple(pad_size)
        self.y_center = slice(ys, ye)
        self.x_center = slice(xs, xe)
        # store psf and its transfer function for solvers, forward & adjoint models
        self.psf = psf
        self.h = self.fft(self.pad(self.psf))
        self.h_conj = self.h.conj()

    def pad(self, image):
        """Zero-embed ``image`` in the centre of a canvas twice the PSF size.

        Args:
            image: array whose last two axes match the PSF shape; leading axes
                are carried through.

        Returns:
            complex64 array of shape ``image.shape[:-2] + self.pad_size``.

        Raises:
            ValueError: if the last two axes of ``image`` differ from the PSF's.
        """
        if image.shape[-2:] != self.size:
            raise ValueError("image shape does not match psf shape")
        shape = image.shape[:-2] + self.pad_size
        # NOTE(review): complex64 rounds float64 input to single precision;
        # confirm this memory/precision trade-off is intended.
        output = np.zeros(shape, dtype=np.complex64)
        output[..., self.y_center, self.x_center] = image
        return output

    def crop(self, image):
        """Return the centred PSF-sized window of a padded image (inverse of ``pad``).

        Raises:
            ValueError: if the last two axes of ``image`` are not ``self.pad_size``.
        """
        if image.shape[-2:] != self.pad_size:
            raise ValueError("padded image shape does not match padded psf shape")
        return image[..., self.y_center, self.x_center]

    def fft(self, image):
        """Return the orthonormal 2-D FFT of ``image`` over its last two axes.

        ``ifftshift`` first moves the array's centre pixel to the origin, so
        arrays are treated as centred on their middle pixel.
        """
        return fft.fft2(fft.ifftshift(image, axes=self.axes), axes=self.axes, norm='ortho')

    def ifft(self, image):
        """Inverse of :meth:`fft`: orthonormal inverse FFT, then ``fftshift``."""
        return fft.fftshift(fft.ifft2(image, axes=self.axes, norm='ortho'), axes=self.axes)

    def autocorrelation(self):
        """Return the PSF's 2-D autocorrelation on the padded canvas.

        Uses the correlation theorem, ``ifft(H * conj(H))``, so zero lag lands at
        the canvas centre.  Like :meth:`convolve`, values carry the ``1/sqrt(P)``
        factor of the orthonormal transforms, ``P`` being the number of pixels in
        the padded plane.
        """
        # Correlation multiplies by the *conjugate* spectrum; ``h * h`` would
        # instead give the PSF convolved with itself.
        image = self.ifft(self.h * self.h_conj)
        return image.real

    def convolve(self, image):
        """Circularly convolve a padded image with the padded PSF.

        Both transforms use ``norm='ortho'``, so the result is the circular
        convolution scaled by ``1/sqrt(P)``, ``P`` being the number of pixels in
        the padded plane.  :meth:`convolve_adj` carries the same factor, so the
        two remain exact adjoints of each other.

        Returns:
            the real part of the result.
        """
        # punt to fourier, multiply by the transfer function, and back
        image = self.ifft(self.h * self.fft(image))
        # for real inputs the imaginary part is only round-off
        return image.real

    def convolve_adj(self, image):
        """Apply the adjoint of :meth:`convolve` (correlation with the PSF).

        Returns:
            the real part of the result.
        """
        # punt to fourier and multiply by the conjugate transfer function
        image = self.ifft(self.h_conj * self.fft(image))
        # for real inputs the imaginary part is only round-off
        return image.real

    def forward(self, image):
        """Simulate a measurement: pad, convolve with the PSF, crop to the sensor.

        Args:
            image: scene with the PSF's shape in its last two axes.

        Returns:
            simulated sensor image, same trailing shape as ``image``.
        """
        # takes a picture through the diffuser
        output = self.pad(image)
        output = self.convolve(output)
        # our sensor has an aperture, so we need to crop
        return self.crop(output)

    def adjoint(self, image):
        """Apply the adjoint of :meth:`forward`: pad, correlate with the PSF, crop.

        ``crop`` is the adjoint of the zero-embedding ``pad``, so this composes to
        the exact adjoint of the forward model (used for gradients).
        """
        # send errors back through the diffuser
        output = self.pad(image)
        output = self.convolve_adj(output)
        return self.crop(output)
class FISTA:
    """Nonnegative least-squares deconvolution with FISTA.

    Minimises ``0.5 * ||sensor.forward(x) - image||**2`` subject to ``x >= 0``
    with the fast iterative shrinkage-thresholding algorithm.  The proximal step
    is a projection onto the nonnegative orthant (``np.maximum(., 0)``).
    """

    def __init__(self, sensor):
        """Store the forward model.

        Args:
            sensor: object exposing ``forward`` / ``adjoint`` operators and the
                transfer function ``h`` / ``h_conj`` (e.g. a ``Sensor``).
        """
        self.sensor = sensor

    def solve(self, image, iters, step_scale=1.8):
        """Reconstruct a nonnegative scene from a measurement.

        Args:
            image: measurement, shaped like the output of ``sensor.forward``.
            iters: number of FISTA iterations; ``0`` returns zeros.
            step_scale: gradient step as a multiple of ``1 / max|H|**2``.
                ``max|H|**2`` upper-bounds the Lipschitz constant ``L`` of the
                data-term gradient.  FISTA's convergence proof assumes a step of
                at most ``1 / L`` (``step_scale <= 1``); the default of ``1.8``
                keeps this solver's original behaviour.

        Returns:
            The final projected (hence nonnegative) iterate ``x_k``.

        Raises:
            ValueError: if the sensor's transfer function is identically zero.
        """
        sensor = self.sensor
        # max |H|^2 bounds ||A^T A|| for A = crop . convolve . pad
        lipschitz_bound = np.max(np.real(sensor.h_conj * sensor.h))
        if not lipschitz_bound > 0:
            raise ValueError("sensor transfer function is identically zero")
        step = step_scale / lipschitz_bound

        x = np.zeros_like(image)  # projected iterate x_k
        y = np.zeros_like(image)  # extrapolated point where the gradient is taken
        t = 1.0
        for _ in range(iters):
            x_prev = x
            gradient = sensor.adjoint(sensor.forward(y) - image)
            # gradient step followed by projection onto x >= 0
            x = np.maximum(y - step * gradient, 0)
            t_next = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
            # Nesterov/FISTA momentum
            y = x + ((t - 1) / t_next) * (x - x_prev)
            t = t_next
        # Return x_k, not the extrapolated y_k, which may leave the feasible set.
        return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment