Skip to content

Commit 9fa08e1

Browse files
authored
Merge pull request #38 from steinmig/general-cv-model
General CV model capability
2 parents 08d142f + 084e292 commit 9fa08e1

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

nff/io/bias_calculators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BiasBase(NeuralFF):
2323
"""Basic Calculator class with neural force field
2424
2525
Args:
26-
model: the deural force field model
26+
model: the neural force field model
2727
cv_def: list of Collective Variable (CV) definitions
2828
[["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]]
2929
equil_temp: float temperature of the simulation (important for extended system dynamics)
@@ -790,7 +790,7 @@ class WTMeABF(eABF):
790790
[["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]]
791791
equil_temp: float temperature of the simulation (important for extended system dynamics)
792792
dt: time step of the extended dynamics (has to be equal to that of the real system dyn!)
793-
friction_per_ps: friction for the Lagevin dyn of extended system
793+
friction_per_ps: friction for the Langevin dyn of extended system
794794
(has to be equal to that of the real system dyn!)
795795
nfull: numer of samples need for full application of bias force
796796
hill_height: unscaled height of the MetaD Gaussian hills in eV
@@ -958,7 +958,7 @@ class AttractiveBias(NeuralFF):
958958
Designed to be used with UQ as CV
959959
960960
Args:
961-
model: the deural force field model
961+
model: the neural force field model
962962
cv_def: list of Collective Variable (CV) definitions
963963
[["cv_type", [atom_indices], np.array([minimum, maximum]), bin_width], [possible second dimension]]
964964
gamma: coupling strength, regulates strength of attraction

nff/md/colvars.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ class ColVar(torch.nn.Module):
4343
"projection_channelnormal",
4444
"Sp",
4545
"Sd",
46-
"adjecencey_matrix",
46+
"adjecencey_matrix", # for backwards compatibility
47+
"adjacency_matrix",
4748
"energy_gap",
49+
"neural_cv"
4850
]
4951

5052
def __init__(self, info_dict: dict):
@@ -76,7 +78,7 @@ def __init__(self, info_dict: dict):
7678
self.ro = self.info_dict["acidhyd"]
7779
self.r1 = self.info_dict["waterhyd"]
7880

79-
elif self.info_dict["name"] == "adjecencey_matrix":
81+
elif self.info_dict["name"] in ["adjecencey_matrix", "adjacency_matrix"]:
8082
self.model = self.info_dict["model"]
8183
self.device = self.info_dict["device"]
8284
self.bond_length = self.info_dict["bond_length"]
@@ -108,6 +110,22 @@ def __init__(self, info_dict: dict):
108110
self.model = self.model.to(self.device)
109111
self.model.eval()
110112

113+
elif self.info_dict["name"] == "neural_cv":
114+
# expects a 'model' entry that behaves like a nn.Module and outputs a CV tensor
115+
# a device for the model must be specified
116+
# a 'descriptor_generation' entry can be specified for coordinate transformations / featurizations
117+
# independent of the model by giving a Callable[[torch.Tensor], torch.Tensor]
118+
# where the input tensor are the Cartesian coordinates of the atoms
119+
# and the output tensor is the input tensor for the model
120+
# if none is specified, the Cartesian coordinates are fed directly to the model
121+
self.device = self.info_dict["device"]
122+
self.descriptor_generation = self.info_dict.get("descriptor_generation")
123+
if self.descriptor_generation is None:
124+
self.descriptor_generation = lambda xyz_tensor: xyz_tensor
125+
self.model = self.info_dict["model"]
126+
self.model = self.model.to(self.device)
127+
self.model.eval()
128+
111129
def _get_com(self, indices: int | list[int]) -> torch.Tensor:
112130
"""Get center of mass (com) of group of atoms
113131
@@ -285,7 +303,7 @@ def minimal_distance(self, index_list: list[list[int]]) -> torch.Tensor:
285303

286304
def projecting_centroidvec(self) -> torch.Tensor:
287305
"""Projection of a position vector onto a reference vector
288-
Atomic indices are used to determine the coordiantes of the vectors.
306+
Atomic indices are used to determine the coordinates of the vectors.
289307
"""
290308
vector_pos = self.xyz[self.vector_inds]
291309
vector = vector_pos[1] - vector_pos[0]
@@ -573,6 +591,14 @@ def forward(self, atoms: Atoms) -> tuple[np.ndarray, np.ndarray]:
573591
elif self.info_dict["name"] == "energy_gap":
574592
cv, cv_grad = self.energy_gap(self.info_dict["enkey_1"], self.info_dict["enkey_2"])
575593

594+
elif self.info_dict["name"] == "neural_cv":
595+
desc = self.descriptor_generation(self.xyz)
596+
cv = self.model(desc)
597+
cv_grad, = torch.autograd.grad(cv.sum(), self.xyz, create_graph=False, retain_graph=False)
598+
599+
else:
600+
raise RuntimeError(f"CV {self.info_dict['name']} not implemented!")
601+
576602
return cv.detach().cpu().numpy(), cv_grad.detach().cpu().numpy()
577603

578604

tutorials/17_eABF_and_WTMeABF.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@
8989
" \"index_list\": [0, 1, 2, 3],\n",
9090
"}\n",
9191
"\n",
92-
"atoms = AtomsBatch.from_atoms(start_geom)\n",
92+
"device = \"cpu\"\n",
93+
"\n",
94+
"atoms = AtomsBatch.from_atoms(start_geom, device=device)\n",
9395
"\n",
9496
"# use the distance of the scan for center of umbrella window\n",
9597
"CV = ColVar(info_dict)\n",
@@ -121,8 +123,6 @@
121123
}
122124
],
123125
"source": [
124-
"device = \"cpu\"\n",
125-
"\n",
126126
"cv_defs = [\n",
127127
" {\n",
128128
" \"definition\": info_dict,\n",
@@ -419,7 +419,7 @@
419419
"outputs": [],
420420
"source": [
421421
"# Let's start exactly where the other simulation started\n",
422-
"atoms = AtomsBatch.from_atoms(start_geom)\n",
422+
"atoms = AtomsBatch.from_atoms(start_geom, device=device)\n",
423423
"\n",
424424
"# use the distance of the scan for center of umbrella window\n",
425425
"CV = ColVar(info_dict)\n",

0 commit comments

Comments
 (0)