
Commit 705b718

arnaudvl, Ashley Scillitoe and mauicv authored
Add KeOps MMD detector (#548)
* first commit keops
* update kernel and mmd keops
* allow multiple kernel bandwidths for keops
* fix bug
* update mmd
* remove learned kernel and base kernel_matrix MMD function
* unify batched mmd2
* update keops mmd
* update docs and kernel import
* bugfixes
* remove unused imports
* add benchmarking example
* update test mmd
* add test mmd keops
* update readme
* bugfix kernel and update mmd test
* remove print from test
* update keops tests
* Add save warning and update tests
* Update setup and associated docs
* Fix typing issue in
* Install keops as part of CI
* Add keops tox environment
* Add keops to all dependency bucket
* Fix minor issue
* Protect GaussianRBF with import optional
* Skip keops tests on Windows, and keops notebook test. Fix backend validator.
* Skip keops kernel tests if not installed
* Add pykeops to op deps ERROR_TYPES
* Skip keops tests on MacOS
* Add note to docs about linux-only support for keops
* Add batch_size_permutations to pydantic models
* remove print
* remove unnecessary comment
* change default bandwidth fn to None
* update infer sigma
* update test warning, update and clarify keops kernels logic
* clean up
* update docstring
* fix bug
* undo unnecessary kwarg removal
* make test consistent with torch/tf backends
* add _mmd2 test
* remove unused import
* clarify docs, remove redundant framework checks
* remove print
* update docs keops
* batched version of sigma_mean part 1
* remove unused import
* update keops kernels test

Co-authored-by: Ashley Scillitoe <[email protected]>
Co-authored-by: Alex Athorne <[email protected]>
1 parent 13dde6d · commit 705b718

24 files changed (+1277, -66 lines)

.github/workflows/ci.yml  (+3)

@@ -54,6 +54,9 @@ jobs:
       if [ "$RUNNER_OS" != "Windows" ] && [ ${{ matrix.python }} < '3.10' ]; then # Skip Prophet tests on Windows as installation complex. Skip on Python 3.10 as not supported.
         python -m pip install --upgrade --upgrade-strategy eager -e .[prophet]
       fi
+      if [ "$RUNNER_OS" == "Linux" ]; then # Currently, we only support KeOps on Linux.
+        python -m pip install --upgrade --upgrade-strategy eager -e .[keops]
+      fi
       python -m pip install --upgrade --upgrade-strategy eager -e .[tensorflow,torch]
       python -m pip freeze

README.md  (+15, -3)

@@ -79,7 +79,7 @@ The package, `alibi-detect` can be installed from:
     pip install git+https://github.com/SeldonIO/alibi-detect.git
     ```
 
-- To install with the tensorflow backend:
+- To install with the TensorFlow backend:
   ```bash
   pip install alibi-detect[tensorflow]
   ```
@@ -89,6 +89,11 @@ The package, `alibi-detect` can be installed from:
   pip install alibi-detect[torch]
   ```
 
+- To install with the KeOps backend:
+  ```bash
+  pip install alibi-detect[keops]
+  ```
+
 - To use the `Prophet` time series outlier detector:
 
   ```bash
@@ -181,8 +186,8 @@ The following tables show the advised use cases for each algorithm. The column *
 
 #### TensorFlow and PyTorch support
 
-The drift detectors support TensorFlow and PyTorch backends. Alibi Detect does not install these as default. See the
-[installation options](#installation-and-usage) for more details.
+The drift detectors support TensorFlow, PyTorch and (where applicable) [KeOps](https://www.kernel-operations.io/keops/index.html) backends.
+However, Alibi Detect does not install these by default. See the [installation options](#installation-and-usage) for more details.
 
 ```python
 from alibi_detect.cd import MMDDrift
@@ -198,6 +203,13 @@ cd = MMDDrift(x_ref, backend='pytorch', p_val=.05)
 preds = cd.predict(x)
 ```
 
+Or in KeOps:
+
+```python
+cd = MMDDrift(x_ref, backend='keops', p_val=.05)
+preds = cd.predict(x)
+```
+
 #### Built-in preprocessing steps
 
 Alibi Detect also comes with various preprocessing steps such as randomly initialized encoders, pretrained text

alibi_detect/cd/base.py  (-5)

@@ -602,11 +602,6 @@ def preprocess(self, x: Union[np.ndarray, list]) -> Tuple[np.ndarray, np.ndarray
         else:
             return self.x_ref, x  # type: ignore[return-value]
 
-    @abstractmethod
-    def kernel_matrix(self, x: Union['torch.Tensor', 'tf.Tensor'], y: Union['torch.Tensor', 'tf.Tensor']) \
-            -> Union['torch.Tensor', 'tf.Tensor']:
-        pass
-
     @abstractmethod
     def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
         pass

alibi_detect/cd/keops/__init__.py

Whitespace-only changes.

alibi_detect/cd/keops/mmd.py  (new file, +182 lines)

import logging
import numpy as np
from pykeops.torch import LazyTensor
import torch
from typing import Callable, Dict, List, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
from alibi_detect.utils.keops.kernels import GaussianRBF
from alibi_detect.utils.pytorch import get_device

logger = logging.getLogger(__name__)


class MMDDriftKeops(BaseMMDDrift):
    def __init__(
            self,
            x_ref: Union[np.ndarray, list],
            p_val: float = .05,
            x_ref_preprocessed: bool = False,
            preprocess_at_init: bool = True,
            update_x_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
            kernel: Callable = GaussianRBF,
            sigma: Optional[np.ndarray] = None,
            configure_kernel_from_x_ref: bool = True,
            n_permutations: int = 100,
            batch_size_permutations: int = 1000000,
            device: Optional[str] = None,
            input_shape: Optional[tuple] = None,
            data_type: Optional[str] = None
    ) -> None:
        """
        Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        p_val
            p-value used for the significance of the permutation test.
        x_ref_preprocessed
            Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
            the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
            data will also be preprocessed.
        preprocess_at_init
            Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
            data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        kernel
            Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
        sigma
            Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
            The kernel evaluation is then averaged over those bandwidths.
        configure_kernel_from_x_ref
            Whether to already configure the kernel bandwidth from the reference data.
        n_permutations
            Number of permutations used in the permutation test.
        batch_size_permutations
            KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(
            x_ref=x_ref,
            p_val=p_val,
            x_ref_preprocessed=x_ref_preprocessed,
            preprocess_at_init=preprocess_at_init,
            update_x_ref=update_x_ref,
            preprocess_fn=preprocess_fn,
            sigma=sigma,
            configure_kernel_from_x_ref=configure_kernel_from_x_ref,
            n_permutations=n_permutations,
            input_shape=input_shape,
            data_type=data_type
        )
        self.meta.update({'backend': 'keops'})

        # set device
        self.device = get_device(device)

        # initialize kernel
        sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma,  # type: ignore[assignment]
                                                                      np.ndarray) else None
        self.kernel = kernel(sigma).to(self.device) if kernel == GaussianRBF else kernel

        # set the correct MMD^2 function based on the batch size for the permutations
        self.batch_size = batch_size_permutations
        self.n_batches = 1 + (n_permutations - 1) // batch_size_permutations

        # infer the kernel bandwidth from the reference data
        if isinstance(sigma, torch.Tensor):
            self.infer_sigma = False
        elif self.infer_sigma:
            x = torch.from_numpy(self.x_ref).to(self.device)
            _ = self.kernel(LazyTensor(x[:, None, :]), LazyTensor(x[None, :, :]), infer_sigma=self.infer_sigma)
            self.infer_sigma = False
        else:
            self.infer_sigma = True

    def _mmd2(self, x_all: torch.Tensor, perms: List[torch.Tensor], m: int, n: int) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Batched (across the permutations) MMD^2 computation for the original test statistic and the permutations.

        Parameters
        ----------
        x_all
            Concatenated reference and test instances.
        perms
            List with permutation vectors.
        m
            Number of reference instances.
        n
            Number of test instances.

        Returns
        -------
        MMD^2 statistic for the original and permuted reference and test sets.
        """
        k_xx, k_yy, k_xy = [], [], []
        for batch in range(self.n_batches):
            i, j = batch * self.batch_size, (batch + 1) * self.batch_size
            # construct stacked tensors with a batch of permutations for the reference set x and test set y
            x = torch.cat([x_all[perm[:m]][None, :, :] for perm in perms[i:j]], 0)
            y = torch.cat([x_all[perm[m:]][None, :, :] for perm in perms[i:j]], 0)
            if batch == 0:
                x = torch.cat([x_all[None, :m, :], x], 0)
                y = torch.cat([x_all[None, m:, :], y], 0)
            x, y = x.to(self.device), y.to(self.device)

            # batch-wise kernel matrix computation over the permutations
            k_xy.append(self.kernel(
                LazyTensor(x[:, :, None, :]), LazyTensor(y[:, None, :, :]), self.infer_sigma).sum(1).sum(1).squeeze(-1))
            k_xx.append(self.kernel(
                LazyTensor(x[:, :, None, :]), LazyTensor(x[:, None, :, :])).sum(1).sum(1).squeeze(-1))
            k_yy.append(self.kernel(
                LazyTensor(y[:, :, None, :]), LazyTensor(y[:, None, :, :])).sum(1).sum(1).squeeze(-1))
        c_xx, c_yy, c_xy = 1 / (m * (m - 1)), 1 / (n * (n - 1)), 2. / (m * n)
        # Note that the MMD^2 estimates assume that the diagonal of the kernel matrix consists of 1's
        stats = c_xx * (torch.cat(k_xx) - m) + c_yy * (torch.cat(k_yy) - n) - c_xy * torch.cat(k_xy)
        return stats[0], stats[1:]

    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
        """
        Compute the p-value resulting from a permutation test using the maximum mean discrepancy
        as a distance measure between the reference data and the data to be tested.

        Parameters
        ----------
        x
            Batch of instances.

        Returns
        -------
        p-value obtained from the permutation test, the MMD^2 between the reference and test set,
        and the MMD^2 threshold above which drift is flagged.
        """
        x_ref, x = self.preprocess(x)
        x_ref = torch.from_numpy(x_ref).float()  # type: ignore[assignment]
        x = torch.from_numpy(x).float()  # type: ignore[assignment]
        # compute kernel matrix, MMD^2 and apply permutation test
        m, n = x_ref.shape[0], x.shape[0]
        perms = [torch.randperm(m + n) for _ in range(self.n_permutations)]
        # TODO - Rethink typings (related to https://github.com/SeldonIO/alibi-detect/issues/540)
        x_all = torch.cat([x_ref, x], 0)  # type: ignore[list-item]
        mmd2, mmd2_permuted = self._mmd2(x_all, perms, m, n)
        if self.device.type == 'cuda':
            mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
        p_val = (mmd2 <= mmd2_permuted).float().mean()
        # compute distance threshold
        idx_threshold = int(self.p_val * len(mmd2_permuted))
        distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
        return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
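
For reference, the statistic assembled in `_mmd2` is the standard unbiased MMD^2 estimator: the KeOps reductions sum the full kernel matrices (diagonal included), and since the Gaussian RBF kernel evaluates to 1 on the diagonal, the within-sample sums are corrected by subtracting m and n before scaling. Written out:

```latex
\widehat{\mathrm{MMD}}^2(X, Y)
  = \frac{1}{m(m-1)}\Big(\sum_{i,j} k(x_i, x_j) - m\Big)
  + \frac{1}{n(n-1)}\Big(\sum_{i,j} k(y_i, y_j) - n\Big)
  - \frac{2}{mn}\sum_{i,j} k(x_i, y_j)
```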
New test file for the KeOps MMD drift detector  (+120 lines)

from functools import partial
from itertools import product
import numpy as np
import pytest
import torch
import torch.nn as nn
from typing import Callable, List
from alibi_detect.utils.frameworks import has_keops
from alibi_detect.utils.pytorch import GaussianRBF, mmd2_from_kernel_matrix
from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift
if has_keops:
    from alibi_detect.cd.keops.mmd import MMDDriftKeops

n, n_hidden, n_classes = 500, 10, 5


class MyModel(nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        self.dense1 = nn.Linear(n_features, 20)
        self.dense2 = nn.Linear(20, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = nn.ReLU()(self.dense1(x))
        return self.dense2(x)


# test List[Any] inputs to the detector
def preprocess_list(x: List[np.ndarray]) -> np.ndarray:
    return np.concatenate(x, axis=0)


n_features = [10]
n_enc = [None, 3]
preprocess = [
    (None, None),
    (preprocess_drift, {'model': HiddenOutput, 'layer': -1}),
    (preprocess_list, None)
]
update_x_ref = [{'last': 750}, {'reservoir_sampling': 750}, None]
preprocess_at_init = [True, False]
n_permutations = [10]
batch_size_permutations = [10, 1000000]
configure_kernel_from_x_ref = [True, False]
tests_mmddrift = list(product(n_features, n_enc, preprocess, n_permutations, preprocess_at_init, update_x_ref,
                              batch_size_permutations, configure_kernel_from_x_ref))
n_tests = len(tests_mmddrift)


@pytest.fixture
def mmd_params(request):
    return tests_mmddrift[request.param]


@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True)
def test_mmd(mmd_params):
    n_features, n_enc, preprocess, n_permutations, preprocess_at_init, update_x_ref, \
        batch_size_permutations, configure_kernel_from_x_ref = mmd_params

    np.random.seed(0)
    torch.manual_seed(0)

    x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
    preprocess_fn, preprocess_kwargs = preprocess
    to_list = False
    if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list':
        if not preprocess_at_init:
            return
        to_list = True
        x_ref = [_[None, :] for _ in x_ref]
    elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \
            and preprocess_kwargs['model'].__name__ == 'HiddenOutput':
        model = MyModel(n_features)
        layer = preprocess_kwargs['layer']
        preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer))
    else:
        preprocess_fn = None

    cd = MMDDriftKeops(
        x_ref=x_ref,
        p_val=.05,
        preprocess_at_init=preprocess_at_init if isinstance(preprocess_fn, Callable) else False,
        update_x_ref=update_x_ref,
        preprocess_fn=preprocess_fn,
        configure_kernel_from_x_ref=configure_kernel_from_x_ref,
        n_permutations=n_permutations,
        batch_size_permutations=batch_size_permutations
    )
    x = x_ref.copy()
    preds = cd.predict(x, return_p_val=True)
    assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val
    if isinstance(update_x_ref, dict):
        k = list(update_x_ref.keys())[0]
        assert cd.n == len(x) + len(x_ref)
        assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref))

    x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
    if to_list:
        x_h1 = [_[None, :] for _ in x_h1]
    preds = cd.predict(x_h1, return_p_val=True)
    if preds['data']['is_drift'] == 1:
        assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val
        assert preds['data']['distance'] > preds['data']['distance_threshold']
    else:
        assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val
        assert preds['data']['distance'] <= preds['data']['distance_threshold']

    # ensure the keops MMD^2 estimate matches the pytorch implementation for the same kernel
    if not isinstance(x_ref, list) and update_x_ref is None:
        p_val, mmd2, distance_threshold = cd.score(x_h1)
        kernel = GaussianRBF(sigma=cd.kernel.sigma)
        if isinstance(preprocess_fn, Callable):
            x_ref, x_h1 = cd.preprocess(x_h1)
        x_ref = torch.from_numpy(x_ref).float()
        x_h1 = torch.from_numpy(x_h1).float()
        x_all = torch.cat([x_ref, x_h1], 0)
        kernel_mat = kernel(x_all, x_all)
        mmd2_torch = mmd2_from_kernel_matrix(kernel_mat, x_h1.shape[0])
        np.testing.assert_almost_equal(mmd2, mmd2_torch, decimal=6)

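For orientation, a minimal usage sketch of the new detector class, mirroring the constructor arguments and `predict` call exercised in the test above (assumes `pykeops` and `torch` are installed; the toy data shapes and values are illustrative only):

```python
import numpy as np

from alibi_detect.cd.keops.mmd import MMDDriftKeops

# toy reference and test sets (shapes are illustrative only)
x_ref = np.random.randn(500, 10).astype(np.float32)
x = np.random.randn(500, 10).astype(np.float32)

# KeOps MMD detector with the defaults introduced in this commit
cd = MMDDriftKeops(x_ref, p_val=.05, n_permutations=100, batch_size_permutations=1000000)

# permutation test on the test batch; the returned keys match those asserted in the test above
preds = cd.predict(x, return_p_val=True)
print(preds['data']['is_drift'], preds['data']['p_val'])
```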