import matplotlib.pyplot as plt
import seaborn as sns
import torch

from diffdrr.data import load_example_ct
from diffdrr.drr import DRR
from diffdrr.visualization import plot_drr

# Use seaborn's larger "talk" plotting context for all figures below
sns.set_context("talk")
How to use DiffDRR
`DRR` module’s functionality
Rendering DRRs
`DiffDRR` is implemented as a custom PyTorch module.
All raytracing operations have been formulated in a vectorized function, enabling use of PyTorch’s GPU support and autograd. This also means that X-ray projection is interoperable as a layer in deep learning frameworks.
Rotations can be parameterized with numerous conventions (not just Euler angles). See `diffdrr.drr.DRR` for more details.
# Read in the volume and get its origin and spacing in world coordinates
subject = load_example_ct()

# Initialize the DRR module for generating synthetic X-rays
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
drr = DRR(
    subject,  # A torchio.Subject object storing the CT volume, origin, and voxel spacing
    sdd=1020,  # Source-to-detector distance (i.e., the C-arm's focal length)
    height=200,  # Height of the DRR (if width is not separately provided, the generated image is square)
    delx=2.0,  # Pixel spacing (in mm)
).to(device)

# Set the camera pose with rotations (yaw, pitch, roll) and translations (x, y, z)
rotations = torch.tensor([[0.0, 0.0, 0.0]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0]], device=device)

# Render a single DRR (a batch of one pose) and display it
img = drr(rotations, translations, parameterization="euler_angles", convention="ZXY")
plot_drr(img, ticks=False)
plt.show()
We demonstrate the speed of `DiffDRR` by timing repeated DRR synthesis. Timing results are on a single NVIDIA RTX 2080 Ti GPU.
38.4 ms ± 64 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Rendering multiple DRRs at once
The `rotations` tensor is expected to be of size $B \times D$, where $B$ is the number of camera poses (the batch size) and $D$ is the number of components needed to represent the rotation (e.g., $3$ for Euler angles, $4$ for quaternions, etc.). The `translations` tensor is expected to be of size $B \times 3$.
# Batch of B = 2 poses: both share the same translation, and the second
# differs from the first only by an angle of pi in its last Euler component
rotations = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, torch.pi]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0], [0.0, 850.0, 0.0]], device=device)
img = drr(rotations, translations, parameterization="euler_angles", convention="ZXY")
plot_drr(img, ticks=False)
plt.show()
Note that rendered DRRs have shape $B \times C \times H \times W$, where
- $B$ is the number of camera poses passed to the renderer
- $C$ is the number of channels in the rendered images
- $H$ is the image height, specified in the constructor of the `diffdrr.drr.DRR` object
- $W$ is the image width, which defaults to the height if not otherwise specified
Typically, $C = 1$. However, we can have more channels if rendering individual anatomical structures (see the next section).
img.size()  # same torch.Size as img.shape, ordered (B, C, H, W)
torch.Size([2, 1, 200, 200])
Rendering individual structures in separate channels
If the `subject` passed to `diffdrr.drr.DRR` also has a `mask` attribute (a `torchio.LabelMap`), we can use this 3D segmentation map to render individual structures in the DRR.
Method 1
The first way to do this is to set `mask_to_channels=True` in `DRR.forward`, which will create a new channel for every structure.
Note that `mask_to_channels` is only an option for the Siddon renderer (which is the default option).
from diffdrr.pose import convert

# Note that you also have the option to directly pass poses in SE(3) to the renderer:
# here `convert` builds a single pose from the Euler angles and translations
rotations = torch.tensor([[0.0, 0.0, 0.0]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0]], device=device)
pose = convert(rotations, translations, parameterization="euler_angles", convention="ZXY")

# Render one channel per labeled structure (only supported by the Siddon renderer)
img = drr(pose, mask_to_channels=True)
We used TotalSegmentator v2 to automatically segment the example CT. This dataset has 118 classes. Therefore, the output image has $C = 119$ channels (the zeroth channel is a rendering of the background).
img.size()  # same torch.Size as img.shape; C now holds one channel per structure plus background
torch.Size([1, 119, 200, 200])
We incur a small amount of additional overhead to partition these channels during rendering:
53.8 ms ± 175 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
We can also visualize all of these channels superimposed on the DRR. Note that summing over the channel dimension recapitulates the original DRR.
Code
from diffdrr.visualization import plot_mask

