Skip to content

Commit bede443

Browse files
authored
Merge pull request #79 from pbarbarant/feat/test-edge-case-bary
[BUGFIX] Fix dense barycenter calculations
2 parents dc8b819 + b5fc56d commit bede443

File tree

3 files changed

+136
-25
lines changed

3 files changed

+136
-25
lines changed

src/fugw/mappings/barycenter.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(
1111
self,
1212
alpha=0.5,
1313
rho=1,
14-
eps=1e-2,
14+
eps=1e-4,
1515
reg_mode="joint",
1616
force_psd=False,
1717
learn_geometry=False,
@@ -79,26 +79,19 @@ def update_barycenter_geometry(
7979
return barycenter_geometry
8080

8181
@staticmethod
82-
def update_barycenter_features(plans, weights_list, features_list, device):
83-
for i, (pi, weights, features) in enumerate(
84-
zip(plans, weights_list, features_list)
85-
):
86-
w = _make_tensor(weights, device=device)
82+
def update_barycenter_features(plans, features_list, device):
83+
for i, (pi, features) in enumerate(zip(plans, features_list)):
84+
# Use uniform weights across subjects
85+
weight = 1 / len(features_list)
8786
f = _make_tensor(features, device=device)
8887
if features is not None:
89-
acc = w * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16)
88+
acc = weight * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16)
9089

9190
if i == 0:
9291
barycenter_features = acc
9392
else:
9493
barycenter_features += acc
9594

96-
# Normalize barycenter features
97-
min_val = barycenter_features.min(dim=0, keepdim=True).values
98-
max_val = barycenter_features.max(dim=0, keepdim=True).values
99-
barycenter_features = (
100-
2 * (barycenter_features - min_val) / (max_val - min_val) - 1
101-
)
10295
return barycenter_features.T
10396

10497
@staticmethod
@@ -127,7 +120,6 @@ def compute_all_ot_plans(
127120
barycenter_geometry,
128121
solver,
129122
solver_params,
130-
callback_barycenter,
131123
device,
132124
verbose,
133125
):
@@ -271,7 +263,12 @@ def fit(
271263
init_barycenter_features, device=device
272264
)
273265

