Commit c78db2a

Author: Pierre-Louis Barbarant

Add sparse kernel matrix computation and epsilon scaling functions
1 parent 3905b93 commit c78db2a

File tree

1 file changed: +93 -57 lines

src/fugw/solvers/utils.py

Lines changed: 93 additions & 57 deletions
@@ -226,6 +226,72 @@ def batch_elementwise_prod_and_sum(
     return res
 
 
+def get_K_sparse(
+    alpha, beta, cost, crow_indices, ccol_indices, csc_to_csr, eps
+):
+    """Compute sparse kernel matrix.
+
+    Parameters
+    ----------
+    alpha : torch.Tensor
+        Source dual variable.
+    beta : torch.Tensor
+        Target dual variable.
+    cost : torch.sparse_csr_tensor
+        Sparse cost matrix.
+    crow_indices : torch.Tensor
+        Row indices of the cost matrix.
+    ccol_indices : torch.Tensor
+        Column indices of the cost matrix.
+    csc_to_csr : torch.Tensor
+        Mapping from CSC to CSR format.
+    eps : float
+        Entropy regularization parameter.
+
+    Returns
+    -------
+    torch.sparse_csr_tensor
+        Sparse kernel matrix.
+    """
+    new_values = (
+        fill_csr_matrix_rows(alpha, crow_indices)
+        + fill_csr_matrix_cols(beta, ccol_indices, csc_to_csr)
+        - cost.values()
+    ) / eps
+
+    K_values = new_values.exp()
+
+    K = torch.sparse_csr_tensor(
+        cost.crow_indices(),
+        cost.col_indices(),
+        K_values,
+        size=cost.size(),
+    )
+    return K
+
+
+def get_reg(idx, eps, epsilon0):
+    """Exponentially decaying epsilon scaling.
+
+    Parameters
+    ----------
+    idx: int
+        Current iteration index.
+    eps: float
+        Epsilon value.
+    epsilon0: float
+        Initial epsilon value.
+
+    Returns
+    -------
+    float
+        Updated epsilon value.
+    """
+    return (epsilon0 - eps) * torch.exp(
+        -torch.tensor(idx, dtype=torch.float64)
+    ) + eps
+
+
 def solver_sinkhorn_log(
     cost, init_duals, uot_params, tuple_weights, train_params, verbose=True
 ):
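Note: the promoted helper builds the entropic Sinkhorn kernel K_ij = exp((alpha_i + beta_j - C_ij) / eps) on the sparsity pattern of cost; fill_csr_matrix_rows and fill_csr_matrix_cols are presumably the existing CSR broadcast helpers defined elsewhere in utils.py. A minimal standalone sketch of the same formula on a dense cost matrix (illustrative only, not the fugw API):

import torch

def dense_kernel(alpha, beta, cost, eps):
    # dense analogue of get_K_sparse: K_ij = exp((alpha_i + beta_j - C_ij) / eps)
    return torch.exp((alpha[:, None] + beta[None, :] - cost) / eps)

# tiny example: 3 source points, 2 target points
cost = torch.rand(3, 2)
alpha, beta = torch.zeros(3), torch.zeros(2)
K = dense_kernel(alpha, beta, cost, eps=0.1)
print(K.shape)  # torch.Size([3, 2])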
@@ -523,51 +589,9 @@ def solver_sinkhorn_stabilized_sparse(
     ).to_sparse_csr()
     csc_to_csr = T.values() - 1
 
-    # Compute kernel matrix K in sparse format
-    def get_K_sparse(alpha, beta):
-        """Return sinkhorn sparse kernel matrix."""
-        new_values = (
-            fill_csr_matrix_rows(alpha, crow_indices)
-            + fill_csr_matrix_cols(beta, ccol_indices, csc_to_csr)
-            - cost.values()
-        ) / eps
-
-        K_values = new_values.exp()
-
-        K = torch.sparse_csr_tensor(
-            cost.crow_indices(),
-            cost.col_indices(),
-            K_values,
-            size=cost.size(),
-        )
-        return K
-
-    def get_pi_sparse(alpha, beta, u, v):
-        """Return dense transport plan."""
-        new_values = (
-            (
-                fill_csr_matrix_rows(alpha, crow_indices)
-                + fill_csr_matrix_cols(beta, ccol_indices, csc_to_csr)
-                - cost.values()
-            )
-            / eps
-            + fill_csr_matrix_rows(u.log(), crow_indices)
-            + fill_csr_matrix_cols(v.log(), ccol_indices, csc_to_csr)
-        )
-
-        gamma_values = new_values.exp()
-
-        gamma = torch.sparse_csr_tensor(
-            cost.crow_indices(),
-            cost.col_indices(),
-            gamma_values,
-            size=cost.size(),
-        )
-
-        return gamma
-
-    # Initial K
-    K = get_K_sparse(alpha, beta)
+    K = get_K_sparse(
+        alpha, beta, cost, crow_indices, ccol_indices, csc_to_csr, eps
+    )
 
     with _get_progress(verbose=verbose, transient=True) as progress:
         if verbose:
@@ -600,13 +624,29 @@ def get_pi_sparse(alpha, beta, u, v):
                 v = torch.ones_like(wt) / len(wt)
 
                 # Recompute K with updated potentials
-                K = get_K_sparse(alpha, beta)
+                K = get_K_sparse(
+                    alpha,
+                    beta,
+                    cost,
+                    crow_indices,
+                    ccol_indices,
+                    csc_to_csr,
+                    eps,
+                )
 
             if verbose:
                 progress.update(task, advance=1)
 
             if tol is not None and idx % eval_freq == 0:
-                pi = get_K_sparse(alpha + eps * u.log(), beta + eps * v.log())
+                pi = get_K_sparse(
+                    alpha + eps * u.log(),
+                    beta + eps * v.log(),
+                    cost,
+                    crow_indices,
+                    ccol_indices,
+                    csc_to_csr,
+                    eps,
+                )
                 err = torch.norm(csr_sum(pi, dim=0) - wt)
                 if err < tol:
                     if verbose:
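Note: the removed get_pi_sparse (which computed pi_ij = u_i * K_ij * v_j) is no longer needed here: calling get_K_sparse with the scalings absorbed into the duals gives the same plan, since exp((alpha_i + eps*log u_i + beta_j + eps*log v_j - C_ij) / eps) = u_i * K_ij * v_j. A small dense check of that identity (standalone sketch with made-up values, not the library API):

import torch

eps = 0.5
cost = torch.rand(4, 3)
alpha, beta = 0.1 * torch.randn(4), 0.1 * torch.randn(3)
u, v = torch.rand(4) + 0.1, torch.rand(3) + 0.1

# kernel, then plan via explicit scalings
K = torch.exp((alpha[:, None] + beta[None, :] - cost) / eps)
pi_scaled = u[:, None] * K * v[None, :]

# plan via potentials with the scalings absorbed
pi_absorbed = torch.exp(
    (
        (alpha + eps * u.log())[:, None]
        + (beta + eps * v.log())[None, :]
        - cost
    )
    / eps
)
assert torch.allclose(pi_scaled, pi_absorbed, rtol=1e-4)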
@@ -617,13 +657,15 @@ def get_pi_sparse(alpha, beta, u, v):
 
             idx += 1
 
-    # Compute final transport plan
-    pi = get_pi_sparse(alpha, beta, u, v)
-
     # Final update to potentials
     alpha = alpha + eps * u.log()
     beta = beta + eps * v.log()
 
+    # Compute final transport plan
+    pi = get_K_sparse(
+        alpha, beta, cost, crow_indices, ccol_indices, csc_to_csr, eps
+    )
+
     return (alpha, beta), pi
 
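Note: with this reordering the final plan is simply the kernel evaluated at the updated duals, so its row and column sums should approximate the input weights ws and wt. A self-contained toy illustration of checking marginals on a sparse CSR plan (csr_sum in the diff is presumably the module's own sparse-sum helper; a dense conversion is used here instead):

import torch

# toy 2 x 3 transport plan stored in CSR format
pi = torch.sparse_csr_tensor(
    torch.tensor([0, 2, 3]),        # crow_indices
    torch.tensor([0, 2, 1]),        # col_indices
    torch.tensor([0.3, 0.2, 0.5]),  # values
    size=(2, 3),
)
pi_dense = pi.to_dense()
print(pi_dense.sum(dim=1))  # row marginals, compare to ws
print(pi_dense.sum(dim=0))  # column marginals, compare to wt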

@@ -694,12 +736,6 @@ def solver_sinkhorn_eps_scaling_sparse(
     err = None
     idx = 0
 
-    def get_reg(idx):
-        """Epsilon scheduler"""
-        return (epsilon0 - eps) * torch.exp(
-            -torch.tensor(idx, dtype=torch.float64)
-        ) + eps
-
     numItermin = 37
     numItermax = max(numItermin, numItermax)
 
@@ -710,7 +746,7 @@ def get_reg(idx):
         (err is None or err >= tol)
         and (numItermax is None or idx < numItermax)
     ) or (idx < numItermin):
-        reg_idx = get_reg(idx)
+        reg_idx = get_reg(idx, eps, epsilon0)
        (alpha, beta), pi = solver_sinkhorn_stabilized_sparse(
             cost,
             ws,
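Note: with the new signature, reg_idx follows the schedule reg(idx) = (epsilon0 - eps) * exp(-idx) + eps, i.e. it starts at epsilon0 for idx = 0 and decays toward the target eps as the outer iterations progress. A minimal sketch of the schedule with made-up values (illustrative only):

import torch

def get_reg(idx, eps, epsilon0):
    # same exponential decay as in the diff
    return (epsilon0 - eps) * torch.exp(
        -torch.tensor(idx, dtype=torch.float64)
    ) + eps

eps, epsilon0 = 1e-3, 1.0  # made-up values
schedule = [get_reg(i, eps, epsilon0).item() for i in range(5)]
# approximately 1.0, 0.369, 0.136, 0.051, 0.019 -> converging to eps
print(schedule)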
