diff --git a/CHANGES.rst b/CHANGES.rst
index a413f53cb..7dc048a0a 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -6,6 +6,11 @@ SCICO Release Notes
Version 0.0.5 (unreleased)
----------------------------
+• New integrated Radon/X-ray transform ``linop.XRayTransform``.
+• Rename modules ``radon_astra`` and ``radon_svmbir`` to ``xray.astra`` and
+ ``xray.svmbir`` respectively, and rename ``TomographicProjector`` classes
+ to ``XRayTransform``.
+• Rename ``AbelProjector`` to ``AbelTransform``.
• Rename ``solver.ATADSolver`` to ``solver.MatrixATADSolver``.
• Support ``jaxlib`` and ``jax`` versions 0.4.3 to 0.4.19.
diff --git a/data b/data
index b63329c3b..23b76fd4f 160000
--- a/data
+++ b/data
@@ -1 +1 @@
-Subproject commit b63329c3b1b89fbebc4cb3ec892badee0b989e40
+Subproject commit 23b76fd4fa092c186689af1dba9d058d6dc433bc
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
index c6cc70e9a..291eeb5a6 100644
--- a/docs/source/examples.rst
+++ b/docs/source/examples.rst
@@ -36,7 +36,8 @@ Computed Tomography
examples/ct_astra_odp_train_foam2
examples/ct_astra_unet_train_foam2
examples/ct_projector_comparison
-
+ examples/ct_multi_cs_tv_admm
+ examples/ct_multi_tv_admm
Deconvolution
^^^^^^^^^^^^^
diff --git a/docs/source/inverse.rst b/docs/source/inverse.rst
index 1542f4f53..696657412 100644
--- a/docs/source/inverse.rst
+++ b/docs/source/inverse.rst
@@ -47,15 +47,15 @@ SCICO provides the :class:`.Operator` and :class:`.LinearOperator`
classes, which may be subclassed by users, in order to implement the
forward operator, :math:`A`. It also has several built-in operators,
most of which are linear, e.g., finite convolutions, discrete Fourier
-transforms, optical propagators, Abel transforms, and Radon
-transforms. For example,
+transforms, optical propagators, Abel transforms, and X-ray transforms
+(the same as Radon transforms in 2D). For example,
.. code:: python
input_shape = (512, 512)
angles = np.linspace(0, 2 * np.pi, 180, endpoint=False)
channels = 512
- A = scico.linop.radon_svmbir.ParallelBeamProjector(input_shape, angles, channels)
+ A = scico.linop.xray.svmbir.XRayTransform(input_shape, angles, channels)
defines a tomographic projection operator.
diff --git a/docs/source/notes.rst b/docs/source/notes.rst
index 08a63be0a..2986f5adc 100644
--- a/docs/source/notes.rst
+++ b/docs/source/notes.rst
@@ -111,13 +111,25 @@ via interfaces to the `bm3d `__ and
when the full benefits of JAX-based code are required.
-Tomographic Projectors
-----------------------
-
-The :class:`.radon_svmbir.TomographicProjector` class is implemented
+Tomographic Projectors/Radon Transforms
+---------------------------------------
+
+Note that the tomographic projections that are frequently referred
+to as Radon transforms are referred to as X-ray transforms in SCICO.
+While the Radon transform is far more well-known than the X-ray
+transform, which is the same as the Radon transform for projections
+in two dimensions, these two transform differ in higher numbers of
+dimensions, and it is the X-ray transform that is the appropriate
+mathematical model for beam attenuation based imaging in three or
+more dimensions.
+
+SCICO includes three different implementations of X-ray transforms.
+Of these, :class:`.linop.XRayTransform` is an integral component of
+SCICO, while the other two depend on external packages.
+The :class:`.xray.svmbir.XRayTransform` class is implemented
via an interface to the `svmbir
`__ package. The
-:class:`.radon_astra.TomographicProjector` class is implemented via an
+:class:`.xray.astra.XRayTransform` class is implemented via an
interface to the `ASTRA toolbox
`__. This toolbox does provide some
GPU acceleration support, but efficiency is expected to be lower than
diff --git a/examples/scripts/README.rst b/examples/scripts/README.rst
index 3910e9671..1cdb0e8e1 100644
--- a/examples/scripts/README.rst
+++ b/examples/scripts/README.rst
@@ -36,8 +36,11 @@ Computed Tomography
`ct_astra_unet_train_foam2.py `_
CT Training and Reconstructions with UNet
`ct_projector_comparison.py `_
- X-ray Projector Comparison
-
+ X-ray Transform Comparison
+ `ct_multi_cs_tv_admm.py `_
+ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
+ `ct_multi_tv_admm.py `_
+ TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)
Deconvolution
^^^^^^^^^^^^^
diff --git a/examples/scripts/ct_abel_tv_admm.py b/examples/scripts/ct_abel_tv_admm.py
index 97ca30169..2adc141ce 100644
--- a/examples/scripts/ct_abel_tv_admm.py
+++ b/examples/scripts/ct_abel_tv_admm.py
@@ -26,7 +26,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_circular_phantom
-from scico.linop.abel import AbelProjector
+from scico.linop.abel import AbelTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -40,7 +40,7 @@
"""
Set up the forward operator and create a test measurement.
"""
-A = AbelProjector(x_gt.shape)
+A = AbelTransform(x_gt.shape)
y = A @ x_gt
np.random.seed(12345)
y = y + np.random.normal(size=y.shape).astype(np.float32)
diff --git a/examples/scripts/ct_abel_tv_admm_tune.py b/examples/scripts/ct_abel_tv_admm_tune.py
index ab7ffd18f..c60ade412 100644
--- a/examples/scripts/ct_abel_tv_admm_tune.py
+++ b/examples/scripts/ct_abel_tv_admm_tune.py
@@ -38,7 +38,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_circular_phantom
-from scico.linop.abel import AbelProjector
+from scico.linop.abel import AbelTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.ray import tune
@@ -52,7 +52,7 @@
"""
Set up the forward operator and create a test measurement.
"""
-A = AbelProjector(x_gt.shape)
+A = AbelTransform(x_gt.shape)
y = A @ x_gt
np.random.seed(12345)
y = y + np.random.normal(size=y.shape).astype(np.float32)
@@ -84,7 +84,7 @@ def setup(self, config, x_gt, x0, y):
# Get arrays passed by tune call.
self.x_gt, self.x0, self.y = snp.array(x_gt), snp.array(x0), snp.array(y)
# Set up problem to be solved.
- self.A = AbelProjector(self.x_gt.shape)
+ self.A = AbelTransform(self.x_gt.shape)
self.f = loss.SquaredL2Loss(y=self.y, A=self.A)
self.C = linop.FiniteDifference(input_shape=self.x_gt.shape)
self.reset_config(config)
diff --git a/examples/scripts/ct_astra_3d_tv_admm.py b/examples/scripts/ct_astra_3d_tv_admm.py
index 3abb9ae89..bb64ea61b 100644
--- a/examples/scripts/ct_astra_3d_tv_admm.py
+++ b/examples/scripts/ct_astra_3d_tv_admm.py
@@ -15,9 +15,9 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
-where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is
-a 3D finite difference operator, and $\mathbf{x}$ is the desired
-image.
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, $C$ is a 3D finite difference operator,
+and $\mathbf{x}$ is the desired image.
"""
@@ -28,7 +28,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -43,9 +43,7 @@
n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = TomographicProjector(
- tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles
-) # Radon transform operator
+A = XRayTransform(tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles) # CT projection operator
y = A @ tangle # sinogram
diff --git a/examples/scripts/ct_astra_modl_train_foam2.py b/examples/scripts/ct_astra_modl_train_foam2.py
index 4214888ab..66a137e9c 100644
--- a/examples/scripts/ct_astra_modl_train_foam2.py
+++ b/examples/scripts/ct_astra_modl_train_foam2.py
@@ -54,7 +54,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -81,12 +81,12 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = TomographicProjector(
+A = XRayTransform(
input_shape=(N, N),
detector_spacing=1,
det_count=N,
angles=angles,
-) # Radon transform operator
+) # CT projection operator
A = (1.0 / N) * A # normalized
diff --git a/examples/scripts/ct_astra_noreg_pcg.py b/examples/scripts/ct_astra_noreg_pcg.py
index fc5dd6f08..9e78f59fd 100644
--- a/examples/scripts/ct_astra_noreg_pcg.py
+++ b/examples/scripts/ct_astra_noreg_pcg.py
@@ -15,8 +15,8 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 \;,$$
-where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, and
-$\mathbf{x}$ is the reconstructed image.
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, and $\mathbf{x}$ is the reconstructed image.
"""
from time import time
@@ -29,7 +29,7 @@
from scico import loss, plot
from scico.linop import CircularConvolve
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
from scico.solver import cg
"""
@@ -45,7 +45,7 @@
"""
n_projection = N # matches the phantom size so this is not few-view CT
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = 1 / N * TomographicProjector(x_gt.shape, 1, N, angles) # Radon transform operator
+A = 1 / N * XRayTransform(x_gt.shape, 1, N, angles) # CT projection operator
y = A @ x_gt # sinogram
diff --git a/examples/scripts/ct_astra_odp_train_foam2.py b/examples/scripts/ct_astra_odp_train_foam2.py
index 77b241dd2..bb9b1a54b 100644
--- a/examples/scripts/ct_astra_odp_train_foam2.py
+++ b/examples/scripts/ct_astra_odp_train_foam2.py
@@ -58,7 +58,7 @@
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
"""
Prepare parallel processing. Set an arbitrary processor count (only
@@ -85,12 +85,12 @@
Build CT projection operator.
"""
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = TomographicProjector(
+A = XRayTransform(
input_shape=(N, N),
detector_spacing=1,
det_count=N,
angles=angles,
-) # Radon transform operator
+) # CT projection operator
A = (1.0 / N) * A # normalized
diff --git a/examples/scripts/ct_astra_tv_admm.py b/examples/scripts/ct_astra_tv_admm.py
index 1f12f7ab3..69520f872 100644
--- a/examples/scripts/ct_astra_tv_admm.py
+++ b/examples/scripts/ct_astra_tv_admm.py
@@ -14,9 +14,9 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
-where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, $C$ is
-a 2D finite difference operator, and $\mathbf{x}$ is the desired
-image.
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and
+$\mathbf{x}$ is the desired image.
"""
import numpy as np
@@ -26,7 +26,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -44,7 +44,7 @@
"""
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
-A = TomographicProjector(x_gt.shape, 1, N, angles) # Radon transform operator
+A = XRayTransform(x_gt.shape, 1, N, angles) # CT projection operator
y = A @ x_gt # sinogram
diff --git a/examples/scripts/ct_astra_weighted_tv_admm.py b/examples/scripts/ct_astra_weighted_tv_admm.py
index 3f14b828d..b3dc439c2 100644
--- a/examples/scripts/ct_astra_weighted_tv_admm.py
+++ b/examples/scripts/ct_astra_weighted_tv_admm.py
@@ -14,11 +14,11 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
-where $A$ is the Radon transform, $\mathbf{y}$ is the sinogram, the norm
-weighting $W$ is chosen so that the weighted norm is an approximation to
-the Poisson negative log likelihood :cite:`sauer-1993-local`, $C$ is
-a 2D finite difference operator, and $\mathbf{x}$ is the desired
-image.
+where $A$ is the X-ray transform (the CT forward projection),
+$\mathbf{y}$ is the sinogram, the norm weighting $W$ is chosen so that
+the weighted norm is an approximation to the Poisson negative log
+likelihood :cite:`sauer-1993-local`, $C$ is a 2D finite difference
+operator, and $\mathbf{x}$ is the desired image.
"""
import numpy as np
@@ -27,7 +27,7 @@
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -51,7 +51,7 @@
𝛼 = 1e-2 # attenuation coefficient
angles = np.linspace(0, 2 * np.pi, n_projection) # evenly spaced projection angles
-A = TomographicProjector(x_gt.shape, 1.0, N, angles) # Radon transform operator
+A = XRayTransform(x_gt.shape, 1.0, N, angles) # CT projection operator
y_c = A @ x_gt # sinogram
diff --git a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py
index 1e334ada5..80299a1ae 100644
--- a/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py
+++ b/examples/scripts/ct_fan_svmbir_ppp_bm3d_admm_prox.py
@@ -35,7 +35,7 @@
from scico import metric, plot
from scico.functional import BM3D
from scico.linop import Diagonal, Identity
-from scico.linop.radon_svmbir import SVMBIRExtendedLoss, TomographicProjector
+from scico.linop.xray.svmbir import SVMBIRExtendedLoss, XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -65,7 +65,7 @@
dist_source_detector = 1500.0
magnification = 1.2
-A_fan = TomographicProjector(
+A_fan = XRayTransform(
x_gt.shape,
angles,
num_channels,
@@ -73,7 +73,7 @@
dist_source_detector=dist_source_detector,
magnification=magnification,
)
-A_parallel = TomographicProjector(
+A_parallel = XRayTransform(
x_gt.shape,
angles,
num_channels,
diff --git a/examples/scripts/ct_multi_cs_tv_admm.py b/examples/scripts/ct_multi_cs_tv_admm.py
new file mode 100644
index 000000000..f7fcfcc10
--- /dev/null
+++ b/examples/scripts/ct_multi_cs_tv_admm.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# This file is part of the SCICO package. Details of the copyright
+# and user license can be found in the 'LICENSE.txt' file distributed
+# with the package.
+
+r"""
+TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
+===================================================================================
+
+This example demonstrates solution of a sparse-view CT reconstruction
+problem with isotropic total variation (TV) regularization
+
+ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
+ \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
+
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and
+$\mathbf{x}$ is the desired image. The solution is computed and compared
+for all three 2D CT projectors available in scico, using a sinogram
+computed with the svmbir projector.
+"""
+
+import numpy as np
+
+import jax
+
+from xdesign import Foam, discrete_phantom
+
+import scico.numpy as snp
+from scico import functional, linop, loss, metric, plot
+from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir
+from scico.optimize.admm import ADMM, LinearSubproblemSolver
+from scico.util import device_info
+
+"""
+Create a ground truth image.
+"""
+N = 512 # phantom size
+np.random.seed(1234)
+x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
+x_gt = jax.device_put(x_gt)
+
+
+"""
+Define CT geometry and construct array of (approximately) equivalent projectors.
+"""
+n_projection = 45 # number of projections
+angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+projectors = {
+ "astra": astra.XRayTransform(x_gt.shape, 1, N, angles - np.pi / 2.0), # astra
+ "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir
+ "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico
+}
+
+
+"""
+Compute common sinogram using svmbir projector.
+"""
+A = projectors["svmbir"]
+noise = np.random.normal(size=(n_projection, N)).astype(np.float32)
+y = A @ x_gt + 2.0 * noise
+
+
+"""
+Solve the same problem using the different projectors.
+"""
+print(f"Solving on {device_info()}")
+x_rec, hist = {}, {}
+for p in ("astra", "svmbir", "scico"):
+ print(f"\nSolving with {p} projector")
+
+ # Set up ADMM solver object.
+ λ = 2e0 # L1 norm regularization parameter
+ ρ = 5e0 # ADMM penalty parameter
+ maxiter = 25 # number of ADMM iterations
+ cg_tol = 1e-4 # CG relative tolerance
+ cg_maxiter = 25 # maximum CG iterations per ADMM iteration
+
+ # The append=0 option makes the results of horizontal and vertical
+ # finite differences the same shape, which is required for the L21Norm,
+ # which is used so that g(Cx) corresponds to isotropic TV.
+ C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
+ g = λ * functional.L21Norm()
+ A = projectors[p]
+ f = loss.SquaredL2Loss(y=y, A=A)
+ x0 = snp.clip(A.T(y), 0, 1.0)
+
+ # Set up the solver.
+ solver = ADMM(
+ f=f,
+ g_list=[g],
+ C_list=[C],
+ rho_list=[ρ],
+ x0=x0,
+ maxiter=maxiter,
+ subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
+ itstat_options={"display": True, "period": 5},
+ )
+
+ # Run the solver.
+ solver.solve()
+ hist[p] = solver.itstat_object.history(transpose=True)
+ x_rec[p] = snp.clip(solver.x, 0, 1.0)
+
+
+"""
+Display sinogram.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3))
+plot.imview(y, title="sinogram", fig=fig, ax=ax)
+fig.show()
+
+
+"""
+Plot convergence statistics.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5))
+plot.plot(
+ np.vstack([hist[p].Objective for p in projectors.keys()]).T,
+ title="Objective function",
+ xlbl="Iteration",
+ ylbl="Functional value",
+ lgnd=projectors.keys(),
+ fig=fig,
+ ax=ax[0],
+)
+plot.plot(
+ np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T,
+ ptyp="semilogy",
+ title="Primal Residual",
+ xlbl="Iteration",
+ fig=fig,
+ ax=ax[1],
+)
+plot.plot(
+ np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T,
+ ptyp="semilogy",
+ title="Dual Residual",
+ xlbl="Iteration",
+ fig=fig,
+ ax=ax[2],
+)
+fig.show()
+
+
+"""
+Show the recovered images.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5))
+plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
+for n, p in enumerate(projectors.keys()):
+ plot.imview(
+ x_rec[p],
+ title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])),
+ fig=fig,
+ ax=ax[n + 1],
+ )
+fig.show()
+
+
+input("\nWaiting for input to close figures and exit")
diff --git a/examples/scripts/ct_multi_tv_admm.py b/examples/scripts/ct_multi_tv_admm.py
new file mode 100644
index 000000000..8ed284c8d
--- /dev/null
+++ b/examples/scripts/ct_multi_tv_admm.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# This file is part of the SCICO package. Details of the copyright
+# and user license can be found in the 'LICENSE.txt' file distributed
+# with the package.
+
+r"""
+TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors)
+==================================================================
+
+This example demonstrates solution of a sparse-view CT reconstruction
+problem with isotropic total variation (TV) regularization
+
+ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
+ \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
+
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and
+$\mathbf{x}$ is the desired image. The solution is computed and compared
+for all three 2D CT projectors available in scico.
+"""
+
+import numpy as np
+
+import jax
+
+from xdesign import Foam, discrete_phantom
+
+import scico.numpy as snp
+from scico import functional, linop, loss, metric, plot
+from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir
+from scico.optimize.admm import ADMM, LinearSubproblemSolver
+from scico.util import device_info
+
+"""
+Create a ground truth image.
+"""
+N = 512 # phantom size
+np.random.seed(1234)
+x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
+x_gt = jax.device_put(x_gt)
+
+
+"""
+Define CT geometry and construct array of (approximately) equivalent projectors.
+"""
+n_projection = 45 # number of projections
+angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
+projectors = {
+ "astra": astra.XRayTransform(x_gt.shape, 1, N, angles - np.pi / 2.0), # astra
+ "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N), # svmbir
+ "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)), # scico
+}
+
+
+"""
+Solve the same problem using the different projectors.
+"""
+print(f"Solving on {device_info()}")
+y, x_rec, hist = {}, {}, {}
+noise = np.random.normal(size=(n_projection, N)).astype(np.float32)
+for p in ("astra", "svmbir", "scico"):
+ print(f"\nSolving with {p} projector")
+ A = projectors[p]
+ y[p] = A @ x_gt + 2.0 * noise # sinogram
+
+ # Set up ADMM solver object.
+ λ = 2e0 # L1 norm regularization parameter
+ ρ = 5e0 # ADMM penalty parameter
+ maxiter = 25 # number of ADMM iterations
+ cg_tol = 1e-4 # CG relative tolerance
+ cg_maxiter = 25 # maximum CG iterations per ADMM iteration
+
+ # The append=0 option makes the results of horizontal and vertical
+ # finite differences the same shape, which is required for the L21Norm,
+ # which is used so that g(Cx) corresponds to isotropic TV.
+ C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
+ g = λ * functional.L21Norm()
+ f = loss.SquaredL2Loss(y=y[p], A=A)
+ x0 = snp.clip(A.T(y[p]), 0, 1.0)
+
+ # Set up the solver.
+ solver = ADMM(
+ f=f,
+ g_list=[g],
+ C_list=[C],
+ rho_list=[ρ],
+ x0=x0,
+ maxiter=maxiter,
+ subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
+ itstat_options={"display": True, "period": 5},
+ )
+
+ # Run the solver.
+ solver.solve()
+ hist[p] = solver.itstat_object.history(transpose=True)
+ x_rec[p] = snp.clip(solver.x, 0, 1.0)
+
+
+"""
+Compare sinograms.
+"""
+fig, ax = plot.subplots(nrows=3, ncols=1, figsize=(15, 10))
+for idx, name in enumerate(projectors.keys()):
+ plot.imview(y[name], title=f"{name} sinogram", cbar=None, fig=fig, ax=ax[idx])
+fig.show()
+
+
+"""
+Plot convergence statistics.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5))
+plot.plot(
+ np.vstack([hist[p].Objective for p in projectors.keys()]).T,
+ title="Objective function",
+ xlbl="Iteration",
+ ylbl="Functional value",
+ lgnd=projectors.keys(),
+ fig=fig,
+ ax=ax[0],
+)
+plot.plot(
+ np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T,
+ ptyp="semilogy",
+ title="Primal Residual",
+ xlbl="Iteration",
+ fig=fig,
+ ax=ax[1],
+)
+plot.plot(
+ np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T,
+ ptyp="semilogy",
+ title="Dual Residual",
+ xlbl="Iteration",
+ fig=fig,
+ ax=ax[2],
+)
+fig.show()
+
+
+"""
+Show the recovered images.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5))
+plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
+for n, p in enumerate(projectors.keys()):
+ plot.imview(
+ x_rec[p],
+ title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])),
+ fig=fig,
+ ax=ax[n + 1],
+ )
+fig.show()
+
+
+input("\nWaiting for input to close figures and exit")
diff --git a/examples/scripts/ct_projector_comparison.py b/examples/scripts/ct_projector_comparison.py
index 94b8d1d2c..58a31d4cd 100644
--- a/examples/scripts/ct_projector_comparison.py
+++ b/examples/scripts/ct_projector_comparison.py
@@ -6,11 +6,11 @@
r"""
-X-ray Projector Comparison
+X-ray Transform Comparison
==========================
-This example compares SCICO's native X-ray projection algorithm to that
-of the ASTRA Toolbox.
+This example compares SCICO's native X-ray transform algorithm
+to that of the ASTRA toolbox.
"""
import numpy as np
@@ -20,9 +20,9 @@
from xdesign import Foam, discrete_phantom
+import scico.linop.xray.astra as astra
from scico import plot
-from scico.linop import ParallelFixedAxis2dProjector, XRayProject
-from scico.linop.radon_astra import TomographicProjector
+from scico.linop import Parallel2dProjector, XRayTransform
from scico.util import Timer
"""
@@ -30,7 +30,9 @@
"""
N = 512
+
det_count = int(jnp.ceil(jnp.sqrt(2 * N**2)))
+
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = jnp.array(x_gt)
@@ -46,12 +48,12 @@
projectors = {}
timer.start("scico_init")
-projectors["scico"] = XRayProject(ParallelFixedAxis2dProjector((N, N), angles))
+projectors["scico"] = XRayTransform(Parallel2dProjector((N, N), angles))
timer.stop("scico_init")
timer.start("astra_init")
-projectors["astra"] = TomographicProjector(
- (N, N), detector_spacing=1.0, det_count=det_count, angles=angles
+projectors["astra"] = astra.XRayTransform(
+ (N, N), detector_spacing=1.0, det_count=det_count, angles=angles - jnp.pi / 2.0
)
timer.stop("astra_init")
@@ -59,9 +61,10 @@
"""
Time first projector application, which might include JIT overhead.
"""
+
ys = {}
for name, H in projectors.items():
- timer_label = f"{name}_first_proj"
+ timer_label = f"{name}_first_fwd"
timer.start(timer_label)
ys[name] = H @ x_gt
jax.block_until_ready(ys[name])
@@ -71,9 +74,10 @@
"""
Compute average time for a projector application.
"""
+
num_repeats = 3
for name, H in projectors.items():
- timer_label = f"{name}_avg_proj"
+ timer_label = f"{name}_avg_fwd"
timer.start(timer_label)
for _ in range(num_repeats):
ys[name] = H @ x_gt
@@ -82,62 +86,16 @@
timer.td[timer_label] /= num_repeats
-"""
-Display timing results.
-
-On our server, the SCICO projection is more than twice
-as fast as ASTRA when both are run on the GPU, and about
-10% slower when both are run the CPU.
-
-On our server, using the GPU:
-```
-Label Accum. Current
--------------------------------------------
-astra_avg_proj 4.62e-02 s Stopped
-astra_first_proj 6.92e-02 s Stopped
-astra_init 1.36e-03 s Stopped
-scico_avg_proj 1.61e-02 s Stopped
-scico_first_proj 2.95e-02 s Stopped
-scico_init 1.37e+01 s Stopped
-```
-
-Using the CPU:
-```
-Label Accum. Current
--------------------------------------------
-astra_avg_proj 9.11e-01 s Stopped
-astra_first_proj 9.16e-01 s Stopped
-astra_init 1.06e-03 s Stopped
-scico_avg_proj 1.03e+00 s Stopped
-scico_first_proj 1.04e+00 s Stopped
-scico_init 1.00e+01 s Stopped
-```
-"""
-
-print(timer)
-
-
-"""
-Show projections.
-"""
-fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
-plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
-plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
-fig.show()
-
-
"""
Time first back projection, which might include JIT overhead.
"""
-timer = Timer()
-
y = np.zeros(H.output_shape, dtype=np.float32)
y[num_angles // 3, det_count // 2] = 1.0
y = jnp.array(y)
HTys = {}
for name, H in projectors.items():
- timer_label = f"{name}_first_BP"
+ timer_label = f"{name}_first_back"
timer.start(timer_label)
HTys[name] = H.T @ y
jax.block_until_ready(ys[name])
@@ -149,7 +107,7 @@
"""
num_repeats = 3
for name, H in projectors.items():
- timer_label = f"{name}_avg_BP"
+ timer_label = f"{name}_avg_back"
timer.start(timer_label)
for _ in range(num_repeats):
HTys[name] = H.T @ y
@@ -159,41 +117,79 @@
"""
-Display back projection timing results.
+Display timing results.
-On our server, the SCICO back projection is slow the first time it is
-run, probably due to JIT overhead. After the first run, it is an order of
-magnitude faster than ASTRA when both are run on the GPU, and about three
-times faster when both are run on the CPU.
+On our server, the SCICO projection is more than twice as fast as ASTRA
+when both are run on the GPU, and about 10% slower when both are run the
+CPU. The SCICO back projection is slow the first time it is run, probably
+due to JIT overhead. After the first run, it is an order of magnitude
+faster than ASTRA when both are run on the GPU, and about three times
+faster when both are run on the CPU.
On our server, using the GPU:
```
-Label Accum. Current
------------------------------------------
-astra_avg_BP 3.71e-02 s Stopped
-astra_first_BP 4.20e-02 s Stopped
-scico_avg_BP 1.05e-03 s Stopped
-scico_first_BP 7.63e+00 s Stopped
+init astra 1.36e-03 s
+init scico 1.37e+01 s
+
+first fwd astra 6.92e-02 s
+first fwd scico 2.95e-02 s
+
+first back astra 4.20e-02 s
+first back scico 7.63e+00 s
+
+avg fwd astra 4.62e-02 s
+avg fwd scico 1.61e-02 s
+
+avg back astra 3.71e-02 s
+avg back scico 1.05e-03 s
```
Using the CPU:
```
-Label Accum. Current
------------------------------------------
-astra_avg_BP 9.34e-01 s Stopped
-astra_first_BP 9.39e-01 s Stopped
-scico_avg_BP 2.62e-01 s Stopped
-scico_first_BP 1.00e+01 s Stopped
+init astra 1.06e-03 s
+init scico 1.00e+01 s
+
+first fwd astra 9.16e-01 s
+first fwd scico 1.04e+00 s
+
+first back astra 9.39e-01 s
+first back scico 1.00e+01 s
+
+avg fwd astra 9.11e-01 s
+avg fwd scico 1.03e+00 s
+
+avg back astra 9.34e-01 s
+avg back scico 2.62e-01 s
```
"""
-print(timer)
+print(f"init astra {timer.td['astra_init']:.2e} s")
+print(f"init scico {timer.td['scico_init']:.2e} s")
+print("")
+for tstr in ("first", "avg"):
+ for dstr in ("fwd", "back"):
+ for pstr in ("astra", "scico"):
+ print(
+ f"{tstr:5s} {dstr:4s} {pstr} {timer.td[pstr + '_' + tstr + '_' + dstr]:.2e} s"
+ )
+ print()
+
+
+"""
+Show projections.
+"""
+
+fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
+plot.imview(ys["scico"], title="SCICO projection", cbar=None, fig=fig, ax=ax[0])
+plot.imview(ys["astra"], title="ASTRA projection", cbar=None, fig=fig, ax=ax[1])
+fig.show()
"""
Show back projections of a single detector element, i.e., a line.
"""
-fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 3))
+
+fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 6))
plot.imview(HTys["scico"], title="SCICO back projection (zoom)", cbar=None, fig=fig, ax=ax[0])
plot.imview(HTys["astra"], title="ASTRA back projection (zoom)", cbar=None, fig=fig, ax=ax[1])
for ax_i in ax:
diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
index d4b2e6050..390925d11 100644
--- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
+++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_cg.py
@@ -31,7 +31,7 @@
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
-from scico.linop.radon_svmbir import SVMBIRSquaredL2Loss, TomographicProjector
+from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -53,7 +53,7 @@
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
-A = TomographicProjector(x_gt.shape, angles, num_channels)
+A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
diff --git a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
index 787709b86..a6e663a09 100644
--- a/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
+++ b/examples/scripts/ct_svmbir_ppp_bm3d_admm_prox.py
@@ -41,10 +41,10 @@
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
-from scico.linop.radon_svmbir import (
+from scico.linop.xray.svmbir import (
SVMBIRExtendedLoss,
SVMBIRSquaredL2Loss,
- TomographicProjector,
+ XRayTransform,
)
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -67,7 +67,7 @@
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
-A = TomographicProjector(x_gt.shape, angles, num_channels)
+A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
diff --git a/examples/scripts/ct_svmbir_tv_multi.py b/examples/scripts/ct_svmbir_tv_multi.py
index 06d99696d..8592b44ff 100644
--- a/examples/scripts/ct_svmbir_tv_multi.py
+++ b/examples/scripts/ct_svmbir_tv_multi.py
@@ -14,7 +14,7 @@
$$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
\|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
-where $A$ is the Radon transform (implemented using the SVMBIR
+where $A$ is the X-ray transform (implemented using the SVMBIR
:cite:`svmbir-2020` tomographic projection), $\mathbf{y}$ is the sinogram,
$C$ is a 2D finite difference operator, and $\mathbf{x}$ is the desired
image.
@@ -29,7 +29,7 @@
import scico.numpy as snp
from scico import functional, linop, metric, plot
from scico.linop import Diagonal
-from scico.linop.radon_svmbir import SVMBIRSquaredL2Loss, TomographicProjector
+from scico.linop.xray.svmbir import SVMBIRSquaredL2Loss, XRayTransform
from scico.optimize import PDHG, LinearizedADMM
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
@@ -52,7 +52,7 @@
num_angles = int(N / 2)
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, dtype=snp.float32)
-A = TomographicProjector(x_gt.shape, angles, num_channels)
+A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
diff --git a/examples/scripts/ct_tv_admm.py b/examples/scripts/ct_tv_admm.py
new file mode 100644
index 000000000..6aa3474a7
--- /dev/null
+++ b/examples/scripts/ct_tv_admm.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# This file is part of the SCICO package. Details of the copyright
+# and user license can be found in the 'LICENSE.txt' file distributed
+# with the package.
+
+r"""
+TV-Regularized Sparse-View CT Reconstruction (Integrated Projector)
+===================================================================
+
+This example demonstrates solution of a sparse-view CT reconstruction
+problem with isotropic total variation (TV) regularization
+
+ $$\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x}
+ \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$
+
+where $A$ is the X-ray transform (the CT forward projection operator),
+$\mathbf{y}$ is the sinogram, $C$ is a 2D finite difference operator, and
+$\mathbf{x}$ is the desired image. This example uses the CT projector
+integrated into scico, while the companion
+[example script](ct_astra_tv_admm.rst) uses the projector provided by
+the astra package.
+"""
+
+import numpy as np
+
+import jax
+
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from xdesign import Foam, discrete_phantom
+
+import scico.numpy as snp
+from scico import functional, linop, loss, metric, plot
+from scico.linop.xray import Parallel2dProjector, XRayTransform
+from scico.optimize.admm import ADMM, LinearSubproblemSolver
+from scico.util import device_info
+
+"""
+Create a ground truth image.
+"""
+N = 512 # phantom size
+np.random.seed(1234)
+x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
+x_gt = jax.device_put(x_gt) # convert to jax type, push to GPU
+
+
+"""
+Configure CT projection operator and generate synthetic measurements.
+"""
+n_projection = 45 # number of projections
+angles = np.linspace(0, np.pi, n_projection) + np.pi / 2.0 # evenly spaced projection angles
+A = XRayTransform(Parallel2dProjector((N, N), angles)) # CT projection operator
+y = A @ x_gt # sinogram
+
+
+"""
+Set up ADMM solver object.
+"""
+λ = 2e0 # L1 norm regularization parameter
+ρ = 5e0 # ADMM penalty parameter
+maxiter = 25 # number of ADMM iterations
+cg_tol = 1e-4 # CG relative tolerance
+cg_maxiter = 25 # maximum CG iterations per ADMM iteration
+
+# The append=0 option makes the results of horizontal and vertical
+# finite differences the same shape, which is required for the L21Norm,
+# which is used so that g(Cx) corresponds to isotropic TV.
+C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
+g = λ * functional.L21Norm()
+
+f = loss.SquaredL2Loss(y=y, A=A)
+
+x0 = snp.clip(A.T(y), 0, 1.0)
+
+solver = ADMM(
+ f=f,
+ g_list=[g],
+ C_list=[C],
+ rho_list=[ρ],
+ x0=x0,
+ maxiter=maxiter,
+ subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
+ itstat_options={"display": True, "period": 5},
+)
+
+
+"""
+Run the solver.
+"""
+print(f"Solving on {device_info()}\n")
+solver.solve()
+hist = solver.itstat_object.history(transpose=True)
+x_reconstruction = snp.clip(solver.x, 0, 1.0)
+
+
+"""
+Show the recovered image.
+"""
+
+fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5))
+plot.imview(x_gt, title="Ground truth", cbar=None, fig=fig, ax=ax[0])
+plot.imview(
+ x_reconstruction,
+ title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
+ % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),
+ fig=fig,
+ ax=ax[1],
+)
+divider = make_axes_locatable(ax[1])
+cax = divider.append_axes("right", size="5%", pad=0.2)
+fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
+fig.show()
+
+
+"""
+Plot convergence statistics.
+"""
+fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
+plot.plot(
+ hist.Objective,
+ title="Objective function",
+ xlbl="Iteration",
+ ylbl="Functional value",
+ fig=fig,
+ ax=ax[0],
+)
+plot.plot(
+ snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
+ ptyp="semilogy",
+ title="Residuals",
+ xlbl="Iteration",
+ lgnd=("Primal", "Dual"),
+ fig=fig,
+ ax=ax[1],
+)
+fig.show()
+
+
+input("\nWaiting for input to close figures and exit")
diff --git a/examples/scripts/index.rst b/examples/scripts/index.rst
index f03f8fa28..584711de2 100644
--- a/examples/scripts/index.rst
+++ b/examples/scripts/index.rst
@@ -23,7 +23,8 @@ Computed Tomography
- ct_astra_odp_train_foam2.py
- ct_astra_unet_train_foam2.py
- ct_projector_comparison.py
-
+ - ct_multi_cs_tv_admm.py
+ - ct_multi_tv_admm.py
Deconvolution
^^^^^^^^^^^^^
diff --git a/scico/flax/examples/data_generation.py b/scico/flax/examples/data_generation.py
index b846830c5..5d706aa28 100644
--- a/scico/flax/examples/data_generation.py
+++ b/scico/flax/examples/data_generation.py
@@ -48,7 +48,7 @@
have_astra = True
if have_astra:
- from scico.linop.radon_astra import TomographicProjector
+ from scico.linop.xray.astra import XRayTransform
# Arbitrary process count: only applies if GPU is not available.
@@ -210,7 +210,7 @@ def generate_ct_data(
angles = np.linspace(0, jnp.pi, nproj) # evenly spaced projection angles
gt_sh = (size, size)
detector_spacing = 1
- A = TomographicProjector(gt_sh, detector_spacing, size, angles) # Radon transform operator
+ A = XRayTransform(gt_sh, detector_spacing, size, angles) # Radon transform operator
# Compute sinograms in parallel.
a_map = lambda v: jnp.atleast_3d(A @ v.squeeze())
diff --git a/scico/linop/__init__.py b/scico/linop/__init__.py
index 0c14de950..598a26aa2 100644
--- a/scico/linop/__init__.py
+++ b/scico/linop/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2021-2022 by SCICO Developers
+# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
@@ -19,7 +19,7 @@
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
-from ._xray import ParallelFixedAxis2dProjector, XRayProject
+from .xray import Parallel2dProjector, XRayTransform
__all__ = [
"CircularConvolve",
@@ -39,8 +39,8 @@
"Sum",
"Transpose",
"LinearOperator",
- "XRayProject",
- "ParallelFixedAxis2dProjector",
+ "XRayTransform",
+ "Parallel2dProjector",
"ComposedLinearOperator",
"linop_from_function",
"operator_norm",
diff --git a/scico/linop/abel.py b/scico/linop/abel.py
index 9a94b3735..6aa2846ca 100644
--- a/scico/linop/abel.py
+++ b/scico/linop/abel.py
@@ -27,12 +27,12 @@
from scipy.linalg import solve_triangular
-class AbelProjector(LinearOperator):
- r"""Abel transform projector based on `PyAbel `_.
+class AbelTransform(LinearOperator):
+ r"""Abel transform based on `PyAbel `_.
- Perform Abel transform (parallel beam tomographic projection of
- cylindrically symmetric objects) for a 2D image. The input 2D image
- is assumed to be centered and left-right symmetric.
+ Perform Abel transform (parallel beam projection of cylindrically
+ symmetric objects) for a 2D image. The input 2D image is assumed to
+ be centered and left-right symmetric.
"""
def __init__(self, img_shape: Shape):
@@ -78,7 +78,7 @@ def inverse(self, y: jax.Array) -> jax.Array:
def _pyabel_transform(
x: jax.Array, direction: str, proj_mat_quad: jax.Array, symmetry_axis: Optional[list] = None
) -> jax.Array:
- """Perform Abel transformations (forward, inverse and transposed).
+ """Apply Abel transforms (forward, inverse and transposed).
This function contains code copied from `PyAbel `_.
"""
diff --git a/scico/linop/xray/__init__.py b/scico/linop/xray/__init__.py
new file mode 100644
index 000000000..49d5752c5
--- /dev/null
+++ b/scico/linop/xray/__init__.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+# Copyright (C) 2023 by SCICO Developers
+# All rights reserved. BSD 3-clause License.
+# This file is part of the SCICO package. Details of the copyright and
+# user license can be found in the 'LICENSE' file distributed with the
+# package.
+
+"""X-ray transform classes.
+
+The tomographic projections that are frequently referred to as Radon
+transforms are referred to as X-ray transforms in SCICO. While the Radon
+transform is far more well-known than the X-ray transform, which is the
+same as the Radon transform for projections in two dimensions, these two
+transform differ in higher numbers of dimensions, and it is the X-ray
+transform that is the appropriate mathematical model for beam attenuation
+based imaging in three or more dimensions.
+"""
+
+import sys
+
+from ._xray import Parallel2dProjector, XRayTransform
+
+__all__ = [
+ "XRayTransform",
+ "Parallel2dProjector",
+]
diff --git a/scico/linop/_xray.py b/scico/linop/xray/_xray.py
similarity index 66%
rename from scico/linop/_xray.py
rename to scico/linop/xray/_xray.py
index 40c649cf4..1f4069401 100644
--- a/scico/linop/_xray.py
+++ b/scico/linop/xray/_xray.py
@@ -1,14 +1,13 @@
# -*- coding: utf-8 -*-
-# Copyright (C) 2020-2023 by SCICO Developers
+# Copyright (C) 2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
+"""X-ray transform classes."""
+
-"""
-X-ray projector classes.
-"""
from functools import partial
from typing import Optional
@@ -20,14 +19,18 @@
from scico.typing import Shape
-from ._linop import LinearOperator
+from .._linop import LinearOperator
-class XRayProject(LinearOperator):
- """X-ray projection operator.
+class XRayTransform(LinearOperator):
+ """X-ray transform operator.
- Wraps an X-ray projector object in a SCICO
- :class:`LinearOperator`.
+ Wrap an X-ray projector object in a SCICO :class:`LinearOperator`.
+ **Warning:** Note that the only X-ray projector object currently
+ supported, :class:`.Parallel2dProjector`, is not a very accurate
+ approximation of the integral transform representing real projection
+ imaging, and may therefore not be suitable for real imaging
+ applications.
"""
def __init__(self, projector):
@@ -35,7 +38,7 @@ def __init__(self, projector):
Args:
projector: instance of an X-ray projector object to wrap,
currently the only option is
- :class:`ParallelFixedAxis2dProjector`
+ :class:`Parallel2dProjector`
"""
self._eval = projector.project
@@ -45,22 +48,26 @@ def __init__(self, projector):
)
-class ParallelFixedAxis2dProjector:
+class Parallel2dProjector:
"""Parallel ray, single axis, 2D X-ray projector."""
def __init__(
self,
im_shape: Shape,
angles: ArrayLike,
- det_length: Optional[int] = None,
+ det_count: Optional[int] = None,
dither: bool = True,
):
r"""
Args:
im_shape: Shape of input array.
- angles: (num_angles,) array of angles in radians.
- det_length: Length of detector, in ``None``, defaults to the
- length of diagonal of `im_shape`.
+ angles: (num_angles,) array of angles in radians. Viewing an
+ (M, N) array as a matrix with M rows and N columns, an
+ angle of 0 corresponds to summing rows, an angle of pi/2
+ corresponds to summing columns, and an angle of pi/4
+ corresponds to summing along antidiagonals.
+ det_count: Number of elements in detector. If ``None``,
+ defaults to the size of the diagonal of `im_shape`.
dither: If ``True`` randomly shift pixel locations to
reduce projection artifacts caused by aliasing.
"""
@@ -71,11 +78,11 @@ def __init__(
x0 = -(im_shape - 1) / 2
- if det_length is None:
- det_length = int(np.ceil(np.linalg.norm(im_shape)))
- self.det_shape = (det_length,)
+ if det_count is None:
+ det_count = int(np.ceil(np.linalg.norm(im_shape)))
+ self.det_shape = (det_count,)
- y0 = -det_length / 2
+ y0 = -det_count / 2
@jax.vmap
def compute_inds(angle: float) -> ArrayLike:
@@ -106,7 +113,7 @@ def compute_inds(angle: float) -> ArrayLike:
# map negative inds to y_size, which is out of bounds and will be ignored
# otherwise they index from the end like x[-1]
- inds = jnp.where(inds < 0, det_length, inds)
+ inds = jnp.where(inds < 0, det_count, inds)
return inds
@@ -115,7 +122,7 @@ def compute_inds(angle: float) -> ArrayLike:
@partial(jax.vmap, in_axes=(None, 0))
def project_inds(im: ArrayLike, inds: ArrayLike) -> ArrayLike:
"""Compute the projection at a single angle."""
- return jnp.zeros(det_length).at[inds].add(im)
+ return jnp.zeros(det_count).at[inds].add(im)
@jax.jit
def project(im: ArrayLike) -> ArrayLike:
diff --git a/scico/linop/radon_astra.py b/scico/linop/xray/astra.py
similarity index 85%
rename from scico/linop/radon_astra.py
rename to scico/linop/xray/astra.py
index 6a5a337b9..b877891ad 100644
--- a/scico/linop/radon_astra.py
+++ b/scico/linop/xray/astra.py
@@ -5,9 +5,9 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.
-"""Radon transform LinearOperator wrapping the ASTRA toolbox.
+"""X-ray transform LinearOperator wrapping the ASTRA toolbox.
-Radon transform :class:`.LinearOperator` wrapping the parallel beam
+X-ray transform :class:`.LinearOperator` wrapping the parallel beam
projections in the
`ASTRA toolbox `_.
This package provides both C and CUDA implementations of core
@@ -37,11 +37,11 @@
from scico.typing import Shape
-from ._linop import LinearOperator
+from .._linop import LinearOperator
-class TomographicProjector(LinearOperator):
- r"""Parallel beam Radon transform based on the ASTRA toolbox.
+class XRayTransform(LinearOperator):
+ r"""Parallel beam X-ray transform based on the ASTRA toolbox.
Perform tomographic projection (also called X-ray projection) of an
image or volume at specified angles, using the
@@ -61,24 +61,28 @@ def __init__(
Args:
input_shape: Shape of the input array. Determines whether 2D
or 3D algorithm is used.
- detector_spacing: Spacing between detector elements. See
- https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries
+ detector_spacing: Spacing between detector elements. See the
+ astra documentation for more information for
+ `2d `__
or
- https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries
- for more information.
- det_count: Number of detector elements. See
- https://www.astra-toolbox.com/docs/geom2d.html#projection-geometries
+ `3d `__
+ geometries.
+ det_count: Number of detector elements. See the astra
+ documentation for more information for
+ `2d `__
or
- https://www.astra-toolbox.com/docs/geom3d.html#projection-geometries
- for more information.
+ `3d `__
+ geometries.
angles: Array of projection angles in radians.
volume_geometry: Specification of the shape of the
discretized reconstruction volume. Must either ``None``,
in which case it is inferred from `input_shape`, or
- follow the astra syntax described in
- https://www.astra-toolbox.com/docs/geom2d.html#volume-geometries
+ follow the astra syntax described in the astra
+ documentation for
+ `2d `__
or
- https://www.astra-toolbox.com/docs/geom3d.html#d-geometries.
+ `3d `__
+ geometries.
device: Specifies device for projection operation.
One of ["auto", "gpu", "cpu"]. If "auto", a GPU is used if
available, otherwise, the CPU is used.
@@ -87,7 +91,7 @@ def __init__(
self.num_dims = len(input_shape)
if self.num_dims not in [2, 3]:
raise ValueError(
- f"Only 2D and 3D projections are supported, but `input_shape` is {input_shape}."
+ f"Only 2D and 3D projections are supported, but input_shape is {input_shape}."
)
output_shape: Shape
@@ -96,7 +100,7 @@ def __init__(
elif self.num_dims == 3:
assert isinstance(det_count, (list, tuple))
if len(det_count) != 2:
- raise ValueError("Expected `det_count` to have 2 elements")
+ raise ValueError("Expected det_count to have 2 elements")
output_shape = (det_count[0], len(angles), det_count[1])
# Set up all the ASTRA config
@@ -112,7 +116,7 @@ def __init__(
assert isinstance(detector_spacing, (list, tuple))
assert isinstance(det_count, (list, tuple))
if len(detector_spacing) != 2:
- raise ValueError("Expected `detector_spacing` to have 2 elements")
+ raise ValueError("Expected detector_spacing to have 2 elements")
self.proj_geom = astra.create_proj_geom(
"parallel3d",
detector_spacing[0],
@@ -132,7 +136,7 @@ def __init__(
self.vol_geom: dict = astra.create_vol_geom(*input_shape, *volume_geometry)
else:
raise ValueError(
- "`volume_geometry` must be a tuple of len 4 (2D) or 6 (3D)."
+ "volume_geometry must be a tuple of len 4 (2D) or 6 (3D)."
"Please see the astra documentation for details."
)
else:
@@ -152,7 +156,7 @@ def __init__(
raise ValueError(f"Invalid device specified; got {device}.")
if self.num_dims == 3 and self.device == "cpu":
- raise ValueError("No CPU algorithm exists for 3D tomography.")
+ raise ValueError("No CPU algorithm for 3D projection.")
if self.num_dims == 3:
# not needed for astra's 3D algorithm
@@ -227,7 +231,7 @@ def fbp(self, sino: jax.Array, filter_type: str = "Ram-Lak") -> jax.Array:
"""
if self.num_dims == 3:
- raise NotImplementedError("3D FBP is not implemented")
+ raise NotImplementedError("3D FBP is not implemented.")
# Just use the CPU FBP alg for now; hitting memory issues with GPU one.
def f(sino):
diff --git a/scico/linop/radon_svmbir.py b/scico/linop/xray/svmbir.py
similarity index 95%
rename from scico/linop/radon_svmbir.py
rename to scico/linop/xray/svmbir.py
index 6d81b0fb7..8e757da84 100644
--- a/scico/linop/radon_svmbir.py
+++ b/scico/linop/xray/svmbir.py
@@ -5,9 +5,9 @@
# user license can be found in the 'LICENSE' file distributed with the
# package.
-"""Tomographic projector LinearOperator wrapping the svmbir package.
+"""X-ray transform LinearOperator wrapping the svmbir package.
-Tomographic projector :class:`.LinearOperator` wrapping the
+X-ray transform :class:`.LinearOperator` wrapping the
`svmbir `_ package. Since this
package is an interface to compiled C code, JAX features such as
automatic differentiation and support for GPU devices are not available.
@@ -24,8 +24,8 @@
from scico.loss import Loss, SquaredL2Loss
from scico.typing import Shape
-from ._diag import Diagonal, Identity
-from ._linop import LinearOperator
+from .._diag import Diagonal, Identity
+from .._linop import LinearOperator
try:
import svmbir
@@ -33,8 +33,8 @@
raise ImportError("Could not import svmbir; please install it.")
-class TomographicProjector(LinearOperator):
- r"""Tomographic projector based on svmbir.
+class XRayTransform(LinearOperator):
+ r"""X-ray transform based on svmbir.
Perform tomographic projection of an image at specified angles, using
the `svmbir `_ package. The
@@ -42,7 +42,7 @@ class TomographicProjector(LinearOperator):
(pixels outside this region are ignored when performing the
projection) is active. This region of validity is also respected by
:meth:`.SVMBIRSquaredL2Loss.prox` when :class:`.SVMBIRSquaredL2Loss`
- is initialized with a :class:`TomographicProjector` with this option
+ is initialized with a :class:`XRayTransform` with this option
enabled.
A brief description of the supported scanner geometries can be found
@@ -316,7 +316,7 @@ class SVMBIRExtendedLoss(Loss):
\alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} -
A(\mb{x})\right) \;,
- where :math:`A` is a :class:`.TomographicProjector`,
+ where :math:`A` is a :class:`.XRayTransform`,
:math:`\alpha` is the scaling parameter and :math:`W` is an instance
of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set
to :class:`scico.linop.Identity`.
@@ -325,12 +325,12 @@ class SVMBIRExtendedLoss(Loss):
:math:`\ell_2` loss as follows. When `positivity=True`, the prox
projects onto the non-negative orthant and the loss is infinite if
any element of the input is negative. When the `is_masked` option
- of the associated :class:`.TomographicProjector` is ``True``, the
+ of the associated :class:`.XRayTransform` is ``True``, the
reconstruction is computed over a masked region of the image as
- described in class :class:`.TomographicProjector`.
+ described in class :class:`.XRayTransform`.
"""
- A: TomographicProjector
+ A: XRayTransform
W: Union[Identity, Diagonal]
def __init__(
@@ -358,8 +358,8 @@ def __init__(
"""
super().__init__(*args, scale=scale, **kwargs) # type: ignore
- if not isinstance(self.A, TomographicProjector):
- raise ValueError("LinearOperator A must be a radon_svmbir.TomographicProjector.")
+ if not isinstance(self.A, XRayTransform):
+ raise ValueError("LinearOperator A must be a radon_svmbir.XRayTransform.")
self.has_prox = True
@@ -445,7 +445,7 @@ class SVMBIRSquaredL2Loss(SVMBIRExtendedLoss, SquaredL2Loss):
\alpha \left(\mb{y} - A(\mb{x})\right)^T W \left(\mb{y} -
A(\mb{x})\right) \;,
- where :math:`A` is a :class:`.TomographicProjector`, :math:`\alpha`
+ where :math:`A` is a :class:`.XRayTransform`, :math:`\alpha`
is the scaling parameter and :math:`W` is an instance
of :class:`scico.linop.Diagonal`. If :math:`W` is ``None``, it is set
to :class:`scico.linop.Identity`.
@@ -473,5 +473,5 @@ def __init__(
if self.A.is_masked:
raise ValueError(
- "Parameter is_masked must be False for the TomographicProjector in SVMBIRSquaredL2Loss."
+ "Parameter is_masked must be False for the XRayTransform in SVMBIRSquaredL2Loss."
)
diff --git a/scico/operator/_operator.py b/scico/operator/_operator.py
index baef6ec91..eabaf356b 100644
--- a/scico/operator/_operator.py
+++ b/scico/operator/_operator.py
@@ -383,7 +383,7 @@ def concat_args(args):
# concat_args(args) = snp.blockarray([args, val]) if argnum = 1
if isinstance(args, (jnp.ndarray, np.ndarray)):
- # In the case that the original operator takes a blockkarray with two
+ # In the case that the original operator takes a blockarray with two
# blocks, wrap in a list so we can use the same indexing as >2 block case
args = [args]
diff --git a/scico/optimize/_admmaux.py b/scico/optimize/_admmaux.py
index 5e7808989..8cf4eac3d 100644
--- a/scico/optimize/_admmaux.py
+++ b/scico/optimize/_admmaux.py
@@ -220,7 +220,6 @@ def internal_init(self, admm: soa.ADMM):
# hessian = A.T @ W @ A; W may be identity
lhs_op += admm.f.hessian
- lhs_op.jit()
self.lhs_op = lhs_op
def compute_rhs(self) -> Union[Array, BlockArray]:
diff --git a/scico/test/flax/test_inv.py b/scico/test/flax/test_inv.py
index b8fb3462c..2326fc430 100644
--- a/scico/test/flax/test_inv.py
+++ b/scico/test/flax/test_inv.py
@@ -16,7 +16,7 @@
from scico.linop import CircularConvolve, Identity
if have_astra:
- from scico.linop.radon_astra import TomographicProjector
+ from scico.linop.xray.astra import XRayTransform
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
@@ -107,7 +107,7 @@ def setup_method(self, method):
self.nproj = 60 # number of projections
angles = np.linspace(0, np.pi, self.nproj) # evenly spaced projection angles
- self.opCT = TomographicProjector(
+ self.opCT = XRayTransform(
input_shape=(self.N, self.N),
detector_spacing=1,
det_count=self.N,
diff --git a/scico/test/linop/test_abel.py b/scico/test/linop/test_abel.py
index ca5323af3..a5024ba7a 100644
--- a/scico/test/linop/test_abel.py
+++ b/scico/test/linop/test_abel.py
@@ -5,7 +5,7 @@
import pytest
import scico.numpy as snp
-from scico.linop.abel import AbelProjector
+from scico.linop.abel import AbelTransform
from scico.test.linop.test_linop import adjoint_test
BIG_INPUT = (128, 128)
@@ -23,7 +23,7 @@ def make_im(Nx, Ny):
@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT))
def test_inverse(Nx, Ny):
im = make_im(Nx, Ny)
- A = AbelProjector(im.shape)
+ A = AbelTransform(im.shape)
Ax = A @ im
im_hat = A.inverse(Ax)
@@ -34,14 +34,14 @@ def test_inverse(Nx, Ny):
@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT))
def test_adjoint(Nx, Ny):
im = make_im(Nx, Ny)
- A = AbelProjector(im.shape)
+ A = AbelTransform(im.shape)
adjoint_test(A)
@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT))
def test_ATA(Nx, Ny):
x = make_im(Nx, Ny)
- A = AbelProjector(x.shape)
+ A = AbelTransform(x.shape)
Ax = A(x)
ATAx = A.adj(Ax)
np.testing.assert_allclose(np.sum(x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5)
@@ -52,7 +52,7 @@ def test_grad(Nx, Ny):
# ensure that we can take grad on a function using our projector
# grad || A(x) ||_2^2 == 2 A.T @ A x
x = make_im(Nx, Ny)
- A = AbelProjector(x.shape)
+ A = AbelTransform(x.shape)
g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2
np.testing.assert_allclose(jax.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5)
@@ -60,7 +60,7 @@ def test_grad(Nx, Ny):
@pytest.mark.parametrize("Nx, Ny", (BIG_INPUT, SMALL_INPUT))
def test_adjoint_grad(Nx, Ny):
x = make_im(Nx, Ny)
- A = AbelProjector(x.shape)
+ A = AbelTransform(x.shape)
Ax = A @ x
f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2
np.testing.assert_allclose(jax.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5)
diff --git a/scico/test/linop/test_xray.py b/scico/test/linop/test_xray.py
deleted file mode 100644
index bb827988f..000000000
--- a/scico/test/linop/test_xray.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import jax.numpy as jnp
-
-from scico.linop import ParallelFixedAxis2dProjector, XRayProject
-
-
-def test_apply():
- im_shape = (12, 13)
- num_angles = 10
- x = jnp.ones(im_shape)
-
- angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
-
- # general projection
- H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles))
- y = H @ x
- assert y.shape[0] == (num_angles)
-
- # fixed det_length
- det_length = 14
- H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, det_length=det_length))
- y = H @ x
- assert y.shape[1] == det_length
-
- # dither off
- H = XRayProject(ParallelFixedAxis2dProjector(x.shape, angles, dither=False))
- y = H @ x
diff --git a/scico/test/linop/test_radon_astra.py b/scico/test/linop/xray/test_astra.py
similarity index 87%
rename from scico/test/linop/test_radon_astra.py
rename to scico/test/linop/xray/test_astra.py
index 42b4e0f57..67699da76 100644
--- a/scico/test/linop/test_radon_astra.py
+++ b/scico/test/linop/xray/test_astra.py
@@ -8,10 +8,10 @@
import scico.numpy as snp
from scico.linop import DiagonalStack
from scico.test.linop.test_linop import adjoint_test
-from scico.test.linop.test_radon_svmbir import make_im
+from scico.test.linop.xray.test_svmbir import make_im
try:
- from scico.linop.radon_astra import TomographicProjector
+ from scico.linop.xray.astra import XRayTransform
except ModuleNotFoundError as e:
if e.name == "astra":
pytest.skip("astra not installed", allow_module_level=True)
@@ -41,7 +41,7 @@ def get_tol_random_input():
return rtol
-class TomographicProjectorTest:
+class XRayTransformTest:
def __init__(self, volume_geometry):
N_proj = 180 # number of projection angles
N_det = 384
@@ -51,7 +51,7 @@ def __init__(self, volume_geometry):
np.random.seed(1234)
self.x = np.random.randn(N, N).astype(np.float32)
self.y = np.random.randn(N_proj, N_det).astype(np.float32)
- self.A = TomographicProjector(
+ self.A = XRayTransform(
input_shape=(N, N),
volume_geometry=volume_geometry,
detector_spacing=detector_spacing,
@@ -62,7 +62,7 @@ def __init__(self, volume_geometry):
@pytest.fixture(params=[None, [-N / 2, N / 2, -N / 2, N / 2]])
def testobj(request):
- yield TomographicProjectorTest(request.param)
+ yield XRayTransformTest(request.param)
def test_ATA_call(testobj):
@@ -125,7 +125,7 @@ def test_adjoint_typical_input(testobj):
def test_jit_in_DiagonalStack():
"""See https://github.com/lanl/scico/issues/331"""
N = 10
- H = DiagonalStack([TomographicProjector((N, N), 1.0, N, snp.linspace(0, snp.pi, N))])
+ H = DiagonalStack([XRayTransform((N, N), 1.0, N, snp.linspace(0, snp.pi, N))])
H.T @ snp.zeros(H.output_shape, dtype=snp.float32)
@@ -133,13 +133,13 @@ def test_jit_in_DiagonalStack():
def test_3D_on_CPU():
x = snp.zeros((4, 5, 6))
with pytest.raises(ValueError):
- A = TomographicProjector(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10))
+ A = XRayTransform(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10))
@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="checking GPU behavior")
def test_3D_on_GPU():
x = snp.zeros((4, 5, 6))
- A = TomographicProjector(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10))
+ A = XRayTransform(x.shape, [1.0, 1.0], [6, 6], snp.linspace(0, snp.pi, 10))
assert A.num_dims == 3
y = A @ x
diff --git a/scico/test/linop/test_radon_svmbir.py b/scico/test/linop/xray/test_svmbir.py
similarity index 98%
rename from scico/test/linop/test_radon_svmbir.py
rename to scico/test/linop/xray/test_svmbir.py
index a41629d6a..9674269c0 100644
--- a/scico/test/linop/test_radon_svmbir.py
+++ b/scico/test/linop/xray/test_svmbir.py
@@ -14,10 +14,10 @@
try:
import svmbir
- from scico.linop.radon_svmbir import (
+ from scico.linop.xray.svmbir import (
SVMBIRExtendedLoss,
SVMBIRSquaredL2Loss,
- TomographicProjector,
+ XRayTransform,
)
except ImportError as e:
pytest.skip("svmbir not installed", allow_module_level=True)
@@ -90,7 +90,7 @@ def make_A(
):
angles = make_angles(num_angles)
- A = TomographicProjector(
+ A = XRayTransform(
im.shape,
angles,
num_channels,
diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py
new file mode 100644
index 000000000..288da22b0
--- /dev/null
+++ b/scico/test/linop/xray/test_xray.py
@@ -0,0 +1,26 @@
+import jax.numpy as jnp
+
+from scico.linop import Parallel2dProjector, XRayTransform
+
+
+def test_apply():
+ im_shape = (12, 13)
+ num_angles = 10
+ x = jnp.ones(im_shape)
+
+ angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
+
+ # general projection
+ H = XRayTransform(Parallel2dProjector(x.shape, angles))
+ y = H @ x
+ assert y.shape[0] == (num_angles)
+
+ # fixed det_count
+ det_count = 14
+ H = XRayTransform(Parallel2dProjector(x.shape, angles, det_count=det_count))
+ y = H @ x
+ assert y.shape[1] == det_count
+
+ # dither off
+ H = XRayTransform(Parallel2dProjector(x.shape, angles, dither=False))
+ y = H @ x