Last active
January 16, 2022 09:37
-
-
Save etienne87/0e56cd7ccf6c684407bcb6d3ce6e1eb3 to your computer and use it in GitHub Desktop.
Understanding spatial transforms in PyTorch (simulating 2 vessels)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
plt.switch_backend('tkagg') | |
import torch as th | |
import torch.nn.functional as F | |
import numpy as np | |
import cv2 | |
from scipy.spatial.transform import Rotation | |
def compute_meshgrid(b, d, h, w, align_corners=False):
    """Return an identity sampling grid for a batch of 3-D volumes.

    The grid comes from ``F.affine_grid`` applied to a batch of identity
    3x4 matrices, so every entry holds the normalized coordinate of its own
    voxel (x, y, z order per PyTorch's affine_grid convention).

    Args:
        b: batch size.
        d, h, w: depth, height and width of the volume.
        align_corners: forwarded to ``F.affine_grid``. It must match the
            value later given to ``F.grid_sample``; ``stn`` samples with
            ``align_corners=False``, which is the default here.

    Returns:
        Tensor of shape (b, d, h, w, 3).
    """
    theta = th.eye(4)[:3][None].repeat(b, 1, 1)
    # Pass align_corners explicitly: leaving it out makes PyTorch emit a
    # UserWarning and ties the grid to a library default rather than to the
    # value stn() uses when sampling.
    return F.affine_grid(theta, (b, 1, d, h, w), align_corners=align_corners)
def affine(vecs, trans, scale, rot, order='fw'):
    """Apply a batched similarity transform, or its inverse, to point sets.

    ``order='bw'`` computes ``(vecs @ rot^T + trans) * scale``.
    ``order='fw'`` computes ``(vecs / scale - trans) @ rot``, which is the
    exact inverse of the bw map whenever ``rot`` is orthonormal.

    Args:
        vecs: (b, ..., 3) points; any number of middle dimensions
            (e.g. (b, 3) vectors or a (b, d, h, w, 3) sampling grid).
        trans: (b, 3) translations.
        scale: (b,) or (b, 1) uniform scales, or (b, 3) per-axis scales.
        rot: (b, 3, 3) matrices.
        order: 'fw' (default) or 'bw'.

    Returns:
        A new tensor with the same shape as ``vecs``. ``vecs`` is not modified.

    Raises:
        ValueError: if ``order`` is neither 'fw' nor 'bw'.
    """
    if order not in ('fw', 'bw'):
        raise ValueError(f"order must be 'fw' or 'bw', got {order!r}")
    b = len(vecs)
    bcast = (b,) + (1,) * (vecs.ndim - 2)
    # A trailing -1 keeps uniform scales as a size-1 axis and lets a per-axis
    # (b, 3) scale broadcast over the x/y/z components.
    scale = scale.reshape(bcast + (-1,))
    trans = trans.reshape(bcast + (3,))
    if order == 'bw':
        out = th.einsum('b...j,bjk->b...k', vecs, rot.transpose(1, 2))
        return (out + trans) * scale
    # All ops below are out of place, so no defensive clone of vecs is needed.
    out = vecs / scale - trans
    return th.einsum('b...j,bjk->b...k', out, rot)
def stn(x, trans, scale, rot, padding_mode='zeros', mode='bilinear'):
    """Resample a batch of volumes through per-sample similarity transforms.

    Builds an identity sampling grid, maps it with ``affine(..., 'fw')`` and
    samples ``x`` at the mapped locations with ``F.grid_sample``.

    Args:
        x: (b, c, d, h, w) batch of volumes.
        trans, scale, rot: one transform per batch element, in the layouts
            ``affine`` accepts. Translation is presumably in normalized grid
            units ([-1, 1] spans the volume); confirm against callers. These
            tensors must live on the same device as ``x``.
        padding_mode: forwarded to ``F.grid_sample``.
        mode: interpolation mode forwarded to ``F.grid_sample``; the default
            'bilinear' matches the previous behaviour.

    Returns:
        Resampled tensor with the same shape as ``x``.
    """
    b, _, d, h, w = x.shape
    # compute_meshgrid builds its grid on CPU in the default dtype. Move it to
    # x's device, and cast the final grid to x's dtype, so that CUDA or
    # non-float32 inputs reach grid_sample with a matching grid.
    mgrid = compute_meshgrid(b, d, h, w).to(x.device)
    grid = affine(mgrid, trans, scale, rot, 'fw').to(x.dtype)
    # align_corners must agree with the value used to build the grid.
    return F.grid_sample(x, grid, mode=mode, padding_mode=padding_mode,
                         align_corners=False)
