Created
June 13, 2024 00:20
-
-
Save alisterburt/3e66afe30e83ae4a1774f0ff623a407e to your computer and use it in GitHub Desktop.
3D correlations from multiple 2D correlations for Will Wan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import einops | |
import numpy as np | |
import torch | |
import mrcfile | |
import napari | |
from torch_fourier_slice import project_3d_to_2d | |
from torch_image_lerp import sample_image_2d | |
from torch_grid_utils import coordinate_grid | |
from scipy.spatial.transform import Rotation as R | |
# path to the reference volume; edit for your machine
VOLUME_FILE = '/Users/burta2/data/4v6x_bin4.mrc'
# number of independent random-shift trials to run, to compare results
N_TRIALS = 100
# -60 to +60 degree tilt series in 3 degree steps (41 tilts)
TILT_ANGLES_DEGREES = np.linspace(-60, 60, 41, endpoint=True)
# radius (in voxels) of the sphere of candidate 3D shifts that is searched
SHIFT_SEARCH_RADIUS = 35


def simulate_shifted_particle(reference_volume, shift):
    """Return a copy of ``reference_volume`` translated by an integer ``shift``.

    ``shift`` is given in array-axis order (dim 0, dim 1, dim 2) - the same order
    as ``valid_shifts_zyx`` in ``main`` - so the simulated and recovered shifts
    can be compared directly. Voxels translated in from outside the box are zero.
    Zero, negative and larger-than-box shifts are all handled; for strictly
    positive shifts this equals ``out[s0:, s1:, s2:] = ref[:-s0, :-s1, :-s2]``.
    """
    shifted = torch.zeros_like(reference_volume)
    source, target = [], []
    for offset, size in zip((int(s) for s in shift), reference_volume.shape):
        # number of voxels along this axis that stay inside the box after shifting
        overlap = max(size - abs(offset), 0)
        if offset >= 0:
            source.append(slice(0, overlap))
            target.append(slice(offset, offset + overlap))
        else:
            source.append(slice(-offset, -offset + overlap))
            target.append(slice(0, overlap))
    shifted[tuple(target)] = reference_volume[tuple(source)]
    return shifted


def correlate_projections_2d(reference_projections, experimental_projections):
    """Cross-correlate each reference image with its experimental counterpart.

    Both inputs are (..., h, w) real images with matching shapes. The result has
    the same shape. A zero shift sits at pixel (h // 2, w // 2). An experimental
    image translated by +t pixels (yx) relative to its reference peaks at
    (h // 2, w // 2) + t (modulo the box size).
    """
    image_shape = reference_projections.shape[-2:]
    # ifftshift moves the box centre (h // 2, w // 2) to the origin, so the
    # zero-shift peak lands back on the box centre. fftshift is only equivalent
    # for even box sizes.
    reference_centered = torch.fft.ifftshift(reference_projections, dim=(-2, -1))
    reference_dft = torch.fft.rfftn(reference_centered, dim=(-2, -1))
    experimental_dft = torch.fft.rfftn(experimental_projections, dim=(-2, -1))
    # Conjugating the reference makes this a cross-correlation. A plain product
    # is a convolution, which for a non-centrosymmetric reference gives a blurred,
    # displaced peak.
    correlation_dft = reference_dft.conj() * experimental_dft
    # pass s= so odd-sized images round-trip back to their original size
    return torch.fft.irfftn(correlation_dft, s=image_shape, dim=(-2, -1))


def project_shifts_3d_to_2d(shifts_zyx, rotation_matrices):
    """Project 3D shifts into every tilt image.

    shifts_zyx: (nshifts, 3) shifts in zyx order.
    rotation_matrices: (ntilts, 3, 3) matrices acting on xyz column vectors, as
    passed to ``project_3d_to_2d``.
    Returns an (ntilts, nshifts, 2) tensor of in-image shifts in yx order.

    Follows the original script's convention that a 3D shift s appears in a
    projection as the first two (xy) components of R^-1 s.
    NOTE(review): confirm against torch_fourier_slice's projection convention.
    """
    shifts_xyz = torch.flip(shifts_zyx, dims=(-1,))
    # the inverse of a rotation matrix is its transpose; keep only the x and y rows
    projection_matrices = rotation_matrices.transpose(-1, -2)[:, :2, :]  # (ntilts, 2, 3)
    projected_xy = torch.einsum('tij,nj->tni', projection_matrices, shifts_xyz)
    # image sampling expects yx
    return torch.flip(projected_xy, dims=(-1,))


