|
1 | 1 | import torch |
2 | | - |
3 | 2 | from fugw.utils import _get_progress, console |
4 | 3 |
|
5 | 4 |
|
@@ -474,7 +473,14 @@ def solver_sinkhorn_stabilized( |
474 | 473 | tau = 1e3 # Threshold for stabilization |
475 | 474 |
|
476 | 475 | def get_K(alpha, beta): |
477 | | - return ((alpha[:, None] + beta[None, :] - cost) / eps).exp() |
| 476 | + return (-(cost - alpha[:, None] - beta[None, :]) / eps).exp() |
| 477 | + |
| 478 | + def get_Gamma(alpha, beta, u, v): |
| 479 | + return ( |
| 480 | + -(cost - alpha[:, None] - beta[None, :]) / eps |
| 481 | + + u[:, None].log() |
| 482 | + + v[None, :].log() |
| 483 | + ).exp() |
478 | 484 |
|
479 | 485 | K = get_K(alpha, beta) |
480 | 486 | with _get_progress(verbose=verbose, transient=True) as progress: |
@@ -515,13 +521,13 @@ def get_K(alpha, beta): |
515 | 521 |
|
516 | 522 | idx += 1 |
517 | 523 |
|
| 524 | + # Compute final transport plan |
| 525 | + pi = get_Gamma(alpha, beta, u, v) |
| 526 | + |
518 | 527 | # Final update to potentials |
519 | 528 | alpha = alpha + eps * u.log() |
520 | 529 | beta = beta + eps * v.log() |
521 | 530 |
|
522 | | - # Compute final transport plan |
523 | | - pi = get_K(alpha, beta) |
524 | | - |
525 | 531 | return (alpha, beta), pi |
526 | 532 |
|
527 | 533 |
|
@@ -663,11 +669,12 @@ def solver_sinkhorn_eps_scaling( |
663 | 669 | idx = 0 |
664 | 670 |
|
665 | 671 | def get_reg(idx): |
666 | | - return (epsilon0 - eps) * torch.exp(-torch.tensor(idx)) + eps |
| 672 | + return (epsilon0 - eps) * torch.exp( |
| 673 | + -torch.tensor(idx, dtype=torch.float64) |
| 674 | + ) + eps |
667 | 675 |
|
668 | 676 | numItermin = 35 |
669 | 677 | numItermax = max(numItermin, numItermax) |
670 | | - |
671 | 678 | with _get_progress(verbose=verbose, transient=True) as progress: |
672 | 679 | if verbose: |
673 | 680 | task = progress.add_task("Scaling iterations", total=numItermax) |
@@ -730,7 +737,9 @@ def solver_sinkhorn_eps_scaling_sparse( |
730 | 737 | idx = 0 |
731 | 738 |
|
732 | 739 | def get_reg(idx): |
733 | | - return (epsilon0 - eps) * torch.exp(-torch.tensor(idx)) + eps |
| 740 | + return (epsilon0 - eps) * torch.exp( |
| 741 | + -torch.tensor(idx, dtype=torch.float32) |
| 742 | + ) + eps |
734 | 743 |
|
735 | 744 | numItermin = 35 |
736 | 745 | numItermax = max(numItermin, numItermax) |
|
0 commit comments