Skip to content

Commit 9fac2a2

Browse files
author
Pierre-Louis Barbarant
committed
Refactor progress tracking in Sinkhorn solvers to include verbosity option
1 parent 6bdf7a7 commit 9fac2a2

File tree

1 file changed

+80
-66
lines changed

1 file changed

+80
-66
lines changed

src/fugw/solvers/utils.py

Lines changed: 80 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def solver_sinkhorn_log(
246246
tau_s = 1 if torch.isinf(rho_s) else rho_s / (rho_s + eps)
247247
tau_t = 1 if torch.isinf(rho_t) else rho_t / (rho_t + eps)
248248

249-
with _get_progress(transient=True) as progress:
249+
with _get_progress(verbose=verbose, transient=True) as progress:
250250
if verbose:
251251
task = progress.add_task("Sinkhorn iterations", total=niters)
252252

@@ -364,7 +364,7 @@ def solver_sinkhorn_log_sparse(
364364
).to_sparse_csr()
365365
col_one_hot_t = col_one_hot.transpose(0, 1).to_sparse_csr()
366366

367-
with _get_progress(transient=True) as progress:
367+
with _get_progress(verbose=verbose, transient=True) as progress:
368368
if verbose:
369369
task = progress.add_task("Sinkhorn iterations", total=niters)
370370

@@ -477,7 +477,7 @@ def get_K(alpha, beta):
477477
return ((alpha[:, None] + beta[None, :] - cost) / eps).exp()
478478

479479
K = get_K(alpha, beta)
480-
with _get_progress(transient=True) as progress:
480+
with _get_progress(verbose=verbose, transient=True) as progress:
481481
if verbose:
482482
task = progress.add_task("Sinkhorn iterations", total=niters)
483483

@@ -500,19 +500,19 @@ def get_K(alpha, beta):
500500
# Recompute K with updated potentials
501501
K = get_K(alpha, beta)
502502

503-
if verbose:
504-
progress.update(task, advance=1)
505-
506503
if tol is not None and idx % eval_freq == 0:
507504
pi = get_K(alpha + eps * u.log(), beta + eps * v.log())
508505
err = torch.norm(pi.sum(0) - wt)
509506
if err < tol:
510507
if verbose:
511508
progress.console.log(
512-
f"Reached tol_uot threshold: {err}"
509+
f"Reached tolerance threshold: {err}"
513510
)
514511
break
515512

513+
if verbose:
514+
progress.update(task, advance=1)
515+
516516
idx += 1
517517

518518
# Final update to potentials
@@ -583,7 +583,7 @@ def get_K_sparse(alpha, beta):
583583
# Initial K
584584
K = get_K_sparse(alpha, beta)
585585

586-
with _get_progress(transient=True) as progress:
586+
with _get_progress(verbose=verbose, transient=True) as progress:
587587
if verbose:
588588
task = progress.add_task("Sinkhorn iterations", total=niters)
589589

@@ -624,7 +624,7 @@ def get_K_sparse(alpha, beta):
624624
if err < tol:
625625
if verbose:
626626
progress.console.log(
627-
f"Reached tol_uot threshold: {err}"
627+
print(f"Reached tolerance threshold: {err}")
628628
)
629629
break
630630

@@ -669,31 +669,37 @@ def get_reg(idx):
669669
numItermin = 35
670670
numItermax = max(numItermin, numItermax)
671671

672-
while (
673-
(err is None or err >= tol)
674-
and (numItermax is None or idx < numItermax)
675-
) or (idx < numItermin):
676-
reg_idx = get_reg(idx)
677-
(alpha, beta), pi = solver_sinkhorn_stabilized(
678-
cost,
679-
(alpha, beta),
680-
(rho_s, rho_t, reg_idx),
681-
tuple_weights,
682-
train_params_inner,
683-
verbose=False,
684-
)
685-
686-
if tol is not None and idx % eval_freq == 0:
687-
err = (
688-
torch.norm(pi.sum(0) - tuple_weights[1]) ** 2
689-
+ torch.norm(pi.sum(1) - tuple_weights[0]) ** 2
672+
with _get_progress(verbose=verbose, transient=True) as progress:
673+
if verbose:
674+
task = progress.add_task("Scaling iterations", total=numItermax)
675+
while (
676+
(err is None or err >= tol)
677+
and (numItermax is None or idx < numItermax)
678+
) or (idx < numItermin):
679+
reg_idx = get_reg(idx)
680+
(alpha, beta), pi = solver_sinkhorn_stabilized(
681+
cost,
682+
(alpha, beta),
683+
(rho_s, rho_t, reg_idx),
684+
tuple_weights,
685+
train_params_inner,
686+
verbose=False,
690687
)
691-
if err < tol and idx > numItermin:
692-
if verbose:
693-
print(f"Reached tol_uot threshold: {err}")
694-
break
695688

696-
idx += 1
689+
if tol is not None and idx % eval_freq == 0:
690+
err = (
691+
torch.norm(pi.sum(0) - tuple_weights[1]) ** 2
692+
+ torch.norm(pi.sum(1) - tuple_weights[0]) ** 2
693+
)
694+
if err < tol and idx > numItermin:
695+
if verbose:
696+
print(f"Reached tolerance threshold: {err}")
697+
break
698+
699+
if verbose:
700+
progress.update(task, advance=1)
701+
702+
idx += 1
697703

698704
return (alpha, beta), pi
699705

@@ -730,33 +736,41 @@ def get_reg(idx):
730736
numItermin = 35
731737
numItermax = max(numItermin, numItermax)
732738

733-
while (
734-
(err is None or err >= tol)
735-
and (numItermax is None or idx < numItermax)
736-
) or (idx < numItermin):
737-
reg_idx = get_reg(idx)
738-
(alpha, beta), pi = solver_sinkhorn_stabilized_sparse(
739-
cost,
740-
(alpha, beta),
741-
(rho_s, rho_t, reg_idx),
742-
tuple_weights,
743-
train_params_inner,
744-
verbose=False,
745-
)
746-
747-
if tol is not None and idx % eval_freq == 0:
748-
pi1 = csr_sum(pi, dim=1)
749-
pi2 = csr_sum(pi, dim=0)
750-
err = (
751-
torch.norm(pi1 - tuple_weights[0]) ** 2
752-
+ torch.norm(pi2 - tuple_weights[1]) ** 2
739+
with _get_progress(verbose=verbose, transient=True) as progress:
740+
if verbose:
741+
task = progress.add_task("Scaling iterations", total=numItermax)
742+
while (
743+
(err is None or err >= tol)
744+
and (numItermax is None or idx < numItermax)
745+
) or (idx < numItermin):
746+
reg_idx = get_reg(idx)
747+
(alpha, beta), pi = solver_sinkhorn_stabilized_sparse(
748+
cost,
749+
(alpha, beta),
750+
(rho_s, rho_t, reg_idx),
751+
tuple_weights,
752+
train_params_inner,
753+
verbose=False,
753754
)
754-
if err < tol and idx > numItermin:
755-
if verbose:
756-
print(f"Reached tol_uot threshold: {err}")
757-
break
758755

759-
idx += 1
756+
if tol is not None and idx % eval_freq == 0:
757+
pi1 = csr_sum(pi, dim=1)
758+
pi2 = csr_sum(pi, dim=0)
759+
err = (
760+
torch.norm(pi1 - tuple_weights[0]) ** 2
761+
+ torch.norm(pi2 - tuple_weights[1]) ** 2
762+
)
763+
if err < tol and idx > numItermin:
764+
if verbose:
765+
progress.console.log(
766+
print(f"Reached tolerance threshold: {err}")
767+
)
768+
break
769+
770+
if verbose:
771+
progress.update(task, advance=1)
772+
773+
idx += 1
760774

761775
return (alpha, beta), pi
762776

@@ -795,7 +809,7 @@ def solver_mm(
795809

796810
pi1, pi2, pi = init_pi.sum(1), init_pi.sum(0), init_pi
797811

798-
with _get_progress(transient=True) as progress:
812+
with _get_progress(verbose=verbose, transient=True) as progress:
799813
if verbose:
800814
task = progress.add_task("MM-KL iterations", total=niters)
801815

@@ -812,9 +826,6 @@ def solver_mm(
812826
)
813827
pi1, pi2 = pi.sum(1), pi.sum(0)
814828

815-
if verbose:
816-
progress.update(task, advance=1)
817-
818829
if tol is not None and idx % eval_freq == 0:
819830
pi1_error = (pi1 - pi1_prev).abs().max()
820831
pi2_error = (pi2 - pi2_prev).abs().max()
@@ -826,6 +837,9 @@ def solver_mm(
826837
f"{pi1_error}, {pi2_error}"
827838
)
828839

840+
if verbose:
841+
progress.update(task, advance=1)
842+
829843
idx += 1
830844

831845
return pi
@@ -859,7 +873,7 @@ def solver_mm_l2(
859873

860874
pi1, pi2, pi = init_pi.sum(1), init_pi.sum(0), init_pi
861875

862-
with _get_progress(transient=True) as progress:
876+
with _get_progress(verbose=verbose, transient=True) as progress:
863877
if verbose:
864878
task = progress.add_task("MM-L2 iterations", total=niters)
865879

@@ -966,7 +980,7 @@ def solver_mm_sparse(
966980
size=(n_cols, n_pi_values),
967981
).to_sparse_csr()
968982

969-
with _get_progress(transient=True) as progress:
983+
with _get_progress(verbose=verbose, transient=True) as progress:
970984
if verbose:
971985
task = progress.add_task("MM iterations", total=niters)
972986

@@ -1092,7 +1106,7 @@ def solver_mm_l2_sparse(
10921106
pi1, pi2 = csr_sum(init_pi, dim=1), csr_sum(init_pi, dim=0)
10931107
pi_values = init_pi.values()
10941108

1095-
with _get_progress(transient=True) as progress:
1109+
with _get_progress(verbose=verbose, transient=True) as progress:
10961110
if verbose:
10971111
task = progress.add_task("MM-L2 iterations", total=niters)
10981112

@@ -1184,7 +1198,7 @@ def solver_ibpp(
11841198

11851199
K = torch.exp(-cost / sum_eps)
11861200

1187-
with _get_progress(transient=True) as progress:
1201+
with _get_progress(verbose=verbose, transient=True) as progress:
11881202
if verbose:
11891203
task = progress.add_task("DC iterations", total=niters)
11901204

@@ -1285,7 +1299,7 @@ def solver_ibpp_sparse(
12851299
# Remove previously added 1
12861300
csr_values_to_transpose_values = T.values() - 1
12871301

1288-
with _get_progress(transient=True) as progress:
1302+
with _get_progress(verbose=verbose, transient=True) as progress:
12891303
if verbose:
12901304
task = progress.add_task("DC iterations", total=niters)
12911305

0 commit comments

Comments
 (0)