Skip to content

Commit

Permalink
Update README.rst
Browse files Browse the repository at this point in the history
  • Loading branch information
kkirchheim authored Dec 17, 2024
1 parent aed9a3e commit b92f93e
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,22 @@ If you notice that the scores predicted by a detector do not match the formulas

⏳ Quick Start
^^^^^^^^^^^^^^^^^
Load model pre-trained on CIFAR-10 with the Energy-Bounded Learning Loss [#EnergyBasedOOD]_, and predict on some dataset ``data_loader`` using
Energy-based Out-of-Distribution Detection [#EnergyBasedOOD]_, calculating the common OOD detection metrics.
Load a WideResNet-40 model (used in major publications), pre-trained on CIFAR-10 with the Energy-Bounded Learning Loss [#EnergyBasedOOD]_ (weights from to original paper), and predict on some dataset ``data_loader`` using
Energy-based OOD Detection (EBO) [#EnergyBasedOOD]_, calculating the common metrics.
OOD data must be marked with labels < 0.

.. code-block:: python
from pytorch_ood.detector import EnergyBased
from pytorch_ood.utils import OODMetrics
from pytorch_ood.model import WideResNet
data_loader = ... # your data
data_loader = ... # your data, OOD with label < 0
# Create Neural Network
model = WideResNet(num_classes=10, pretrained="er-cifar10-tune").eval().cuda()
preprocess = WideResNet.transform_for("er-cifar10-tune")
# Create detector
detector = EnergyBased(model)
Expand All @@ -76,7 +78,8 @@ OOD data must be marked with labels < 0.
metrics = OODMetrics()
for x, y in data_loader:
metrics.update(detector(x.cuda()), y)
x = preprocess(x).cuda()
metrics.update(detector(x, y)
print(metrics.compute())
Expand Down

0 comments on commit b92f93e

Please sign in to comment.