Skip to content

Commit b92f93e

Browse files
authored
Update README.rst
1 parent aed9a3e commit b92f93e

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

README.rst

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,22 @@ If you notice that the scores predicted by a detector do not match the formulas
5454

5555
⏳ Quick Start
5656
^^^^^^^^^^^^^^^^^
57-
Load model pre-trained on CIFAR-10 with the Energy-Bounded Learning Loss [#EnergyBasedOOD]_, and predict on some dataset ``data_loader`` using
58-
Energy-based Out-of-Distribution Detection [#EnergyBasedOOD]_, calculating the common OOD detection metrics.
57+
Load a WideResNet-40 model (used in major publications), pre-trained on CIFAR-10 with the Energy-Bounded Learning Loss [#EnergyBasedOOD]_ (weights from to original paper), and predict on some dataset ``data_loader`` using
58+
Energy-based OOD Detection (EBO) [#EnergyBasedOOD]_, calculating the common metrics.
5959
OOD data must be marked with labels < 0.
6060

6161
.. code-block:: python
6262
6363
6464
from pytorch_ood.detector import EnergyBased
6565
from pytorch_ood.utils import OODMetrics
66+
from pytorch_ood.model import WideResNet
6667
67-
data_loader = ... # your data
68+
data_loader = ... # your data, OOD with label < 0
6869
6970
# Create Neural Network
7071
model = WideResNet(num_classes=10, pretrained="er-cifar10-tune").eval().cuda()
72+
preprocess = WideResNet.transform_for("er-cifar10-tune")
7173
7274
# Create detector
7375
detector = EnergyBased(model)
@@ -76,7 +78,8 @@ OOD data must be marked with labels < 0.
7678
metrics = OODMetrics()
7779
7880
for x, y in data_loader:
79-
metrics.update(detector(x.cuda()), y)
81+
x = preprocess(x).cuda()
82+
metrics.update(detector(x, y)
8083
8184
print(metrics.compute())
8285

0 commit comments

Comments
 (0)