
Commit 74dd761

Authored and committed by Pierre-Louis Barbarant
Refactor Sinkhorn solver functions to improve parameter handling and enhance readability
1 parent 8c99c6c commit 74dd761

File tree: 1 file changed (+105 -59 lines)


src/fugw/solvers/utils.py

Lines changed: 105 additions & 59 deletions
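
The refactor replaces the packed `init_duals` / `uot_params` / `tuple_weights` / `train_params` tuples with explicit arguments (`ws`, `wt`, `eps`, `numItermax`, `tol`, `eval_freq`, `stabilization_threshold`). As a minimal call-site sketch of what that implies, assuming the new signature shown in the diff below (the toy cost matrix, weights and epsilon value are made up for illustration and are not part of the commit):

import torch
from fugw.solvers.utils import solver_sinkhorn_stabilized

# Hypothetical toy problem: 3 source points, 4 target points
cost = torch.rand(3, 4)
ws = torch.ones(3) / 3  # source weights
wt = torch.ones(4) / 4  # target weights

# Old call (before this commit): parameters packed into tuples
# (alpha, beta), pi = solver_sinkhorn_stabilized(
#     cost,
#     (alpha, beta),
#     (rho_s, rho_t, eps),
#     (ws, wt, ws_dot_wt),
#     (niters, tol, eval_freq),
# )

# New call (after this commit): explicit keyword arguments
(alpha, beta), pi = solver_sinkhorn_stabilized(
    cost,
    ws,
    wt,
    eps=1e-2,
    numItermax=1000,
    tol=1e-9,
    eval_freq=20,
    stabilization_threshold=1e3,
)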
@@ -450,7 +450,16 @@ def solver_sinkhorn_log_sparse(
 
 
 def solver_sinkhorn_stabilized(
-    cost, init_duals, uot_params, tuple_weights, train_params, verbose=True
+    cost,
+    ws,
+    wt,
+    eps,
+    init_duals=None,
+    numItermax=1000,
+    tol=1e-9,
+    eval_freq=20,
+    stabilization_threshold=1e3,
+    verbose=True,
 ):
     """
     Stabilized Sinkhorn algorithm as described in
@@ -459,23 +468,23 @@ def solver_sinkhorn_stabilized(
     arXiv preprint arXiv:1610.06519.
 
     """
-    ws, wt, ws_dot_wt = tuple_weights
-
     # Initialize scaling vectors to ones
     u = torch.ones_like(ws) / len(ws)
     v = torch.ones_like(wt) / len(wt)
 
     # Initialize the dual potentials
-    alpha, beta = init_duals
-
-    rho_s, rho_t, eps = uot_params
-    niters, tol, eval_freq = train_params
-    tau = 1e3  # Threshold for stabilization
+    if init_duals is None:
+        alpha = torch.zeros(cost.shape[0])
+        beta = torch.zeros(cost.shape[1])
+    else:
+        alpha, beta = init_duals
 
     def get_K(alpha, beta):
+        """Return sinkhorn kernel matrix."""
         return (-(cost - alpha[:, None] - beta[None, :]) / eps).exp()
 
-    def get_Gamma(alpha, beta, u, v):
+    def get_pi(alpha, beta, u, v):
+        """Return dense transport plan."""
         return (
             -(cost - alpha[:, None] - beta[None, :]) / eps
             + u[:, None].log()
@@ -485,16 +494,21 @@ def get_Gamma(alpha, beta, u, v):
     K = get_K(alpha, beta)
     with _get_progress(verbose=verbose, transient=True) as progress:
         if verbose:
-            task = progress.add_task("Sinkhorn iterations", total=niters)
+            task = progress.add_task("Sinkhorn iterations", total=numItermax)
 
         err = None
         idx = 0
-        while (err is None or err >= tol) and (niters is None or idx < niters):
+        while (err is None or err >= tol) and (
+            numItermax is None or idx < numItermax
+        ):
             v = wt / (K.T @ u)
             u = ws / (K @ v)
 
             # Check for numerical instability and stabilize if needed
-            if torch.max(torch.abs(u)) > tau or torch.max(torch.abs(v)) > tau:
+            if (
+                torch.max(torch.abs(u)) > stabilization_threshold
+                or torch.max(torch.abs(v)) > stabilization_threshold
+            ):
                 # Absorb large values into potentials
                 alpha = alpha + eps * torch.log(u)
                 beta = beta + eps * torch.log(v)
@@ -522,7 +536,7 @@ def get_Gamma(alpha, beta, u, v):
             idx += 1
 
     # Compute final transport plan
-    pi = get_Gamma(alpha, beta, u, v)
+    pi = get_pi(alpha, beta, u, v)
 
     # Final update to potentials
     alpha = alpha + eps * u.log()
@@ -532,24 +546,32 @@ def get_Gamma(alpha, beta, u, v):
 
 
 def solver_sinkhorn_stabilized_sparse(
-    cost, init_duals, uot_params, tuple_weights, train_params, verbose=True
+    cost,
+    ws,
+    wt,
+    eps,
+    init_duals=None,
+    numItermax=1000,
+    tol=1e-9,
+    eval_freq=20,
+    stabilization_threshold=1e3,
+    verbose=True,
 ):
     """
     Stabilized sparse Sinkhorn algorithm following the structure of the dense
     stabilized implementation but using sparse matrix operations.
     """
-    ws, wt, ws_dot_wt = tuple_weights
 
     # Initialize scaling vectors to ones
     u = torch.ones_like(ws) / len(ws)
     v = torch.ones_like(wt) / len(wt)
 
     # Initialize the dual potentials
-    alpha, beta = init_duals
-
-    rho_s, rho_t, eps = uot_params
-    niters, tol, eval_freq = train_params
-    tau = 1e3  # Threshold for stabilization
+    if init_duals is None:
+        alpha = torch.zeros(cost.shape[0])
+        beta = torch.zeros(cost.shape[1])
+    else:
+        alpha, beta = init_duals
 
     # Set up sparse matrix operations
     crow_indices = cost.crow_indices()
@@ -569,7 +591,7 @@ def solver_sinkhorn_stabilized_sparse(
 
     # Compute kernel matrix K in sparse format
     def get_K_sparse(alpha, beta):
-        # (alpha[:, None] + beta[None, :] - cost) / eps in sparse form
+        """Return sinkhorn sparse kernel matrix."""
         new_values = (
             fill_csr_matrix_rows(alpha, crow_indices)
             + fill_csr_matrix_cols(beta, ccol_indices, csc_to_csr)
@@ -586,7 +608,8 @@ def get_K_sparse(alpha, beta):
         )
         return K
 
-    def get_Gamma_sparse(alpha, beta, u, v):
+    def get_pi_sparse(alpha, beta, u, v):
+        """Return dense transport plan."""
         new_values = (
             (
                 fill_csr_matrix_rows(alpha, crow_indices)
@@ -614,11 +637,13 @@ def get_Gamma_sparse(alpha, beta, u, v):
 
     with _get_progress(verbose=verbose, transient=True) as progress:
         if verbose:
-            task = progress.add_task("Sinkhorn iterations", total=niters)
+            task = progress.add_task("Sinkhorn iterations", total=numItermax)
 
         err = None
         idx = 0
-        while (err is None or err >= tol) and (niters is None or idx < niters):
+        while (err is None or err >= tol) and (
+            numItermax is None or idx < numItermax
+        ):
             # Update v using sparse matrix multiplication
             Kt_u = torch.mv(K.transpose(0, 1), u)
             v = wt / Kt_u
@@ -628,7 +653,10 @@ def get_Gamma_sparse(alpha, beta, u, v):
             u = ws / Kv
 
             # Check for numerical instability and stabilize if needed
-            if torch.max(torch.abs(u)) > tau or torch.max(torch.abs(v)) > tau:
+            if (
+                torch.max(torch.abs(u)) > stabilization_threshold
+                or torch.max(torch.abs(v)) > stabilization_threshold
+            ):
                 # Absorb large values into potentials
                 alpha = alpha + eps * torch.log(u)
                 beta = beta + eps * torch.log(v)
@@ -656,7 +684,7 @@ def get_Gamma_sparse(alpha, beta, u, v):
             idx += 1
 
     # Compute final transport plan
-    pi = get_Gamma_sparse(alpha, beta, u, v)
+    pi = get_pi_sparse(alpha, beta, u, v)
 
     # Final update to potentials
     alpha = alpha + eps * u.log()
@@ -667,22 +695,29 @@ def get_Gamma_sparse(alpha, beta, u, v):
 
 def solver_sinkhorn_eps_scaling(
     cost,
-    init_duals,
-    uot_params,
-    tuple_weights,
-    train_params,
+    ws,
+    wt,
+    eps,
+    init_duals=None,
+    numInnerItermax=100,
     numItermax=100,
-    verbose=True,
+    tol=1e-9,
+    eval_freq=10,
+    stabilization_threshold=1e3,
     epsilon0=1e4,
+    verbose=True,
 ):
     """
     Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
     Relies on the stabilized Sinkhorn algorithm.
     """
-    rho_s, rho_t, eps = uot_params
-    numInnerItermax, tol, eval_freq = train_params
-    train_params_inner = numInnerItermax, tol, eval_freq
-    alpha, beta = init_duals
+
+    # Initialize the dual potentials
+    if init_duals is None:
+        alpha = torch.zeros(cost.shape[0])
+        beta = torch.zeros(cost.shape[1])
+    else:
+        alpha, beta = init_duals
 
     err = None
     idx = 0
@@ -704,17 +739,21 @@ def get_reg(idx):
             reg_idx = get_reg(idx)
             (alpha, beta), pi = solver_sinkhorn_stabilized(
                 cost,
-                (alpha, beta),
-                (rho_s, rho_t, reg_idx),
-                tuple_weights,
-                train_params_inner,
+                ws,
+                wt,
+                reg_idx,
+                init_duals=(alpha, beta),
+                numItermax=numInnerItermax,
+                tol=tol,
+                eval_freq=eval_freq,
+                stabilization_threshold=stabilization_threshold,
                 verbose=False,
             )
 
             if tol is not None and idx % eval_freq == 0:
                 err = (
-                    torch.norm(pi.sum(0) - tuple_weights[1]) ** 2
-                    + torch.norm(pi.sum(1) - tuple_weights[0]) ** 2
+                    torch.norm(pi.sum(0) - wt[1]) ** 2
+                    + torch.norm(pi.sum(1) - ws[0]) ** 2
                 )
                 if err < tol and idx > numItermin:
                     if verbose:
@@ -731,14 +770,17 @@ def get_reg(idx):
 
 def solver_sinkhorn_eps_scaling_sparse(
     cost,
-    init_duals,
-    uot_params,
-    tuple_weights,
-    train_params,
-    numItermax=100,
+    ws,
+    wt,
+    eps,
+    init_duals=None,
     numInnerItermax=100,
-    verbose=True,
+    numItermax=100,
+    tol=1e-9,
+    eval_freq=10,
+    stabilization_threshold=1e3,
     epsilon0=1e4,
+    verbose=True,
 ):
     """
     Scaling algorithm (ie Sinkhorn algorithm) with epsilon scaling.
@@ -747,10 +789,13 @@ def solver_sinkhorn_eps_scaling_sparse(
     This implementation uses sparse matrix operations to speed up computations
     and reduce memory usage.
     """
-    rho_s, rho_t, eps = uot_params
-    _, tol, eval_freq = train_params
-    train_params_inner = numInnerItermax, tol, eval_freq
-    alpha, beta = init_duals
+
+    # Initialize the dual potentials
+    if init_duals is None:
+        alpha = torch.zeros(cost.shape[0])
+        beta = torch.zeros(cost.shape[1])
+    else:
+        alpha, beta = init_duals
 
     err = None
     idx = 0
@@ -773,20 +818,21 @@ def get_reg(idx):
             reg_idx = get_reg(idx)
             (alpha, beta), pi = solver_sinkhorn_stabilized_sparse(
                 cost,
-                (alpha, beta),
-                (rho_s, rho_t, reg_idx),
-                tuple_weights,
-                train_params_inner,
+                ws,
+                wt,
+                reg_idx,
+                init_duals=(alpha, beta),
+                numItermax=numInnerItermax,
+                tol=tol,
+                eval_freq=eval_freq,
+                stabilization_threshold=stabilization_threshold,
                 verbose=False,
             )
 
             if tol is not None and idx % eval_freq == 0:
                 pi1 = csr_sum(pi, dim=1)
                 pi2 = csr_sum(pi, dim=0)
-                err = (
-                    torch.norm(pi1 - tuple_weights[0]) ** 2
-                    + torch.norm(pi2 - tuple_weights[1]) ** 2
-                )
+                err = torch.norm(pi1 - ws) ** 2 + torch.norm(pi2 - wt) ** 2
                 if err < tol and idx > numItermin:
                     if verbose:
                         progress.console.log(

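For context, the new `stabilization_threshold` argument drives the log-domain absorption step visible in the hunks above: when a scaling vector grows past the threshold, its log is folded into the corresponding dual potential and the kernel is rebuilt, which keeps `u` and `v` in a numerically safe range. Below is a standalone sketch of that idea, assuming the standard stabilized-Sinkhorn scheme of Chizat et al. (arXiv:1610.06519); the reset of `u` and `v` to ones and the toy tensors are illustrative and are not copied from the commit:

import torch

def absorb_if_needed(alpha, beta, u, v, cost, eps, threshold=1e3):
    """Fold large scaling factors into the dual potentials (log-domain absorption)."""
    if torch.max(torch.abs(u)) > threshold or torch.max(torch.abs(v)) > threshold:
        # Absorb the scaling vectors into the potentials
        alpha = alpha + eps * torch.log(u)
        beta = beta + eps * torch.log(v)
        # Reset the scaling vectors (standard stabilized-Sinkhorn step)
        u = torch.ones_like(u)
        v = torch.ones_like(v)
    # Recompute the kernel with the updated potentials
    K = (-(cost - alpha[:, None] - beta[None, :]) / eps).exp()
    return alpha, beta, u, v, K

# Toy usage (hypothetical sizes): u has blown up past the threshold
cost = torch.rand(5, 6)
eps = 1e-2
alpha, beta = torch.zeros(5), torch.zeros(6)
u, v = torch.full((5,), 1e4), torch.ones(6)
alpha, beta, u, v, K = absorb_if_needed(alpha, beta, u, v, cost, eps)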