Skip to content

Commit

Permalink
Reducing VRAM usage.
Browse files Browse the repository at this point in the history
  • Loading branch information
Konstantin Kirchheim committed Jan 13, 2025
1 parent a0c804d commit a8ce83d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 22 deletions.
5 changes: 2 additions & 3 deletions src/pytorch_ood/detector/klmatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def fit_features(self: Self, logits: Tensor, labels: Tensor, device="cpu") -> Se
:param labels: class labels
:param device: device which should be used for calculations
"""
logits, labels = logits.to(device), labels.to(device)
y_hat = logits.max(dim=1).indices
labels = labels.to(device)
probabilities = logits.softmax(dim=1)

for label in labels.unique():
log.debug(f"Fitting class {label}")
d_k = probabilities[labels == label].mean(dim=0)
d_k = probabilities[labels == label].to(device).mean(dim=0)
self.dists[str(label.item())] = Parameter(d_k)

return self
Expand Down
7 changes: 4 additions & 3 deletions src/pytorch_ood/detector/mahalanobis.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
device = z.device
log.warning(f"No device given. Will use '{device}'.")

z, y = z.to(device), y.to(device)
y = y.to(device)

log.debug("Calculating mahalanobis parameters.")
classes = y.unique()
Expand All @@ -103,9 +103,10 @@ def fit_features(self: Self, z: Tensor, y: Tensor, device: str = None) -> Self:
self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device)

for clazz in range(n_classes):
idxs = y.eq(clazz)
idxs = y.eq(clazz).to(z.device)
assert idxs.sum() != 0
zs = z[idxs]
# we only move them to device after indexing to reduce ram usage.
zs = z[idxs].to(device)
self.mu[clazz] = zs.mean(dim=0)
self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz])

Expand Down
74 changes: 58 additions & 16 deletions src/pytorch_ood/detector/she.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@
.. autoclass:: pytorch_ood.detector.SHE
:members:
"""
from typing import Callable, TypeVar
from typing import TypeVar, Callable

import torch
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader

import logging
from pytorch_ood.utils import extract_features, is_known

from ..api import Detector, ModelNotSetException

Self = TypeVar("Self")

log = logging.getLogger(__name__)


class SHE(Detector):
"""
Expand All @@ -32,26 +35,25 @@ class SHE(Detector):
:see Paper: `OpenReview <https://openreview.net/pdf?id=KkazG4lgKL>`__
"""

def __init__(self, model: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
def __init__(self, backbone: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
"""
:param model: feature extractor
:param backbone: feature extractor
:param head: maps feature vectors to logits
"""
super(SHE, self).__init__()
self.model = model
self.backbone = backbone
self.head = head
self.patterns = None

self.is_fitted = False

def predict(self, x: Tensor) -> Tensor:
"""
:param x: model inputs
"""
if self.model is None:
if self.backbone is None:
raise ModelNotSetException

z = self.model(x)
z = self.backbone(x)
return self.predict_features(z)

def predict_features(self, z: Tensor) -> Tensor:
Expand All @@ -69,16 +71,55 @@ def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
:param loader: data to fit
:param device: device to use for computations
"""
self.model.to(device)
x, y = extract_features(loader, self.model, device=device)
return self.fit_features(x.to(device), y.to(device))
return self.fit_features(x, y, device=device)

@torch.no_grad()
def _filter_correct_predictions(
self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
):
"""
:param z: a tensor of shape (N, D) or similar
:param y: labels of shape (N,)
:param device: device to use for computations
:param batch_size: how many samples we process at a time
"""
z_correct = []
y_correct = []

for start_idx in range(0, z.size(0), batch_size):
end_idx = start_idx + batch_size

z_batch = z[start_idx:end_idx]
y_batch = y[start_idx:end_idx]

def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
y_hat_batch = self.head(z_batch.to(device)).argmax(dim=1)

mask = y_hat_batch == y_batch

z_correct.append(z_batch[mask].cpu())
y_correct.append(y_batch[mask].cpu())

z_correct = torch.cat(z_correct, dim=0)
y_correct = torch.cat(y_correct, dim=0)

return z_correct, y_correct

def fit_features(
self: Self, z: Tensor, y: Tensor, device: str = "cpu", batch_size: int = 1024
) -> Self:
"""
Calculates mean patterns per class.
:param z: features to fit
:param y: labels
:param device: device to use for computations
:param batch_size: how many samples we process at a time
"""
if isinstance(self.backbone, nn.Module):
self.backbone.to(device)

known = is_known(y)

if not known.any():
Expand All @@ -88,17 +129,18 @@ def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
z = z[known]
classes = y.unique()

# assume all classes are present
# make sure all classes are present
assert len(classes) == classes.max().item() + 1

# select correctly classified
y_hat = self.head(z).argmax(dim=1)
z = z[y_hat == y]
y = y[y_hat == y]
z, y = self._filter_correct_predictions(z, y, device=device, batch_size=batch_size)

m = []
for clazz in classes:
mav = z[y == clazz].mean(dim=0)
idx = y == clazz
if not idx.any():
raise ValueError(f"No correct predictions for class {clazz.item()}")

mav = z[idx].to(device).mean(dim=0)
m.append(mav)

self.patterns = torch.stack(m)
Expand Down
40 changes: 40 additions & 0 deletions tests/detectors/test_she.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import torch

from src.pytorch_ood.detector import SHE
from src.pytorch_ood.model import WideResNet


class TestASH(unittest.TestCase):
"""
Tests for activation shaping
"""

def setUp(self) -> None:
torch.manual_seed(123)

@torch.no_grad()
def test_input(self):
""" """
model = WideResNet(num_classes=10).eval()
detector = SHE(
backbone=model.features,
head=model.fc,
)

detector.fit_features(
z=torch.randn(1000, 128),
y=torch.arange(
1000,
)
% 10,
batch_size=128,
)

x = torch.randn(size=(16, 3, 32, 32))

output = detector(x)

print(output)
self.assertIsNotNone(output)

0 comments on commit a8ce83d

Please sign in to comment.