@@ -481,12 +481,9 @@ def get_K(alpha, beta):
481481 if verbose :
482482 task = progress .add_task ("Sinkhorn iterations" , total = niters )
483483
484- pi_diff = None
484+ err = None
485485 idx = 0
486- while (pi_diff is None or pi_diff >= tol ) and (
487- niters is None or idx < niters
488- ):
489- u_prev , v_prev = u .detach ().clone (), v .detach ().clone ()
486+ while (err is None or err >= tol ) and (niters is None or idx < niters ):
490487 v = wt / (K .T @ u )
491488 u = ws / (K @ v )
492489
@@ -507,14 +504,14 @@ def get_K(alpha, beta):
507504 progress .update (task , advance = 1 )
508505
509506 if tol is not None and idx % eval_freq == 0 :
510- pi_diff = max (
511- (u - u_prev ).abs ().max (), (v - v_prev ).abs ().max ()
512- )
513- if pi_diff < tol :
507+ pi = get_K (alpha + eps * u .log (), beta + eps * v .log ())
508+ err = torch .norm (pi .sum (0 ) - wt )
509+ if err < tol :
514510 if verbose :
515511 progress .console .log (
516- f"Reached tol_uot threshold: { pi_diff } "
512+ f"Reached tol_uot threshold: { err } "
517513 )
514+ break
518515
519516 idx += 1
520517
@@ -590,12 +587,9 @@ def get_K_sparse(alpha, beta):
590587 if verbose :
591588 task = progress .add_task ("Sinkhorn iterations" , total = niters )
592589
593- pi_diff = None
590+ err = None
594591 idx = 0
595- while (pi_diff is None or pi_diff >= tol ) and (
596- niters is None or idx < niters
597- ):
598- u_prev , v_prev = u .detach ().clone (), v .detach ().clone ()
592+ while (err is None or err >= tol ) and (niters is None or idx < niters ):
599593 # Update v using sparse matrix multiplication
600594 Kt_u = (
601595 torch .sparse .mm (K .transpose (0 , 1 ), u .reshape (- 1 , 1 ))
@@ -625,14 +619,14 @@ def get_K_sparse(alpha, beta):
625619 progress .update (task , advance = 1 )
626620
627621 if tol is not None and idx % eval_freq == 0 :
628- pi_diff = max (
629- (u - u_prev ).abs ().max (), (v - v_prev ).abs ().max ()
630- )
631- if pi_diff < tol :
622+ pi = get_K_sparse (alpha + eps * u .log (), beta + eps * v .log ())
623+ err = torch .norm (csr_sum (pi , dim = 0 ) - wt )
624+ if err < tol :
632625 if verbose :
633626 progress .console .log (
634- f"Reached tol_uot threshold: { pi_diff } "
627+ f"Reached tol_uot threshold: { err } "
635628 )
629+ break
636630
637631 idx += 1
638632
@@ -666,7 +660,7 @@ def solver_sinkhorn_eps_scaling(
666660 train_params_inner = numInnerItermax , tol , eval_freq
667661 alpha , beta = init_duals
668662
669- pi_diff = None
663+ err = None
670664 idx = 0
671665
672666 def get_reg (idx ):
@@ -676,11 +670,10 @@ def get_reg(idx):
676670 numItermax = max (numItermin , numItermax )
677671
678672 while (
679- (pi_diff is None or pi_diff >= tol )
673+ (err is None or err >= tol )
680674 and (numItermax is None or idx < numItermax )
681675 ) or (idx < numItermin ):
682676 reg_idx = get_reg (idx )
683- alpha_prev , beta_prev = alpha .detach ().clone (), beta .detach ().clone ()
684677 (alpha , beta ), pi = solver_sinkhorn_stabilized (
685678 cost ,
686679 (alpha , beta ),
@@ -691,13 +684,14 @@ def get_reg(idx):
691684 )
692685
693686 if tol is not None and idx % eval_freq == 0 :
694- pi_diff = max (
695- ( alpha - alpha_prev ). abs (). max (),
696- ( beta - beta_prev ). abs (). max (),
687+ err = (
688+ torch . norm ( pi . sum ( 0 ) - tuple_weights [ 1 ]) ** 2
689+ + torch . norm ( pi . sum ( 1 ) - tuple_weights [ 0 ]) ** 2
697690 )
698- if pi_diff < tol :
691+ if err < tol and idx > numItermin :
699692 if verbose :
700- print (f"Reached tol_uot threshold: { pi_diff } " )
693+ print (f"Reached tol_uot threshold: { err } " )
694+ break
701695
702696 idx += 1
703697
@@ -727,7 +721,7 @@ def solver_sinkhorn_eps_scaling_sparse(
727721 train_params_inner = numInnerItermax , tol , eval_freq
728722 alpha , beta = init_duals
729723
730- pi_diff = None
724+ err = None
731725 idx = 0
732726
733727 def get_reg (idx ):
@@ -737,11 +731,10 @@ def get_reg(idx):
737731 numItermax = max (numItermin , numItermax )
738732
739733 while (
740- (pi_diff is None or pi_diff >= tol )
734+ (err is None or err >= tol )
741735 and (numItermax is None or idx < numItermax )
742736 ) or (idx < numItermin ):
743737 reg_idx = get_reg (idx )
744- alpha_prev , beta_prev = alpha .detach ().clone (), beta .detach ().clone ()
745738 (alpha , beta ), pi = solver_sinkhorn_stabilized_sparse (
746739 cost ,
747740 (alpha , beta ),
@@ -752,13 +745,17 @@ def get_reg(idx):
752745 )
753746
754747 if tol is not None and idx % eval_freq == 0 :
755- pi_diff = max (
756- (alpha - alpha_prev ).abs ().max (),
757- (beta - beta_prev ).abs ().max (),
748+ pi1 = csr_sum (pi , dim = 1 )
749+ pi2 = csr_sum (pi , dim = 0 )
750+ err = (
751+ torch .norm (pi1 - tuple_weights [0 ]) ** 2
752+ + torch .norm (pi2 - tuple_weights [1 ]) ** 2
758753 )
759- if pi_diff < tol :
754+ print (f"fugw iter { idx } , err { err } " )
755+ if err < tol and idx > numItermin :
760756 if verbose :
761- print (f"Reached tol_uot threshold: { pi_diff } " )
757+ print (f"Reached tol_uot threshold: { err } " )
758+ break
762759
763760 idx += 1
764761
0 commit comments