Skip to content

Commit 2cc2a9e

Browse files
authored
93 GPU out-of-memory problem during detector fit (#95)
* Added catch-all arguments to MCD fitting.
* Optimized VRAM usage.
* Models are now moved to the expected device before fitting.
1 parent f541df7 commit 2cc2a9e

File tree

10 files changed

+142
-29
lines changed

10 files changed

+142
-29
lines changed

src/pytorch_ood/detector/dice.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
104104
:param loader: data loader to extract features from. OOD inputs will be ignored.
105105
:param device: device to use for feature extraction
106106
"""
107+
if isinstance(self.model, torch.nn.Module):
108+
log.debug(f"Moving model to {device}")
109+
self.model.to(device)
110+
107111
z, y = extract_features(loader, self.model, device=device)
108112
self.fit_features(z, y)
109113
return self

src/pytorch_ood/detector/klmatching.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def fit(self: Self, data_loader: DataLoader, device="cpu") -> Self:
6060
if self.model is None:
6161
raise ModelNotSetException
6262

63+
if isinstance(self.model, torch.nn.Module):
64+
log.debug(f"Moving model to {device}")
65+
self.model.to(device)
66+
6367
logits, labels = extract_features(data_loader, self.model, device)
6468
return self.fit_features(logits, labels, device)
6569

@@ -72,13 +76,12 @@ def fit_features(self: Self, logits: Tensor, labels: Tensor, device="cpu") -> Se
7276
:param labels: class labels
7377
:param device: device which should be used for calculations
7478
"""
75-
logits, labels = logits.to(device), labels.to(device)
76-
y_hat = logits.max(dim=1).indices
79+
labels = labels.to(device)
7780
probabilities = logits.softmax(dim=1)
7881

7982
for label in labels.unique():
8083
log.debug(f"Fitting class {label}")
81-
d_k = probabilities[labels == label].mean(dim=0)
84+
d_k = probabilities[labels == label].to(device).mean(dim=0)
8285
self.dists[str(label.item())] = Parameter(d_k)
8386

8487
return self

src/pytorch_ood/detector/knn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
from typing import Callable, TypeVar
1414

15+
import torch
1516
from torch import Tensor, tensor
1617
from torch.utils.data import DataLoader
1718

@@ -104,5 +105,9 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
104105
:param loader: data loader
105106
:param device: device used for extracting logits
106107
"""
108+
if isinstance(self.model, torch.nn.Module):
109+
log.debug(f"Moving model to {device}")
110+
self.model.to(device)
111+
107112
z, y = extract_features(model=self.model, data_loader=loader, device=device)
108113
return self.fit_features(z, y)

src/pytorch_ood/detector/mahalanobis.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def fit(self: Self, data_loader: DataLoader, device: str = None) -> Self:
7373
device = list(self.model.parameters())[0].device
7474
log.warning(f"No device given. Will use '{device}'.")
7575

76+
if isinstance(self.model, torch.nn.Module):
77+
log.debug(f"Moving model to device {device}")
78+
self.model.to(device)
79+
7680
z, y = extract_features(data_loader, self.model, device)
7781
return self.fit_features(z, y, device)
7882

@@ -89,7 +93,7 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
8993
device = z.device
9094
log.warning(f"No device given. Will use '{device}'.")
9195

92-
z, y = z.to(device), y.to(device)
96+
y = y.to(device)
9397

9498
log.debug("Calculating mahalanobis parameters.")
9599
classes = y.unique()
@@ -103,9 +107,10 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
103107
self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device)
104108

105109
for clazz in range(n_classes):
106-
idxs = y.eq(clazz)
110+
idxs = y.eq(clazz).to(z.device)
107111
assert idxs.sum() != 0
108-
zs = z[idxs]
112+
# Index first and move only the selected rows to the target device, to reduce memory usage.
113+
zs = z[idxs].to(device)
109114
self.mu[clazz] = zs.mean(dim=0)
110115
self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz])
111116

src/pytorch_ood/detector/mcd.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ def __init__(
6969
self.mode = mode
7070
self.batch_norm = batch_norm
7171

72-
def fit(self: Self, data_loader) -> Self:
72+
def fit(self: Self, data_loader, **kwargs) -> Self:
7373
"""
7474
Not required
7575
"""
7676
return self
7777

78-
def fit_features(self: Self, x: Tensor, y: Tensor) -> Self:
78+
def fit_features(self: Self, x: Tensor, y: Tensor, **kwargs) -> Self:
7979
"""
8080
Not required
8181
"""
@@ -92,6 +92,7 @@ def _switch_mode(model: Module, batch_norm: bool = True) -> bool:
9292
"""
9393
Puts the model into training mode, except for variants of the batch-norm layer.
9494
95+
:param batch_norm: set to False if batch-norm should also be in training mode
9596
:returns: true if model was switched, false otherwise
9697
"""
9798
mode_switch = False

src/pytorch_ood/detector/mmahalanobis.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def fit(self: Self, data_loader: DataLoader, device: str = None) -> Self:
8282
device = list(self.model[0].parameters())[0].device
8383
log.warning(f"No device given. Will use '{device}'.")
8484

85+
if isinstance(self.model, torch.nn.Module):
86+
log.debug(f"Moving model to {device}")
87+
self.model.to(device)
88+
8589
zs = []
8690

8791
for layer_idx in range(len(self.model)):

src/pytorch_ood/detector/rmd.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def fit(self, loader: DataLoader, device: str = "cpu") -> Self:
6565
Fit parameters of the multi variate gaussian for the given loader.
6666
Ignores OOD Inputs.
6767
"""
68+
if isinstance(self.model, torch.nn.Module):
69+
log.debug(f"Moving model to {device}")
70+
self.model.to(device)
71+
6872
z, y = extract_features(loader, self.model, device=device)
6973
return self.fit_features(z, y, device=device)
7074

@@ -81,14 +85,16 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
8185
device = z.device
8286
log.warning(f"No device given. Will use '{device}'.")
8387

84-
z, y = z.to(device), y.to(device)
88+
y = y.to(device)
8589
known = is_known(y)
8690

8791
super(RMD, self).fit_features(z, y, device)
8892

93+
z_known = z[known].to(device)
94+
8995
log.debug("Fitting background gaussian.")
90-
self.background_mu = z[known].mean(dim=0)
91-
self.background_cov = (z[known] - self.background_mu).T.mm(z[known] - self.background_mu)
96+
self.background_mu = z_known.mean(dim=0)
97+
self.background_cov = (z_known - self.background_mu).T.mm(z_known - self.background_mu)
9298
self.background_cov += (
9399
torch.eye(self.background_cov.shape[0], device=self.background_cov.device) * 1e-6
94100
)

src/pytorch_ood/detector/she.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,21 @@
77
.. autoclass:: pytorch_ood.detector.SHE
88
:members:
99
"""
10-
from typing import Callable, TypeVar
10+
from typing import TypeVar, Callable
1111

1212
import torch
13+
from torch import nn
1314
from torch import Tensor
1415
from torch.utils.data import DataLoader
15-
16-
from pytorch_ood.utils import extract_features, is_known
16+
import logging
17+
from pytorch_ood.utils import extract_features, is_known, TensorBuffer
1718

1819
from ..api import Detector, ModelNotSetException
1920

2021
Self = TypeVar("Self")
2122

23+
log = logging.getLogger(__name__)
24+
2225

2326
class SHE(Detector):
2427
"""
@@ -32,26 +35,25 @@ class SHE(Detector):
3235
:see Paper: `OpenReview <https://openreview.net/pdf?id=KkazG4lgKL>`__
3336
"""
3437

35-
def __init__(self, model: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
38+
def __init__(self, backbone: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
3639
"""
37-
:param model: feature extractor
40+
:param backbone: feature extractor
3841
:param head: maps feature vectors to logits
3942
"""
4043
super(SHE, self).__init__()
41-
self.model = model
44+
self.backbone = backbone
4245
self.head = head
4346
self.patterns = None
44-
4547
self.is_fitted = False
4648

4749
def predict(self, x: Tensor) -> Tensor:
4850
"""
4951
:param x: model inputs
5052
"""
51-
if self.model is None:
53+
if self.backbone is None:
5254
raise ModelNotSetException
5355

54-
z = self.model(x)
56+
z = self.backbone(x)
5557
return self.predict_features(z)
5658

5759
def predict_features(self, z: Tensor) -> Tensor:
@@ -67,18 +69,56 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
6769
Extracts features and calculates mean patterns.
6870
6971
:param loader: data to fit
72+
:param device: device to use for computations. If the backbone is an nn.Module, it will be moved to this device.
73+
"""
74+
if isinstance(self.backbone, nn.Module):
75+
log.debug(f"Moving model to {device}")
76+
self.backbone.to(device)
77+
78+
x, y = extract_features(loader, self.backbone, device=device)
79+
return self.fit_features(x, y, device=device)
80+
81+
@torch.no_grad()
82+
def _filter_correct_predictions(
83+
self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
84+
):
85+
"""
86+
:param z: a tensor of shape (N, D) or similar
87+
:param y: labels of shape (N,)
7088
:param device: device to use for computations
89+
:param batch_size: how many samples we process at a time
7190
"""
72-
x, y = extract_features(loader, self.model, device=device)
73-
return self.fit_features(x.to(device), y.to(device))
91+
buffer = TensorBuffer()
92+
93+
for start_idx in range(0, z.size(0), batch_size):
94+
end_idx = start_idx + batch_size
7495

75-
def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
96+
z_batch = z[start_idx:end_idx].to(device)
97+
y_batch = y[start_idx:end_idx].to(device)
98+
99+
y_hat_batch = self.head(z_batch).argmax(dim=1)
100+
101+
mask = y_hat_batch == y_batch
102+
buffer.append("z", z_batch[mask])
103+
buffer.append("y", y_hat_batch[mask])
104+
105+
return buffer["z"], buffer["y"]
106+
107+
def fit_features(
108+
self: Self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
109+
) -> Self:
76110
"""
77111
Calculates mean patterns per class.
78112
79113
:param z: features to fit
80114
:param y: labels
115+
:param device: device to use for computations
116+
:param batch_size: how many samples we process at a time
81117
"""
118+
if isinstance(self.backbone, nn.Module):
119+
log.debug(f"Moving model to {device}")
120+
self.backbone.to(device)
121+
82122
known = is_known(y)
83123

84124
if not known.any():
@@ -88,17 +128,18 @@ def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
88128
z = z[known]
89129
classes = y.unique()
90130

91-
# assume all classes are present
131+
# make sure all classes are present
92132
assert len(classes) == classes.max().item() + 1
93133

94-
# select correctly classified
95-
y_hat = self.head(z).argmax(dim=1)
96-
z = z[y_hat == y]
97-
y = y[y_hat == y]
134+
z, y = self._filter_correct_predictions(z, y, device=device, batch_size=batch_size)
98135

99136
m = []
100137
for clazz in classes:
101-
mav = z[y == clazz].mean(dim=0)
138+
idx = y == clazz
139+
if not idx.any():
140+
raise ValueError(f"No correct predictions for class {clazz.item()}")
141+
142+
mav = z[idx].to(device).mean(dim=0)
102143
m.append(mav)
103144

104145
self.patterns = torch.stack(m)

src/pytorch_ood/detector/vim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ def fit(self: Self, data_loader, device="cpu") -> Self:
100100
if self.model is None:
101101
raise ModelNotSetException
102102

103+
if isinstance(self.model, torch.nn.Module):
104+
log.debug(f"Moving model to {device}")
105+
self.model.to(device)
106+
103107
features, labels = extract_features(data_loader, self.model, device)
104108
return self.fit_features(features, labels)
105109

tests/detectors/test_she.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
3+
import torch
4+
5+
from src.pytorch_ood.detector import SHE
6+
from src.pytorch_ood.model import WideResNet
7+
8+
9+
class TestASH(unittest.TestCase):
10+
"""
11+
Tests for the SHE detector
12+
"""
13+
14+
def setUp(self) -> None:
15+
torch.manual_seed(123)
16+
17+
@torch.no_grad()
18+
def test_input(self):
19+
"""Fit the detector on random features and check that calling it on random inputs returns an output."""
20+
model = WideResNet(num_classes=10).eval()
21+
detector = SHE(
22+
backbone=model.features,
23+
head=model.fc,
24+
)
25+
26+
detector.fit_features(
27+
z=torch.randn(1000, 128),
28+
y=torch.arange(
29+
1000,
30+
)
31+
% 10,
32+
batch_size=128,
33+
)
34+
35+
x = torch.randn(size=(16, 3, 32, 32))
36+
37+
output = detector(x)
38+
39+
print(output)
40+
self.assertIsNotNone(output)

0 commit comments

Comments
 (0)