274-
if init_barycenter_geometry is None:
266+
if init_barycenter_geometry is None and self.learn_geometry is False:
267+
raise ValueError(
268+
"In the fixed support case, init_barycenter_geometry must be"
269+
" provided."
270+
)
271+
elif init_barycenter_geometry is None and self.learn_geometry is True:
275272
barycenter_geometry = (
276273
torch.ones((barycenter_size, barycenter_size)).to(device)
277274
/ barycenter_size
@@ -302,7 +299,6 @@ def fit(
302299
barycenter_geometry,
303300
solver,
304301
solver_params,
305-
callback_barycenter,
306302
device,
307303
verbose,
308304
)
@@ -311,7 +307,7 @@ def fit(
311307

312308
# Update barycenter features and geometry
313309
barycenter_features = self.update_barycenter_features(
314-
plans, weights_list, features_list, device
310+
plans, features_list, device
315311
)
316312
if self.learn_geometry:
317313
barycenter_geometry = self.update_barycenter_geometry(

src/fugw/mappings/sparse_barycenter.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,16 @@ def __init__(
3333
self.selection_radius = selection_radius
3434

3535
@staticmethod
36-
def update_barycenter_features(plans, weights_list, features_list, device):
37-
for i, (pi, weights, features) in enumerate(
38-
zip(plans, weights_list, features_list)
39-
):
40-
w = _make_tensor(weights, device=device)
36+
def update_barycenter_features(plans, features_list, device):
37+
for i, (pi, features) in enumerate(zip(plans, features_list)):
4138
f = _make_tensor(features, device=device)
42-
39+
weight = 1 / len(features_list)
4340
if features is not None:
4441
pi_sum = (
4542
torch.sparse.sum(pi, dim=0).to_dense().reshape(-1, 1)
4643
+ 1e-16
4744
)
48-
acc = w * pi.T @ f.T / pi_sum
45+
acc = weight * pi.T @ f.T / pi_sum
4946

5047
if i == 0:
5148
barycenter_features = acc
@@ -291,7 +288,7 @@ def fit(
291288

292289
# Update barycenter features and geometry
293290
barycenter_features = self.update_barycenter_features(
294-
plans, weights_list, features_list, device
291+
plans, features_list, device
295292
)
296293

297294
if callback_barycenter is not None:

tests/mappings/test_barycenter.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55
import torch
6+
import ot
67

78
from fugw.mappings import FUGWBarycenter
89
from fugw.utils import _init_mock_distribution
@@ -12,6 +13,7 @@
1213
devices.append(torch.device("cuda:0"))
1314

1415
callbacks = [None, lambda x: x["plans"]]
16+
alphas = [0.0, 0.5, 1.0]
1517

1618

1719
@pytest.mark.parametrize(
@@ -57,6 +59,7 @@ def test_fugw_barycenter(device, callback):
5759
nits_barycenter=nits_barycenter,
5860
device=device,
5961
callback_barycenter=callback,
62+
init_barycenter_geometry=geometry_list[0],
6063
)
6164

6265
assert isinstance(barycenter_weights, torch.Tensor)
@@ -67,3 +70,118 @@ def test_fugw_barycenter(device, callback):
6770
assert barycenter_geometry.shape == (n_voxels, n_voxels)
6871
assert len(plans) == n_subjects
6972
assert len(losses_each_bar_step) == nits_barycenter
73+
74+
75+
@pytest.mark.parametrize(
76+
"alpha",
77+
alphas,
78+
)
79+
def test_identity_case(alpha):
80+
"""Test the case where all subjects are the same."""
81+
torch.manual_seed(0)
82+
n_subjects = 3
83+
n_features = 10
84+
n_voxels = 100
85+
nits_barycenter = 2
86+
87+
geometry = _init_mock_distribution(n_features, n_voxels)[2]
88+
features = torch.rand(n_features, n_voxels)
89+
90+
geometry_list = [geometry for _ in range(n_subjects)]
91+
features_list = [features for _ in range(n_subjects)]
92+
weights_list = [torch.ones(n_voxels) / n_voxels for _ in range(n_subjects)]
93+
94+
fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf"))
95+
(
96+
barycenter_weights,
97+
barycenter_features,
98+
barycenter_geometry,
99+
plans,
100+
_,
101+
_,
102+
) = fugw_barycenter.fit(
103+
weights_list,
104+
features_list,
105+
geometry_list,
106+
solver_params={"nits_bcd": 5, "nits_uot": 100},
107+
nits_barycenter=nits_barycenter,
108+
device=torch.device("cpu"),
109+
init_barycenter_geometry=geometry_list[0],
110+
init_barycenter_features=features_list[0],
111+
)
112+
113+
# Check that the barycenter is the same as the input
114+
assert torch.allclose(barycenter_weights, torch.ones(n_voxels) / n_voxels)
115+
assert torch.allclose(barycenter_geometry, geometry_list[0])
116+
117+
# In the case alpha=1.0, the features can be permuted
118+
# since the GW distance is invariant under isometries
119+
if alpha != 1.0:
120+
assert torch.allclose(barycenter_features, features)
121+
# Check that all the plans are the identity matrix divided
122+
# by the number of voxels
123+
for plan in plans:
124+
assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels)
125+
126+
127+
@pytest.mark.parametrize(
128+
"alpha",
129+
alphas,
130+
)
131+
def test_fgw_barycenter(alpha):
132+
"""Tests the FUGW barycenter in the case rho=inf and compare with POT."""
133+
torch.manual_seed(0)
134+
n_subjects = 3
135+
n_features = 1
136+
n_voxels = 5
137+
nits_barycenter = 2
138+
139+
geometry = _init_mock_distribution(
140+
n_features, n_voxels, should_normalize=True
141+
)[2]
142+
geometry_list = [geometry for _ in range(n_subjects)]
143+
weights_list = [torch.ones(n_voxels) / n_voxels] * n_subjects
144+
features_list = [torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])] * n_subjects
145+
146+
fugw_barycenter = FUGWBarycenter(
147+
alpha=alpha,
148+
rho=float("inf"),
149+
eps=1e-6,
150+
)
151+
152+
fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf"))
153+
(
154+
fugw_bary_weights,
155+
fugw_bary_features,
156+
fugw_bary_geometry,
157+
_,
158+
_,
159+
_,
160+
) = fugw_barycenter.fit(
161+
weights_list,
162+
features_list,
163+
geometry_list,
164+
solver_params={"nits_bcd": 5, "nits_uot": 100},
165+
nits_barycenter=nits_barycenter,
166+
device=torch.device("cpu"),
167+
init_barycenter_geometry=geometry_list[0],
168+
init_barycenter_features=features_list[0],
169+
)
170+
171+
# Compare the barycenter with the one obtained with POT
172+
pot_bary_features, pot_bary_geometry, log = ot.gromov.fgw_barycenters(
173+
n_voxels,
174+
[features.T for features in features_list],
175+
geometry_list,
176+
weights_list,
177+
alpha=1 - alpha,
178+
log=True,
179+
fixed_structure=True,
180+
init_C=geometry_list[0],
181+
)
182+
183+
assert torch.allclose(fugw_bary_geometry, pot_bary_geometry)
184+
assert torch.allclose(fugw_bary_weights, log["p"])
185+
186+
if alpha != 1.0:
187+
assert torch.allclose(fugw_bary_features, pot_bary_features.T)

0 commit comments

Comments
 (0)