Trilinear rendering

Timing comparison of trilinear interpolation and Siddon’s method for DRR rendering
import matplotlib.pyplot as plt
import torch
from IPython.core.magics.execution import _format_time

from diffdrr.data import load_example_ct
from diffdrr.drr import DRR
from diffdrr.pose import convert
from diffdrr.visualization import plot_drr
subject = load_example_ct()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Camera pose: a batch of one Euler-angle rotation (ZXY convention) and one
# translation, converted into the pose object consumed by the DRR module.
rotations = torch.tensor([[0.0, 0.0, 0.0]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0]], device=device)
pose = convert(rotations, translations, parameterization="euler_angles", convention="ZXY")

Siddon’s method

Rendering a standard AP view with Siddon’s method takes ~25 ms. This is slower than trilinear interpolation with a moderate number of samples per ray, because Siddon’s method computes the exact intersection of every cast ray with the voxel grid.

# Initialize the DRR module for generating synthetic X-rays
drr = DRR(
    subject,
    sdd=1020.0,
    height=200,
    delx=2.0,
).to(device)
_ = drr(pose)  # Initialize drr.density

source, target = drr.detector(pose, calibration=None)
source = drr.affine_inverse(source)
target = drr.affine_inverse(target)
times = %timeit -o drr.renderer(drr.density, source, target)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose)
plot_drr(img, title=f"Siddon ({time})")
plt.show()
24.7 ms ± 18.2 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Trilinear interpolation

Rendering the same view with trilinear interpolation can be much faster. The main hyperparameter to control is n_points, the number of points sampled along each ray; the rendering cost grows roughly linearly with it. Trilinear interpolation costs about the same as Siddon’s method (~25 ms) at around 2,000 points per ray.

drr = DRR(
    subject,
    sdd=1020.0,
    height=200,
    delx=2.0,
    renderer="trilinear",  # Switch the rendering mode
).to(device)

source, target = drr.detector(pose, calibration=None)
_ = drr(pose)  # Initialize drr.density
n_points = 25

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
737 μs ± 2.87 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

n_points = 50

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
1.03 ms ± 1.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

n_points = 100

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
1.65 ms ± 812 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

n_points = 200

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
3.52 ms ± 3.84 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

n_points = 250

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
4.75 ms ± 5.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

n_points = 500

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
7.63 ms ± 935 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

n_points = 1000

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
13.1 ms ± 4.73 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

n_points = 2000

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
25 ms ± 7.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

n_points = 2500

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
32.3 ms ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

n_points = 3750

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
48 ms ± 11 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

n_points = 5000

times = %timeit -o drr.renderer(drr.density, source, target, n_points)
time = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"

img = drr(pose, n_points=n_points)
plot_drr(img, title=f"Trilinear with {n_points} points ({time})")
plt.show()
66 ms ± 11.1 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)