Skip to content

Commit c58b819

Browse files
Make XRayTransform3D us 32-bit floats
1 parent f4e9cce commit c58b819

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

scico/linop/xray/_xray.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def __init__(
116116

117117
super().__init__(
118118
input_shape=self.input_shape,
119+
input_dtype=np.float32,
119120
output_shape=self.output_shape,
121+
output_dtype=np.float32,
120122
eval_fn=self.project,
121123
adj_fn=self.back_project,
122124
)
@@ -283,7 +285,7 @@ def __init__(
283285
"""
284286

285287
self.input_shape: Shape = input_shape
286-
self.matrices = matrices
288+
self.matrices = jnp.asarray(matrices, dtype=np.float32)
287289
self.det_shape = det_shape
288290
self.output_shape = (len(matrices), *det_shape)
289291
super().__init__(

scico/test/linop/xray/test_xray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def test_3d_scaling():
8181
# default spacing
8282
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
8383
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)
84+
8485
# fmt: off
8586
truth = jnp.array(
8687
[[[0.0, 0.0, 0.0, 0.0],

0 commit comments

Comments
 (0)