Skip to content

Commit 008697c

Browse files
Fix XRayTransform2D projection dtype and docs (#557)
Co-authored-by: Brendt Wohlberg <[email protected]>
1 parent 8dc1a2a commit 008697c

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

scico/linop/xray/_xray.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
class XRayTransform2D(LinearOperator):
30-
"""Parallel ray, single axis, 2D X-ray projector.
30+
r"""Parallel ray, single axis, 2D X-ray projector.
3131
3232
This implementation approximates the projection of each rectangular
3333
pixel as a boxcar function (whereas the exact projection is a
@@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator):
4242
accumulation of pixel values into bins (equivalently, makes the
4343
linear operator sparse).
4444
45+
Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather
46+
than 1) in order to satisfy the aforementioned spacing requirement.
47+
4548
`x0`, `dx`, and `y0` should be expressed in units such that the
4649
detector spacing `dy` is 1.0.
4750
"""
@@ -64,9 +67,11 @@ def __init__(
6467
corresponds to summing columns, and an angle of pi/4
6568
corresponds to summing along antidiagonals.
6669
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
67-
default, `(-input_shape / 2, -input_shape / 2)`.
68-
dx: Image pixel side length in x- and y-direction. Should be
69-
<= 1.0 in each dimension. By default, [1.0, 1.0].
70+
default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
71+
dx: Image pixel side length in x- and y-direction. Must be
72+
set so that the width of a projected pixel is never
73+
larger than 1.0. By default, [:math:`\sqrt{2}/2`,
74+
:math:`\sqrt{2}/2`].
7075
y0: Location of the edge of the first detector bin. By
7176
default, `-det_count / 2`
7277
det_count: Number of elements in detector. If ``None``,
@@ -111,25 +116,36 @@ def __init__(
111116

112117
super().__init__(
113118
input_shape=self.input_shape,
119+
input_dtype=np.float32,
114120
output_shape=self.output_shape,
121+
output_dtype=np.float32,
115122
eval_fn=self.project,
116123
adj_fn=self.back_project,
117124
)
118125

119126
def project(self, im: ArrayLike) -> snp.Array:
120-
"""Compute X-ray projection."""
127+
"""Compute X-ray projection, equivalent to `H @ im`.
128+
129+
Args:
130+
im: Input array representing the image to project.
131+
"""
121132
return XRayTransform2D._project(im, self.x0, self.dx, self.y0, self.ny, self.angles)
122133

123134
def back_project(self, y: ArrayLike) -> snp.Array:
124-
"""Compute X-ray back projection"""
135+
"""Compute X-ray back projection, equivalent to `H.T @ y`.
136+
137+
Args:
138+
y: Input array representing the sinogram to back project.
139+
"""
125140
return XRayTransform2D._back_project(y, self.x0, self.dx, self.nx, self.y0, self.angles)
126141

127142
@staticmethod
128143
@partial(jax.jit, static_argnames=["ny"])
129144
def _project(
130145
im: ArrayLike, x0: ArrayLike, dx: ArrayLike, y0: float, ny: int, angles: ArrayLike
131146
) -> snp.Array:
132-
r"""
147+
r"""Compute X-ray projection.
148+
133149
Args:
134150
im: Input array, (M, N).
135151
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -146,8 +162,11 @@ def _project(
146162
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
147163
inds = jnp.where(inds >= 0, inds, ny)
148164

165+
# avoid incompatible types in the .add (scatter operation)
166+
weights = weights.astype(im.dtype)
167+
149168
y = (
150-
jnp.zeros((len(angles), ny))
169+
jnp.zeros((len(angles), ny), dtype=im.dtype)
151170
.at[jnp.arange(len(angles)).reshape(-1, 1, 1), inds]
152171
.add(im * weights)
153172
)
@@ -161,7 +180,8 @@ def _project(
161180
def _back_project(
162181
y: ArrayLike, x0: ArrayLike, dx: ArrayLike, nx: Shape, y0: float, angles: ArrayLike
163182
) -> ArrayLike:
164-
r"""
183+
r"""Compute X-ray back projection.
184+
165185
Args:
166186
y: Input projection, (num_angles, N).
167187
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -259,10 +279,6 @@ class XRayTransform3D(LinearOperator):
259279
260280
:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
261281
make these geometry arrays.
262-
263-
264-
265-
266282
"""
267283

268284
def __init__(
@@ -279,7 +295,7 @@ def __init__(
279295
"""
280296

281297
self.input_shape: Shape = input_shape
282-
self.matrices = matrices
298+
self.matrices = jnp.asarray(matrices, dtype=np.float32)
283299
self.det_shape = det_shape
284300
self.output_shape = (len(matrices), *det_shape)
285301
super().__init__(

scico/test/linop/xray/test_xray.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_apply():
4949
def test_apply_adjoint():
5050
im_shape = (12, 13)
5151
num_angles = 10
52-
x = jnp.ones(im_shape)
52+
x = jnp.ones(im_shape, dtype=jnp.float32)
5353

5454
angles = jnp.linspace(0, jnp.pi, num=num_angles, endpoint=False)
5555

@@ -81,6 +81,7 @@ def test_3d_scaling():
8181
# default spacing
8282
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
8383
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
84+
8485
# fmt: off
8586
truth = jnp.array(
8687
[[[0.0, 0.0, 0.0, 0.0],

0 commit comments

Comments
 (0)