# Relabel classes in the TotalSegmentator dataset
groups = {
    "skeleton": "Appendicular Skeleton",
    "ribs": "Ribs",
    "vertebrae": "Vertebrae",
    "cardiac": "Cardiovasculature",
    "organs": "Organs",
    "muscles": "Muscles",
}

# Plot the segmentation masks. Column 0 holds the full DRR (top) and all
# segmentations overlaid on it (bottom); the remaining 2 x 3 panels hold one
# structure group each.
fig, axs = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(14, 7.75),
    tight_layout=True,
    dpi=300,
)

# Summing over the channel dimension recovers the original DRR
im = img.sum(dim=1, keepdim=True)
plot_drr(im, axs=axs[0, 0], ticks=False, title="DRR")
plot_drr(im, axs=axs[1, 0], ticks=False, title="All Segmentations")

for (group, title), ax in zip(groups.items(), axs[:, 1:].flatten()):
    # Channel indices of the structures that belong to this group
    jdxs = subject.structures.query(f"group == '{group}'")["id"].tolist()
    im = img[:, jdxs]
    plot_drr(im.sum(dim=1, keepdim=True), title=title, axs=ax, ticks=False)
    masks = plot_mask(im, ax=ax, return_masks=True)
    # Also overlay this group's masks on the "All Segmentations" panel
    for mask in masks[0]:
        axs[1, 0].imshow(mask, alpha=0.5)
plt.show()
Method 2
If we only care about a subset of the structures, we can instead partition the 3D CT prior to rendering. Note that this method is compatible with different rendering backends.
# Only load the bones in the CT (and the costal cartilage, but it looks weird without it)
structures = ["skeleton", "ribs", "vertebrae"]
labels = subject.structures.query(f"group in {structures}")["id"].tolist()
subject = load_example_ct(labels=labels)
drr = DRR(subject, sdd=1020, height=200, delx=2.0).to(device)

# Render with the camera pose from the previous section
img = drr(pose)
plot_drr(img, ticks=False)
plt.show()
Because we are rendering all structures at once, we don’t incur additional overhead.
38.1 ms ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Rendering sparse DRRs
You can also render random sparse subsets of the pixels in a DRR.
Sparse DRR rendering can be useful in registration and reconstruction tasks when coupled with a pixel-wise loss, such as MSE.
# Make the DRR with 10% of the pixels
# (reload the full CT, since the previous section kept only the bones)
subject = load_example_ct()
drr = DRR(
    subject,
    sdd=1020,
    height=200,
    delx=2.0,
    p_subsample=0.1,  # Set the proportion of pixels that should be rendered
    reshape=True,  # Map rendered pixels back to their location in true space - useful for plotting, but can be disabled if using MSE as a loss function
).to(device)

# Render the sparse DRR
img = drr(pose)
plot_drr(img, ticks=False)
plt.show()
6.39 ms ± 16.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Using different rendering backends
`DiffDRR` can also render synthetic X-rays using trilinear interpolation instead of Siddon’s method. The key argument to pay attention to is `n_points`, which controls how many points are sampled along each ray for interpolation. Higher values make more realistic images, at the cost of higher rendering time.
drr = DRR(
    subject,
    sdd=1020,
    height=200,
    delx=2.0,
    renderer="trilinear",  # Set the rendering backend to trilinear
).to(device)

# Render the same pose with increasingly many samples per ray
n_points = [100, 250, 500, 1000]
imgs = [drr(pose, n_points=n) for n in n_points]

fig, axs = plt.subplots(1, 4, figsize=(14, 7), dpi=300, tight_layout=True)
img = torch.concat(imgs)  # concatenate along dim 0 (the batch dimension): one image per n_points value
axs = plot_drr(img, ticks=False, axs=axs)
for ax, n in zip(axs, n_points):
    ax.set(title=f"n_points={n}")
plt.show()