2D-to-3D Registration

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from diffdrr.drr import DRR
from diffdrr.data import load_example_ct
from diffdrr.metrics import XCorr2
from diffdrr.visualization import plot_drr

np.random.seed(39)
def converged(df):
    return df["loss"].iloc[-1] <= -0.999
defaults = ("cuda" if torch.cuda.is_available() else "cpu", torch.float32)
print(defaults)
('cuda', torch.float32)
# Make the ground truth X-ray
SDR = 300.0
HEIGHT = 100
DELX = 8.0

volume, spacing = load_example_ct()
bx, by, bz = np.array(volume.shape) * np.array(spacing) / 2
true_params = {
    "sdr": SDR,
    "theta": torch.pi,
    "phi": 0,
    "gamma": torch.pi / 2,
    "bx": bx,
    "by": by,
    "bz": bz,
}

drr = DRR(volume, spacing, height=HEIGHT, delx=DELX)
ground_truth = drr(**true_params).to(*defaults)

plot_drr(ground_truth)
plt.show()

# Make a random DRR
def get_initial_parameters(true_params):
    sdr = true_params["sdr"]
    theta = true_params["theta"] + np.random.uniform(-np.pi / 4, np.pi / 4)
    phi = true_params["phi"] + np.random.uniform(-np.pi / 3, np.pi / 3)
    gamma = true_params["gamma"] + np.random.uniform(-np.pi / 3, np.pi / 3)
    bx = true_params["bx"] + np.random.uniform(-30.0, 31.0)
    by = true_params["by"] + np.random.uniform(-30.0, 31.0)
    bz = true_params["bz"] + np.random.uniform(-30.0, 31.0)
    return torch.tensor([[sdr, theta, phi, gamma, bx, by, bz]])


params = get_initial_parameters(true_params)
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params)
est = drr()

plot_drr(est)
plt.show()

def optimize(
    drr,
    ground_truth,
    lr_rotations=5.3e-2,
    lr_translations=7.5e1,
    momentum=0,
    dampening=0,
    n_itrs=250
):
    criterion = XCorr2(zero_mean_normalized=True)
    optimizer = torch.optim.SGD(
        [
            {"params": [drr.rotations], "lr": lr_rotations},
            {"params": [drr.translations], "lr": lr_translations},
        ],
        momentum=momentum,
        dampening=dampening,
    )
    
    params = []
    for itr in tqdm(range(n_itrs)):
        estimate = drr()
        theta, phi, gamma = drr.rotations.squeeze()
        bx, by, bz = drr.translations.squeeze()
        params.append([i.item() for i in [theta, phi, gamma, bx, by, bz]])
        loss = -criterion(ground_truth, estimate)
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        
        if loss < -0.999:
            tqdm.write(f"Converged in {itr} iterations")
            break
        
    return pd.DataFrame(params, columns=["theta", "phi", "gamma", "bx", "by", "bz"])
# Base SGD
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_base = optimize(drr, ground_truth)
del drr

# SGD + momentum
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_momentum = optimize(drr, ground_truth, momentum=0.9)
del drr

# SGD + momentum + dampening
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX, params=params).to(*defaults)
params_momentum_dampen = optimize(drr, ground_truth, momentum=0.9, dampening=0.1)
del drr
 55%|███████████████████████████████████████████▌                                   | 138/250 [00:11<00:08, 12.50it/s]
Converged in 138 iterations
 26%|████████████████████▍                                                           | 64/250 [00:01<00:03, 49.45it/s]
Converged in 64 iterations
 23%|██████████████████▌                                                             | 58/250 [00:01<00:03, 49.35it/s]
Converged in 58 iterations
from IPython.display import display, HTML
from base64 import b64encode

from diffdrr.visualization import animate
def animate_in_browser(df):
    out = animate("<bytes>", df, SDR, drr, ground_truth, verbose=True, extension=".webp", duration=30)
    display(HTML(f"""<img src='{"data:img/gif;base64," + b64encode(out).decode()}'>"""))
drr = DRR(volume, spacing, height=HEIGHT, delx=DELX).to(*defaults)
animate_in_browser(params_base)
Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 70/70 [00:13<00:00,  5.12it/s]
animate_in_browser(params_momentum)
Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 33/33 [00:06<00:00,  4.88it/s]
animate_in_browser(params_momentum_dampen)
Precomputing DRRs: 100%|██████████████████████████████████████████████████████████████| 30/30 [00:06<00:00,  4.86it/s]