from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from diffdrr.drr import DRR
from diffdrr.data import load_example_ct
from diffdrr.metrics import XCorr2
from diffdrr.visualization import plot_drr
# Fix NumPy's global RNG so the random starting pose drawn by
# get_initial_parameters (via np.random.uniform) is reproducible.
np.random.seed(39)
# 2D-to-3D Registration
def converged(df, threshold=-0.999):
    """Report whether a recorded optimisation run reached the target loss.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"loss"`` column; only its last entry is inspected.
        A missing column raises ``KeyError``.
    threshold : float, default -0.999
        The run counts as converged when its final loss is ``<= threshold``.

    Returns
    -------
    bool
        ``False`` for an empty run (no losses recorded) instead of raising.

    NOTE(review): the DataFrame returned by ``optimize`` has only pose
    columns (no ``"loss"``), so this function cannot be applied to it as-is.
    """
    losses = df["loss"]
    if losses.empty:
        return False
    return bool(losses.iloc[-1] <= threshold)
# (device, dtype) pair, unpacked as `.to(*defaults)` wherever a module or
# tensor is moved below; prefers the GPU when CUDA is available.
defaults = ("cuda" if torch.cuda.is_available() else "cpu", torch.float32)
print(defaults)
# Output (on a CUDA machine): ('cuda', torch.float32)
# Make the ground truth X-ray
# Imaging constants shared by every DRR built below.
SDR = 300.0  # the "sdr" pose parameter (presumably source-to-detector radius; confirm units)
HEIGHT = 100  # passed as DRR(height=...)
DELX = 8.0  # passed as DRR(delx=...)
volume, spacing = load_example_ct()

# Half of the volume's physical extent along each axis (voxel count x voxel
# spacing / 2), i.e. the volume centre -- assuming `spacing` gives the voxel
# size per axis in the same order as `volume.shape` (TODO confirm).
bx, by, bz = np.array(volume.shape) * np.array(spacing) / 2

# Pose used to render the ground-truth X-ray.
true_params = {
    "sdr": SDR,
    "theta": torch.pi,
    "phi": 0,
    "gamma": torch.pi / 2,
    "bx": bx,
    "by": by,
    "bz": bz,
}
# Render the target image with a module that is NOT moved via .to(*defaults)
# (pose supplied at call time), then move the resulting image to the working
# device/dtype so it can be compared against the device-resident renders below.
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX)
ground_truth = drr(**true_params).to(*defaults)

plot_drr(ground_truth)
plt.show()
# Make a random DRR
def get_initial_parameters(true_params, rng=np.random):
    """Draw a random starting pose by perturbing the ground-truth pose.

    Parameters
    ----------
    true_params : dict
        Must provide ``"sdr"``, ``"theta"``, ``"phi"``, ``"gamma"``, ``"bx"``,
        ``"by"`` and ``"bz"``.
    rng : object with a ``uniform(low, high)`` method, default ``np.random``
        Source of the perturbations. The default draws from NumPy's global
        RNG, so ``np.random.seed`` controls it; a ``np.random.Generator``
        can be passed for an independent stream.

    Returns
    -------
    torch.Tensor
        Shape ``(1, 7)``: ``[[sdr, theta, phi, gamma, bx, by, bz]]``, with
        ``sdr`` copied unchanged.
    """
    # Draw order matters for reproducibility under a fixed seed:
    # theta, phi, gamma, bx, by, bz.
    theta = true_params["theta"] + rng.uniform(-np.pi / 4, np.pi / 4)
    phi = true_params["phi"] + rng.uniform(-np.pi / 3, np.pi / 3)
    gamma = true_params["gamma"] + rng.uniform(-np.pi / 3, np.pi / 3)
    # NOTE(review): the translation window (-30, 31) is asymmetric; kept
    # unchanged so seeded runs reproduce the same starting pose.
    bx = true_params["bx"] + rng.uniform(-30.0, 31.0)
    by = true_params["by"] + rng.uniform(-30.0, 31.0)
    bz = true_params["bz"] + rng.uniform(-30.0, 31.0)
    return torch.tensor([[true_params["sdr"], theta, phi, gamma, bx, by, bz]])
# Render the randomly perturbed starting pose for visual comparison with the
# ground truth.
params = get_initial_parameters(true_params)
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params)
est = drr()

