Spherical Coordinates

import matplotlib.pyplot as plt
import numpy as np
import torch

from diffdrr.drr import DRR
Code
# Fixed part of the pose: source-to-detector distance and translation.
# Only the three rotation angles are varied in the figures below.
sdr = 1.0
theta, phi, gamma = 0, 0, 0
bx, by, bz = 10, -10, -40
params = torch.Tensor([[sdr, theta, phi, gamma, bx, by, bz]])

device = "cuda" if torch.cuda.is_available() else "cpu"
drr = DRR(
    volume=np.zeros([512, 512, 133]),
    spacing=[1, 1, 1],
    height=5,
    delx=0.75,
    params=params,
).to(device)


def _plot_angle_sweep(angle: str, n_steps: int = 8, half_width: float = 2.0) -> None:
    """Plot the X-ray source and its rays while sweeping one rotation angle.

    The chosen angle takes the values ``k / n_steps * pi`` for
    ``k = 0 .. n_steps - 1``. The other two angles stay at 0, and the
    distance/translation stay at the module-level ``sdr`` and ``(bx, by, bz)``.
    Each angle value is drawn in its own 3D subplot on a 2-row grid, and the
    figure is shown with ``plt.show()``.

    Args:
        angle: Which angle to sweep: ``"theta"``, ``"phi"`` or ``"gamma"``.
        n_steps: Number of angle values (and subplots) in the sweep.
        half_width: Half-size of the plotted cube centred on ``(bx, by, bz)``.

    Raises:
        ValueError: If ``angle`` is not one of the three supported names.
    """
    if angle not in ("theta", "phi", "gamma"):
        raise ValueError(f"angle must be 'theta', 'phi' or 'gamma', got {angle!r}")
    symbol = "\\" + angle  # LaTeX command used in the titles, e.g. \theta
    ncols = (n_steps + 1) // 2  # 2-row grid; 4 columns for the default 8 steps

    fig = plt.figure()
    for k in range(n_steps):
        pose = {"theta": 0.0, "phi": 0.0, "gamma": 0.0}
        pose[angle] = (k / n_steps) * torch.pi
        # Same layout as the initial ``params`` tensor above.
        # NOTE(review): moved to the renderer's device because the DRR was
        # moved with .to(device); confirm whether _update_params already
        # handles device placement itself.
        new_params = torch.Tensor(
            [[sdr, pose["theta"], pose["phi"], pose["gamma"], bx, by, bz]]
        ).to(device)

        # NOTE(review): _update_params is a private DRR method and may change
        # between diffdrr versions.
        drr._update_params(new_params)
        source, rays = drr.detector.make_xrays(drr.sdr, drr.rotations, drr.translations)
        src = source.detach().cpu()[0, 0]  # first source point; src[0..2] = x, y, z
        # Put the coordinate axis first so x/y/z can be unpacked below;
        # presumably rays is (batch, n_rays, 3) -- TODO confirm.
        targets = rays.permute(2, 0, 1).detach().cpu()
        xs, ys, zs = targets.reshape(3, -1)

        ax = fig.add_subplot(2, ncols, k + 1, projection="3d")
        # Raw f-string: avoids the invalid "\p" escape in a plain string literal.
        ax.set(title=rf"${symbol}=\frac{{{k}}}{{{n_steps}}}\pi$")
        ax.scatter(src[0], src[1], src[2], label="Source", c="black")
        # Colour by ray index; sized from the data rather than hard-coding 25.
        ax.scatter(xs, ys, zs, label="Targets", c=torch.arange(xs.numel()), cmap="jet")
        for x, y, z in zip(xs, ys, zs):
            ax.plot([src[0], x], [src[1], y], [src[2], z], "k", alpha=0.2)

        ax.set(xlabel="x", ylabel="y", zlabel="z")
        ax.set(
            xlim=[bx - half_width, bx + half_width],
            ylim=[by - half_width, by + half_width],
            zlim=[bz - half_width, bz + half_width],
        )

    plt.tight_layout()
    plt.show()


for swept_angle in ("theta", "phi", "gamma"):
    _plot_angle_sweep(swept_angle)