Skip to content

Commit 3eb72e2

Browse files
committed
Fix recall at k when batch size = 1
1 parent 23d5e3b commit 3eb72e2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformers4rec/torch/ranking_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _metric(self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor)
131131

132132
# Compute recalls at K
133133
num_relevant = torch.sum(labels, dim=-1)
134-
rel_indices = (num_relevant != 0).nonzero().squeeze()
134+
rel_indices = (num_relevant != 0).nonzero().squeeze(dim=1)
135135
rel_count = num_relevant[rel_indices]
136136

137137
if rel_indices.shape[0] > 0:

0 commit comments

Comments
 (0)