Skip to content

Commit be13850

Browse files
authored
Merge pull request #84 from alexisthual/feat/stabilized-sinkhorn
[ENH] Stabilized and epsilon scaled sinkhorn solvers
2 parents 89dd0c8 + c78db2a commit be13850

File tree

5 files changed

+428
-23
lines changed

5 files changed

+428
-23
lines changed

src/fugw/solvers/dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
compute_divergence,
1212
solver_ibpp,
1313
solver_mm,
14-
solver_sinkhorn,
14+
solver_sinkhorn_log,
1515
solver_mm_l2,
1616
)
1717
from fugw.utils import _add_dict, console
@@ -444,7 +444,7 @@ def solve(
444444

445445
# If divergence is KL
446446
self_solver_sinkhorn = partial(
447-
solver_sinkhorn,
447+
solver_sinkhorn_log,
448448
tuple_weights=(ws, wt, ws_dot_wt),
449449
train_params=(self.nits_uot, self.tol_uot, self.eval_uot),
450450
)

src/fugw/solvers/sparse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
crow_indices_to_row_indices,
1818
csr_sum,
1919
elementwise_prod_fact_sparse,
20-
solver_sinkhorn_sparse,
20+
solver_sinkhorn_log_sparse,
2121
solver_ibpp_sparse,
2222
solver_mm_sparse,
2323
solver_mm_l2_sparse,
@@ -517,7 +517,7 @@ def solve(
517517

518518
# If divergence is KL
519519
self_solver_sinkhorn = partial(
520-
solver_sinkhorn_sparse,
520+
solver_sinkhorn_log_sparse,
521521
tuple_weights=(ws, wt, ws_dot_wt),
522522
train_params=(self.nits_uot, self.tol_uot, self.eval_uot),
523523
verbose=verbose,

0 commit comments

Comments
 (0)