Skip to content

Commit 2dd220b

Browse files
author
Pierre-Louis Barbarant
committed
Update tolerance and evaluation frequency in Sinkhorn solver tests for improved accuracy
1 parent 74b0c06 commit 2dd220b

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tests/solvers/test_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_solvers_sinkhorn(pot_method, solver):
2323
nf = 10
2424
eps = 1.0
2525

26-
niters, tol, eval_freq = 100, 1e-16, 1
26+
niters, tol, eval_freq = 100, 1e-7, 20
2727

2828
ws = torch.ones(ns) / ns
2929
wt = torch.ones(nt) / nt
@@ -81,7 +81,7 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
8181
nf = 10
8282
eps = 1.0
8383

84-
niters, tol, eval_freq = 100, 1e-16, 1
84+
niters, tol, eval_freq = 100, 1e-7, 20
8585

8686
ws = torch.ones(ns) / ns
8787
wt = torch.ones(nt) / nt
@@ -91,6 +91,11 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
9191

9292
cost = torch.cdist(source_features, target_features)
9393

94+
# Convert the tensors to float64
95+
ws = ws.double()
96+
wt = wt.double()
97+
cost = cost.double()
98+
9499
gamma, log = ot.sinkhorn(
95100
ws,
96101
wt,
@@ -119,4 +124,4 @@ def test_solvers_sinkhorn_sparse(pot_method, solver):
119124
alpha,
120125
)
121126
assert torch.allclose(log["beta"], beta)
122-
assert torch.allclose(gamma, pi.to_dense(), atol=1e-6)
127+
assert torch.allclose(gamma, pi.to_dense())

0 commit comments

Comments
 (0)