Skip to content

Commit

Permalink
93 GPU out-of-memory problem during detector fit (#95)
Browse files Browse the repository at this point in the history
* Added catch-all keyword arguments (`**kwargs`) to MCD fitting
* Optimizing VRAM usage.
* Moving models to expected device before fitting
  • Loading branch information
kkirchheim authored Jan 16, 2025
1 parent f541df7 commit 2cc2a9e
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 29 deletions.
4 changes: 4 additions & 0 deletions src/pytorch_ood/detector/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
:param loader: data loader to extract features from. OOD inputs will be ignored.
:param device: device to use for feature extraction
"""
if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

z, y = extract_features(loader, self.model, device=device)
self.fit_features(z, y)
return self
9 changes: 6 additions & 3 deletions src/pytorch_ood/detector/klmatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def fit(self: Self, data_loader: DataLoader, device="cpu") -> Self:
if self.model is None:
raise ModelNotSetException

if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

logits, labels = extract_features(data_loader, self.model, device)
return self.fit_features(logits, labels, device)

Expand All @@ -72,13 +76,12 @@ def fit_features(self: Self, logits: Tensor, labels: Tensor, device="cpu") -> Se
:param labels: class labels
:param device: device which should be used for calculations
"""
logits, labels = logits.to(device), labels.to(device)
y_hat = logits.max(dim=1).indices
labels = labels.to(device)
probabilities = logits.softmax(dim=1)

for label in labels.unique():
log.debug(f"Fitting class {label}")
d_k = probabilities[labels == label].mean(dim=0)
d_k = probabilities[labels == label].to(device).mean(dim=0)
self.dists[str(label.item())] = Parameter(d_k)

return self
Expand Down
5 changes: 5 additions & 0 deletions src/pytorch_ood/detector/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
from typing import Callable, TypeVar

import torch
from torch import Tensor, tensor
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -104,5 +105,9 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
:param loader: data loader
:param device: device used for extracting logits
"""
if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

z, y = extract_features(model=self.model, data_loader=loader, device=device)
return self.fit_features(z, y)
11 changes: 8 additions & 3 deletions src/pytorch_ood/detector/mahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def fit(self: Self, data_loader: DataLoader, device: str = None) -> Self:
device = list(self.model.parameters())[0].device
log.warning(f"No device given. Will use '{device}'.")

if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to device {device}")
self.model.to(device)

z, y = extract_features(data_loader, self.model, device)
return self.fit_features(z, y, device)

Expand All @@ -89,7 +93,7 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
device = z.device
log.warning(f"No device given. Will use '{device}'.")

z, y = z.to(device), y.to(device)
y = y.to(device)

log.debug("Calculating mahalanobis parameters.")
classes = y.unique()
Expand All @@ -103,9 +107,10 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device)

for clazz in range(n_classes):
idxs = y.eq(clazz)
idxs = y.eq(clazz).to(z.device)
assert idxs.sum() != 0
zs = z[idxs]
# we only move them to device after indexing to reduce ram usage.
zs = z[idxs].to(device)
self.mu[clazz] = zs.mean(dim=0)
self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz])

Expand Down
5 changes: 3 additions & 2 deletions src/pytorch_ood/detector/mcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ def __init__(
self.mode = mode
self.batch_norm = batch_norm

def fit(self: Self, data_loader) -> Self:
def fit(self: Self, data_loader, **kwargs) -> Self:
"""
Not required
"""
return self

def fit_features(self: Self, x: Tensor, y: Tensor) -> Self:
def fit_features(self: Self, x: Tensor, y: Tensor, **kwargs) -> Self:
"""
Not required
"""
Expand All @@ -92,6 +92,7 @@ def _switch_mode(model: Module, batch_norm: bool = True) -> bool:
"""
Puts the model into training mode, except for variants of the batch-norm layer.
:param batch_norm: set to False if batch-norm should also be in training mode
:returns: true if model was switched, false otherwise
"""
mode_switch = False
Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_ood/detector/mmahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def fit(self: Self, data_loader: DataLoader, device: str = None) -> Self:
device = list(self.model[0].parameters())[0].device
log.warning(f"No device given. Will use '{device}'.")

if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

zs = []

for layer_idx in range(len(self.model)):
Expand Down
12 changes: 9 additions & 3 deletions src/pytorch_ood/detector/rmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ def fit(self, loader: DataLoader, device: str = "cpu") -> Self:
Fit parameters of the multi variate gaussian for the given loader.
Ignores OOD Inputs.
"""
if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

z, y = extract_features(loader, self.model, device=device)
return self.fit_features(z, y, device=device)

Expand All @@ -81,14 +85,16 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
device = z.device
log.warning(f"No device given. Will use '{device}'.")

z, y = z.to(device), y.to(device)
y = y.to(device)
known = is_known(y)

super(RMD, self).fit_features(z, y, device)

z_known = z[known].to(device)

log.debug("Fitting background gaussian.")
self.background_mu = z[known].mean(dim=0)
self.background_cov = (z[known] - self.background_mu).T.mm(z[known] - self.background_mu)
self.background_mu = z_known.mean(dim=0)
self.background_cov = (z_known - self.background_mu).T.mm(z_known - self.background_mu)
self.background_cov += (
torch.eye(self.background_cov.shape[0], device=self.background_cov.device) * 1e-6
)
Expand Down
77 changes: 59 additions & 18 deletions src/pytorch_ood/detector/she.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
.. autoclass:: pytorch_ood.detector.SHE
:members:
"""
from typing import Callable, TypeVar
from typing import TypeVar, Callable

import torch
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.utils import extract_features, is_known
import logging
from pytorch_ood.utils import extract_features, is_known, TensorBuffer

from ..api import Detector, ModelNotSetException

Self = TypeVar("Self")

log = logging.getLogger(__name__)


class SHE(Detector):
"""
Expand All @@ -32,26 +35,25 @@ class SHE(Detector):
:see Paper: `OpenReview <https://openreview.net/pdf?id=KkazG4lgKL>`__
"""

def __init__(self, model: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
def __init__(self, backbone: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
"""
:param model: feature extractor
:param backbone: feature extractor
:param head: maps feature vectors to logits
"""
super(SHE, self).__init__()
self.model = model
self.backbone = backbone
self.head = head
self.patterns = None

self.is_fitted = False

def predict(self, x: Tensor) -> Tensor:
"""
:param x: model inputs
"""
if self.model is None:
if self.backbone is None:
raise ModelNotSetException

z = self.model(x)
z = self.backbone(x)
return self.predict_features(z)

def predict_features(self, z: Tensor) -> Tensor:
Expand All @@ -67,18 +69,56 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
Extracts features and calculates mean patterns.
:param loader: data to fit
:param device: device to use for computations. If the backbone is a nn.Module, it will be moved to this device.
"""
if isinstance(self.backbone, nn.Module):
log.debug(f"Moving model to {device}")
self.backbone.to(device)

x, y = extract_features(loader, self.backbone, device=device)
return self.fit_features(x, y, device=device)

@torch.no_grad()
def _filter_correct_predictions(
self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
):
"""
:param z: a tensor of shape (N, D) or similar
:param y: labels of shape (N,)
:param device: device to use for computations
:param batch_size: how many samples we process at a time
"""
x, y = extract_features(loader, self.model, device=device)
return self.fit_features(x.to(device), y.to(device))
buffer = TensorBuffer()

for start_idx in range(0, z.size(0), batch_size):
end_idx = start_idx + batch_size

def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
z_batch = z[start_idx:end_idx].to(device)
y_batch = y[start_idx:end_idx].to(device)

y_hat_batch = self.head(z_batch).argmax(dim=1)

mask = y_hat_batch == y_batch
buffer.append("z", z_batch[mask])
buffer.append("y", y_hat_batch[mask])

return buffer["z"], buffer["y"]

def fit_features(
self: Self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
) -> Self:
"""
Calculates mean patterns per class.
:param z: features to fit
:param y: labels
:param device: device to use for computations
:param batch_size: how many samples we process at a time
"""
if isinstance(self.backbone, nn.Module):
log.debug(f"Moving model to {device}")
self.backbone.to(device)

known = is_known(y)

if not known.any():
Expand All @@ -88,17 +128,18 @@ def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
z = z[known]
classes = y.unique()

# assume all classes are present
# make sure all classes are present
assert len(classes) == classes.max().item() + 1

# select correctly classified
y_hat = self.head(z).argmax(dim=1)
z = z[y_hat == y]
y = y[y_hat == y]
z, y = self._filter_correct_predictions(z, y, device=device, batch_size=batch_size)

m = []
for clazz in classes:
mav = z[y == clazz].mean(dim=0)
idx = y == clazz
if not idx.any():
raise ValueError(f"No correct predictions for class {clazz.item()}")

mav = z[idx].to(device).mean(dim=0)
m.append(mav)

self.patterns = torch.stack(m)
Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_ood/detector/vim.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def fit(self: Self, data_loader, device="cpu") -> Self:
if self.model is None:
raise ModelNotSetException

if isinstance(self.model, torch.nn.Module):
log.debug(f"Moving model to {device}")
self.model.to(device)

features, labels = extract_features(data_loader, self.model, device)
return self.fit_features(features, labels)

Expand Down
40 changes: 40 additions & 0 deletions tests/detectors/test_she.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import torch

from src.pytorch_ood.detector import SHE
from src.pytorch_ood.model import WideResNet


class TestASH(unittest.TestCase):
    """
    Tests for the SHE detector.

    NOTE(review): the class name and the previous docstring ("Tests for
    activation shaping") look copied from the ASH tests. The name is kept so
    that existing test selectors keep working -- consider renaming it to
    ``TestSHE``.
    """

    def setUp(self) -> None:
        # Fixed seed so the random tensors used below are reproducible.
        torch.manual_seed(123)

    @torch.no_grad()
    def test_input(self):
        """
        Fit on random features in batches, then score a batch of random inputs.
        """
        model = WideResNet(num_classes=10).eval()
        detector = SHE(
            backbone=model.features,
            head=model.fc,
        )

        # 1000 random 128-dimensional feature vectors with labels cycling
        # through 0..9. batch_size=128 does not divide 1000, so the final
        # batch is a partial one.
        detector.fit_features(
            z=torch.randn(1000, 128),
            y=torch.arange(1000) % 10,
            batch_size=128,
        )

        x = torch.randn(size=(16, 3, 32, 32))

        output = detector(x)

        self.assertIsNotNone(output)

0 comments on commit 2cc2a9e

Please sign in to comment.