@@ -226,6 +226,72 @@ def batch_elementwise_prod_and_sum(
226226 return res
227227
228228
229+ def get_K_sparse (
230+ alpha , beta , cost , crow_indices , ccol_indices , csc_to_csr , eps
231+ ):
232+ """Compute sparse kernel matrix.
233+
234+ Parameters
235+ ----------
236+ alpha : torch.Tensor
237+ Source dual variable.
238+ beta : torch.Tensor
239+ Target dual variable.
240+ cost : torch.sparse_csr_tensor
241+ Sparse cost matrix.
242+ crow_indices : torch.Tensor
243+ Row indices of the cost matrix.
244+ ccol_indices : torch.Tensor
245+ Column indices of the cost matrix.
246+ csc_to_csr : torch.Tensor
247+ Mapping from CSC to CSR format.
248+ eps : float
249+ Entropy regularization parameter.
250+
251+ Returns
252+ -------
253+ torch.sparse_csr_tensor
254+ Sparse kernel matrix.
255+ """
256+ new_values = (
257+ fill_csr_matrix_rows (alpha , crow_indices )
258+ + fill_csr_matrix_cols (beta , ccol_indices , csc_to_csr )
259+ - cost .values ()
260+ ) / eps
261+
262+ K_values = new_values .exp ()
263+
264+ K = torch .sparse_csr_tensor (
265+ cost .crow_indices (),
266+ cost .col_indices (),
267+ K_values ,
268+ size = cost .size (),
269+ )
270+ return K
271+
272+
273+ def get_reg (idx , eps , epsilon0 ):
274+ """Exponentially decaying epsilon scaling.
275+
276+ Parameters
277+ ----------
278+ idx: int
279+ Current iteration index.
280+ eps: float
281+ Epsilon value.
282+ epsilon0: float
283+ Initial epsilon value.
284+
285+ Returns
286+ -------
287+ float
288+ Updated epsilon value.
289+ """
290+ return (epsilon0 - eps ) * torch .exp (
291+ - torch .tensor (idx , dtype = torch .float64 )
292+ ) + eps
293+
294+
229295def solver_sinkhorn_log (
230296 cost , init_duals , uot_params , tuple_weights , train_params , verbose = True
231297):
@@ -523,51 +589,9 @@ def solver_sinkhorn_stabilized_sparse(
523589 ).to_sparse_csr ()
524590 csc_to_csr = T .values () - 1
525591
526- # Compute kernel matrix K in sparse format
527- def get_K_sparse (alpha , beta ):
528- """Return sinkhorn sparse kernel matrix."""
529- new_values = (
530- fill_csr_matrix_rows (alpha , crow_indices )
531- + fill_csr_matrix_cols (beta , ccol_indices , csc_to_csr )
532- - cost .values ()
533- ) / eps
534-
535- K_values = new_values .exp ()
536-
537- K = torch .sparse_csr_tensor (
538- cost .crow_indices (),
539- cost .col_indices (),
540- K_values ,
541- size = cost .size (),
542- )
543- return K
544-
545- def get_pi_sparse (alpha , beta , u , v ):
546- """Return dense transport plan."""
547- new_values = (
548- (
549- fill_csr_matrix_rows (alpha , crow_indices )
550- + fill_csr_matrix_cols (beta , ccol_indices , csc_to_csr )
551- - cost .values ()
552- )
553- / eps
554- + fill_csr_matrix_rows (u .log (), crow_indices )
555- + fill_csr_matrix_cols (v .log (), ccol_indices , csc_to_csr )
556- )
557-
558- gamma_values = new_values .exp ()
559-
560- gamma = torch .sparse_csr_tensor (
561- cost .crow_indices (),
562- cost .col_indices (),
563- gamma_values ,
564- size = cost .size (),
565- )
566-
567- return gamma
568-
569- # Initial K
570- K = get_K_sparse (alpha , beta )
592+ K = get_K_sparse (
593+ alpha , beta , cost , crow_indices , ccol_indices , csc_to_csr , eps
594+ )
571595
572596 with _get_progress (verbose = verbose , transient = True ) as progress :
573597 if verbose :
@@ -600,13 +624,29 @@ def get_pi_sparse(alpha, beta, u, v):
600624 v = torch .ones_like (wt ) / len (wt )
601625
602626 # Recompute K with updated potentials
603- K = get_K_sparse (alpha , beta )
627+ K = get_K_sparse (
628+ alpha ,
629+ beta ,
630+ cost ,
631+ crow_indices ,
632+ ccol_indices ,
633+ csc_to_csr ,
634+ eps ,
635+ )
604636
605637 if verbose :
606638 progress .update (task , advance = 1 )
607639
608640 if tol is not None and idx % eval_freq == 0 :
609- pi = get_K_sparse (alpha + eps * u .log (), beta + eps * v .log ())
641+ pi = get_K_sparse (
642+ alpha + eps * u .log (),
643+ beta + eps * v .log (),
644+ cost ,
645+ crow_indices ,
646+ ccol_indices ,
647+ csc_to_csr ,
648+ eps ,
649+ )
610650 err = torch .norm (csr_sum (pi , dim = 0 ) - wt )
611651 if err < tol :
612652 if verbose :
@@ -617,13 +657,15 @@ def get_pi_sparse(alpha, beta, u, v):
617657
618658 idx += 1
619659
620- # Compute final transport plan
621- pi = get_pi_sparse (alpha , beta , u , v )
622-
623660 # Final update to potentials
624661 alpha = alpha + eps * u .log ()
625662 beta = beta + eps * v .log ()
626663
664+ # Compute final transport plan
665+ pi = get_K_sparse (
666+ alpha , beta , cost , crow_indices , ccol_indices , csc_to_csr , eps
667+ )
668+
627669 return (alpha , beta ), pi
628670
629671
@@ -694,12 +736,6 @@ def solver_sinkhorn_eps_scaling_sparse(
694736 err = None
695737 idx = 0
696738
697- def get_reg (idx ):
698- """Epsilon scheduler"""
699- return (epsilon0 - eps ) * torch .exp (
700- - torch .tensor (idx , dtype = torch .float64 )
701- ) + eps
702-
703739 numItermin = 37
704740 numItermax = max (numItermin , numItermax )
705741
@@ -710,7 +746,7 @@ def get_reg(idx):
710746 (err is None or err >= tol )
711747 and (numItermax is None or idx < numItermax )
712748 ) or (idx < numItermin ):
713- reg_idx = get_reg (idx )
749+ reg_idx = get_reg (idx , eps , epsilon0 )
714750 (alpha , beta ), pi = solver_sinkhorn_stabilized_sparse (
715751 cost ,
716752 ws ,
0 commit comments