@@ -246,7 +246,7 @@ def solver_sinkhorn_log(
246246 tau_s = 1 if torch .isinf (rho_s ) else rho_s / (rho_s + eps )
247247 tau_t = 1 if torch .isinf (rho_t ) else rho_t / (rho_t + eps )
248248
249- with _get_progress (transient = True ) as progress :
249+ with _get_progress (verbose = verbose , transient = True ) as progress :
250250 if verbose :
251251 task = progress .add_task ("Sinkhorn iterations" , total = niters )
252252
@@ -364,7 +364,7 @@ def solver_sinkhorn_log_sparse(
364364 ).to_sparse_csr ()
365365 col_one_hot_t = col_one_hot .transpose (0 , 1 ).to_sparse_csr ()
366366
367- with _get_progress (transient = True ) as progress :
367+ with _get_progress (verbose = verbose , transient = True ) as progress :
368368 if verbose :
369369 task = progress .add_task ("Sinkhorn iterations" , total = niters )
370370
@@ -477,7 +477,7 @@ def get_K(alpha, beta):
477477 return ((alpha [:, None ] + beta [None , :] - cost ) / eps ).exp ()
478478
479479 K = get_K (alpha , beta )
480- with _get_progress (transient = True ) as progress :
480+ with _get_progress (verbose = verbose , transient = True ) as progress :
481481 if verbose :
482482 task = progress .add_task ("Sinkhorn iterations" , total = niters )
483483
@@ -500,19 +500,19 @@ def get_K(alpha, beta):
500500 # Recompute K with updated potentials
501501 K = get_K (alpha , beta )
502502
503- if verbose :
504- progress .update (task , advance = 1 )
505-
506503 if tol is not None and idx % eval_freq == 0 :
507504 pi = get_K (alpha + eps * u .log (), beta + eps * v .log ())
508505 err = torch .norm (pi .sum (0 ) - wt )
509506 if err < tol :
510507 if verbose :
511508 progress .console .log (
512- f"Reached tol_uot threshold: { err } "
509+ f"Reached tolerance threshold: { err } "
513510 )
514511 break
515512
513+ if verbose :
514+ progress .update (task , advance = 1 )
515+
516516 idx += 1
517517
518518 # Final update to potentials
@@ -583,7 +583,7 @@ def get_K_sparse(alpha, beta):
583583 # Initial K
584584 K = get_K_sparse (alpha , beta )
585585
586- with _get_progress (transient = True ) as progress :
586+ with _get_progress (verbose = verbose , transient = True ) as progress :
587587 if verbose :
588588 task = progress .add_task ("Sinkhorn iterations" , total = niters )
589589
@@ -624,7 +624,7 @@ def get_K_sparse(alpha, beta):
624624 if err < tol :
625625 if verbose :
626626 progress .console .log (
627- f"Reached tol_uot threshold: { err } "
627+ print ( f"Reached tolerance threshold: { err } " )
628628 )
629629 break
630630
@@ -669,31 +669,37 @@ def get_reg(idx):
669669 numItermin = 35
670670 numItermax = max (numItermin , numItermax )
671671
672- while (
673- (err is None or err >= tol )
674- and (numItermax is None or idx < numItermax )
675- ) or (idx < numItermin ):
676- reg_idx = get_reg (idx )
677- (alpha , beta ), pi = solver_sinkhorn_stabilized (
678- cost ,
679- (alpha , beta ),
680- (rho_s , rho_t , reg_idx ),
681- tuple_weights ,
682- train_params_inner ,
683- verbose = False ,
684- )
685-
686- if tol is not None and idx % eval_freq == 0 :
687- err = (
688- torch .norm (pi .sum (0 ) - tuple_weights [1 ]) ** 2
689- + torch .norm (pi .sum (1 ) - tuple_weights [0 ]) ** 2
672+ with _get_progress (verbose = verbose , transient = True ) as progress :
673+ if verbose :
674+ task = progress .add_task ("Scaling iterations" , total = numItermax )
675+ while (
676+ (err is None or err >= tol )
677+ and (numItermax is None or idx < numItermax )
678+ ) or (idx < numItermin ):
679+ reg_idx = get_reg (idx )
680+ (alpha , beta ), pi = solver_sinkhorn_stabilized (
681+ cost ,
682+ (alpha , beta ),
683+ (rho_s , rho_t , reg_idx ),
684+ tuple_weights ,
685+ train_params_inner ,
686+ verbose = False ,
690687 )
691- if err < tol and idx > numItermin :
692- if verbose :
693- print (f"Reached tol_uot threshold: { err } " )
694- break
695688
696- idx += 1
689+ if tol is not None and idx % eval_freq == 0 :
690+ err = (
691+ torch .norm (pi .sum (0 ) - tuple_weights [1 ]) ** 2
692+ + torch .norm (pi .sum (1 ) - tuple_weights [0 ]) ** 2
693+ )
694+ if err < tol and idx > numItermin :
695+ if verbose :
696+ print (f"Reached tolerance threshold: { err } " )
697+ break
698+
699+ if verbose :
700+ progress .update (task , advance = 1 )
701+
702+ idx += 1
697703
698704 return (alpha , beta ), pi
699705
@@ -730,33 +736,41 @@ def get_reg(idx):
730736 numItermin = 35
731737 numItermax = max (numItermin , numItermax )
732738
733- while (
734- (err is None or err >= tol )
735- and (numItermax is None or idx < numItermax )
736- ) or (idx < numItermin ):
737- reg_idx = get_reg (idx )
738- (alpha , beta ), pi = solver_sinkhorn_stabilized_sparse (
739- cost ,
740- (alpha , beta ),
741- (rho_s , rho_t , reg_idx ),
742- tuple_weights ,
743- train_params_inner ,
744- verbose = False ,
745- )
746-
747- if tol is not None and idx % eval_freq == 0 :
748- pi1 = csr_sum (pi , dim = 1 )
749- pi2 = csr_sum (pi , dim = 0 )
750- err = (
751- torch .norm (pi1 - tuple_weights [0 ]) ** 2
752- + torch .norm (pi2 - tuple_weights [1 ]) ** 2
739+ with _get_progress (verbose = verbose , transient = True ) as progress :
740+ if verbose :
741+ task = progress .add_task ("Scaling iterations" , total = numItermax )
742+ while (
743+ (err is None or err >= tol )
744+ and (numItermax is None or idx < numItermax )
745+ ) or (idx < numItermin ):
746+ reg_idx = get_reg (idx )
747+ (alpha , beta ), pi = solver_sinkhorn_stabilized_sparse (
748+ cost ,
749+ (alpha , beta ),
750+ (rho_s , rho_t , reg_idx ),
751+ tuple_weights ,
752+ train_params_inner ,
753+ verbose = False ,
753754 )
754- if err < tol and idx > numItermin :
755- if verbose :
756- print (f"Reached tol_uot threshold: { err } " )
757- break
758755
759- idx += 1
756+ if tol is not None and idx % eval_freq == 0 :
757+ pi1 = csr_sum (pi , dim = 1 )
758+ pi2 = csr_sum (pi , dim = 0 )
759+ err = (
760+ torch .norm (pi1 - tuple_weights [0 ]) ** 2
761+ + torch .norm (pi2 - tuple_weights [1 ]) ** 2
762+ )
763+ if err < tol and idx > numItermin :
764+ if verbose :
765+ progress .console .log (
766+ print (f"Reached tolerance threshold: { err } " )
767+ )
768+ break
769+
770+ if verbose :
771+ progress .update (task , advance = 1 )
772+
773+ idx += 1
760774
761775 return (alpha , beta ), pi
762776
@@ -795,7 +809,7 @@ def solver_mm(
795809
796810 pi1 , pi2 , pi = init_pi .sum (1 ), init_pi .sum (0 ), init_pi
797811
798- with _get_progress (transient = True ) as progress :
812+ with _get_progress (verbose = verbose , transient = True ) as progress :
799813 if verbose :
800814 task = progress .add_task ("MM-KL iterations" , total = niters )
801815
@@ -812,9 +826,6 @@ def solver_mm(
812826 )
813827 pi1 , pi2 = pi .sum (1 ), pi .sum (0 )
814828
815- if verbose :
816- progress .update (task , advance = 1 )
817-
818829 if tol is not None and idx % eval_freq == 0 :
819830 pi1_error = (pi1 - pi1_prev ).abs ().max ()
820831 pi2_error = (pi2 - pi2_prev ).abs ().max ()
@@ -826,6 +837,9 @@ def solver_mm(
826837 f"{ pi1_error } , { pi2_error } "
827838 )
828839
840+ if verbose :
841+ progress .update (task , advance = 1 )
842+
829843 idx += 1
830844
831845 return pi
@@ -859,7 +873,7 @@ def solver_mm_l2(
859873
860874 pi1 , pi2 , pi = init_pi .sum (1 ), init_pi .sum (0 ), init_pi
861875
862- with _get_progress (transient = True ) as progress :
876+ with _get_progress (verbose = verbose , transient = True ) as progress :
863877 if verbose :
864878 task = progress .add_task ("MM-L2 iterations" , total = niters )
865879
@@ -966,7 +980,7 @@ def solver_mm_sparse(
966980 size = (n_cols , n_pi_values ),
967981 ).to_sparse_csr ()
968982
969- with _get_progress (transient = True ) as progress :
983+ with _get_progress (verbose = verbose , transient = True ) as progress :
970984 if verbose :
971985 task = progress .add_task ("MM iterations" , total = niters )
972986
@@ -1092,7 +1106,7 @@ def solver_mm_l2_sparse(
10921106 pi1 , pi2 = csr_sum (init_pi , dim = 1 ), csr_sum (init_pi , dim = 0 )
10931107 pi_values = init_pi .values ()
10941108
1095- with _get_progress (transient = True ) as progress :
1109+ with _get_progress (verbose = verbose , transient = True ) as progress :
10961110 if verbose :
10971111 task = progress .add_task ("MM-L2 iterations" , total = niters )
10981112
@@ -1184,7 +1198,7 @@ def solver_ibpp(
11841198
11851199 K = torch .exp (- cost / sum_eps )
11861200
1187- with _get_progress (transient = True ) as progress :
1201+ with _get_progress (verbose = verbose , transient = True ) as progress :
11881202 if verbose :
11891203 task = progress .add_task ("DC iterations" , total = niters )
11901204
@@ -1285,7 +1299,7 @@ def solver_ibpp_sparse(
12851299 # Remove previously added 1
12861300 csr_values_to_transpose_values = T .values () - 1
12871301
1288- with _get_progress (transient = True ) as progress :
1302+ with _get_progress (verbose = verbose , transient = True ) as progress :
12891303 if verbose :
12901304 task = progress .add_task ("DC iterations" , total = niters )
12911305
0 commit comments