Skip to content

Commit 97914aa

Browse files
author
Pierre-Louis Barbarant
committed
Add docstrings
1 parent 39993be commit 97914aa

File tree

1 file changed

+130
-1
lines changed

1 file changed

+130
-1
lines changed

src/fugw/solvers/utils.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,35 @@ def solver_sinkhorn_stabilized(
467467
Entropy Regularized Transport Problems.
468468
arXiv preprint arXiv:1610.06519.
469469
470+
Parameters
471+
----------
472+
cost: torch.Tensor
473+
Cost matrix.
474+
ws: torch.Tensor
475+
Source weights.
476+
wt: torch.Tensor
477+
Target weights.
478+
eps: float
479+
Entropy regularization parameter.
480+
init_duals: tuple of torch.Tensor, optional, defaults to None
481+
Initial dual potentials.
482+
numItermax: int, optional, defaults to 1000
483+
Maximum number of iterations.
484+
tol: float, optional, defaults to 1e-9
485+
Tolerance threshold.
486+
eval_freq: int, optional, defaults to 20
487+
Frequency at which to evaluate the tolerance threshold.
488+
stabilization_threshold: float, optional, defaults to 1e3
489+
Threshold for numerical stabilization.
490+
verbose: bool, optional, defaults to False
491+
Whether to display progress.
492+
493+
Returns
494+
-------
495+
(alpha, beta): tuple of torch.Tensor
496+
Dual potentials.
497+
pi: torch.Tensor
498+
Optimal transport plan.
470499
"""
471500
# Initialize scaling vectors to ones
472501
u = torch.ones_like(ws) / len(ws)
@@ -560,6 +589,36 @@ def solver_sinkhorn_stabilized_sparse(
560589
"""
561590
Stabilized sparse Sinkhorn algorithm following the structure of the dense
562591
stabilized implementation but using sparse matrix operations.
592+
593+
Parameters
594+
----------
595+
cost: torch.sparse_csr_tensor
596+
Sparse cost matrix.
597+
ws: torch.Tensor
598+
Source weights.
599+
wt: torch.Tensor
600+
Target weights.
601+
eps: float
602+
Entropy regularization parameter.
603+
init_duals: tuple of torch.Tensor, optional, defaults to None
604+
Initial dual potentials.
605+
numItermax: int, optional, defaults to 1000
606+
Maximum number of iterations.
607+
tol: float, optional, defaults to 1e-9
608+
Tolerance threshold.
609+
eval_freq: int, optional, defaults to 20
610+
Frequency at which to evaluate the tolerance threshold.
611+
stabilization_threshold: float, optional, defaults to 1e3
612+
Threshold for numerical stabilization.
613+
verbose: bool, optional, defaults to False
614+
Whether to display progress.
615+
616+
Returns
617+
-------
618+
(alpha, beta): tuple of torch.Tensor
619+
Dual potentials.
620+
pi: torch.sparse_csr_tensor
621+
Optimal transport plan.
563622
"""
564623

565624
# Initialize scaling vectors to ones
@@ -710,6 +769,41 @@ def solver_sinkhorn_eps_scaling(
710769
"""
711770
Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
712771
Relies on the stabilized Sinkhorn algorithm.
772+
773+
Parameters
774+
----------
775+
cost: torch.Tensor
776+
Cost matrix.
777+
ws: torch.Tensor
778+
Source weights.
779+
wt: torch.Tensor
780+
Target weights.
781+
eps: float
782+
Entropy regularization parameter.
783+
init_duals: tuple of torch.Tensor, optional, defaults to None
784+
Initial dual potentials.
785+
numInnerItermax: int, optional, defaults to 100
786+
Maximum number of inner iterations for the stabilized
787+
Sinkhorn algorithm.
788+
numItermax: int, optional, defaults to 1000
789+
Maximum number of iterations.
790+
tol: float, optional, defaults to 1e-9
791+
Tolerance threshold.
792+
eval_freq: int, optional, defaults to 20
793+
Frequency at which to evaluate the tolerance threshold.
794+
stabilization_threshold: float, optional, defaults to 1e3
795+
Threshold for numerical stabilization.
796+
epsilon0: float, optional, defaults to 1e4
797+
Initial epsilon value.
798+
verbose: bool, optional, defaults to False
799+
Whether to display progress.
800+
801+
Returns
802+
-------
803+
(alpha, beta): tuple of torch.Tensor
804+
Dual potentials.
805+
pi: torch.Tensor
806+
Optimal transport plan.
713807
"""
714808

715809
# Initialize the dual potentials
@@ -785,10 +879,45 @@ def solver_sinkhorn_eps_scaling_sparse(
785879
):
786880
"""
787881
Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
788-
Relies on the stabilized Sinkhorn algorithm.
882+
Relies on the stabilized sparse Sinkhorn algorithm.
789883
790884
This implementation uses sparse matrix operations to speed up computations
791885
and reduce memory usage.
886+
887+
Parameters
888+
----------
889+
cost: torch.sparse_csr_tensor
890+
Sparse cost matrix.
891+
ws: torch.Tensor
892+
Source weights.
893+
wt: torch.Tensor
894+
Target weights.
895+
eps: float
896+
Entropy regularization parameter.
897+
init_duals: tuple of torch.Tensor, optional, defaults to None
898+
Initial dual potentials.
899+
numInnerItermax: int, optional, defaults to 100
900+
Maximum number of inner iterations for the stabilized
901+
Sinkhorn algorithm.
902+
numItermax: int, optional, defaults to 1000
903+
Maximum number of iterations.
904+
tol: float, optional, defaults to 1e-9
905+
Tolerance threshold.
906+
eval_freq: int, optional, defaults to 20
907+
Frequency at which to evaluate the tolerance threshold.
908+
stabilization_threshold: float, optional, defaults to 1e3
909+
Threshold for numerical stabilization.
910+
epsilon0: float, optional, defaults to 1e4
911+
Initial epsilon value.
912+
verbose: bool, optional, defaults to False
913+
Whether to display progress.
914+
915+
Returns
916+
-------
917+
(alpha, beta): tuple of torch.Tensor
918+
Dual potentials.
919+
pi: torch.sparse_csr_tensor
920+
Optimal transport plan.
792921
"""
793922

794923
# Initialize the dual potentials

0 commit comments

Comments
 (0)