@@ -467,6 +467,35 @@ def solver_sinkhorn_stabilized(
467467 Entropy Regularized Transport Problems.
468468 arXiv preprint arXiv:1610.06519.
469469
470+ Parameters
471+ ----------
472+ cost: torch.Tensor
473+ Cost matrix.
474+ ws: torch.Tensor
475+ Source weights.
476+ wt: torch.Tensor
477+ Target weights.
478+ eps: float
479+ Entropy regularization parameter.
480+ init_duals: tuple of torch.Tensor, optional, defaults to None
481+ Initial dual potentials.
482+ numItermax: int, optional, defaults to 1000
483+ Maximum number of iterations.
484+ tol: float, optional, defaults to 1e-9
485+ Tolerance threshold.
486+ eval_freq: int, optional, defaults to 20
487+ Frequency at which to evaluate the tolerance threshold.
488+ stabilization_threshold: float, optional, defaults to 1e3
489+ Threshold for numerical stabilization.
490+ verbose: bool, optional, defaults to False
491+ Whether to display progress.
492+
493+ Returns
494+ -------
495+ (alpha, beta): tuple of torch.Tensor
496+ Dual potentials.
497+ pi: torch.Tensor
498+ Optimal transport plan.
470499 """
471500 # Initialize scaling vectors to ones
472501 u = torch .ones_like (ws ) / len (ws )
@@ -560,6 +589,36 @@ def solver_sinkhorn_stabilized_sparse(
560589 """
561590 Stabilized sparse Sinkhorn algorithm following the structure of the dense
562591 stabilized implementation but using sparse matrix operations.
592+
593+ Parameters
594+ ----------
595+ cost: torch.sparse_csr_tensor
596+ Sparse cost matrix.
597+ ws: torch.Tensor
598+ Source weights.
599+ wt: torch.Tensor
600+ Target weights.
601+ eps: float
602+ Entropy regularization parameter.
603+ init_duals: tuple of torch.Tensor, optional, defaults to None
604+ Initial dual potentials.
605+ numItermax: int, optional, defaults to 1000
606+ Maximum number of iterations.
607+ tol: float, optional, defaults to 1e-9
608+ Tolerance threshold.
609+ eval_freq: int, optional, defaults to 20
610+ Frequency at which to evaluate the tolerance threshold.
611+ stabilization_threshold: float, optional, defaults to 1e3
612+ Threshold for numerical stabilization.
613+ verbose: bool, optional, defaults to False
614+ Whether to display progress.
615+
616+ Returns
617+ -------
618+ (alpha, beta): tuple of torch.Tensor
619+ Dual potentials.
620+ pi: torch.sparse_csr_tensor
621+ Optimal transport plan.
563622 """
564623
565624 # Initialize scaling vectors to ones
@@ -710,6 +769,41 @@ def solver_sinkhorn_eps_scaling(
710769 """
711770 Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
712771 Relies on the stabilized Sinkhorn algorithm.
772+
773+ Parameters
774+ ----------
775+ cost: torch.Tensor
776+ Cost matrix.
777+ ws: torch.Tensor
778+ Source weights.
779+ wt: torch.Tensor
780+ Target weights.
781+ eps: float
782+ Entropy regularization parameter.
783+ init_duals: tuple of torch.Tensor, optional, defaults to None
784+ Initial dual potentials.
785+ numInnerItermax: int, optional, defaults to 100
786+ Maximum number of inner iterations for the stabilized
787+ Sinkhorn algorithm.
788+ numItermax: int, optional, defaults to 1000
789+ Maximum number of iterations.
790+ tol: float, optional, defaults to 1e-9
791+ Tolerance threshold.
792+ eval_freq: int, optional, defaults to 20
793+ Frequency at which to evaluate the tolerance threshold.
794+ stabilization_threshold: float, optional, defaults to 1e3
795+ Threshold for numerical stabilization.
796+ epsilon0: float, optional, defaults to 1e4
797+ Initial epsilon value.
798+ verbose: bool, optional, defaults to False
799+ Whether to display progress.
800+
801+ Returns
802+ -------
803+ (alpha, beta): tuple of torch.Tensor
804+ Dual potentials.
805+ pi: torch.Tensor
806+ Optimal transport plan.
713807 """
714808
715809 # Initialize the dual potentials
@@ -785,10 +879,45 @@ def solver_sinkhorn_eps_scaling_sparse(
785879):
786880 """
787881 Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
788- Relies on the stabilized Sinkhorn algorithm.
882+ Relies on the stabilized sparse Sinkhorn algorithm.
789883
790884 This implementation uses sparse matrix operations to speed up computations
791885 and reduce memory usage.
886+
887+ Parameters
888+ ----------
889+ cost: torch.sparse_csr_tensor
890+ Sparse cost matrix.
891+ ws: torch.Tensor
892+ Source weights.
893+ wt: torch.Tensor
894+ Target weights.
895+ eps: float
896+ Entropy regularization parameter.
897+ init_duals: tuple of torch.Tensor, optional, defaults to None
898+ Initial dual potentials.
899+ numInnerItermax: int, optional, defaults to 100
900+ Maximum number of inner iterations for the stabilized
901+ Sinkhorn algorithm.
902+ numItermax: int, optional, defaults to 1000
903+ Maximum number of iterations.
904+ tol: float, optional, defaults to 1e-9
905+ Tolerance threshold.
906+ eval_freq: int, optional, defaults to 20
907+ Frequency at which to evaluate the tolerance threshold.
908+ stabilization_threshold: float, optional, defaults to 1e3
909+ Threshold for numerical stabilization.
910+ epsilon0: float, optional, defaults to 1e4
911+ Initial epsilon value.
912+ verbose: bool, optional, defaults to False
913+ Whether to display progress.
914+
915+ Returns
916+ -------
917+ (alpha, beta): tuple of torch.Tensor
918+ Dual potentials.
919+ pi: torch.sparse_csr_tensor
920+ Optimal transport plan.
792921 """
793922
794923 # Initialize the dual potentials
0 commit comments