import matplotlib.pyplot as plt
import torch
from IPython.core.magics.execution import _format_time
from diffdrr.data import load_example_ct
from diffdrr.drr import DRR
from diffdrr.pose import convert
from diffdrr.visualization import plot_drr
Trilinear rendering
Timing demonstration for trilinear interpolation
# Load the example CT volume and pick the fastest available device
subject = load_example_ct()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the camera pose with rotations (yaw, pitch, roll) and translations (x, y, z)
rotations = torch.tensor([[0.0, 0.0, 0.0]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0]], device=device)
pose = convert(
    rotations,
    translations,
    parameterization="euler_angles",
    convention="ZXY",
)
Siddon’s method
Rendering a standard anteroposterior (AP) view with Siddon’s method takes ~25 ms. This is slower than trilinear interpolation with a moderate number of samples per ray, because Siddon’s method computes the exact intersection of every cast ray with the voxel grid of the volume.
# Initialize the DRR module for generating synthetic X-rays
= DRR(
drr
subject,=1020.0,
sdd=200,
height=2.0,
delx
).to(device)= drr(pose) # Initialize drr.density
_
= drr.detector(pose, calibration=None)
source, target = drr.affine_inverse(source)
source = drr.affine_inverse(target) target
= %timeit -o drr.renderer(drr.density, source, target)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose)
img =f"Siddon ({time})")
plot_drr(img, title plt.show()
24.7 ms ± 18.2 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Trilinear interpolation
Rendering the same view with trilinear interpolation can be much faster. The main hyperparameter to control is `n_points`, the number of points sampled along each ray. Trilinear interpolation costs about as much as Siddon’s method when `n_points` is about 2,000, and it becomes slower beyond that.
drr = DRR(
    subject,
    sdd=1020.0,
    height=200,
    delx=2.0,
    renderer="trilinear",  # Switch the rendering mode
).to(device)
_ = drr(pose)  # Initialize drr.density

# Build the rays exactly as in the Siddon benchmark, including the inverse
# affine transform, so both renderers are timed on identical inputs.
source, target = drr.detector(pose, calibration=None)
source = drr.affine_inverse(source)
target = drr.affine_inverse(target)
= 25
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
737 μs ± 2.87 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
= 50
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
1.03 ms ± 1.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
= 100
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
1.65 ms ± 812 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
= 200
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
3.52 ms ± 3.84 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
= 250
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
4.75 ms ± 5.13 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
= 500
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
7.63 ms ± 935 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
= 1000
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
13.1 ms ± 4.73 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
= 2000
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
25 ms ± 7.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
= 2500
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
32.3 ms ± 10.8 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
= 3750
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
48 ms ± 11 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
= 5000
n_points
= %timeit -o drr.renderer(drr.density, source, target, n_points)
times = f"{_format_time(times.average, times._precision)} ± {_format_time(times.stdev, times._precision)}"
time
= drr(pose, n_points=n_points)
img =f"Trilinear with {n_points} points ({time})")
plot_drr(img, title plt.show()
66 ms ± 11.1 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)