@@ -43,8 +43,10 @@ class ColVar(torch.nn.Module):
4343 "projection_channelnormal" ,
4444 "Sp" ,
4545 "Sd" ,
46- "adjecencey_matrix" ,
46+ "adjecencey_matrix" , # for backwards compatibility
47+ "adjacency_matrix" ,
4748 "energy_gap" ,
49+ "neural_cv"
4850 ]
4951
5052 def __init__ (self , info_dict : dict ):
@@ -76,7 +78,7 @@ def __init__(self, info_dict: dict):
7678 self .ro = self .info_dict ["acidhyd" ]
7779 self .r1 = self .info_dict ["waterhyd" ]
7880
79- elif self .info_dict ["name" ] == "adjecencey_matrix" :
81+ elif self .info_dict ["name" ] in [ "adjecencey_matrix" , "adjacency_matrix" ] :
8082 self .model = self .info_dict ["model" ]
8183 self .device = self .info_dict ["device" ]
8284 self .bond_length = self .info_dict ["bond_length" ]
@@ -108,6 +110,22 @@ def __init__(self, info_dict: dict):
108110 self .model = self .model .to (self .device )
109111 self .model .eval ()
110112
113+ elif self .info_dict ["name" ] == "neural_cv" :
114+ # expects a 'model' entry that behaves like a nn.Module and outputs a CV tensor
115+ # a device for the model must be specified
116+ # a 'descriptor_generation' entry can be specified for coordinate transformations / featurizations
117+ # independent of the model by giving a Callable[[torch.Tensor], torch.Tensor]
118+ # where the input tensor are the Cartesian coordinates of the atoms
119+ # and the output tensor is the input tensor for the model
120+ # if none is specified, the Cartesian coordinates are fed directly to the model
121+ self .device = self .info_dict ["device" ]
122+ self .descriptor_generation = self .info_dict .get ("descriptor_generation" )
123+ if self .descriptor_generation is None :
124+ self .descriptor_generation = lambda xyz_tensor : xyz_tensor
125+ self .model = self .info_dict ["model" ]
126+ self .model = self .model .to (self .device )
127+ self .model .eval ()
128+
111129 def _get_com (self , indices : int | list [int ]) -> torch .Tensor :
112130 """Get center of mass (com) of group of atoms
113131
@@ -285,7 +303,7 @@ def minimal_distance(self, index_list: list[list[int]]) -> torch.Tensor:
285303
286304 def projecting_centroidvec (self ) -> torch .Tensor :
287305 """Projection of a position vector onto a reference vector
288- Atomic indices are used to determine the coordiantes of the vectors.
306+ Atomic indices are used to determine the coordinates of the vectors.
289307 """
290308 vector_pos = self .xyz [self .vector_inds ]
291309 vector = vector_pos [1 ] - vector_pos [0 ]
@@ -573,6 +591,14 @@ def forward(self, atoms: Atoms) -> tuple[np.ndarray, np.ndarray]:
573591 elif self .info_dict ["name" ] == "energy_gap" :
574592 cv , cv_grad = self .energy_gap (self .info_dict ["enkey_1" ], self .info_dict ["enkey_2" ])
575593
594+ elif self .info_dict ["name" ] == "neural_cv" :
595+ desc = self .descriptor_generation (self .xyz )
596+ cv = self .model (desc )
597+ cv_grad , = torch .autograd .grad (cv .sum (), self .xyz , create_graph = False , retain_graph = False )
598+
599+ else :
600+ raise RuntimeError (f"CV { self .info_dict ['name' ]} not implemented!" )
601+
576602 return cv .detach ().cpu ().numpy (), cv_grad .detach ().cpu ().numpy ()
577603
578604
0 commit comments