diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index 074ab1fc..a8ef4cf1 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -116,7 +116,9 @@ def __init__( super().__init__( input_shape=self.input_shape, + input_dtype=np.float32, output_shape=self.output_shape, + output_dtype=np.float32, eval_fn=self.project, adj_fn=self.back_project, ) @@ -283,7 +285,7 @@ def __init__( """ self.input_shape: Shape = input_shape - self.matrices = matrices + self.matrices = jnp.asarray(matrices, dtype=np.float32) self.det_shape = det_shape self.output_shape = (len(matrices), *det_shape) super().__init__( diff --git a/scico/test/linop/xray/test_xray.py b/scico/test/linop/xray/test_xray.py index 960a9947..b9e12776 100644 --- a/scico/test/linop/xray/test_xray.py +++ b/scico/test/linop/xray/test_xray.py @@ -81,6 +81,7 @@ def test_3d_scaling(): # default spacing M = XRayTransform3D.matrices_from_euler_angles(input_shape, output_shape, "X", [0.0]) H = XRayTransform3D(input_shape, matrices=M, det_shape=output_shape) + # fmt: off truth = jnp.array( [[[0.0, 0.0, 0.0, 0.0],