Skip to content

Commit 7a83ccc

Browse files
author
Pierre-Louis Barbarant
committed
Refactor Sinkhorn solver to replace pi_diff with err for convergence checks and update logging messages
1 parent dea3159 commit 7a83ccc

File tree

1 file changed

+33
-36
lines changed

1 file changed

+33
-36
lines changed

src/fugw/solvers/utils.py

Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -481,12 +481,9 @@ def get_K(alpha, beta):
481481
if verbose:
482482
task = progress.add_task("Sinkhorn iterations", total=niters)
483483

484-
pi_diff = None
484+
err = None
485485
idx = 0
486-
while (pi_diff is None or pi_diff >= tol) and (
487-
niters is None or idx < niters
488-
):
489-
u_prev, v_prev = u.detach().clone(), v.detach().clone()
486+
while (err is None or err >= tol) and (niters is None or idx < niters):
490487
v = wt / (K.T @ u)
491488
u = ws / (K @ v)
492489

@@ -507,14 +504,14 @@ def get_K(alpha, beta):
507504
progress.update(task, advance=1)
508505

509506
if tol is not None and idx % eval_freq == 0:
510-
pi_diff = max(
511-
(u - u_prev).abs().max(), (v - v_prev).abs().max()
512-
)
513-
if pi_diff < tol:
507+
pi = get_K(alpha + eps * u.log(), beta + eps * v.log())
508+
err = torch.norm(pi.sum(0) - wt)
509+
if err < tol:
514510
if verbose:
515511
progress.console.log(
516-
f"Reached tol_uot threshold: {pi_diff}"
512+
f"Reached tol_uot threshold: {err}"
517513
)
514+
break
518515

519516
idx += 1
520517

@@ -590,12 +587,9 @@ def get_K_sparse(alpha, beta):
590587
if verbose:
591588
task = progress.add_task("Sinkhorn iterations", total=niters)
592589

593-
pi_diff = None
590+
err = None
594591
idx = 0
595-
while (pi_diff is None or pi_diff >= tol) and (
596-
niters is None or idx < niters
597-
):
598-
u_prev, v_prev = u.detach().clone(), v.detach().clone()
592+
while (err is None or err >= tol) and (niters is None or idx < niters):
599593
# Update v using sparse matrix multiplication
600594
Kt_u = (
601595
torch.sparse.mm(K.transpose(0, 1), u.reshape(-1, 1))
@@ -625,14 +619,14 @@ def get_K_sparse(alpha, beta):
625619
progress.update(task, advance=1)
626620

627621
if tol is not None and idx % eval_freq == 0:
628-
pi_diff = max(
629-
(u - u_prev).abs().max(), (v - v_prev).abs().max()
630-
)
631-
if pi_diff < tol:
622+
pi = get_K_sparse(alpha + eps * u.log(), beta + eps * v.log())
623+
err = torch.norm(csr_sum(pi, dim=0) - wt)
624+
if err < tol:
632625
if verbose:
633626
progress.console.log(
634-
f"Reached tol_uot threshold: {pi_diff}"
627+
f"Reached tol_uot threshold: {err}"
635628
)
629+
break
636630

637631
idx += 1
638632

@@ -666,7 +660,7 @@ def solver_sinkhorn_eps_scaling(
666660
train_params_inner = numInnerItermax, tol, eval_freq
667661
alpha, beta = init_duals
668662

669-
pi_diff = None
663+
err = None
670664
idx = 0
671665

672666
def get_reg(idx):
@@ -676,11 +670,10 @@ def get_reg(idx):
676670
numItermax = max(numItermin, numItermax)
677671

678672
while (
679-
(pi_diff is None or pi_diff >= tol)
673+
(err is None or err >= tol)
680674
and (numItermax is None or idx < numItermax)
681675
) or (idx < numItermin):
682676
reg_idx = get_reg(idx)
683-
alpha_prev, beta_prev = alpha.detach().clone(), beta.detach().clone()
684677
(alpha, beta), pi = solver_sinkhorn_stabilized(
685678
cost,
686679
(alpha, beta),
@@ -691,13 +684,14 @@ def get_reg(idx):
691684
)
692685

693686
if tol is not None and idx % eval_freq == 0:
694-
pi_diff = max(
695-
(alpha - alpha_prev).abs().max(),
696-
(beta - beta_prev).abs().max(),
687+
err = (
688+
torch.norm(pi.sum(0) - tuple_weights[1]) ** 2
689+
+ torch.norm(pi.sum(1) - tuple_weights[0]) ** 2
697690
)
698-
if pi_diff < tol:
691+
if err < tol and idx > numItermin:
699692
if verbose:
700-
print(f"Reached tol_uot threshold: {pi_diff}")
693+
print(f"Reached tol_uot threshold: {err}")
694+
break
701695

702696
idx += 1
703697

@@ -727,7 +721,7 @@ def solver_sinkhorn_eps_scaling_sparse(
727721
train_params_inner = numInnerItermax, tol, eval_freq
728722
alpha, beta = init_duals
729723

730-
pi_diff = None
724+
err = None
731725
idx = 0
732726

733727
def get_reg(idx):
@@ -737,11 +731,10 @@ def get_reg(idx):
737731
numItermax = max(numItermin, numItermax)
738732

739733
while (
740-
(pi_diff is None or pi_diff >= tol)
734+
(err is None or err >= tol)
741735
and (numItermax is None or idx < numItermax)
742736
) or (idx < numItermin):
743737
reg_idx = get_reg(idx)
744-
alpha_prev, beta_prev = alpha.detach().clone(), beta.detach().clone()
745738
(alpha, beta), pi = solver_sinkhorn_stabilized_sparse(
746739
cost,
747740
(alpha, beta),
@@ -752,13 +745,17 @@ def get_reg(idx):
752745
)
753746

754747
if tol is not None and idx % eval_freq == 0:
755-
pi_diff = max(
756-
(alpha - alpha_prev).abs().max(),
757-
(beta - beta_prev).abs().max(),
748+
pi1 = csr_sum(pi, dim=1)
749+
pi2 = csr_sum(pi, dim=0)
750+
err = (
751+
torch.norm(pi1 - tuple_weights[0]) ** 2
752+
+ torch.norm(pi2 - tuple_weights[1]) ** 2
758753
)
759-
if pi_diff < tol:
754+
print(f"fugw iter {idx}, err {err}")
755+
if err < tol and idx > numItermin:
760756
if verbose:
761-
print(f"Reached tol_uot threshold: {pi_diff}")
757+
print(f"Reached tol_uot threshold: {err}")
758+
break
762759

763760
idx += 1
764761

0 commit comments

Comments
 (0)