Skip to content

Commit 8a528b1

Browse files
author
Pierre-Louis Barbarant
committed
Refactor tensor initialization to use dynamic dtype and shape in Sinkhorn solver
1 parent ee630ee commit 8a528b1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/fugw/solvers/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def csr_dim_sum(values, group_indices, n_groups):
120120

121121
A = torch.sparse_coo_tensor(
122122
indices,
123-
torch.ones_like(group_indices).type(torch.float32).to(device),
123+
torch.ones_like(group_indices).type(values.dtype).to(device),
124124
size=(n_groups, n_values),
125125
)
126126

@@ -503,8 +503,8 @@ def solver_sinkhorn_stabilized(
503503

504504
# Initialize the dual potentials
505505
if init_duals is None:
506-
alpha = torch.zeros(cost.shape[0])
507-
beta = torch.zeros(cost.shape[1])
506+
alpha = torch.zeros_like(ws)
507+
beta = torch.zeros_like(wt)
508508
else:
509509
alpha, beta = init_duals
510510

@@ -627,8 +627,8 @@ def solver_sinkhorn_stabilized_sparse(
627627

628628
# Initialize the dual potentials
629629
if init_duals is None:
630-
alpha = torch.zeros(cost.shape[0])
631-
beta = torch.zeros(cost.shape[1])
630+
alpha = torch.zeros_like(ws)
631+
beta = torch.zeros_like(wt)
632632
else:
633633
alpha, beta = init_duals
634634

0 commit comments

Comments
 (0)