# Global plotting theme for every figure in this file: seaborn's "ticks"
# style at the "notebook" scaling context.
seaborn.set_theme(style="ticks", context="notebook")
# Subplot title for each metric index accepted by ``plot`` (the title is
# drawn over the top-middle panel).
_METRIC_TITLES = (
    "Gradient NCC",
    "Local NCC",
    "-MAE",
    "-MSE",
    "Global NCC",
    "PSNR",
    "SSIM",
    "mNCC",
    "mSSIM",
)

# Metric indices whose values are negated before plotting, matching the
# "-MAE" / "-MSE" entries in ``_METRIC_TITLES``.
_NEGATED_METRICS = (2, 3)


def _to_numpy(tensor):
    """Return a NumPy copy of a torch tensor, safe for grad/GPU tensors."""
    return tensor.detach().cpu().numpy()


def _draw_surface_panel(ax, axis_a, axis_b, values, xlabel, ylabel, zmin, zmax):
    """Draw one 3D surface with a filled-contour projection underneath.

    Args:
        ax: A 3D axes (created with ``projection="3d"``) to draw into.
        axis_a: 1D tensor of sample positions along the panel's x axis.
        axis_b: 1D tensor of sample positions along the panel's y axis.
        values: Tensor of surface heights. It must match the
            ``indexing="ij"`` meshgrid of ``axis_a`` and ``axis_b``, i.e.
            have shape ``(len(axis_a), len(axis_b))``.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        zmin: Lower z-axis limit; ``None`` leaves that limit unchanged.
        zmax: Upper z-axis limit; ``None`` leaves that limit unchanged.
    """
    grid_a, grid_b = torch.meshgrid(axis_a, axis_b, indexing="ij")
    grid_a = _to_numpy(grid_a)
    grid_b = _to_numpy(grid_b)
    z = _to_numpy(values)

    # Shadow the landscape as a translucent contour plane placed at the
    # surface's lowest value. A plain float keeps matplotlib from receiving
    # a 0-d tensor.
    ax.contourf(
        grid_a,
        grid_b,
        z,
        zdir="z",
        offset=float(z.min()),
        cmap="rainbow",
        alpha=0.5,
    )
    ax.plot_surface(
        grid_a,
        grid_b,
        z,
        rstride=1,
        cstride=1,
        cmap="rainbow",
        linewidth=0.0,
    )
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlim3d(zmin, zmax)


def plot(idx, zmin=None, zmax=None):
    """Plot one similarity metric as six 3D landscapes on a 2x3 grid.

    The top row plots the angle pairs (Δα, Δβ), (Δα, Δγ) and (Δβ, Δγ) in
    radians. The bottom row plots the offset pairs (ΔX, ΔY), (ΔX, ΔZ) and
    (ΔY, ΔZ) in mm.

    The data comes from module-level globals defined elsewhere in the file:
    the metric volumes ``TP``, ``TG``, ``PG``, ``XY``, ``XZ`` and ``YZ``, and
    the 1D sample axes ``t_angles``, ``p_angles``, ``g_angles``, ``xs``,
    ``ys`` and ``zs``. Each volume is indexed as ``volume[..., idx]``, which
    must yield a 2D tensor matching the meshgrid of its two axes.

    Args:
        idx: Metric index into the last dimension of every volume. It also
            selects the title from ``_METRIC_TITLES``.
        zmin: Optional lower z limit shared by all six panels.
        zmax: Optional upper z limit shared by all six panels.

    Returns:
        ``(fig, axs)``: the matplotlib figure and the list of its six 3D
        axes in subplot order.

    Raises:
        IndexError: If ``idx`` is not a valid index into ``_METRIC_TITLES``.
            This is raised before any figure is created.
    """
    title = _METRIC_TITLES[idx]
    # Error metrics are flipped so that "higher is better" holds for every
    # plotted landscape.
    multiplier = -1 if idx in _NEGATED_METRICS else 1

    # (x-axis samples, y-axis samples, metric volume, x label, y label),
    # in subplot order.
    panels = (
        # Angle landscapes (radians).
        (t_angles, p_angles, TP, "Δα (radians)", "Δβ (radians)"),
        (t_angles, g_angles, TG, "Δα (radians)", "Δγ (radians)"),
        (p_angles, g_angles, PG, "Δβ (radians)", "Δγ (radians)"),
        # Offset landscapes (mm).
        (xs, ys, XY, "ΔX (mm)", "ΔY (mm)"),
        (xs, zs, XZ, "ΔX (mm)", "ΔZ (mm)"),
        (ys, zs, YZ, "ΔY (mm)", "ΔZ (mm)"),
    )

    fig = plt.figure(figsize=(10, 6.5), dpi=300)
    axs = []
    for position, (axis_a, axis_b, volume, xlabel, ylabel) in enumerate(
        panels, start=1
    ):
        ax = fig.add_subplot(2, 3, position, projection="3d")
        _draw_surface_panel(
            ax,
            axis_a,
            axis_b,
            multiplier * volume[..., idx],
            xlabel,
            ylabel,
            zmin,
            zmax,
        )
        axs.append(ax)

    # One title for the whole figure, drawn over the top-middle panel.
    axs[1].set_title(title)
    return fig, axs