import matplotlib.pyplot as plt
import torch
from diffdrr.drr import DRR
from diffdrr.visualization import plot_drr
from diffpose.deepfluoro import DeepFluoroDataset, Transforms
from diffpose.registration import SparseRegistration, vector_to_img
# Sparse rendering
# Hacking DiffDRR to render sparse subsets of image patches.
# Load a DeepFluoro specimen and define the (downsampled) detector geometry.
# NOTE(review): the argument 1 is presumably the specimen id — confirm.
specimen = DeepFluoroDataset(1)

# Side length of the rendered DRR in pixels (cf. the 256 x 256 images below).
height = 256

# Downsampling factor from the native detector to `height` pixels.
# NOTE(review): presumably a 1536 px native detector with 100 px trimmed — confirm.
subsample = (1536 - 100) / height

# Pixel spacing of the downsampled detector.
# NOTE(review): 0.194 looks like the native pixel spacing (mm?) — confirm.
delx = 0.194 * subsample

# Keep only the pose from the first dataset item, and move it to the GPU.
_, pose = specimen[0]
pose = pose.cuda()
# Differentiable DRR renderer for the specimen's CT volume, placed on the GPU.
drr = DRR(
    specimen.volume,
    specimen.spacing,
    # NOTE(review): DiffDRR's `sdr` is presumably a radius, hence half the focal length — confirm.
    sdr=specimen.focal_len / 2,
    height=height,
    delx=delx,
    x0=specimen.x0,
    y0=specimen.y0,
    reverse_x_axis=True,
    bone_attenuation_multiplier=2.5,
).to("cuda")
# Registration module wrapping the renderer and the pose. It is called below
# with a number of patches (or None for the full image) and a patch size.
registration = SparseRegistration(drr, pose, parameterization="se3_log_map")
# Generate images with different numbers of patches
# Render the same pose with an increasing number of 13x13 patches.
# `None` renders the full image (labelled "full image" in the plot below).
imgs = []
n_patches = [None, 100, 250, 500, 750]
for n in n_patches:
    img, mask = registration(n, patch_size=13)
    if n is not None:
        # Sparse renders come back as a pixel vector plus a mask;
        # convert them back into an image so every entry can be plotted.
        img = vector_to_img(img, mask)
    # Append every render, including the full image, so there is one image per entry of n_patches.
    imgs.append(img)
# Plot the images with various levels of sparsity
# One panel per sparsity level, labelled with its number of patches.
axs = plot_drr(torch.concat(imgs))
for ax, n in zip(axs, n_patches):
    # Label without rebinding the loop variable.
    ax.set(xlabel="full image" if n is None else n)
plt.show()
# Full image with SparseRegistration
29.5 ms ± 375 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
# 100 patches
8.57 ms ± 51.3 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
# 250 patches
15.3 ms ± 112 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
# 500 patches
22.1 ms ± 141 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
# 750 patches
25.7 ms ± 157 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
# Full image with DiffDRR
29.7 ms ± 203 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)
from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d
from diffpose.registration import VectorizedNormalizedCrossCorrelation2d
# Dense multiscale NCC with equal weights on two scales.
# NOTE(review): patch size `None` presumably means whole-image NCC — confirm.
mncc = MultiscaleNormalizedCrossCorrelation2d(
    patch_sizes=[None, 13],
    patch_weights=[0.5, 0.5],
)
# Vectorized NCC from diffpose, presumably for the flat sparse renders — confirm.
smncc = VectorizedNormalizedCrossCorrelation2d()
1.51 ms ± 1.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# Sparse render with 1000 13x13 patches: returns a flat pixel vector
# (cf. pred_img.shape below) together with its mask.
pred_img, mask = registration(n_patches=1000, patch_size=13)
3.27 ms ± 8.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Shape of the last image produced by the patch loop above.
img.size()
torch.Size([1, 1, 256, 256])
# Shape of the sparse render: a flat vector of pixels rather than an image.
pred_img.size()
torch.Size([1, 14675])
# Iterations per second implied by ~30 ms per iteration
# (cf. the full-image render timings recorded above).
1 / 0.030
33.333333333333336
# Same rate with an extra ~1.5 ms per iteration (cf. the 1.51 ms timing
# recorded above, presumably the similarity-metric evaluation).
1 / (31.5 * 1e-3)
31.746031746031747