Skip to content

Commit 73a3ae6

Browse files
add dpa3 calculator fixes #411 (#428)
* add dpa3 calculator fixes #411 * Update docs for DPA3 * Separate missing and invalid DPA3 model * Add test for invalid DPA3 model * Add reference for example model --------- Co-authored-by: ElliottKasoar <[email protected]>
1 parent 7eae3aa commit 73a3ae6

File tree

8 files changed

+58
-5
lines changed

8 files changed

+58
-5
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Tools for machine learnt interatomic potentials
3535
- sevenn = 0.10.3 (optional)
3636
- alignn = 2024.5.27 (optional)
3737
- nequip = 0.6.1 (optional)
38+
- deepmd-kit = dpa3-alpha (optional)
3839

3940
All required and optional dependencies can be found in [pyproject.toml](pyproject.toml).
4041

@@ -94,6 +95,7 @@ Current and planned features include:
9495
- ALIGNN (experimental)
9596
- SevenNet (experimental)
9697
- NequIP (experimental)
98+
- DPA3 (experimental)
9799
- [x] Single point calculations
98100
- [x] Geometry optimisation
99101
- [x] Molecular Dynamics

docs/source/getting_started/getting_started.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Dependencies
1515
- sevenn = 0.10.3 (optional)
1616
- alignn = 2024.5.27 (optional)
1717
- nequip = 0.6.1 (optional)
18+
- deepmd-kit = dpa3-alpha (optional)
1819

1920
All required and optional dependencies can be found in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_.
2021

@@ -58,5 +59,6 @@ Currently supported extras are:
5859
- ``m3gnet``: `M3GNet <https://github.com/materialsvirtuallab/matgl/>`_
5960
- ``sevenn``: `SevenNet <https://github.com/MDIL-SNU/SevenNet/>`_
6061
- ``nequip``: `NequIP <https://github.com/mir-group/nequip>`_
62+
- ``dpa3``: `DPA3 <https://github.com/deepmodeling/deepmd-kit/tree/dpa3-alpha>`_
6163

6264
``extras`` are also listed in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_ under ``[project.optional-dependencies]``.

janus_core/helpers/janus_types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,15 @@ class CorrelationKwargs(TypedDict, total=True):
115115

116116
# Janus specific
117117
Architectures = Literal[
118-
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "nequip"
118+
"mace",
119+
"mace_mp",
120+
"mace_off",
121+
"m3gnet",
122+
"chgnet",
123+
"alignn",
124+
"sevennet",
125+
"nequip",
126+
"dpa3",
119127
]
120128
Devices = Literal["cpu", "cuda", "mps", "xpu"]
121129
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]

janus_core/helpers/mlip_calculators.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,23 @@ def choose_calculator(
240240
model_path=model_path, device=device, **kwargs
241241
)
242242

243+
elif arch == "dpa3":
244+
from deepmd import __version__
245+
from deepmd.calculator import DP
246+
247+
# No default `model_path`
248+
if model_path is None:
249+
# From https://matbench-discovery.materialsproject.org/models/dpa3-v1-mptrj
250+
raise ValueError(
251+
"Please specify `model_path`, as there is no "
252+
f"default model for {arch} "
253+
"e.g. https://bohrium-api.dp.tech/ds-dl/dpa3openlam-74ng-v3.zip"
254+
)
255+
256+
model_path = str(model_path)
257+
258+
calculator = DP(model=model_path, **kwargs)
259+
243260
else:
244261
raise ValueError(
245262
f"Unrecognized {arch=}. Suported architectures "

pyproject.toml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,26 @@ alignn = [
4848
chgnet = [
4949
"chgnet == 0.3.8",
5050
]
51+
dpa3 = [
52+
"deepmd-kit",
53+
]
5154
m3gnet = [
5255
"matgl == 1.1.3",
5356
"dgl == 2.1.0",
5457
]
55-
sevennet = [
56-
"sevenn == 0.10.3",
57-
]
5858
nequip = [
5959
"nequip == 0.6.1",
6060
]
61+
sevennet = [
62+
"sevenn == 0.10.3",
63+
]
6164
all = [
6265
"janus-core[alignn]",
6366
"janus-core[chgnet]",
67+
"janus-core[dpa3]",
6468
"janus-core[m3gnet]",
65-
"janus-core[sevennet]",
6669
"janus-core[nequip]",
70+
"janus-core[sevennet]",
6771
]
6872

6973
[project.scripts]
@@ -179,3 +183,6 @@ default-groups = [
179183
"docs",
180184
"pre-commit",
181185
]
186+
187+
[tool.uv.sources]
188+
deepmd-kit = { git = "https://github.com/deepmodeling/deepmd-kit.git", rev = "dpa3-alpha" }
13.2 MB
Binary file not shown.

tests/test_mlip_calculators.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
NEQUIP_PATH = MODEL_PATH / "toluene.pth"
3232

33+
DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth"
34+
3335

3436
@pytest.mark.parametrize(
3537
"arch, device, kwargs",
@@ -80,6 +82,7 @@ def test_invalid_arch():
8082
("m3gnet", "/invalid/path"),
8183
("chgnet", "/invalid/path"),
8284
("nequip", "/invalid/path"),
85+
("dpa3", "/invalid/path"),
8386
],
8487
)
8588
def test_invalid_model_path(arch, model_path):
@@ -126,6 +129,8 @@ def test_invalid_device(arch):
126129
("sevennet", "cpu", {"model": "sevennet-0"}),
127130
("nequip", "cpu", {"model_path": NEQUIP_PATH}),
128131
("nequip", "cpu", {"model": NEQUIP_PATH}),
132+
("dpa3", "cpu", {"model_path": DPA3_PATH}),
133+
("dpa3", "cpu", {"model": DPA3_PATH}),
129134
],
130135
)
131136
def test_extra_mlips(arch, device, kwargs):
@@ -176,6 +181,16 @@ def test_extra_mlips(arch, device, kwargs):
176181
"model_path": NEQUIP_PATH,
177182
"path": NEQUIP_PATH,
178183
},
184+
{
185+
"arch": "dpa3",
186+
"model_path": DPA3_PATH,
187+
"model": DPA3_PATH,
188+
},
189+
{
190+
"arch": "dpa3",
191+
"model_path": DPA3_PATH,
192+
"path": DPA3_PATH,
193+
},
179194
],
180195
)
181196
def test_extra_mlips_invalid(kwargs):

tests/test_single_point.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SEVENNET_PATH = MODEL_PATH / "sevennet_0.pth"
1919
ALIGNN_PATH = MODEL_PATH / "v5.27.2024"
2020
NEQUIP_PATH = MODEL_PATH / "toluene.pth"
21+
DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth"
2122

2223
test_data = [
2324
(DATA_PATH / "benzene.xyz", -76.0605725422795, "energy", "energy", {}, None),
@@ -293,6 +294,7 @@ def test_mlips(arch, device, expected_energy):
293294
{"model_path": "SevenNet-0_11July2024"},
294295
),
295296
("nequip", "cpu", -169815.1282456301, "toluene.xyz", {"model_path": NEQUIP_PATH}),
297+
("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}),
296298
]
297299

298300

0 commit comments

Comments
 (0)