Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix XRayTransform2D projection dtype and docs #557

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


class XRayTransform2D(LinearOperator):
"""Parallel ray, single axis, 2D X-ray projector.
r"""Parallel ray, single axis, 2D X-ray projector.

This implementation approximates the projection of each rectangular
pixel as a boxcar function (whereas the exact projection is a
Expand All @@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator):
accumulation of pixel values into bins (equivalently, makes the
linear operator sparse).

Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather
than 1) in order to satisfy the aforementioned spacing requirement.

`x0`, `dx`, and `y0` should be expressed in units such that the
detector spacing `dy` is 1.0.
"""
Expand All @@ -64,9 +67,11 @@ def __init__(
corresponds to summing columns, and an angle of pi/4
corresponds to summing along antidiagonals.
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
default, `(-input_shape / 2, -input_shape / 2)`.
dx: Image pixel side length in x- and y-direction. Should be
<= 1.0 in each dimension. By default, [1.0, 1.0].
default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
dx: Image pixel side length in x- and y-direction. Must be
set so that the width of a projected pixel is never
larger than 1.0. By default, [:math:`\sqrt{2}/2`,
:math:`\sqrt{2}/2`].
y0: Location of the edge of the first detector bin. By
default, `-det_count / 2`
det_count: Number of elements in detector. If ``None``,
Expand Down Expand Up @@ -111,25 +116,36 @@ def __init__(

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
output_shape=self.output_shape,
output_dtype=np.float32,
eval_fn=self.project,
adj_fn=self.back_project,
)

def project(self, im: ArrayLike) -> snp.Array:
"""Compute X-ray projection."""
"""Compute X-ray projection, equivalent to `H @ im`.

Args:
im: Input array representing the image to project.
"""
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)

def back_project(self, y: ArrayLike) -> snp.Array:
"""Compute X-ray back projection"""
"""Compute X-ray back projection, equivalent to `H.T @ y`.

Args:
y: Input array representing the sinogram to back project.
"""
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)

@staticmethod
@partial(jax.jit, static_argnames=["ny"])
def _project(
im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike
) -> snp.Array:
r"""
r"""Compute X-ray projection.

Args:
im: Input array, (M, N).
x0: (x, y) position of the corner of the pixel im[0,0].
Expand All @@ -146,8 +162,11 @@ def _project(
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
inds = jnp.where(inds >= 0, inds, ny)

# avoid incompatible types in the .add (scatter operation)
weights = weights.astype(im.dtype)

y = (
jnp.zeros((len(angles), ny))
jnp.zeros((len(angles), ny), dtype=im.dtype)
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
.add(im * weights)
)
Expand All @@ -161,7 +180,8 @@ def _project(
def _back_project(
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
) -> ArrayLike:
r"""
r"""Compute X-ray back projection.

Args:
y: Input projection, (num_angles, N).
x0: (x, y) position of the corner of the pixel im[0,0].
Expand Down Expand Up @@ -259,10 +279,6 @@ class XRayTransform3D(LinearOperator):

:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
make these geometry arrays.




"""

def __init__(
Expand All @@ -279,7 +295,7 @@ def __init__(
"""

self.input_shape: Shape = input_shape
self.matrices = matrices
self.matrices = jnp.asarray(matrices, dtype=np.float32)
self.det_shape = det_shape
self.output_shape = (len(matrices), *det_shape)
super().__init__(
Expand Down
3 changes: 2 additions & 1 deletion scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_apply():
def test_apply_adjoint():
im_shape = (12, 13)
num_angles = 10
x = jnp.ones(im_shape)
x = jnp.ones(im_shape, dtype=jnp.float32)
Michael-T-McCann marked this conversation as resolved.
Show resolved Hide resolved

angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)

Expand Down Expand Up @@ -81,6 +81,7 @@ def test_3d_scaling():
# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)

# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
Expand Down
Loading