plot_drr(est)
plt.show()
def optimize(
    drr,
    ground_truth,
    lr_rotations=5.3e-2,
    lr_translations=7.5e1,
    momentum=0,
    dampening=0,
    n_itrs=250,
    threshold=-0.999,
):
    """Register ``drr`` to ``ground_truth`` with SGD and record the trajectory.

    Rotations and translations are separate SGD parameter groups so each can
    use its own learning rate. The loss is the negated zero-mean normalised
    cross-correlation (``XCorr2``) between the target and the current render.

    Parameters
    ----------
    drr : DRR
        Module exposing ``rotations`` and ``translations`` parameters; these
        are updated in place.
    ground_truth : torch.Tensor
        Target image. It is detached internally, so no gradient flows into
        whatever produced it.
    lr_rotations, lr_translations : float
        Learning rates of the two parameter groups.
    momentum, dampening : float
        Passed straight to ``torch.optim.SGD``.
    n_itrs : int
        Maximum number of iterations.
    threshold : float, default -0.999
        Stop as soon as the loss is strictly below this value.

    Returns
    -------
    pandas.DataFrame
        One row per rendered pose with columns
        ``["theta", "phi", "gamma", "bx", "by", "bz"]``. If the threshold was
        met, the last row is the pose that met it and ``drr`` is left there.
    """
    criterion = XCorr2(zero_mean_normalized=True)
    optimizer = torch.optim.SGD(
        [
            {"params": [drr.rotations], "lr": lr_rotations},
            {"params": [drr.translations], "lr": lr_translations},
        ],
        momentum=momentum,
        dampening=dampening,
    )
    # The target is a constant: cutting it from any autograd graph means each
    # backward() only traverses the current render, so no graph has to be
    # retained between iterations.
    target = ground_truth.detach()

    trajectory = []
    for itr in tqdm(range(n_itrs)):
        estimate = drr()
        # Record the pose that produced `estimate` (before this step's update).
        theta, phi, gamma = drr.rotations.detach().squeeze().tolist()
        bx, by, bz = drr.translations.detach().squeeze().tolist()
        trajectory.append([theta, phi, gamma, bx, by, bz])

        loss = -criterion(target, estimate)
        # Check convergence before stepping so `drr` stays at the converged
        # pose (the last recorded row) and no wasted backward pass runs.
        if loss < threshold:
            tqdm.write(f"Converged in {itr} iterations")
            break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return pd.DataFrame(trajectory, columns=["theta", "phi", "gamma", "bx", "by", "bz"])
# Base SGD
# Every run starts from the same random `params`, so the optimizer variants
# are compared from one initial pose.
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_base = optimize(drr, ground_truth)
del drr  # drop this module before the next run builds its own
# SGD + momentum
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_momentum = optimize(drr, ground_truth, momentum=0.9)
del drr  # drop this module before the next run builds its own
# SGD + momentum + dampening
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_momentum_dampen = optimize(drr, ground_truth, momentum=0.9, dampening=0.1)
del drr  # drop this module before the next step builds its own
# Output of the three runs above, in order:
# Base SGD:
#  55%|███████████████████████████████████████████▌ | 138/250 [00:11<00:08, 12.50it/s]
# Converged in 138 iterations
# SGD + momentum:
#  26%|████████████████████▍ | 64/250 [00:01<00:03, 49.45it/s]
# Converged in 64 iterations
# SGD + momentum + dampening:
#  23%|██████████████████▌ | 58/250 [00:01<00:03, 49.35it/s]
# Converged in 58 iterations
from IPython.display import display, HTML
from base64 import b64encode
from diffdrr.visualization import animate
def animate_in_browser(df, duration=30):
    """Show the optimisation trajectory in ``df`` as an inline animation.

    Passes the trajectory to ``animate`` to get encoded WebP bytes. It then
    embeds them in the notebook output as a base64 data URI.

    Relies on the module-level ``SDR``, ``drr`` and ``ground_truth``, which
    must be defined before this is called.

    Parameters
    ----------
    df : pandas.DataFrame
        Pose trajectory, e.g. as returned by ``optimize``.
    duration : int, default 30
        Forwarded to ``animate`` (presumably the per-frame duration passed on
        to the image writer; confirm units).
    """
    out = animate(
        "<bytes>", df, SDR, drr, ground_truth,
        verbose=True, extension=".webp", duration=duration,
    )
    # `out` is encoded with extension=".webp", so the data URI must declare
    # image/webp ("img/..." is not a registered MIME top-level type).
    encoded = b64encode(out).decode()
    display(HTML(f"""<img src='data:image/webp;base64,{encoded}'>"""))
# Module read (as a global) by animate_in_browser; no pose is passed here, so
# presumably animate sets the pose for each recorded row.
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX).to(*defaults)

animate_in_browser(params_base)
# Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 70/70 [00:13<00:00, 5.12it/s]
animate_in_browser(params_momentum)
# Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 33/33 [00:06<00:00, 4.88it/s]
animate_in_browser(params_momentum_dampen)
# Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 30/30 [00:06<00:00, 4.86it/s]