33import numpy as np
44import pytest
55import torch
6+ import ot
67
78from fugw .mappings import FUGWBarycenter
89from fugw .utils import _init_mock_distribution
1213 devices .append (torch .device ("cuda:0" ))
1314
1415callbacks = [None , lambda x : x ["plans" ]]
16+ alphas = [0.0 , 0.5 , 1.0 ]
1517
1618
1719@pytest .mark .parametrize (
@@ -57,6 +59,7 @@ def test_fugw_barycenter(device, callback):
5759 nits_barycenter = nits_barycenter ,
5860 device = device ,
5961 callback_barycenter = callback ,
62+ init_barycenter_geometry = geometry_list [0 ],
6063 )
6164
6265 assert isinstance (barycenter_weights , torch .Tensor )
@@ -67,3 +70,118 @@ def test_fugw_barycenter(device, callback):
6770 assert barycenter_geometry .shape == (n_voxels , n_voxels )
6871 assert len (plans ) == n_subjects
6972 assert len (losses_each_bar_step ) == nits_barycenter
73+
74+
75+ @pytest .mark .parametrize (
76+ "alpha" ,
77+ alphas ,
78+ )
79+ def test_identity_case (alpha ):
80+ """Test the case where all subjects are the same."""
81+ torch .manual_seed (0 )
82+ n_subjects = 3
83+ n_features = 10
84+ n_voxels = 100
85+ nits_barycenter = 2
86+
87+ geometry = _init_mock_distribution (n_features , n_voxels )[2 ]
88+ features = torch .rand (n_features , n_voxels )
89+
90+ geometry_list = [geometry for _ in range (n_subjects )]
91+ features_list = [features for _ in range (n_subjects )]
92+ weights_list = [torch .ones (n_voxels ) / n_voxels for _ in range (n_subjects )]
93+
94+ fugw_barycenter = FUGWBarycenter (alpha = alpha , eps = 1e-6 , rho = float ("inf" ))
95+ (
96+ barycenter_weights ,
97+ barycenter_features ,
98+ barycenter_geometry ,
99+ plans ,
100+ _ ,
101+ _ ,
102+ ) = fugw_barycenter .fit (
103+ weights_list ,
104+ features_list ,
105+ geometry_list ,
106+ solver_params = {"nits_bcd" : 5 , "nits_uot" : 100 },
107+ nits_barycenter = nits_barycenter ,
108+ device = torch .device ("cpu" ),
109+ init_barycenter_geometry = geometry_list [0 ],
110+ init_barycenter_features = features_list [0 ],
111+ )
112+
113+ # Check that the barycenter is the same as the input
114+ assert torch .allclose (barycenter_weights , torch .ones (n_voxels ) / n_voxels )
115+ assert torch .allclose (barycenter_geometry , geometry_list [0 ])
116+
117+ # In the case alpha=1.0, the features can be permuted
118+ # since the GW distance is invariant under isometries
119+ if alpha != 1.0 :
120+ assert torch .allclose (barycenter_features , features )
121+ # Check that all the plans are the identity matrix divided
122+ # by the number of voxels
123+ for plan in plans :
124+ assert torch .allclose (plan , torch .eye (n_voxels ) / n_voxels )
125+
126+
127+ @pytest .mark .parametrize (
128+ "alpha" ,
129+ alphas ,
130+ )
131+ def test_fgw_barycenter (alpha ):
132+ """Tests the FUGW barycenter in the case rho=inf and compare with POT."""
133+ torch .manual_seed (0 )
134+ n_subjects = 3
135+ n_features = 1
136+ n_voxels = 5
137+ nits_barycenter = 2
138+
139+ geometry = _init_mock_distribution (
140+ n_features , n_voxels , should_normalize = True
141+ )[2 ]
142+ geometry_list = [geometry for _ in range (n_subjects )]
143+ weights_list = [torch .ones (n_voxels ) / n_voxels ] * n_subjects
144+ features_list = [torch .tensor ([[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]])] * n_subjects
145+
146+ fugw_barycenter = FUGWBarycenter (
147+ alpha = alpha ,
148+ rho = float ("inf" ),
149+ eps = 1e-6 ,
150+ )
151+
152+ fugw_barycenter = FUGWBarycenter (alpha = alpha , eps = 1e-6 , rho = float ("inf" ))
153+ (
154+ fugw_bary_weights ,
155+ fugw_bary_features ,
156+ fugw_bary_geometry ,
157+ _ ,
158+ _ ,
159+ _ ,
160+ ) = fugw_barycenter .fit (
161+ weights_list ,
162+ features_list ,
163+ geometry_list ,
164+ solver_params = {"nits_bcd" : 5 , "nits_uot" : 100 },
165+ nits_barycenter = nits_barycenter ,
166+ device = torch .device ("cpu" ),
167+ init_barycenter_geometry = geometry_list [0 ],
168+ init_barycenter_features = features_list [0 ],
169+ )
170+
171+ # Compare the barycenter with the one obtained with POT
172+ pot_bary_features , pot_bary_geometry , log = ot .gromov .fgw_barycenters (
173+ n_voxels ,
174+ [features .T for features in features_list ],
175+ geometry_list ,
176+ weights_list ,
177+ alpha = 1 - alpha ,
178+ log = True ,
179+ fixed_structure = True ,
180+ init_C = geometry_list [0 ],
181+ )
182+
183+ assert torch .allclose (fugw_bary_geometry , pot_bary_geometry )
184+ assert torch .allclose (fugw_bary_weights , log ["p" ])
185+
186+ if alpha != 1.0 :
187+ assert torch .allclose (fugw_bary_features , pot_bary_features .T )
0 commit comments