Skip to content

Commit 3ae891a

Browse files
author
Pierre-Louis Barbarant
committed
Refactor Sinkhorn solver functions to improve stability and transport plan computation
1 parent 0858858 commit 3ae891a

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/fugw/solvers/utils.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
32
from fugw.utils import _get_progress, console
43

54

@@ -474,7 +473,14 @@ def solver_sinkhorn_stabilized(
474473
tau = 1e3 # Threshold for stabilization
475474

476475
def get_K(alpha, beta):
477-
return ((alpha[:, None] + beta[None, :] - cost) / eps).exp()
476+
return (-(cost - alpha[:, None] - beta[None, :]) / eps).exp()
477+
478+
def get_Gamma(alpha, beta, u, v):
479+
return (
480+
-(cost - alpha[:, None] - beta[None, :]) / eps
481+
+ u[:, None].log()
482+
+ v[None, :].log()
483+
).exp()
478484

479485
K = get_K(alpha, beta)
480486
with _get_progress(verbose=verbose, transient=True) as progress:
@@ -515,13 +521,13 @@ def get_K(alpha, beta):
515521

516522
idx += 1
517523

524+
# Compute final transport plan
525+
pi = get_Gamma(alpha, beta, u, v)
526+
518527
# Final update to potentials
519528
alpha = alpha + eps * u.log()
520529
beta = beta + eps * v.log()
521530

522-
# Compute final transport plan
523-
pi = get_K(alpha, beta)
524-
525531
return (alpha, beta), pi
526532

527533

@@ -663,11 +669,12 @@ def solver_sinkhorn_eps_scaling(
663669
idx = 0
664670

665671
def get_reg(idx):
666-
return (epsilon0 - eps) * torch.exp(-torch.tensor(idx)) + eps
672+
return (epsilon0 - eps) * torch.exp(
673+
-torch.tensor(idx, dtype=torch.float64)
674+
) + eps
667675

668676
numItermin = 35
669677
numItermax = max(numItermin, numItermax)
670-
671678
with _get_progress(verbose=verbose, transient=True) as progress:
672679
if verbose:
673680
task = progress.add_task("Scaling iterations", total=numItermax)
@@ -730,7 +737,9 @@ def solver_sinkhorn_eps_scaling_sparse(
730737
idx = 0
731738

732739
def get_reg(idx):
733-
return (epsilon0 - eps) * torch.exp(-torch.tensor(idx)) + eps
740+
return (epsilon0 - eps) * torch.exp(
741+
-torch.tensor(idx, dtype=torch.float32)
742+
) + eps
734743

735744
numItermin = 35
736745
numItermax = max(numItermin, numItermax)

0 commit comments

Comments
 (0)