def generate_random_affine(num, trans_range=(-0.7, 0.7), scale_range=(0.7, 1.3)):
    """Sample ``num`` random similarity transforms.

    Args:
        num: number of transforms to draw.
        trans_range: (low, high) bounds of the uniform translation
            distribution, applied independently to each of the 3 axes.
        scale_range: (low, high) bounds of the uniform, isotropic scale.

    Returns:
        T: (num, 3) float32 translations.
        S: (num, 1) float32 scales.
        R: (num, 3, 3) float32 rotation matrices from
            ``scipy.spatial.transform.Rotation.random``.
    """
    rot_np = Rotation.random(num).as_matrix().astype(np.float32)
    R = th.from_numpy(rot_np)
    # Draw T before S to keep the same torch RNG consumption order as before.
    T = th.empty((num, 3), dtype=th.float32).uniform_(*trans_range)
    S = th.empty((num, 1), dtype=th.float32).uniform_(*scale_range)
    return T, S, R
def compose_tsr(trans, scale, rot, order='bw'):
    """Pack translation, scale and rotation into batched 3x4 affine matrices.

    ``order='bw'`` builds ``[diag(scale) @ rot | trans]``.
    ``order='fw'`` returns the inverse of the corresponding 4x4 homogeneous
    matrix (top three rows).

    Args:
        trans: (b, 3) translations.
        scale: (b,) or (b, 1) uniform scales, or (b, 3) per-row scales.
        rot: (b, 3, 3) matrices. Not modified.
        order: 'bw' (default) or 'fw'.

    Returns:
        (b, 3, 4) tensor, shaped like the ``theta`` argument of
        ``F.affine_grid``.

    Raises:
        ValueError: if ``order`` is neither 'fw' nor 'bw'.
    """
    if order not in ('fw', 'bw'):
        raise ValueError(f"order must be 'fw' or 'bw', got {order!r}")
    b = len(trans)
    # Scale whole rows (diag(s) @ R), not just the diagonal entries, and do it
    # out of place so the caller's rotation tensor is left untouched.
    linear = rot * scale.reshape(b, -1, 1)
    T = th.zeros((b, 4, 4), dtype=rot.dtype, device=rot.device)
    T[:, :3, :3] = linear
    T[:, :3, 3] = trans
    T[:, 3, 3] = 1
    if order == 'fw':
        T = th.linalg.inv(T)
    return T[:, :3]
