Skip to content

Commit dc8b819

Browse files
authored
Merge pull request #76 from pbarbarant/fix/fix-barycenter-tests
[BUGFIX] Fixes CI + other tests
2 parents e96384b + 37476c2 commit dc8b819

File tree

3 files changed

+39
-9
lines changed

3 files changed

+39
-9
lines changed

src/fugw/scripts/coarse_to_fine.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,13 @@ def compute_sparsity_mask(
316316
device = torch.device("cuda", 0)
317317
else:
318318
device = torch.device("cpu")
319+
320+
# Convert coarse plan to numpy
321+
coarse_plan = coarse_mapping.pi.to("cpu").numpy()
319322
if method == "quantile":
320323
# Method 1: keep first percentile
321-
threshold = np.percentile(coarse_mapping.pi, 99.95)
322-
rows, cols = np.nonzero(coarse_mapping.pi > threshold)
324+
threshold = np.percentile(coarse_plan, 99.95)
325+
rows, cols = np.nonzero(coarse_plan > threshold)
323326

324327
elif method == "topk":
325328
# Method 2: keep topk indices per line and per column
@@ -328,12 +331,12 @@ def compute_sparsity_mask(
328331
rows = np.concatenate(
329332
[
330333
np.arange(source_sample.shape[0]),
331-
np.argmax(coarse_mapping.pi, axis=0),
334+
np.argmax(coarse_plan, axis=0),
332335
]
333336
)
334337
cols = np.concatenate(
335338
[
336-
np.argmax(coarse_mapping.pi, axis=1),
339+
np.argmax(coarse_plan, axis=1),
337340
np.arange(target_sample.shape[0]),
338341
]
339342
)

tests/mappings/test_barycenter.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
product(devices, callbacks),
2020
)
2121
def test_fugw_barycenter(device, callback):
22+
"""Tests the FUGW barycenter fitting on toy data."""
2223
np.random.seed(0)
2324
n_subjects = 4
2425
n_voxels = 100
2526
n_features = 10
27+
nits_barycenter = 3
2628

2729
# Generate random training data for n subjects
2830
features_list = []
@@ -38,10 +40,30 @@ def test_fugw_barycenter(device, callback):
3840
geometry_list.append(geometry)
3941

4042
fugw_barycenter = FUGWBarycenter()
41-
fugw_barycenter.fit(
43+
44+
# Fit the barycenter
45+
(
46+
barycenter_weights,
47+
barycenter_features,
48+
barycenter_geometry,
49+
plans,
50+
_,
51+
losses_each_bar_step,
52+
) = fugw_barycenter.fit(
4253
weights_list,
4354
features_list,
4455
geometry_list,
56+
solver_params={"nits_bcd": 2, "nits_uot": 5},
57+
nits_barycenter=nits_barycenter,
4558
device=device,
4659
callback_barycenter=callback,
4760
)
61+
62+
assert isinstance(barycenter_weights, torch.Tensor)
63+
assert barycenter_weights.shape == (n_voxels,)
64+
assert isinstance(barycenter_features, torch.Tensor)
65+
assert barycenter_features.shape == (n_features, n_voxels)
66+
assert isinstance(barycenter_geometry, torch.Tensor)
67+
assert barycenter_geometry.shape == (n_voxels, n_voxels)
68+
assert len(plans) == n_subjects
69+
assert len(losses_each_bar_step) == nits_barycenter

tests/mappings/test_sparse_barycenter.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
"device, callback",
2020
product(devices, callbacks),
2121
)
22-
def test_fugw_barycenter(device, callback):
22+
def test_fugw_sparse_barycenter(device, callback):
23+
"""Tests the FUGW sparse barycenter fitting on toy data."""
2324
np.random.seed(0)
2425
n_subjects = 4
2526
n_voxels = 100
@@ -39,23 +40,27 @@ def test_fugw_barycenter(device, callback):
3940
weights_list.append(weights)
4041
features_list.append(features)
4142

42-
fugw_barycenter = FUGWSparseBarycenter()
43+
geometry_embedding_normalized = (
44+
geometry_embedding / geometry_embedding.norm()
45+
)
46+
fugw_sparse_barycenter = FUGWSparseBarycenter()
4347

4448
# Fit the barycenter
4549
(
4650
barycenter_weights,
4751
barycenter_features,
4852
plans,
4953
losses_each_bar_step,
50-
) = fugw_barycenter.fit(
54+
) = fugw_sparse_barycenter.fit(
5155
weights_list,
5256
features_list,
53-
geometry_embedding,
57+
geometry_embedding_normalized,
5458
mesh_sample=mesh_sample,
5559
coarse_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
5660
fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
5761
nits_barycenter=nits_barycenter,
5862
device=device,
63+
callback_barycenter=callback,
5964
)
6065

6166
assert isinstance(barycenter_weights, torch.Tensor)

0 commit comments

Comments
 (0)