The random pose is parameterized as a perturbation of the true pose. Angular perturbations are uniformly sampled from [-π/4, π/4] and translational perturbations are uniformly sampled from [-30, 30].
If the normalized cross-correlation (NCC) between the target and moving DRRs is greater than 0.999, we say the registration has converged.
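As a rough sketch, such a perturbation can be sampled directly with torch; the helper below (including its name, sample_perturbation) is illustrative and not part of the DiffDRR API:

import torch

# Illustrative helper: sample a random offset from the true pose.
# Rotations (in radians) are drawn uniformly from [-pi/4, pi/4],
# translations from [-30, 30].
def sample_perturbation():
    d_rotations = torch.distributions.Uniform(-torch.pi / 4, torch.pi / 4).sample((1, 3))
    d_translations = torch.distributions.Uniform(-30.0, 30.0).sample((1, 3))
    return d_rotations, d_translations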
4. Backpropagate the loss to the moving DRR parameters
We also use this example to show how different optimizers affect the outcome of registration. The parameters we tweak are:
lr_rotations: learning rate for rotation parameters
lr_translations: learning rate for translation parameters
momentum: momentum for stochastic gradient descent
dampening: dampening for stochastic gradient descent
A basic implementation of an optimization loop is provided below:
def optimize(
    reg: Registration,
    ground_truth,
    lr_rotations=5e-2,
    lr_translations=1e2,
    momentum=0,
    dampening=0,
    n_itrs=500,
    optimizer="sgd",  # 'sgd' or 'adam'
):
    # Initialize an optimizer with different learning rates
    # for rotations and translations since they have different scales
    if optimizer == "sgd":
        optim = torch.optim.SGD(
            [
                {"params": [reg._rotation], "lr": lr_rotations},
                {"params": [reg._translation], "lr": lr_translations},
            ],
            momentum=momentum,
            dampening=dampening,
            maximize=True,
        )
        optimizer = optimizer.upper()
    elif optimizer == "adam":
        optim = torch.optim.Adam(
            [
                {"params": [reg._rotation], "lr": lr_rotations},
                {"params": [reg._translation], "lr": lr_translations},
            ],
            maximize=True,
        )
        optimizer = optimizer.title()
    else:
        raise ValueError(f"Unrecognized optimizer {optimizer}")

    params = []
    losses = [criterion(ground_truth, reg()).item()]
    for itr in (pbar := tqdm(range(n_itrs), ncols=100)):
        # Save the current set of parameters
        alpha, beta, gamma = reg.rotation.squeeze().tolist()
        bx, by, bz = reg.translation.squeeze().tolist()
        params.append([alpha, beta, gamma, bx, by, bz])

        # Run the optimization loop
        optim.zero_grad()
        estimate = reg()
        loss = criterion(ground_truth, estimate)
        loss.backward()
        optim.step()

        losses.append(loss.item())
        pbar.set_description(f"NCC = {loss.item():06f}")

        # Stop the optimization if the estimated and ground truth images are 99.9% correlated
        if loss > 0.999:
            if momentum != 0:
                optimizer += " + momentum"
            if dampening != 0:
                optimizer += " + dampening"
            tqdm.write(f"{optimizer} converged in {itr + 1} iterations")
            break

    # Save the final estimated pose
    alpha, beta, gamma = reg.rotation.squeeze().tolist()
    bx, by, bz = reg.translation.squeeze().tolist()
    params.append([alpha, beta, gamma, bx, by, bz])

    df = pd.DataFrame(params, columns=["alpha", "beta", "gamma", "bx", "by", "bz"])
    df["loss"] = losses
    return df
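For example, the SGD variants and Adam compared below can be run by calling optimize with different settings. The hyperparameter values here are placeholders, and in practice reg should be re-initialized to the same perturbed pose before each run so that every method starts from the same initial guess:

# Illustrative calls; hyperparameter values are placeholders
params_sgd = optimize(reg, ground_truth)
params_momentum = optimize(reg, ground_truth, momentum=5e-1)
params_dampening = optimize(reg, ground_truth, momentum=5e-1, dampening=5e-1)
params_adam = optimize(reg, ground_truth, lr_translations=5e0, optimizer="adam")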
The PyTorch implementation of L-BFGS has a different API from many other optimizers in the library. Notably, it requires a closure function that re-evaluates the model multiple times before taking a step, and it accepts neither per-parameter learning rates nor a maximize flag. Below is an implementation of L-BFGS for DiffDRR.
def optimize_lbfgs(
    reg: Registration,
    ground_truth,
    lr,
    line_search_fn=None,
    n_itrs=500,
):
    # Initialize the optimizer and define the closure function
    optim = torch.optim.LBFGS(reg.parameters(), lr, line_search_fn=line_search_fn)

    def closure():
        if torch.is_grad_enabled():
            optim.zero_grad()
        estimate = reg()
        loss = -criterion(ground_truth, estimate)
        if loss.requires_grad:
            loss.backward()
        return loss

    params = []
    losses = [closure().abs().item()]
    for itr in (pbar := tqdm(range(n_itrs), ncols=100)):
        # Save the current set of parameters
        alpha, beta, gamma = reg.rotation.squeeze().tolist()
        bx, by, bz = reg.translation.squeeze().tolist()
        params.append([alpha, beta, gamma, bx, by, bz])

        # Run the optimization loop
        optim.step(closure)
        with torch.no_grad():
            loss = closure().abs().item()

        losses.append(loss)
        pbar.set_description(f"NCC = {loss:06f}")

        # Stop the optimization if the estimated and ground truth images are 99.9% correlated
        if loss > 0.999:
            if line_search_fn is not None:
                method = "L-BFGS + strong Wolfe conditions"
            else:
                method = "L-BFGS"
            tqdm.write(f"{method} converged in {itr + 1} iterations")
            break

    # Save the final estimated pose
    alpha, beta, gamma = reg.rotation.squeeze().tolist()
    bx, by, bz = reg.translation.squeeze().tolist()
    params.append([alpha, beta, gamma, bx, by, bz])

    df = pd.DataFrame(params, columns=["alpha", "beta", "gamma", "bx", "by", "bz"])
    df["loss"] = losses
    return df
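Its usage mirrors optimize; passing line_search_fn="strong_wolfe" enables the only line search that PyTorch's L-BFGS currently implements. The learning rate below is a placeholder:

# Illustrative calls; the learning rate is a placeholder value
params_lbfgs = optimize_lbfgs(reg, ground_truth, lr=1e-1)
params_lbfgs_wolfe = optimize_lbfgs(reg, ground_truth, lr=1e-1, line_search_fn="strong_wolfe")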
5. Run the optimization algorithm
Below, we compare the following gradient-based iterative optimization methods:
SGD
SGD + momentum
SGD + momentum + dampening
Adam
L-BFGS
L-BFGS + line search
Tip
For 2D/3D registration with Siddon's method, we don't need gradients to be computed through grid_sample (which uses nearest-neighbor interpolation and therefore has zero gradients w.r.t. the grid points). To avoid computing these gradients, which improves rendering speed, you can set stop_gradients_through_grid_sample=True.
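For instance, the flag might be passed when constructing the renderer; this is only a sketch with placeholder geometry values, so check the DRR constructor in your installed version of DiffDRR for the exact signature:

# Sketch: skip gradients through grid_sample when using Siddon's method.
# Geometry values are placeholders; verify where this keyword is accepted
# in your DiffDRR version.
drr = DRR(
    subject,
    sdd=1020.0,
    height=200,
    delx=2.0,
    renderer="siddon",
    stop_gradients_through_grid_sample=True,
)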
Visualizing the loss curves allows us to interpret interesting dynamics during optimization (a minimal plotting sketch follows this list):
SGD and its variants all arrive at a local maximum around NCC = 0.95 and take differing numbers of iterations to escape it
While Adam arrives at the answer much faster, its loss curve is not monotonically increasing, which we will visualize in the next section
L-BFGS without line search converges slowly, and each of its iterations takes much longer than an iteration of the first-order methods
L-BFGS with line search is highly efficient in terms of the number of iterations required, and it runs in roughly the same wall-clock time as the best first-order gradient-based method
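A minimal way to overlay the loss curves, assuming the dataframes returned by the runs above are collected in a dict keyed by method name (the dict and variable names are illustrative):

import matplotlib.pyplot as plt

# Illustrative: map method names to the dataframes returned by optimize / optimize_lbfgs
runs = {
    "SGD": params_sgd,
    "SGD + momentum": params_momentum,
    "SGD + momentum + dampening": params_dampening,
    "Adam": params_adam,
    "L-BFGS": params_lbfgs,
    "L-BFGS + line search": params_lbfgs_wolfe,
}

plt.figure(figsize=(8, 4))
for name, df in runs.items():
    plt.plot(df["loss"], label=name)
plt.axhline(0.999, color="k", linestyle="--", label="convergence threshold")
plt.xlabel("Iteration")
plt.ylabel("NCC")
plt.legend()
plt.show()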
Visualize the parameter updates
Note that the differences between the optimization algorithms can also be seen in the motion of the DRRs!
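One way to see this numerically is to plot each pose parameter over iterations from the returned dataframes; this sketch reuses the illustrative runs dict from above:

# Illustrative: plot the trajectory of every pose parameter for each method
fig, axs = plt.subplots(2, 3, figsize=(12, 6), sharex=True)
for ax, column in zip(axs.flatten(), ["alpha", "beta", "gamma", "bx", "by", "bz"]):
    for name, df in runs.items():
        ax.plot(df[column], label=name)
    ax.set_title(column)
    ax.set_xlabel("Iteration")
axs[0, 0].legend(fontsize=8)
plt.tight_layout()
plt.show()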