import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import torch
from tqdm import tqdm
from diffdrr.drr import DRR
from diffdrr.data import load_example_ct
from diffdrr.metrics import XCorr2
from diffdrr.visualization import plot_drr
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Visualize registration loss landscapes
#
# Utility functions for the simulation:
# - Generate the ground-truth DRR
# - Define a function for generating estimated DRRs
# - Define functions for scoring (negative NCC and L2 norm)
# Read in the volume: diffdrr's example CT, returned as a (volume, spacing) pair.
volume, spacing = load_example_ct()
# Get parameters for the detector (the reference / ground-truth pose).
# Translation: half of volume.shape * spacing along each axis -- presumably the
# centre of the volume in world coordinates (NOTE(review): confirm).
bx, by, bz = torch.tensor(volume.shape) * torch.tensor(spacing) / 2
# Rotation angles (radians).
theta, phi, gamma = torch.pi, 0.0, torch.pi / 2
# Keyword arguments passed to the DRR renderer to produce the target image.
detector_kwargs = {
    "sdr": 300.0,
    "theta": theta,
    "phi": phi,
    "gamma": gamma,
    "bx": bx,
    "by": by,
    "bz": bz,
}
# Make the DRR renderer, render the ground-truth (target) image at the
# reference pose, and display it.
drr = DRR(volume, spacing, height=200, delx=4.0).to(device)
target_drr = drr(**detector_kwargs)
ax = plot_drr(target_drr)
plt.show()
# Scoring functions
# Zero-mean normalized 2-D cross-correlation between two images.
xcorr2 = XCorr2(zero_mean_normalized=True)


def get_normxcorr2(theta, phi, gamma, bx, by, bz, sdr=300.0):
    """Score a candidate pose against the module-level ``target_drr``.

    Renders a DRR at the given pose with the module-level ``drr`` renderer and
    compares it to ``target_drr`` with ``xcorr2``.

    Args:
        theta, phi, gamma: Rotation of the candidate pose (radians).
        bx, by, bz: Translation of the candidate pose.
        sdr: Forwarded as the renderer's ``sdr`` argument; the default matches
            the value used for the target pose.

    Returns:
        The value returned by ``xcorr2(target_drr, moving_drr)``.
    """
    # Keyword arguments mirror the keys used for ``drr(**detector_kwargs)``, so
    # the call does not depend on the renderer's positional parameter order.
    moving_drr = drr(
        sdr=sdr, theta=theta, phi=phi, gamma=gamma, bx=bx, by=by, bz=bz
    )
    return xcorr2(target_drr, moving_drr)
# ## Negative normalized cross-correlation (NCC)
# NCC for the XYZs
# Translation offsets in mm: -15 to +15 in 1 mm steps (31 points, including 0).
xs = torch.arange(-15.0, 16.0)
ys = torch.arange(-15.0, 16.0)
zs = torch.arange(-15.0, 16.0)
# Get coordinate-wise correlations.
# Negative NCC over the (dX, dY) plane with rotation and bz at the reference
# pose; XY[i, j] is the score at offset (xs[i], ys[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    xy_corrs = [
        -get_normxcorr2(theta, phi, gamma, bx + x, by + y, bz).item()
        for x in tqdm(xs)
        for y in ys
    ]
XY = torch.tensor(xy_corrs).reshape(len(xs), len(ys))
# Negative NCC over the (dX, dZ) plane with rotation and by at the reference
# pose; XZ[i, j] is the score at offset (xs[i], zs[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    xz_corrs = [
        -get_normxcorr2(theta, phi, gamma, bx + x, by, bz + z).item()
        for x in tqdm(xs)
        for z in zs
    ]
XZ = torch.tensor(xz_corrs).reshape(len(xs), len(zs))
# Negative NCC over the (dY, dZ) plane with rotation and bx at the reference
# pose; YZ[i, j] is the score at offset (ys[i], zs[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    yz_corrs = [
        -get_normxcorr2(theta, phi, gamma, bx, by + y, bz + z).item()
        for y in tqdm(ys)
        for z in zs
    ]
