@@ -68,13 +68,17 @@ def _compute_auc_helper(
68
68
# TODO - [add flag to set bining dyamically] for use with soft labels, >=0.039 --> 1, <0.039 --> 0
69
69
sorted_labels = torch .ge (sorted_labels , 0.039 ).to (dtype = sorted_labels .dtype )
70
70
sorted_weights = torch .index_select (weights , dim = 0 , index = sorted_indices )
71
- cum_fp = torch .cumsum (sorted_weights * (1.0 - sorted_labels ), dim = 0 )
72
- cum_tp = torch .cumsum (sorted_weights * sorted_labels , dim = 0 )
73
- auc = torch .where (
74
- cum_fp [- 1 ] * cum_tp [- 1 ] == 0 ,
75
- 0.5 , # 0.5 is the no-signal default value for auc.
76
- torch .trapz (cum_tp , cum_fp ) / cum_fp [- 1 ] / cum_tp [- 1 ],
77
- )
71
+ if sorted_weights .numel () > 0 :
72
+ cum_fp = torch .cumsum (sorted_weights * (1.0 - sorted_labels ), dim = 0 )
73
+ cum_tp = torch .cumsum (sorted_weights * sorted_labels , dim = 0 )
74
+ auc = torch .where (
75
+ cum_fp [- 1 ] * cum_tp [- 1 ] == 0 ,
76
+ 0.5 , # 0.5 is the no-signal default value for auc.
77
+ torch .trapz (cum_tp , cum_fp ) / cum_fp [- 1 ] / cum_tp [- 1 ],
78
+ )
79
+ else :
80
+ # if empty predictions, default value
81
+ auc = torch .tensor (0.5 , device = sorted_weights .device )
78
82
return auc
79
83
80
84
0 commit comments