-
Notifications
You must be signed in to change notification settings - Fork 9
/
alanine_attack.py
139 lines (117 loc) · 4.3 KB
/
alanine_attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import torch
import ase
import argparse
from torch.utils.data import DataLoader
from nff.data import Dataset, concatenate_dict, collate_dicts
from nff.utils.cuda import batch_to, batch_detach
from robust.loss import AdvLoss
from robust.dihedrals import set_dihedrals
def parse_args(argv=None):
    """Parse and validate the command-line arguments of the attack script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``, as before.

    Returns:
        argparse.Namespace: The parsed arguments. ``num_attacks`` and
        ``n_epochs`` are guaranteed to be at least 1.

    Raises:
        SystemExit: If the arguments cannot be parsed, or if ``num_attacks``
            or ``--n_epochs`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(
        description="Arguments to perform adversarial attack on dihedral angles of alanine dipeptide. Models have to be trained before performing attacks."
    )
    parser.add_argument(
        "model_dir",
        type=str,
        help="Path to trained models",
    )
    parser.add_argument(
        "generation",
        type=int,
        help="Number of active learning loop",
    )
    parser.add_argument(
        "num_attacks",
        type=int,
        help="Number of data points to perform attack on",
    )
    parser.add_argument(
        "--n_epochs",
        type=int,
        default=1000,
        help="Number of epochs to perform attacks for",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-2,
        help="Learning rate",
    )
    parser.add_argument(
        "--kT",
        type=float,
        default=20,
        help="Temperature at which the adversarial loss is set to",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    args = parser.parse_args(argv)

    # A negative count would silently slice the random permutation from the
    # end, and zero would produce an empty batch, so reject both up front.
    if args.num_attacks < 1:
        parser.error("num_attacks must be a positive integer")
    # With no epochs the attack loop never runs and there is nothing to save.
    if args.n_epochs < 1:
        parser.error("--n_epochs must be a positive integer")

    return args


def main():
    """Run an adversarial attack on the dihedral angles of seed geometries.

    Loads the trained models and the dataset of the previous generation,
    picks ``num_attacks`` random seed geometries, optimizes a perturbation of
    their (phi, psi) dihedral angles with Adam against the adversarial loss,
    and saves the geometries of the last epoch to
    ``dataset/gen{generation}_attacks.pth.tar``.
    """
    args = parse_args()

    print(f"Loading trained model from {args.model_dir}")
    # NOTE: torch.load unpickles the checkpoint, which can execute arbitrary
    # code. Only point model_dir at files from a trusted source.
    models = torch.load(
        os.path.join(args.model_dir, "best_model"), map_location=args.device
    )

    print("Loading dataset from previous generation")
    seed_dset = Dataset.from_file(
        os.path.join("dataset", f"gen{args.generation-1}.pth.tar")
    )

    # choose random seeds from the dataset
    randperm = torch.randperm(len(seed_dset))[: args.num_attacks]
    seed_configs = [seed_dset[i] for i in randperm]

    # Measure the initial (phi, psi) dihedral angles of every seed geometry.
    starting_points = []
    for config in seed_configs:
        # presumably column 0 of "nxyz" holds the atomic numbers and columns
        # 1: the Cartesian coordinates -- verify against the dataset schema
        mol = ase.Atoms(
            symbols=config["nxyz"][:, 0],
            positions=config["nxyz"][:, 1:],
        )
        phi = mol.get_dihedral(a1=7, a2=6, a3=1, a4=2)
        psi = mol.get_dihedral(a1=4, a2=2, a3=1, a4=6)
        starting_points.append([phi, psi])
    starting_points = torch.Tensor(starting_points).to(args.device)

    # The perturbation of the dihedral angles is the only optimized tensor.
    delta = torch.randn_like(starting_points, requires_grad=True, device=args.device)

    print("Starting adversarial attack")
    opt = torch.optim.Adam([delta], lr=args.lr)
    loss_fun = AdvLoss(energies=seed_dset.props["energy"], temperature=args.kT)

    # All unordered pairs of atom indices 0..21 (22 is presumably the number
    # of atoms of alanine dipeptide -- confirm if other molecules are used).
    nbr_list = torch.combinations(torch.arange(22), r=2, with_replacement=False).to(
        args.device
    )

    # backprop against dihedral angle of alanine dipeptide for n_epochs
    for t in range(args.n_epochs):
        opt.zero_grad()

        # mod by 360 since rotation by 360 deg is the same as 0 deg
        inputs = ((starting_points + delta) % 360).to(args.device)

        # Rebuild every geometry with the currently perturbed dihedrals.
        dset = []
        for inp, config in zip(inputs, seed_configs):
            seed_nxyz = config["nxyz"].to(args.device)
            nxyz, phi_psi = set_dihedrals(seed_nxyz, inp[0], inp[1], device=args.device)
            dset.append(
                {
                    # NOTE(review): nxyz is detached, so any gradient that
                    # reaches delta has to flow through "phi_psi" -- confirm
                    # that the models consume this key.
                    "nxyz": nxyz.detach(),
                    "phi_psi": phi_psi.reshape(-1, 2),
                    # Placeholder targets; the attack does not use labels.
                    "energy": torch.Tensor([0]),
                    "energy_grad": torch.zeros(size=(len(nxyz), 3)),
                    "nbr_list": nbr_list,
                    "num_atoms": torch.Tensor([len(nxyz)]),
                }
            )

        # A single batch containing all attacked geometries.
        dataloader = DataLoader(dset, batch_size=len(dset), collate_fn=collate_dicts)
        batch = next(iter(dataloader))

        results = []
        for model in models:
            batch = batch_to(batch, args.device)
            results.append(model(batch))
            # NOTE(review): the batch is detached after every model; if
            # batch_detach returns detached copies, later models receive a
            # "phi_psi" that is cut off from delta -- confirm this is intended.
            batch = batch_detach(batch)

        # Stack the per-model predictions along a trailing "model" axis.
        energy = torch.stack([r["energy"] for r in results], dim=-1)
        forces = -torch.stack([r["energy_grad"] for r in results], dim=-1)

        loss = loss_fun.loss_fn(e=energy, f=forces).sum()
        loss.backward()
        opt.step()

        if t % 10 == 0:
            print(t, loss.item())

    # dset holds the geometries built in the last epoch, i.e. those evaluated
    # before the final optimizer step.
    adv_path = os.path.join("dataset", f"gen{args.generation}_attacks.pth.tar")
    advdset = Dataset(concatenate_dict(*dset))
    advdset.save(adv_path)
    print(f"Saved attack points to dataset {adv_path}")


if __name__ == "__main__":
    main()