YZ = torch.tensor(yz_corrs).reshape(len(ys), len(zs))
# Output:
# 100%|██████████| 31/31 [00:34<00:00, 1.10s/it]
# 100%|██████████| 31/31 [00:34<00:00, 1.10s/it]
# 100%|██████████| 31/31 [00:34<00:00, 1.12s/it]
# NCC for the angles
# Rotation offsets in radians. An odd number of evenly spaced points keeps each
# grid symmetric and includes the zero offset (the reference pose itself); the
# step is pi/64 ~= 0.049 rad for all three grids.
t_angles = torch.linspace(-torch.pi / 4, torch.pi / 4, 33)
p_angles = torch.linspace(-torch.pi / 4, torch.pi / 4, 33)
g_angles = torch.linspace(-torch.pi / 8, torch.pi / 8, 17)
# Get coordinate-wise correlations.
# Negative NCC over the (dtheta, dphi) plane with gamma and translation at the
# reference pose; TP[i, j] is the score at offset (t_angles[i], p_angles[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    tp_corrs = [
        -get_normxcorr2(theta + t, phi + p, gamma, bx, by, bz).item()
        for t in tqdm(t_angles)
        for p in p_angles
    ]
TP = torch.tensor(tp_corrs).reshape(len(t_angles), len(p_angles))
# Negative NCC over the (dtheta, dgamma) plane with phi and translation at the
# reference pose; TG[i, j] is the score at offset (t_angles[i], g_angles[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    tg_corrs = [
        -get_normxcorr2(theta + t, phi, gamma + g, bx, by, bz).item()
        for t in tqdm(t_angles)
        for g in g_angles
    ]
TG = torch.tensor(tg_corrs).reshape(len(t_angles), len(g_angles))
# Negative NCC over the (dphi, dgamma) plane with theta and translation at the
# reference pose; PG[i, j] is the score at offset (p_angles[i], g_angles[j]).
with torch.no_grad():  # scores only -- no autograd bookkeeping needed
    pg_corrs = [
        -get_normxcorr2(theta, phi + p, gamma + g, bx, by, bz).item()
        for p in tqdm(p_angles)
        for g in g_angles
    ]
PG = torch.tensor(pg_corrs).reshape(len(p_angles), len(g_angles))
# Output:
# 100%|██████████| 32/32 [00:41<00:00, 1.30s/it]
# 100%|██████████| 32/32 [00:20<00:00, 1.56it/s]
# 100%|██████████| 32/32 [00:19<00:00, 1.67it/s]
# Make the plots
def _plot_landscape(ax, u, v, loss, xlabel, ylabel, zlim):
    """Draw one 2-D loss landscape on a 3-D axis.

    The surface is drawn together with a filled-contour projection placed on
    the z = -1 plane.

    Args:
        ax: A matplotlib axis created with ``projection="3d"``.
        u, v: 1-D tensors of offsets along the two swept parameters.
        loss: 2-D tensor where ``loss[i, j]`` is the score at ``(u[i], v[j])``.
        xlabel, ylabel: Axis labels for ``u`` and ``v``.
        zlim: ``(low, high)`` limits for the loss axis.
    """
    # indexing="ij" matches the row-major order in which the landscapes were
    # reshaped (row i <-> u[i]) and avoids torch's meshgrid indexing warning.
    uu, vv = torch.meshgrid(u, v, indexing="ij")
    cmap = plt.get_cmap("rainbow")
    ax.contourf(uu, vv, loss, zdir="z", offset=-1, cmap=cmap, alpha=0.5)
    ax.plot_surface(uu, vv, loss, rstride=1, cstride=1, cmap=cmap)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlim3d(*zlim)


fig = plt.figure(figsize=3 * plt.figaspect(1.2 / 1), dpi=300)

# All six panels share one 2 x 3 grid so they do not overlap: translations (XYZ)
# on the top row, rotations (angles) on the bottom row.
panels = [
    # (row offsets, column offsets, landscape, x label, y label, z limits)
    (xs, ys, XY, "ΔX (mm)", "ΔY (mm)", (-1.0, -0.825)),
    (xs, zs, XZ, "ΔX (mm)", "ΔZ (mm)", (-1.0, -0.825)),
    (ys, zs, YZ, "ΔY (mm)", "ΔZ (mm)", (-1.0, -0.825)),
    (t_angles, p_angles, TP, "Δθ (radians)", "Δφ (radians)", (-1.0, -0.4)),
    (t_angles, g_angles, TG, "Δθ (radians)", "Δγ (radians)", (-1.0, -0.4)),
    (p_angles, g_angles, PG, "Δφ (radians)", "Δγ (radians)", (-1.0, -0.4)),
]
for panel_index, (u, v, landscape, xlabel, ylabel, zlim) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 3, panel_index, projection="3d")
    _plot_landscape(ax, u, v, landscape, xlabel, ylabel, zlim)

plt.show()
# Output (UserWarning):
# /data/vision/polina/users/vivekg/utils/mambaforge/envs/diffdrr/lib/python3.11/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /home/conda/feedstock_root/build_artifacts/pytorch-recipe_1673745441827/work/aten/src/ATen/native/TensorShape.cpp:3190.)
#   return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]