@@ -450,7 +450,16 @@ def solver_sinkhorn_log_sparse(
450450
451451
452452def solver_sinkhorn_stabilized (
453- cost , init_duals , uot_params , tuple_weights , train_params , verbose = True
453+ cost ,
454+ ws ,
455+ wt ,
456+ eps ,
457+ init_duals = None ,
458+ numItermax = 1000 ,
459+ tol = 1e-9 ,
460+ eval_freq = 20 ,
461+ stabilization_threshold = 1e3 ,
462+ verbose = True ,
454463):
455464 """
456465 Stabilized Sinkhorn algorithm as described in
@@ -459,23 +468,23 @@ def solver_sinkhorn_stabilized(
459468 arXiv preprint arXiv:1610.06519.
460469
461470 """
462- ws , wt , ws_dot_wt = tuple_weights
463-
464471 # Initialize scaling vectors to ones
465472 u = torch .ones_like (ws ) / len (ws )
466473 v = torch .ones_like (wt ) / len (wt )
467474
468475 # Initialize the dual potentials
469- alpha , beta = init_duals
470-
471- rho_s , rho_t , eps = uot_params
472- niters , tol , eval_freq = train_params
473- tau = 1e3 # Threshold for stabilization
476+ if init_duals is None :
477+ alpha = torch . zeros ( cost . shape [ 0 ])
478+ beta = torch . zeros ( cost . shape [ 1 ])
479+ else :
480+ alpha , beta = init_duals
474481
475482 def get_K (alpha , beta ):
483+ """Return sinkhorn kernel matrix."""
476484 return (- (cost - alpha [:, None ] - beta [None , :]) / eps ).exp ()
477485
478- def get_Gamma (alpha , beta , u , v ):
486+ def get_pi (alpha , beta , u , v ):
487+ """Return dense transport plan."""
479488 return (
480489 - (cost - alpha [:, None ] - beta [None , :]) / eps
481490 + u [:, None ].log ()
@@ -485,16 +494,21 @@ def get_Gamma(alpha, beta, u, v):
485494 K = get_K (alpha , beta )
486495 with _get_progress (verbose = verbose , transient = True ) as progress :
487496 if verbose :
488- task = progress .add_task ("Sinkhorn iterations" , total = niters )
497+ task = progress .add_task ("Sinkhorn iterations" , total = numItermax )
489498
490499 err = None
491500 idx = 0
492- while (err is None or err >= tol ) and (niters is None or idx < niters ):
501+ while (err is None or err >= tol ) and (
502+ numItermax is None or idx < numItermax
503+ ):
493504 v = wt / (K .T @ u )
494505 u = ws / (K @ v )
495506
496507 # Check for numerical instability and stabilize if needed
497- if torch .max (torch .abs (u )) > tau or torch .max (torch .abs (v )) > tau :
508+ if (
509+ torch .max (torch .abs (u )) > stabilization_threshold
510+ or torch .max (torch .abs (v )) > stabilization_threshold
511+ ):
498512 # Absorb large values into potentials
499513 alpha = alpha + eps * torch .log (u )
500514 beta = beta + eps * torch .log (v )
@@ -522,7 +536,7 @@ def get_Gamma(alpha, beta, u, v):
522536 idx += 1
523537
524538 # Compute final transport plan
525- pi = get_Gamma (alpha , beta , u , v )
539+ pi = get_pi (alpha , beta , u , v )
526540
527541 # Final update to potentials
528542 alpha = alpha + eps * u .log ()
@@ -532,24 +546,32 @@ def get_Gamma(alpha, beta, u, v):
532546
533547
534548def solver_sinkhorn_stabilized_sparse (
535- cost , init_duals , uot_params , tuple_weights , train_params , verbose = True
549+ cost ,
550+ ws ,
551+ wt ,
552+ eps ,
553+ init_duals = None ,
554+ numItermax = 1000 ,
555+ tol = 1e-9 ,
556+ eval_freq = 20 ,
557+ stabilization_threshold = 1e3 ,
558+ verbose = True ,
536559):
537560 """
538561 Stabilized sparse Sinkhorn algorithm following the structure of the dense
539562 stabilized implementation but using sparse matrix operations.
540563 """
541- ws , wt , ws_dot_wt = tuple_weights
542564
543565 # Initialize scaling vectors to ones
544566 u = torch .ones_like (ws ) / len (ws )
545567 v = torch .ones_like (wt ) / len (wt )
546568
547569 # Initialize the dual potentials
548- alpha , beta = init_duals
549-
550- rho_s , rho_t , eps = uot_params
551- niters , tol , eval_freq = train_params
552- tau = 1e3 # Threshold for stabilization
570+ if init_duals is None :
571+ alpha = torch . zeros ( cost . shape [ 0 ])
572+ beta = torch . zeros ( cost . shape [ 1 ])
573+ else :
574+ alpha , beta = init_duals
553575
554576 # Set up sparse matrix operations
555577 crow_indices = cost .crow_indices ()
@@ -569,7 +591,7 @@ def solver_sinkhorn_stabilized_sparse(
569591
570592 # Compute kernel matrix K in sparse format
571593 def get_K_sparse (alpha , beta ):
572- # (alpha[:, None] + beta[None, :] - cost) / eps in sparse form
594+ """Return sinkhorn sparse kernel matrix."""
573595 new_values = (
574596 fill_csr_matrix_rows (alpha , crow_indices )
575597 + fill_csr_matrix_cols (beta , ccol_indices , csc_to_csr )
@@ -586,7 +608,8 @@ def get_K_sparse(alpha, beta):
586608 )
587609 return K
588610
589- def get_Gamma_sparse (alpha , beta , u , v ):
611+ def get_pi_sparse (alpha , beta , u , v ):
612+ """Return dense transport plan."""
590613 new_values = (
591614 (
592615 fill_csr_matrix_rows (alpha , crow_indices )
@@ -614,11 +637,13 @@ def get_Gamma_sparse(alpha, beta, u, v):
614637
615638 with _get_progress (verbose = verbose , transient = True ) as progress :
616639 if verbose :
617- task = progress .add_task ("Sinkhorn iterations" , total = niters )
640+ task = progress .add_task ("Sinkhorn iterations" , total = numItermax )
618641
619642 err = None
620643 idx = 0
621- while (err is None or err >= tol ) and (niters is None or idx < niters ):
644+ while (err is None or err >= tol ) and (
645+ numItermax is None or idx < numItermax
646+ ):
622647 # Update v using sparse matrix multiplication
623648 Kt_u = torch .mv (K .transpose (0 , 1 ), u )
624649 v = wt / Kt_u
@@ -628,7 +653,10 @@ def get_Gamma_sparse(alpha, beta, u, v):
628653 u = ws / Kv
629654
630655 # Check for numerical instability and stabilize if needed
631- if torch .max (torch .abs (u )) > tau or torch .max (torch .abs (v )) > tau :
656+ if (
657+ torch .max (torch .abs (u )) > stabilization_threshold
658+ or torch .max (torch .abs (v )) > stabilization_threshold
659+ ):
632660 # Absorb large values into potentials
633661 alpha = alpha + eps * torch .log (u )
634662 beta = beta + eps * torch .log (v )
@@ -656,7 +684,7 @@ def get_Gamma_sparse(alpha, beta, u, v):
656684 idx += 1
657685
658686 # Compute final transport plan
659- pi = get_Gamma_sparse (alpha , beta , u , v )
687+ pi = get_pi_sparse (alpha , beta , u , v )
660688
661689 # Final update to potentials
662690 alpha = alpha + eps * u .log ()
@@ -667,22 +695,29 @@ def get_Gamma_sparse(alpha, beta, u, v):
667695
668696def solver_sinkhorn_eps_scaling (
669697 cost ,
670- init_duals ,
671- uot_params ,
672- tuple_weights ,
673- train_params ,
698+ ws ,
699+ wt ,
700+ eps ,
701+ init_duals = None ,
702+ numInnerItermax = 100 ,
674703 numItermax = 100 ,
675- verbose = True ,
704+ tol = 1e-9 ,
705+ eval_freq = 10 ,
706+ stabilization_threshold = 1e3 ,
676707 epsilon0 = 1e4 ,
708+ verbose = True ,
677709):
678710 """
679711 Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
680712 Relies on the stabilized Sinkhorn algorithm.
681713 """
682- rho_s , rho_t , eps = uot_params
683- numInnerItermax , tol , eval_freq = train_params
684- train_params_inner = numInnerItermax , tol , eval_freq
685- alpha , beta = init_duals
714+
715+ # Initialize the dual potentials
716+ if init_duals is None :
717+ alpha = torch .zeros (cost .shape [0 ])
718+ beta = torch .zeros (cost .shape [1 ])
719+ else :
720+ alpha , beta = init_duals
686721
687722 err = None
688723 idx = 0
@@ -704,17 +739,21 @@ def get_reg(idx):
704739 reg_idx = get_reg (idx )
705740 (alpha , beta ), pi = solver_sinkhorn_stabilized (
706741 cost ,
707- (alpha , beta ),
708- (rho_s , rho_t , reg_idx ),
709- tuple_weights ,
710- train_params_inner ,
742+ ws ,
743+ wt ,
744+ reg_idx ,
745+ init_duals = (alpha , beta ),
746+ numItermax = numInnerItermax ,
747+ tol = tol ,
748+ eval_freq = eval_freq ,
749+ stabilization_threshold = stabilization_threshold ,
711750 verbose = False ,
712751 )
713752
714753 if tol is not None and idx % eval_freq == 0 :
715754 err = (
716- torch .norm (pi .sum (0 ) - tuple_weights [1 ]) ** 2
717- + torch .norm (pi .sum (1 ) - tuple_weights [0 ]) ** 2
755+ torch .norm (pi .sum (0 ) - wt [1 ]) ** 2
756+ + torch .norm (pi .sum (1 ) - ws [0 ]) ** 2
718757 )
719758 if err < tol and idx > numItermin :
720759 if verbose :
@@ -731,14 +770,17 @@ def get_reg(idx):
731770
732771def solver_sinkhorn_eps_scaling_sparse (
733772 cost ,
734- init_duals ,
735- uot_params ,
736- tuple_weights ,
737- train_params ,
738- numItermax = 100 ,
773+ ws ,
774+ wt ,
775+ eps ,
776+ init_duals = None ,
739777 numInnerItermax = 100 ,
740- verbose = True ,
778+ numItermax = 100 ,
779+ tol = 1e-9 ,
780+ eval_freq = 10 ,
781+ stabilization_threshold = 1e3 ,
741782 epsilon0 = 1e4 ,
783+ verbose = True ,
742784):
743785 """
744786 Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
@@ -747,10 +789,13 @@ def solver_sinkhorn_eps_scaling_sparse(
747789 This implementation uses sparse matrix operations to speed up computations
748790 and reduce memory usage.
749791 """
750- rho_s , rho_t , eps = uot_params
751- _ , tol , eval_freq = train_params
752- train_params_inner = numInnerItermax , tol , eval_freq
753- alpha , beta = init_duals
792+
793+ # Initialize the dual potentials
794+ if init_duals is None :
795+ alpha = torch .zeros (cost .shape [0 ])
796+ beta = torch .zeros (cost .shape [1 ])
797+ else :
798+ alpha , beta = init_duals
754799
755800 err = None
756801 idx = 0
@@ -773,20 +818,21 @@ def get_reg(idx):
773818 reg_idx = get_reg (idx )
774819 (alpha , beta ), pi = solver_sinkhorn_stabilized_sparse (
775820 cost ,
776- (alpha , beta ),
777- (rho_s , rho_t , reg_idx ),
778- tuple_weights ,
779- train_params_inner ,
821+ ws ,
822+ wt ,
823+ reg_idx ,
824+ init_duals = (alpha , beta ),
825+ numItermax = numInnerItermax ,
826+ tol = tol ,
827+ eval_freq = eval_freq ,
828+ stabilization_threshold = stabilization_threshold ,
780829 verbose = False ,
781830 )
782831
783832 if tol is not None and idx % eval_freq == 0 :
784833 pi1 = csr_sum (pi , dim = 1 )
785834 pi2 = csr_sum (pi , dim = 0 )
786- err = (
787- torch .norm (pi1 - tuple_weights [0 ]) ** 2
788- + torch .norm (pi2 - tuple_weights [1 ]) ** 2
789- )
835+ err = torch .norm (pi1 - ws ) ** 2 + torch .norm (pi2 - wt ) ** 2
790836 if err < tol and idx > numItermin :
791837 if verbose :
792838 progress .console .log (
0 commit comments