def main():
    """Run N_TRIALS simulated alignments and print the true vs recovered shifts.

    Each trial:
    - translates the reference volume by a random integer shift;
    - projects it over the tilt series;
    - cross-correlates each tilt with the matching reference projection in 2D;
    - scores every candidate 3D shift by summing, over tilts, the 2D
      correlations sampled where that shift lands in each tilt.
    """
    reference_volume = torch.tensor(mrcfile.read(VOLUME_FILE))

    # Everything before the trial loop depends only on the reference and the
    # tilt geometry, so it is computed once rather than once per trial.

    # simulate a -60 to +60 tilt series of the reference: single-axis rotations
    # about y (for one axis scipy's intrinsic 'Y' and extrinsic 'y' agree)
    rotation_matrices = R.from_euler(angles=TILT_ANGLES_DEGREES, seq='y', degrees=True).as_matrix()
    rotation_matrices = torch.tensor(rotation_matrices).float()
    reference_projections = project_3d_to_2d(reference_volume, rotation_matrices=rotation_matrices)
    image_center_2d = torch.tensor(reference_projections.shape[-2:]) // 2
    # vis
    # viewer = napari.Viewer()
    # viewer.add_image(reference_projections.numpy())
    # napari.run()

    # grid of possible 3D shift values
    shift_grid = coordinate_grid(
        image_shape=reference_volume.shape,
        center=np.array(reference_volume.shape) // 2,
    )
    # restrict to a sphere of radius SHIFT_SEARCH_RADIUS around zero shift
    correlation_mask_3d = torch.linalg.norm(shift_grid, dim=-1) <= SHIFT_SEARCH_RADIUS  # (d, d, d)
    valid_shifts_zyx = shift_grid[correlation_mask_3d, :]  # (nshifts, 3) array of zyx shifts

    # where each candidate 3D shift lands in each tilt's correlation image
    # (zero shift is at the image centre, see correlate_projections_2d)
    sampling_positions_yx_2d = (
        project_shifts_3d_to_2d(valid_shifts_zyx, rotation_matrices) + image_center_2d
    )  # (ntilts, nshifts, 2)
    # visualise sample positions relative to image center
    # viewer = napari.Viewer()
    # points_napari = einops.rearrange(sampling_positions_yx_2d[:, 0:5], 'ntilts nshifts yx -> (ntilts nshifts) yx')
    # viewer.add_points(points_napari, size=0.1, face_color='cornflowerblue')
    # viewer.add_points(image_center_2d.numpy(), size=0.5, face_color='red')
    # napari.run()

    # per-tilt weights for the sum over tilts (uniform for now)
    weights = torch.ones(size=(reference_projections.shape[0], 1))

    # do the whole experiment a bunch of times to compare results
    for _ in range(N_TRIALS):
        # random shift in array-axis (zyx) order; every component is in [1, 9]
        true_shift = np.random.randint(low=1, high=10, size=(3, ))
        while np.linalg.norm(true_shift) > 30:
            true_shift = np.random.randint(low=1, high=10, size=(3, ))
        print(f'true shift: {true_shift}')

        # shifted copy of the volume: this is our 'experimental particle'
        experimental_particle_3d = simulate_shifted_particle(reference_volume, true_shift)
        # vis
        # viewer = napari.Viewer(ndisplay=3)
        # viewer.add_image(reference_volume.numpy())
        # viewer.add_image(experimental_particle_3d.numpy())
        # napari.run()

        # simulate experimental projections from the experimental particle
        # (these would be extracted from TS in actual processing)
        experimental_projections = project_3d_to_2d(experimental_particle_3d, rotation_matrices=rotation_matrices)

        # correlate 2D images
        correlations_2d = correlate_projections_2d(reference_projections, experimental_projections)
        # vis
        # viewer = napari.Viewer()
        # viewer.add_image(reference_projections.numpy())
        # viewer.add_image(experimental_projections.numpy())
        # viewer.add_image(correlations_2d.numpy())
        # viewer.add_points(image_center_2d.numpy(), face_color='red', size=5)
        # napari.run()

        # sample each tilt's 2D correlation at the projected positions of every candidate shift
        correlation_samples = torch.stack(
            [
                sample_image_2d(tilt_correlation_image, coordinates=per_tilt_samples)
                for tilt_correlation_image, per_tilt_samples
                in zip(correlations_2d, sampling_positions_yx_2d)
            ]
        )  # (ntilts, nshifts) array of per-tilt, per-shift correlation values

        # the 3D correlation value for each shift is a weighted sum of the 2D correlations
        per_shift_correlations = einops.reduce(
            weights * correlation_samples, 'ntilts nshifts -> nshifts', reduction='sum'
        )

        # which shift has the max correlation?
        best_shift_idx = torch.argmax(per_shift_correlations)
        best_shift = np.array(valid_shifts_zyx[best_shift_idx])
        print(f'best shift: {best_shift}')
        print(f'difference: {best_shift - true_shift}')
        print('\n')


# NOTE(review): earlier runs reported z/y shifts systematically 1 voxel off and very
# diffuse cc peaks. The 2D step used to multiply the DFTs without a complex conjugate,
# which computes a convolution rather than a cross-correlation and blurs/displaces the
# peak for a non-centrosymmetric reference. It also used fftshift, and irfftn without
# `s`; both misbehave for odd box sizes. correlate_projections_2d fixes all three.
# Re-check whether the offsets remain.

if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment