Skip to content

Commit 6cd4ff0

Browse files
committed
Merge remote-tracking branch 'origin/master' into fine_tuning
2 parents 6257ddb + 9fa08e1 commit 6cd4ff0

File tree

158 files changed

+7844
-11428
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

158 files changed

+7844
-11428
lines changed

.github/workflows/python-app.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
name: Test NeuralForceField package
2+
3+
on: [push]
4+
5+
jobs:
6+
build:
7+
8+
runs-on: ubuntu-latest
9+
strategy:
10+
matrix:
11+
# python-version: ["pypy3.10", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
12+
python-version: ["3.10"]
13+
14+
steps:
15+
- uses: actions/checkout@v4
16+
- name: Set up Python ${{ matrix.python-version }}
17+
uses: actions/setup-python@v5
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
- name: Display Python version
21+
run: python -c "import sys; print(sys.version)"
22+
- name: Install basics
23+
run: python -m pip install --upgrade pip setuptools wheel
24+
- name: Install package
25+
run: python -m pip install .
26+
# - name: Install linters
27+
# run: python -m pip install flake8 mypy pylint
28+
# - name: Install documentation requirements
29+
# run: python -m pip install -r docs/requirements.txt
30+
# - name: Test with flake8
31+
# run: flake8 nff
32+
# - name: Test with mypy
33+
# run: mypy nff
34+
# - name: Test with pylint
35+
# run: pylint nff
36+
- name: Test with pytest
37+
run: |
38+
pip install pytest pytest-cov
39+
pytest nff/tests --doctest-modules --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=nff --cov-report=xml --cov-report=html
40+
- name: Upload pytest test results
41+
uses: actions/upload-artifact@v4
42+
with:
43+
name: pytest-results-${{ matrix.python-version }}
44+
path: junit/test-results-${{ matrix.python-version }}.xml
45+
if: ${{ always() }}
46+
# - name: Test documentation
47+
# run: sphinx-build docs/source docs/build

.gitignore

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,17 @@ dist/
6666
sandbox_excited/
6767
build/
6868

69+
# Editor files
70+
# vim
71+
*.swp
72+
*.swo
73+
74+
# pycharm
75+
.idea/
76+
77+
# coverage and tests
78+
junit
79+
.coverage
80+
6981
# required exceptions
7082
!tutorials/models/ammonia/Ammonia.xyz

nff/analysis/attribution.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1+
from typing import Dict, List, Optional, Union
2+
3+
import numpy as np
14
import torch
2-
from ase.io import Trajectory, write
35
from ase import Atoms
4-
import numpy as np
6+
from ase.io import Trajectory, write
7+
from tqdm import tqdm
58

6-
from nff.io.ase_calcs import EnsembleNFF
79
from nff.io.ase import AtomsBatch
8-
from nff.utils.scatter import compute_grad
10+
from nff.io.ase_calcs import EnsembleNFF
911
from nff.utils.cuda import batch_to
10-
from typing import Union
11-
12-
from tqdm import tqdm
12+
from nff.utils.scatter import compute_grad
1313

1414

15-
def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond", **kwargs) -> list[np.array]:
15+
def get_molecules(
16+
atom: AtomsBatch, bond_length: Optional[Dict[str, float]] = None, mode: str = "bond", **kwargs
17+
) -> List[np.array]:
1618
"""
1719
find molecules in periodic or non-periodic system. bond mode finds molecules within bond length.
1820
Must pass bond_length dict: e.g bond_length=dict()
@@ -29,7 +31,8 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
2931
give extra cutoff = 6 e.g input
3032
3133
output:
32-
list of array of atom indices in molecules. e.g: if there is a H2O molecule, you will get a list with the atom indices
34+
list of array of atom indices in molecules. e.g: if there is a H2O molecule,
35+
you will get a list with the atom indices
3336
3437
"""
3538
types = list(set(atom.numbers))
@@ -50,15 +53,18 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
5053
oxy_neighbors = []
5154
if mode == "bond":
5255
for t in types:
53-
if bond_length.get("%s-%s" % (ty, t)) != None:
56+
if bond_length.get(f"{ty}-{t}") is not None:
5457
oxy_neighbors.extend(
5558
list(
5659
np.where(atom.numbers == t)[0][
57-
np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length["%s-%s" % (ty, t)])[0]
60+
np.where(dis_sq[i, np.where(atom.numbers == t)[0]] <= bond_length[f"{ty}-{t}"])[0]
5861
]
5962
)
6063
)
6164
elif mode == "cutoff":
65+
if "cutoff" not in kwargs:
66+
raise ValueError("Specifying mode 'cutoff' requires passing a cutoff value as a keyword argument")
67+
cutoff = kwargs["cutoff"]
6268
oxy_neighbors.extend(list(np.where(dis_sq[i] <= cutoff)[0])) # cutoff input extra argument
6369
oxy_neighbors = np.array(oxy_neighbors)
6470
if len(oxy_neighbors) == 0:
@@ -69,10 +75,10 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
6975
elif (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0:
7076
clusters[oxy_neighbors] = mm + 1
7177
clusters[i] = mm + 1
72-
elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] == 0:
78+
elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] == 0:
7379
clusters[i] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
7480
clusters[oxy_neighbors] = min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
75-
elif (clusters[oxy_neighbors] == 0).all() == False and clusters[i] != 0:
81+
elif not (clusters[oxy_neighbors] == 0).all() and clusters[i] != 0:
7682
tmp = clusters[oxy_neighbors][clusters[oxy_neighbors] != 0][
7783
clusters[oxy_neighbors][clusters[oxy_neighbors] != 0]
7884
!= min(clusters[oxy_neighbors][clusters[oxy_neighbors] != 0])
@@ -91,17 +97,17 @@ def get_molecules(atom: AtomsBatch, bond_length: dict = None, mode: str = "bond"
9197
return molecules
9298

9399

94-
def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre: int = None):
100+
def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: List[np.array], centre: Optional[int] = None):
95101
"""
96102
Function to shift atoms when we create non-periodic system from periodic.
97103
inputs:
98104
atomsobject: Atomsbatch object from NFF
99105
mol_idx: list of array of atom indices in molecules or atoms you want to keep together when changing to non-periodic
100106
system
101-
centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close to the centre which
102-
is by default the first atom index in the array. For reconstructing molecules this is fine. However, for attribution,
103-
we may have to shift a whole molecule to come closer to the atoms with high attribution. In that case, we manually assign
104-
the atom index.
107+
centre: by default the atoms in a molecule or set of close atoms are shifted so as to get them close
108+
to the centre which is by default the first atom index in the array. For reconstructing molecules this is fine.
109+
However, for attribution, we may have to shift a whole molecule to come closer to the atoms with high attribution.
110+
In that case, we manually assign the atom index.
105111
"""
106112

107113
sys_xyz = torch.Tensor(atomsobject.get_positions(wrap=True))
@@ -111,38 +117,34 @@ def reconstruct_atoms(atomsobject: AtomsBatch, mol_idx: list[np.array], centre:
111117
mol_xyz = sys_xyz[idx]
112118
if any(atomsobject.pbc):
113119
center = mol_xyz.shape[0] // 2
114-
if centre != None:
120+
if centre is not None:
115121
center = centre # changes the central atom to atom in focus
116122
intra_dmat = (mol_xyz[None, :, ...] - mol_xyz[:, None, ...])[center]
117123
if np.count_nonzero(atomsobject.cell.T - np.diag(np.diagonal(atomsobject.cell.T))) != 0:
118-
M, N = intra_dmat.shape[0], intra_dmat.shape[1]
124+
M, _ = intra_dmat.shape[0], intra_dmat.shape[1]
119125
f = torch.linalg.solve(torch.Tensor(atomsobject.cell.T), (intra_dmat.view(-1, 3).T)).T
120126
g = f - torch.floor(f + 0.5)
121127
intra_dmat = torch.matmul(g, torch.Tensor(atomsobject.cell))
122128
intra_dmat = intra_dmat.view(M, 3)
123129
offsets = -torch.floor(f + 0.5).view(M, 3)
124130
traj_unwrap = mol_xyz + torch.matmul(offsets, torch.Tensor(atomsobject.cell))
125131
else:
126-
sub = (intra_dmat > 0.5 * box_len).to(torch.float) * box_len
127-
add = (intra_dmat <= -0.5 * box_len).to(torch.float) * box_len
132+
(intra_dmat > 0.5 * box_len).to(torch.float) * box_len
133+
(intra_dmat <= -0.5 * box_len).to(torch.float) * box_len
128134
shift = torch.round(torch.divide(intra_dmat, box_len))
129135
offsets = -shift
130136
traj_unwrap = mol_xyz + offsets * box_len
131137
else:
132138
traj_unwrap = mol_xyz
133-
# traj_unwrap=mol_xyz+add-sub
134139
sys_xyz[idx] = traj_unwrap
135140

136141
new_pos = sys_xyz.numpy()
137142

138143
return new_pos
139144

140145

141-
# -
142-
143-
144146
class Attribution:
145-
def __init__(self, ensemble: EnsembleNFF, save_file: str = None):
147+
def __init__(self, ensemble: EnsembleNFF, save_file: Optional[str] = None):
146148
self.ensemble = ensemble
147149
self.save_file = save_file
148150

@@ -197,17 +199,15 @@ def calc_attribution_file(
197199
step: int = 1,
198200
progress_bar: bool = True,
199201
to_chemiscope: bool = False,
200-
bond_length: dict = None,
202+
bond_length: Optional[dict] = None,
201203
) -> list:
202204
attributions = []
203205
atoms_list = []
204206
energies = []
205207
energy_stds = []
206208
grads = []
207209
grad_stds = []
208-
with tqdm(
209-
range(skip, len(traj), step), disable=True if progress_bar == False else False
210-
) as pbar: # , postfix={"fbest":"?",}) as pbar:
210+
with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar:
211211
# for i in range(skip,len(traj),step):
212212
for i in pbar:
213213
# create atoms batch object
@@ -269,8 +269,7 @@ def calc_attribution_file(
269269
},
270270
}
271271
return atoms_list, properties
272-
else:
273-
return attributions
272+
return attributions
274273

275274
def activelearning(
276275
self,
@@ -281,12 +280,10 @@ def activelearning(
281280
skip: int = 0,
282281
step: int = 1,
283282
progress_bar: bool = True,
284-
bond_length: dict = None,
283+
bond_length: Optional[dict] = None,
285284
):
286285
atom_list = []
287-
with tqdm(
288-
range(skip, len(traj), step), disable=True if progress_bar == False else False
289-
) as pbar: # , postfix={"fbest":"?",}) as pbar:
286+
with tqdm(range(skip, len(traj), step), disable=not progress_bar) as pbar: # , postfix={"fbest":"?",}) as pbar:
290287
# for i in range(skip,len(traj),step):
291288
for i in pbar:
292289
# create atoms batch object
@@ -337,15 +334,15 @@ def activelearning(
337334
neighs = np.append(neighs, a)
338335
for n in neighs:
339336
atomstocare = np.append(atomstocare, molecules[np.where(balanced_mols == n)[0][0]])
340-
atomstocare = np.array((list(set(atomstocare))))
337+
atomstocare = np.array(list(set(atomstocare)))
341338
atomstocare = np.int64(atomstocare)
342339
atoms1 = atoms[atomstocare]
343340
index = np.where(atoms1.positions == atoms.positions[a])[0][0]
344341
xyz = reconstruct_atoms(atoms1, [np.arange(0, len(atoms1))], centre=index)
345342
atoms1.positions = xyz
346343
is_repeated = False
347-
for Atoms in atom_list:
348-
if atoms1.__eq__(Atoms):
344+
for at in atom_list:
345+
if atoms1 == at:
349346
is_repeated = True
350347
break
351348
if not is_repeated:

0 commit comments

Comments
 (0)