
Conversion from logits to probabilities happens on a batch-by-batch basis #2195

@GabrielBianconi

Description


🐛 Bug

I'm using BinaryAUROC and calling .update(logits, y) after each batch (in validation_step), expecting the behavior described in the documentation:

"If preds has values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element."

I noticed some discrepancies in the results. After investigating, I realized that this rule is applied on a batch-by-batch basis rather than across the entire dataset: the check happens in _binary_precision_recall_curve_format, which is called by update.

Therefore, if every logit in a given batch happens to lie in [0, 1], the conversion is skipped for that batch, and its raw logits end up mixed with the sigmoid outputs of other batches.

My use case has a batch size of 1, so this happens often.
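To make the effect concrete, here is a small pure-Python model of the rule quoted above. It is not torchmetrics code: format_batch, auroc and the data are hypothetical, and format_batch only mimics the "apply sigmoid if any value is outside [0, 1]" check. Because AUROC depends only on how the scores are ranked, mixing raw logits from some batches with sigmoid outputs from others can reorder samples and change the result, whereas a single dataset-level sigmoid is monotonic and preserves the order.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def format_batch(preds: list[float]) -> list[float]:
    # Hypothetical simplified model of the per-update check: sigmoid is applied
    # only if at least one value in *this* batch lies outside [0, 1].
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)
    return [sigmoid(p) for p in preds]


def auroc(scores: list[float], targets: list[int]) -> float:
    # Rank-based AUROC: fraction of (positive, negative) pairs in which the
    # positive sample has the higher score; ties count as half a win.
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical data: two validation batches of size 1 holding raw logits.
batches = [([0.9], [0]), ([1.5], [1])]

# Per-batch decision (the behavior described in this issue).
per_batch_scores: list[float] = []
targets: list[int] = []
for logits, y in batches:
    per_batch_scores.extend(format_batch(logits))
    targets.extend(y)

# Dataset-level decision (the behavior expected in this issue).
all_logits = [x for logits, _ in batches for x in logits]
dataset_scores = format_batch(all_logits)

print(auroc(per_batch_scores, targets))  # 0.0
print(auroc(dataset_scores, targets))    # 1.0
```

With batch size 1, the logit 0.9 is kept as-is and compared against sigmoid(1.5) ≈ 0.818, so the positive sample is ranked below the negative one, even though its logit is larger.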

Expected behavior

I'd expect the conversion to be applied consistently across the entire dataset (i.e., when calling .compute()). If that isn't possible, this behavior should be documented more prominently.
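Until then, one workaround is to convert the logits yourself before each update, e.g. metric.update(torch.sigmoid(logits), y). Since sigmoid outputs always lie in [0, 1], the per-batch check never fires and every batch is treated the same way. The sketch below models this with the same hypothetical simplified per-batch rule (pure Python, not torchmetrics code); the batches are made up for illustration.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def format_batch(preds: list[float]) -> list[float]:
    # Hypothetical simplified per-batch rule: sigmoid only if some value is outside [0, 1].
    if all(0.0 <= p <= 1.0 for p in preds):
        return list(preds)
    return [sigmoid(p) for p in preds]


# Hypothetical batch-size-1 logits: some fall inside [0, 1], some do not.
batches = [[0.9], [1.5], [-0.3], [0.2]]

scores: list[float] = []
for logits in batches:
    probs = [sigmoid(x) for x in logits]  # converted up front, always in [0, 1]
    scores.extend(format_batch(probs))    # so the per-batch check leaves them untouched

# Every sample went through exactly one sigmoid, regardless of its batch.
expected = [sigmoid(x) for batch in batches for x in batch]
print(scores == expected)  # True
```

Because every value receives exactly one sigmoid, the ordering of the scores matches the ordering of the logits, which is what a rank-based metric like AUROC needs.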

Environment

torchmetrics==1.0.3
