Skip to content

Commit d572e9d

Browse files
Xiaolong Wangfacebook-github-bot
Xiaolong Wang
authored andcommitted
: pe tasks (pytorch#2988)
Summary: Pull Request resolved: pytorch#2988 Differential Revision: D74934835
1 parent 20f01a5 commit d572e9d

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

torchrec/metrics/auc.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,17 @@ def _compute_auc_helper(
6868
# TODO - [add flag to set bining dyamically] for use with soft labels, >=0.039 --> 1, <0.039 --> 0
6969
sorted_labels = torch.ge(sorted_labels, 0.039).to(dtype=sorted_labels.dtype)
7070
sorted_weights = torch.index_select(weights, dim=0, index=sorted_indices)
71-
cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)
72-
cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)
73-
auc = torch.where(
74-
cum_fp[-1] * cum_tp[-1] == 0,
75-
0.5, # 0.5 is the no-signal default value for auc.
76-
torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1],
77-
)
71+
if sorted_weights.numel() > 0:
72+
cum_fp = torch.cumsum(sorted_weights * (1.0 - sorted_labels), dim=0)
73+
cum_tp = torch.cumsum(sorted_weights * sorted_labels, dim=0)
74+
auc = torch.where(
75+
cum_fp[-1] * cum_tp[-1] == 0,
76+
0.5, # 0.5 is the no-signal default value for auc.
77+
torch.trapz(cum_tp, cum_fp) / cum_fp[-1] / cum_tp[-1],
78+
)
79+
else:
80+
# if empty predictions, default value
81+
auc = torch.tensor(0.5, device=sorted_weights.device)
7882
return auc
7983

8084

0 commit comments

Comments
 (0)