Description
🐛 Bug
I'm using BinaryAUROC and calling .update(logits, y) after each batch (in validation_step), expecting the following behavior:
"If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element."
I noticed some discrepancies in the results, and after investigating I realized that the rule above is applied on a batch-by-batch basis rather than across the entire dataset (the check happens in _binary_precision_recall_curve_format, which is called by update).
Therefore, if every value in a batch happens to lie within [0, 1], the sigmoid is not applied to that batch, even though those values are logits.
My use case has a batch size of 1, so this happens often.
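To make the effect concrete, here is a small pure-Python model of the per-batch rule. It is not torchmetrics code: `format_batch` and `auroc` are hypothetical helpers written for illustration, and `format_batch` simply assumes the quoted rule ("any value outside [0, 1] means sigmoid is applied to every element of that batch"). Because AUROC depends only on how scores are ranked, applying sigmoid to some batches but not others can reorder them. A raw logit of 0.9 is kept as 0.9, while a logit of 1.5 becomes sigmoid(1.5) ≈ 0.82 and drops below it.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def format_batch(preds: List[float]) -> List[float]:
    # Hypothetical model of the quoted rule, evaluated per batch:
    # sigmoid is applied only if some value in THIS batch is outside [0, 1].
    if any(p < 0.0 or p > 1.0 for p in preds):
        return [sigmoid(p) for p in preds]
    return list(preds)


def auroc(scores: List[float], targets: List[int]) -> float:
    # Probability that a random positive outscores a random negative (ties = 0.5).
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


logits = [0.9, -1.0, 1.5, 2.0]
targets = [0, 0, 1, 1]

# Batch size 1: each single-element batch is formatted on its own.
per_batch = [q for x in logits for q in format_batch([x])]
# One decision over the whole dataset.
whole = format_batch(logits)

print(auroc(whole, targets))      # 1.0
print(auroc(per_batch, targets))  # 0.5
```

With a single dataset-level decision the positives (1.5, 2.0) outrank both negatives, giving 1.0. With per-batch decisions the negative logit 0.9 is left untransformed and ends up above both sigmoid-transformed positives, so the score drops to 0.5.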
Expected behavior
I'd expect the conversion to be applied consistently across the entire dataset (i.e., decided once when calling .compute()). Failing that, this behavior should be documented more prominently.
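A minimal sketch of the expected behavior follows, assuming raw scores can be kept until compute time. `DatasetLevelAUROC` is a hypothetical class for illustration only, not part of the torchmetrics API. On the user side, applying sigmoid to the logits before calling .update() should also sidestep the heuristic, because the values passed in are then always within [0, 1].

```python
import math
from typing import List


class DatasetLevelAUROC:
    """Hypothetical metric that defers the logit check to compute().

    Illustration only; this is not the torchmetrics API.
    """

    def __init__(self) -> None:
        self.preds: List[float] = []
        self.targets: List[int] = []

    def update(self, preds: List[float], targets: List[int]) -> None:
        # Store raw values; no per-batch logit/probability decision here.
        self.preds.extend(preds)
        self.targets.extend(targets)

    def compute(self) -> float:
        scores = self.preds
        # Single decision for the whole dataset, made once.
        if any(p < 0.0 or p > 1.0 for p in scores):
            scores = [1.0 / (1.0 + math.exp(-p)) for p in scores]
        pos = [s for s, t in zip(scores, self.targets) if t == 1]
        neg = [s for s, t in zip(scores, self.targets) if t == 0]
        # Pairwise AUROC: 1 per correctly ordered pair, 0.5 per tie.
        wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
        return wins / (len(pos) * len(neg))


metric = DatasetLevelAUROC()
for x, y in zip([0.9, -1.0, 1.5, 2.0], [0, 0, 1, 1]):
    metric.update([x], [y])  # batch size 1, as in the report
print(metric.compute())  # 1.0
```

The trade-off of this design is that every raw score must be retained until compute(), since the decision cannot be made from any single batch.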
Environment
torchmetrics==1.0.3