Open
Labels: tests (Pertaining to SCICO tests)
Description
Many LinOps have tests for scalar multiplication and division, i.e., that `a * (H @ x) == (a * H) @ x`. Right now, this involves code duplication; for example:
scico/scico/test/linop/test_circconv.py
Lines 74 to 87 in 1a66887

```python
@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(self, axes_shape_spec, operator, jit):
    input_dtype = np.float32
    scalar = np.float32(3.141)
    x_shape, ndims, h_shape = axes_shape_spec
    h, key = randn(tuple(h_shape), dtype=input_dtype, key=self.key)
    A = CircularConvolve(h, x_shape, ndims, input_dtype, jit=jit)
    cA = operator(A, scalar)
    np.testing.assert_allclose(operator(A.h_dft.ravel(), scalar), cA.h_dft.ravel(), rtol=5e-5)
```
scico/scico/test/linop/test_diag.py
Lines 146 to 158 in 1a66887

```python
@pytest.mark.parametrize("operator", [op.mul, op.truediv])
def test_scalar_left(self, operator):
    diagonal_dtype = np.float32
    input_shape = (8,)
    diagonal1, key = randn(input_shape, dtype=diagonal_dtype, key=self.key)
    scalar = np.random.randn()
    x, key = randn(input_shape, dtype=diagonal_dtype, key=key)
    D = linop.Diagonal(diagonal=diagonal1)
    scaled_D = operator(D, scalar)
    np.testing.assert_allclose(scaled_D @ x, operator(D @ x, scalar), rtol=5e-5)
```
Can we create a standard helper function for testing scaling (and other similar LinOp properties) rather than copying and pasting variants of this test into each test module?
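One possible shape for such a helper is sketched below, assuming only that a LinOp supports `@`, scalar `*` on either side, and `/` by a scalar, as in the identity stated above. The functions `check_scalar_scaling` and `check_linearity` are hypothetical names, not existing SCICO utilities, and a dense NumPy matrix stands in for a LinOp so the sketch runs without SCICO.

```python
import operator as op

import numpy as np


def check_scalar_scaling(A, x, scalar, rtol=5e-5, atol=0.0):
    """Hypothetical helper: check that scaling by a scalar commutes with applying A.

    Covers A * s, A / s and s * A by comparing each scaled operator's action
    on x with the same scaling applied to A @ x.
    """
    y = A @ x
    # Right multiplication and division: operator(A, s) @ x == operator(A @ x, s)
    for binop in (op.mul, op.truediv):
        np.testing.assert_allclose(binop(A, scalar) @ x, binop(y, scalar), rtol=rtol, atol=atol)
    # Left multiplication: (s * A) @ x == s * (A @ x)
    np.testing.assert_allclose((scalar * A) @ x, scalar * y, rtol=rtol, atol=atol)


def check_linearity(A, x1, x2, a, b, rtol=5e-5, atol=0.0):
    """Hypothetical helper: check A @ (a*x1 + b*x2) == a*(A @ x1) + b*(A @ x2)."""
    np.testing.assert_allclose(
        A @ (a * x1 + b * x2), a * (A @ x1) + b * (A @ x2), rtol=rtol, atol=atol
    )


# Usage with a dense matrix standing in for a LinOp (assumption for illustration).
rng = np.random.default_rng(0)
H = rng.standard_normal((8, 8))
x1, x2 = rng.standard_normal((2, 8))
# float64 with a small atol keeps near-zero entries from tripping the rtol check
check_scalar_scaling(H, x1, 3.141, rtol=1e-10, atol=1e-12)
check_linearity(H, x1, x2, 2.0, -0.5, rtol=1e-10, atol=1e-12)
```

Each test module could then replace its copy of the scaling test with a call such as `check_scalar_scaling(D, x, scalar)`. Checks that inspect internal state rather than the operator's action, such as the `h_dft` comparison in `test_circconv.py`, would still need their own operator-specific tests.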