Skip to content

Commit

Permalink
Make XRayTransform3D us 32-bit floats
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-T-McCann committed Oct 2, 2024
1 parent f4e9cce commit c58b819
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scico/linop/xray/_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def __init__(

super().__init__(
input_shape=self.input_shape,
input_dtype=np.float32,
output_shape=self.output_shape,
output_dtype=np.float32,
eval_fn=self.project,
adj_fn=self.back_project,
)
Expand Down Expand Up @@ -283,7 +285,7 @@ def __init__(
"""

self.input_shape: Shape = input_shape
self.matrices = matrices
self.matrices = jnp.asarray(matrices, dtype=np.float32)
self.det_shape = det_shape
self.output_shape = (len(matrices), *det_shape)
super().__init__(
Expand Down
1 change: 1 addition & 0 deletions scico/test/linop/xray/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_3d_scaling():
# default spacing
M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0])
H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape)

# fmt: off
truth = jnp.array(
[[[0.0, 0.0, 0.0, 0.0],
Expand Down

0 comments on commit c58b819

Please sign in to comment.