1+ from typing import Dict , List , Optional , Union
2+
3+ import numpy as np
14import torch
2- from ase .io import Trajectory , write
35from ase import Atoms
4- import numpy as np
6+ from ase .io import Trajectory , write
7+ from tqdm import tqdm
58
6- from nff .io .ase_calcs import EnsembleNFF
79from nff .io .ase import AtomsBatch
8- from nff .utils . scatter import compute_grad
10+ from nff .io . ase_calcs import EnsembleNFF
911from nff .utils .cuda import batch_to
10- from typing import Union
11-
12- from tqdm import tqdm
12+ from nff .utils .scatter import compute_grad
1313
1414
15- def get_molecules (atom : AtomsBatch , bond_length : dict = None , mode : str = "bond" , ** kwargs ) -> list [np .array ]:
15+ def get_molecules (
16+ atom : AtomsBatch , bond_length : Optional [Dict [str , float ]] = None , mode : str = "bond" , ** kwargs
17+ ) -> List [np .array ]:
1618 """
1719 Find molecules in a periodic or non-periodic system. Bond mode finds molecules within the bond length.
1820 Must pass a bond_length dict, e.g. bond_length=dict()
@@ -29,7 +31,8 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
2931 give extra cutoff = 6 e.g input
3032
3133 output:
32- list of array of atom indices in molecules. e.g: if there is a H2O molecule, you will get a list with the atom indices
34+ list of arrays of atom indices in molecules, e.g. if there is an H2O molecule,
35+ you will get a list with the atom indices
3336
3437 """
3538 types = list (set (atom .numbers ))
@@ -50,15 +53,18 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
5053 oxy_neighbors = []
5154 if mode == "bond" :
5255 for t in types :
53- if bond_length .get ("%s-%s" % ( ty , t )) != None :
56+ if bond_length .get (f" { ty } - { t } " ) is not None :
5457 oxy_neighbors .extend (
5558 list (
5659 np .where (atom .numbers == t )[0 ][
57- np .where (dis_sq [i , np .where (atom .numbers == t )[0 ]] <= bond_length ["%s-%s" % ( ty , t ) ])[0 ]
60+ np .where (dis_sq [i , np .where (atom .numbers == t )[0 ]] <= bond_length [f" { ty } - { t } " ])[0 ]
5861 ]
5962 )
6063 )
6164 elif mode == "cutoff" :
65+ if "cutoff" not in kwargs :
66+ raise ValueError ("Specifying mode 'cutoff' requires passing a cutoff value as a keyword argument" )
67+ cutoff = kwargs ["cutoff" ]
6268 oxy_neighbors .extend (list (np .where (dis_sq [i ] <= cutoff )[0 ])) # cutoff input extra argument
6369 oxy_neighbors = np .array (oxy_neighbors )
6470 if len (oxy_neighbors ) == 0 :
@@ -69,10 +75,10 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
6975 elif (clusters [oxy_neighbors ] == 0 ).all () and clusters [i ] == 0 :
7076 clusters [oxy_neighbors ] = mm + 1
7177 clusters [i ] = mm + 1
72- elif (clusters [oxy_neighbors ] == 0 ).all () == False and clusters [i ] == 0 :
78+ elif not (clusters [oxy_neighbors ] == 0 ).all () and clusters [i ] == 0 :
7379 clusters [i ] = min (clusters [oxy_neighbors ][clusters [oxy_neighbors ] != 0 ])
7480 clusters [oxy_neighbors ] = min (clusters [oxy_neighbors ][clusters [oxy_neighbors ] != 0 ])
75- elif (clusters [oxy_neighbors ] == 0 ).all () == False and clusters [i ] != 0 :
81+ elif not (clusters [oxy_neighbors ] == 0 ).all () and clusters [i ] != 0 :
7682 tmp = clusters [oxy_neighbors ][clusters [oxy_neighbors ] != 0 ][
7783 clusters [oxy_neighbors ][clusters [oxy_neighbors ] != 0 ]
7884 != min (clusters [oxy_neighbors ][clusters [oxy_neighbors ] != 0 ])
@@ -91,17 +97,17 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
9197 return molecules
9298
9399
94- def reconstruct_atoms (atomsobject : AtomsBatch , mol_idx : list [np .array ], centre : int = None ):
100+ def reconstruct_atoms (atomsobject : AtomsBatch , mol_idx : List [np .array ], centre : Optional [ int ] = None ):
95101 """
96102 Function to shift atoms when we create non-periodic system from periodic.
97103 inputs:
98104 atomsobject: AtomsBatch object from NFF
99105 mol_idx: list of array of atom indices in molecules or atoms you want to keep together when changing to non-periodic
100106 system
101- centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close to the centre which
102- is by default the first atom index in the array. For reconstructing molecules this is fine. However, for attribution,
103- we may have to shift a whole molecule to come closer to the atoms with high attribution. In that case, we manually assign
104- the atom index.
107+ centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close
108+ to the centre which is by default the first atom index in the array. For reconstructing molecules this is fine.
109+ However, for attribution, we may have to shift a whole molecule to come closer to the atoms with high attribution.
110+ In that case, we manually assign the atom index.
105111 """
106112
107113 sys_xyz = torch .Tensor (atomsobject .get_positions (wrap = True ))
@@ -111,38 +117,34 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre:
111117 mol_xyz = sys_xyz [idx ]
112118 if any (atomsobject .pbc ):
113119 center = mol_xyz .shape [0 ] // 2
114- if centre != None :
120+ if centre is not None :
115121 center = centre # changes the central atom to atom in focus
116122 intra_dmat = (mol_xyz [None , :, ...] - mol_xyz [:, None , ...])[center ]
117123 if np .count_nonzero (atomsobject .cell .T - np .diag (np .diagonal (atomsobject .cell .T ))) != 0 :
118- M , N = intra_dmat .shape [0 ], intra_dmat .shape [1 ]
124+ M , _ = intra_dmat .shape [0 ], intra_dmat .shape [1 ]
119125 f = torch .linalg .solve (torch .Tensor (atomsobject .cell .T ), (intra_dmat .view (- 1 , 3 ).T )).T
120126 g = f - torch .floor (f + 0.5 )
121127 intra_dmat = torch .matmul (g , torch .Tensor (atomsobject .cell ))
122128 intra_dmat = intra_dmat .view (M , 3 )
123129 offsets = - torch .floor (f + 0.5 ).view (M , 3 )
124130 traj_unwrap = mol_xyz + torch .matmul (offsets , torch .Tensor (atomsobject .cell ))
125131 else :
126- sub = (intra_dmat > 0.5 * box_len ).to (torch .float ) * box_len
127- add = (intra_dmat <= - 0.5 * box_len ).to (torch .float ) * box_len
132+ (intra_dmat > 0.5 * box_len ).to (torch .float ) * box_len
133+ (intra_dmat <= - 0.5 * box_len ).to (torch .float ) * box_len
128134 shift = torch .round (torch .divide (intra_dmat , box_len ))
129135 offsets = - shift
130136 traj_unwrap = mol_xyz + offsets * box_len
131137 else :
132138 traj_unwrap = mol_xyz
133- # traj_unwrap=mol_xyz+add-sub
134139 sys_xyz [idx ] = traj_unwrap
135140
136141 new_pos = sys_xyz .numpy ()
137142
138143 return new_pos
139144
140145
141- # -
142-
143-
144146class Attribution :
145- def __init__ (self , ensemble : EnsembleNFF , save_file : str = None ):
147+ def __init__ (self , ensemble : EnsembleNFF , save_file : Optional [ str ] = None ):
146148 self .ensemble = ensemble
147149 self .save_file = save_file
148150
@@ -197,17 +199,15 @@ def calc_attribution_file(
197199 step : int = 1 ,
198200 progress_bar : bool = True ,
199201 to_chemiscope : bool = False ,
200- bond_length : dict = None ,
202+ bond_length : Optional [ dict ] = None ,
201203 ) -> list :
202204 attributions = []
203205 atoms_list = []
204206 energies = []
205207 energy_stds = []
206208 grads = []
207209 grad_stds = []
208- with tqdm (
209- range (skip , len (traj ), step ), disable = True if progress_bar == False else False
210- ) as pbar : # , postfix={"fbest":"?",}) as pbar:
210+ with tqdm (range (skip , len (traj ), step ), disable = not progress_bar ) as pbar : # , postfix={"fbest":"?",}) as pbar:
211211 # for i in range(skip,len(traj),step):
212212 for i in pbar :
213213 # create atoms batch object
@@ -269,8 +269,7 @@ def calc_attribution_file(
269269 },
270270 }
271271 return atoms_list , properties
272- else :
273- return attributions
272+ return attributions
274273
275274 def activelearning (
276275 self ,
@@ -281,12 +280,10 @@ def activelearning(
281280 skip : int = 0 ,
282281 step : int = 1 ,
283282 progress_bar : bool = True ,
284- bond_length : dict = None ,
283+ bond_length : Optional [ dict ] = None ,
285284 ):
286285 atom_list = []
287- with tqdm (
288- range (skip , len (traj ), step ), disable = True if progress_bar == False else False
289- ) as pbar : # , postfix={"fbest":"?",}) as pbar:
286+ with tqdm (range (skip , len (traj ), step ), disable = not progress_bar ) as pbar : # , postfix={"fbest":"?",}) as pbar:
290287 # for i in range(skip,len(traj),step):
291288 for i in pbar :
292289 # create atoms batch object
@@ -337,15 +334,15 @@ def activelearning(
337334 neighs = np .append (neighs , a )
338335 for n in neighs :
339336 atomstocare = np .append (atomstocare , molecules [np .where (balanced_mols == n )[0 ][0 ]])
340- atomstocare = np .array (( list (set (atomstocare ) )))
337+ atomstocare = np .array (list (set (atomstocare )))
341338 atomstocare = np .int64 (atomstocare )
342339 atoms1 = atoms [atomstocare ]
343340 index = np .where (atoms1 .positions == atoms .positions [a ])[0 ][0 ]
344341 xyz = reconstruct_atoms (atoms1 , [np .arange (0 , len (atoms1 ))], centre = index )
345342 atoms1 .positions = xyz
346343 is_repeated = False
347- for Atoms in atom_list :
348- if atoms1 . __eq__ ( Atoms ) :
344+ for at in atom_list :
345+ if atoms1 == at :
349346 is_repeated = True
350347 break
351348 if not is_repeated :
0 commit comments