-
Notifications
You must be signed in to change notification settings - Fork 0
/
poisoned_dataset.py
61 lines (50 loc) · 2.07 KB
/
poisoned_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from numpy._typing import NDArray
import torch
import numpy as np
def create_trigger(n, *, generator=None):
    """Create a random binary ``n`` x ``n`` trigger pattern.

    Each cell is independently 1.0 with probability 1/2 and 0.0 otherwise.

    :param n: side length of the square trigger
    :param generator: optional ``torch.Generator`` to draw from; pass a seeded
        generator to get a reproducible trigger. When ``None`` (the default)
        torch's global RNG is used, exactly as before.
    :returns: a float tensor of shape [n, n] whose entries are 0.0 or 1.0
    """
    return (torch.rand(n, n, generator=generator) > 0.5).float()
def insert_trigger(images, pattern):
    """Paste ``pattern`` into the bottom-right corner of ``images``.

    The last two dimensions of ``images`` are treated as (height, width), so
    this works for a single image ``[height, width]``, a channel-first image
    ``[C, height, width]`` or a batch ``[N, C, height, width]``; the pattern
    is broadcast over every leading dimension.

    :param images: A tensor with values between 0 and 1 whose last two
        dimensions are (height, width), e.g. shape [N, 1, height, width]
    :param pattern: A tensor with values between 0 and 1 and shape
        [side_len, side_len]
    :returns: ``images`` itself, modified in place, with ``pattern`` pasted
        into the bottom-right corner of every image (clone first if the
        caller's tensor must stay untouched)
    """
    rows, cols = pattern.shape[-2], pattern.shape[-1]
    if rows == 0 or cols == 0:
        # A ``-0:`` slice would select the whole image, so an empty pattern
        # has to be special-cased as a no-op.
        return images
    # ``...`` pins the slice to the two spatial axes no matter how many
    # leading batch/channel axes ``images`` has.
    images[..., -rows:, -cols:] = pattern
    return images
class PoisonedDataset(torch.utils.data.Dataset):
    """Wrap a dataset and stamp a trigger onto a fixed random subset of it.

    A seeded random subset of the indices of ``clean_data`` is chosen once at
    construction time. Those samples are returned with ``trigger`` pasted into
    the image and their label replaced by ``target_label``; every other sample
    is passed through unchanged. Each label is returned together with a flag
    saying whether the sample was poisoned.
    """

    def __init__(
        self, clean_data, trigger, target_label=9, poison_fraction=0.1, seed=1
    ):
        """
        :param clean_data: the clean dataset to poison; ``clean_data[i]`` must
            return a sequence whose first two items are ``(image, label)``
        :param trigger: A tensor with values between 0 and 1 and shape [side_len, side_len]
        :param target_label: the label to switch poisoned images to
        :param poison_fraction: the fraction of the data to poison, in [0, 1]
        :param seed: the seed determining the random subset of the data to poison
        :raises ValueError: if ``poison_fraction`` is outside [0, 1]
        """
        super().__init__()
        if not 0.0 <= poison_fraction <= 1.0:
            raise ValueError(
                f"poison_fraction must be in [0, 1], got {poison_fraction}"
            )
        self.clean_data = clean_data
        self.trigger = trigger
        self.target_label = target_label
        # select indices to poison (rounding down, so a very small fraction
        # may poison nothing)
        num_to_poison = int(np.floor(poison_fraction * len(clean_data)))
        rng = np.random.default_rng(seed)
        self.poisoned_indices: NDArray = rng.choice(
            len(clean_data), size=num_to_poison, replace=False
        )
        # ``idx in ndarray`` scans the whole array on every lookup; a set keeps
        # the per-sample membership test O(1).
        self._poisoned_set = frozenset(self.poisoned_indices.tolist())

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, tuple[int, bool]]:
        """Return ``(image, (label, is_poisoned))`` for sample ``idx``.

        Poisoned samples get the trigger pasted into the image and
        ``target_label`` as their label; clean samples are returned as-is.
        Negative indices count from the end, as for a list.

        :raises IndexError: if a negative ``idx`` is out of range
        """
        index = int(idx)
        if index < 0:
            # Normalise so a negative index gets the same poisoned/clean
            # decision as its positive counterpart.
            index += len(self)
            if index < 0:
                raise IndexError(
                    f"index {idx} is out of range for dataset of length {len(self)}"
                )
        sample = self.clean_data[index]  # fetch once; may run transforms
        image = sample[0]
        if index in self._poisoned_set:
            # ``squeeze`` returns a view and ``insert_trigger`` writes in place,
            # so without the clone the trigger would be written into the
            # wrapped dataset's own tensor whenever it stores its images.
            flat = torch.squeeze(image).clone()
            # Restore the original shape so poisoned and clean samples match.
            poisoned_image = insert_trigger(flat, self.trigger).reshape(image.shape)
            return (poisoned_image, (self.target_label, True))
        return (image, (sample[1], False))

    def __len__(self):
        """Number of samples, identical to the wrapped dataset's length."""
        return len(self.clean_data)

    def n_poisoned(self):
        """Number of samples selected for poisoning."""
        return len(self.poisoned_indices)