def _draw_vessel(d, h, w, radius):
    """Rasterize a slowly drifting tube into a (d, h, w) float32 volume.

    Each z-slice receives a circle of ``radius`` whose centre drifts
    quadratically with the slice index.

    Returns:
        vol: (d, h, w) float32 tensor, 1 on drawn pixels and 0 elsewhere.
        centers: list of ``d`` integer (x, y) circle centres, one per slice.
    """
    vol = th.zeros((d, h, w), dtype=th.float32)
    offx, offy = -5, -4
    centers = []
    for i in range(d):
        offx += (i * 0.005) ** 2
        offy += (i * 0.003) ** 2
        center = (int(w // 2 + offx), int(h // 2 + offy))
        centers.append(center)
        # vol[i].numpy() shares memory with vol, so cv2 draws straight into
        # the tensor. Thickness 0 is non-negative, so the circle is not filled.
        cv2.circle(vol[i].numpy(), center, radius, 1, 0)
    return vol, centers


def _vessel_directions(centers, d):
    """Return the mid-slice anchor and the vectors to the last/first slice centres.

    Returns:
        anchor: (x, y, z) of the centre of slice ``d // 2``.
        vec: LongTensor from the anchor to the centre of slice ``d - 1``.
        vecpr: LongTensor from the anchor to the centre of slice ``0``.
    """
    mid = d // 2
    (x0, y0), (x1, y1), (x2, y2) = centers[0], centers[mid], centers[-1]
    # z offsets must target the same slices the x/y offsets come from
    # (d - 1 and 0). Parenthesize -(mid): -d//2 would parse as (-d)//2,
    # which is one too far for odd d.
    vec = th.LongTensor([x2 - x1, y2 - y1, (d - 1) - mid])
    vecpr = th.LongTensor([x0 - x1, y0 - y1, -mid])
    return (x1, y1, mid), vec, vecpr


def test_stn_affine(d=64, h=64, w=64, radius=4, num=1):
    """Visually check stn/affine on a synthetic vessel.

    Draws a tube and augments it with ``num`` random similarity transforms via
    ``stn``. It then maps the vessel's forward/backward direction vectors with
    ``affine(..., 'fw')`` and plots the original volume together with the
    first augmented sample in a 3-D matplotlib figure.
    """
    vol, centers = _draw_vessel(d, h, w, radius)
    (cx, cy, cz), vec, vecpr = _vessel_directions(centers, d)
    px, py, pz = vec.tolist()
    prx, pry, prz = vecpr.tolist()
    # Half extents in voxels, (x, y, z) order, used to convert the normalized
    # translation into voxel units below.
    size = th.LongTensor([w // 2, h // 2, d // 2])

    T, S, R = generate_random_affine(num)
    vol1 = vol[None, None].repeat(num, 1, 1, 1, 1)
    # Keep the batch axis. squeeze() only removed it when num == 1, which made
    # the 3-way unpack below fail for num > 1. Only the first sample is shown.
    vol2 = stn(vol1, T, S, R)[:, 0]
    z, y, x = th.where(vol > 0)
    z2, y2, x2 = th.where(vol2[0] > 0)

    # Find direction using the translation rescaled to voxel units; use only
    # the first transform so its batch size matches the single vector.
    T2 = (T * size)[:1]
    S1, R1 = S[:1], R[:1]
    vx2, vy2, vz2 = affine(vec[None].float(), T2, S1, R1, 'fw')[0].tolist()
    vx3, vy3, vz3 = affine(vecpr[None].float(), T2, S1, R1, 'fw')[0].tolist()

    # Plot volume, transformed volume, direction & transformed direction.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x, y, z, marker='.', alpha=0.2, color='gray', s=3, label='volume')
    ax.scatter(x2, y2, z2, marker='.', alpha=0.1, color='gray', s=3, label='volume aug')
    ax.quiver(cx, cy, cz, px, py, pz, linewidths=3, color=['blue'], label='direction_gt')
    ax.quiver(cx, cy, cz, prx, pry, prz, linewidths=3, color=['orange'], label='direction_gt_prev')
    ax.quiver(cx, cy, cz, vx2, vy2, vz2, linewidths=3, color=['blue'], label='direction_gt_aug')
    ax.quiver(cx, cy, cz, vx3, vy3, vz3, linewidths=3, color=['orange'], label='direction_gt_prev_aug')
    ax.set_xlim3d(0, w)
    ax.set_ylim3d(0, h)
    ax.set_zlim3d(0, d)
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    ax.legend()
    plt.show()
if __name__ == '__main__':
    # python-fire turns test_stn_affine's parameters (d, h, w, radius, num)
    # into command-line arguments.
    import fire

    fire.Fire(test_stn_affine)
Author
etienne87
commented
Jan 16, 2022
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment