27
27
28
28
29
29
class XRayTransform2D (LinearOperator ):
30
- """Parallel ray, single axis, 2D X-ray projector.
30
+ r """Parallel ray, single axis, 2D X-ray projector.
31
31
32
32
This implementation approximates the projection of each rectangular
33
33
pixel as a boxcar function (whereas the exact projection is a
@@ -42,6 +42,9 @@ class XRayTransform2D(LinearOperator):
42
42
accumulation of pixel values into bins (equivalently, makes the
43
43
linear operator sparse).
44
44
45
+ Warning: The default pixel spacing is :math:`\sqrt{2}/2` (rather
46
+ than 1) in order to satisfy the aforementioned spacing requirement.
47
+
45
48
`x0`, `dx`, and `y0` should be expressed in units such that the
46
49
detector spacing `dy` is 1.0.
47
50
"""
@@ -64,9 +67,11 @@ def __init__(
64
67
corresponds to summing columns, and an angle of pi/4
65
68
corresponds to summing along antidiagonals.
66
69
x0: (x, y) position of the corner of the pixel `im[0,0]`. By
67
- default, `(-input_shape / 2, -input_shape / 2)`.
68
- dx: Image pixel side length in x- and y-direction. Should be
69
- <= 1.0 in each dimension. By default, [1.0, 1.0].
70
+ default, `(-input_shape * dx[0] / 2, -input_shape * dx[1] / 2)`.
71
+ dx: Image pixel side length in x- and y-direction. Must be
72
+ set so that the width of a projected pixel is never
73
+ larger than 1.0. By default, [:math:`\sqrt{2}/2`,
74
+ :math:`\sqrt{2}/2`].
70
75
y0: Location of the edge of the first detector bin. By
71
76
default, `-det_count / 2`
72
77
det_count: Number of elements in detector. If ``None``,
@@ -111,25 +116,36 @@ def __init__(
111
116
112
117
super ().__init__ (
113
118
input_shape = self .input_shape ,
119
+ input_dtype = np .float32 ,
114
120
output_shape = self .output_shape ,
121
+ output_dtype = np .float32 ,
115
122
eval_fn = self .project ,
116
123
adj_fn = self .back_project ,
117
124
)
118
125
119
126
def project (self , im : ArrayLike ) -> snp .Array :
120
- """Compute X-ray projection."""
127
+ """Compute X-ray projection, equivalent to `H @ im`.
128
+
129
+ Args:
130
+ im: Input array representing the image to project.
131
+ """
121
132
return XRayTransform2D ._project (im , self .x0 , self .dx , self .y0 , self .ny , self .angles )
122
133
123
134
def back_project (self , y : ArrayLike ) -> snp .Array :
124
- """Compute X-ray back projection"""
135
+ """Compute X-ray back projection, equivalent to `H.T @ y`.
136
+
137
+ Args:
138
+ y: Input array representing the sinogram to back project.
139
+ """
125
140
return XRayTransform2D ._back_project (y , self .x0 , self .dx , self .nx , self .y0 , self .angles )
126
141
127
142
@staticmethod
128
143
@partial (jax .jit , static_argnames = ["ny" ])
129
144
def _project (
130
145
im : ArrayLike , x0 : ArrayLike , dx : ArrayLike , y0 : float , ny : int , angles : ArrayLike
131
146
) -> snp .Array :
132
- r"""
147
+ r"""Compute X-ray projection.
148
+
133
149
Args:
134
150
im: Input array, (M, N).
135
151
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -146,8 +162,11 @@ def _project(
146
162
# ignored, while inds < 0 wrap around. So we set inds < 0 to ny.
147
163
inds = jnp .where (inds >= 0 , inds , ny )
148
164
165
+ # avoid incompatible types in the .add (scatter operation)
166
+ weights = weights .astype (im .dtype )
167
+
149
168
y = (
150
- jnp .zeros ((len (angles ), ny ))
169
+ jnp .zeros ((len (angles ), ny ), dtype = im . dtype )
151
170
.at [jnp .arange (len (angles )).reshape (- 1 , 1 , 1 ), inds ]
152
171
.add (im * weights )
153
172
)
@@ -161,7 +180,8 @@ def _project(
161
180
def _back_project (
162
181
y : ArrayLike , x0 : ArrayLike , dx : ArrayLike , nx : Shape , y0 : float , angles : ArrayLike
163
182
) -> ArrayLike :
164
- r"""
183
+ r"""Compute X-ray back projection.
184
+
165
185
Args:
166
186
y: Input projection, (num_angles, N).
167
187
x0: (x, y) position of the corner of the pixel im[0,0].
@@ -259,10 +279,6 @@ class XRayTransform3D(LinearOperator):
259
279
260
280
:meth:`XRayTransform3D.matrices_from_euler_angles` can help to
261
281
make these geometry arrays.
262
-
263
-
264
-
265
-
266
282
"""
267
283
268
284
def __init__ (
@@ -279,7 +295,7 @@ def __init__(
279
295
"""
280
296
281
297
self .input_shape : Shape = input_shape
282
- self .matrices = matrices
298
+ self .matrices = jnp . asarray ( matrices , dtype = np . float32 )
283
299
self .det_shape = det_shape
284
300
self .output_shape = (len (matrices ), * det_shape )
285
301
super ().__init__ (
0 commit comments