|
1 | 1 | import numpy as np |
| 2 | +import pytest |
2 | 3 | import sparse as sps |
3 | 4 | import xarray as xr |
4 | 5 |
|
@@ -27,3 +28,61 @@ def test_add_nans_to_weights(): |
27 | 28 |
|
28 | 29 | Matout = xe.smm.add_nans_to_weights(xr.DataArray(Matin, dims=('in', 'out'))) |
29 | 30 | assert np.allclose(Matin.todense(), Matout.data.todense()) |
| 31 | + |
| 32 | + |
| 33 | +def test_post_apply_target_mask_to_weights(): |
| 34 | + # Create a small sparse weights matrix with shape (9 target, 4 source) |
| 35 | + # coords = [[target_indices], [source_indices]] |
| 36 | + coords = np.array([[0, 1, 1, 2, 3, 3, 4, 5], [0, 0, 1, 1, 2, 3, 2, 3]]) |
| 37 | + data = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.45, 0.7, 0.8]) |
| 38 | + shape = (6, 4) |
| 39 | + W_sparse = sps.COO(coords, data, shape=shape) |
| 40 | + weights = xr.DataArray(W_sparse, dims=('out_dim', 'in_dim')) |
| 41 | + |
| 42 | + # Define a 3x3 mask for target (flattened size = 9): |
| 43 | + # If all goes to plan, weights of cells 3 and 4 (i.e. index 2 and 3) |
| 44 | + # will be set to 0. |
| 45 | + target_mask_2d = np.array([[True, False], [True, True], [False, True]]) |
| 46 | + |
| 47 | + # Apply mask |
| 48 | + masked_weights = xe.smm.post_apply_target_mask_to_weights(weights, target_mask_2d) |
| 49 | + |
| 50 | + # Check results |
| 51 | + np.testing.assert_array_equal(masked_weights.data.data, np.array([0.1, 0.2, 0.3, 0.7, 0.8])) |
| 52 | + np.testing.assert_array_equal( |
| 53 | + masked_weights.data.coords, np.array([[0, 1, 1, 4, 5], [0, 0, 1, 2, 3]]) |
| 54 | + ) |
| 55 | + |
| 56 | + |
| 57 | +def test_post_apply_target_mask_to_weights_exceptions(): |
| 58 | + # Create a weights DataArray & mask |
| 59 | + coords = np.array([[0, 1], [0, 1]]) |
| 60 | + data = np.array([0.5, 0.5]) |
| 61 | + shape = (2, 2) |
| 62 | + W_sparse = sps.COO(coords, data, shape=shape) |
| 63 | + weights = xr.DataArray(W_sparse, dims=('out_dim', 'in_dim')) |
| 64 | + valid_mask = np.array([[True, False]]) |
| 65 | + |
| 66 | + # Mask not array-like |
| 67 | + with pytest.raises( |
| 68 | + TypeError, |
| 69 | + match="Argument 'target_mask_2d' must be array-like and convertible to a numeric/boolean array", |
| 70 | + ): |
| 71 | + xe.smm.post_apply_target_mask_to_weights(weights, 'not_array_like') |
| 72 | + |
| 73 | + # Shape mismatch |
| 74 | + wrong_shape_mask = np.array([[True, False, True]]) |
| 75 | + with pytest.raises( |
| 76 | + ValueError, match='Mismatch: weight matrix has 2 target cells, but mask has 3 elements' |
| 77 | + ): |
| 78 | + xe.smm.post_apply_target_mask_to_weights(weights, wrong_shape_mask) |
| 79 | + |
| 80 | + # Mask not 2D |
| 81 | + wrong_shape_mask = np.array([[[True]], [[True]]]) |
| 82 | + with pytest.raises( |
| 83 | + ValueError, match="Argument 'target_mask_2d' must be 2D, got shape \\(2, 1, 1\\)" |
| 84 | + ): |
| 85 | + xe.smm.post_apply_target_mask_to_weights(weights, wrong_shape_mask) |
| 86 | + |
| 87 | + # That should work |
| 88 | + xe.smm.post_apply_target_mask_to_weights(weights, valid_mask) |
0 commit comments