@@ -23,7 +23,7 @@ def test_solvers_sinkhorn(pot_method, solver):
2323 nf = 10
2424 eps = 1.0
2525
26- niters , tol , eval_freq = 100 , 1e-16 , 1
26+ niters , tol , eval_freq = 100 , 1e-7 , 20
2727
2828 ws = torch .ones (ns ) / ns
2929 wt = torch .ones (nt ) / nt
@@ -81,7 +81,7 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
8181 nf = 10
8282 eps = 1.0
8383
84- niters , tol , eval_freq = 100 , 1e-16 , 1
84+ niters , tol , eval_freq = 100 , 1e-7 , 20
8585
8686 ws = torch .ones (ns ) / ns
8787 wt = torch .ones (nt ) / nt
@@ -91,6 +91,11 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
9191
9292 cost = torch .cdist (source_features , target_features )
9393
94+ # Convert the tensors to float64
95+ ws = ws .double ()
96+ wt = wt .double ()
97+ cost = cost .double ()
98+
9499 gamma , log = ot .sinkhorn (
95100 ws ,
96101 wt ,
@@ -119,4 +124,4 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
119124 alpha ,
120125 )
121126 assert torch .allclose (log ["beta" ], beta )
122- assert torch .allclose (gamma , pi .to_dense (), atol = 1e-6 )
127+ assert torch .allclose (gamma , pi .to_dense ())
0 commit comments