Skip to content

Commit 485084e

Browse files
Add Orb support (#303)
* add Orb support * Generalise orb version * calculators * add tests * Update docs for Orb * Update orb * Fix orb model selection * Add MLIP tests for orb --------- Co-authored-by: ElliottKasoar <[email protected]>
1 parent 73a3ae6 commit 485084e

File tree

8 files changed

+53
-0
lines changed

8 files changed

+53
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Tools for machine learnt interatomic potentials
3636
- alignn = 2024.5.27 (optional)
3737
- nequip = 0.6.1 (optional)
3838
- deepmd-kit = dpa3-alpha (optional)
39+
- orb-models = 0.4.2 (optional)
3940

4041
All required and optional dependencies can be found in [pyproject.toml](pyproject.toml).
4142

@@ -96,6 +97,7 @@ Current and planned features include:
9697
- SevenNet (experimental)
9798
- NequIP (experimental)
9899
- DPA3 (experimental)
100+
- Orb (experimental)
99101
- [x] Single point calculations
100102
- [x] Geometry optimisation
101103
- [x] Molecular Dynamics

docs/source/getting_started/getting_started.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Dependencies
1616
- alignn = 2024.5.27 (optional)
1717
- nequip = 0.6.1 (optional)
1818
- deepmd-kit = dpa3-alpha (optional)
19+
- orb-models = 0.4.2 (optional)
1920

2021
All required and optional dependencies can be found in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_.
2122

@@ -60,5 +61,6 @@ Currently supported extras are:
6061
- ``sevenn``: `SevenNet <https://github.com/MDIL-SNU/SevenNet/>`_
6162
- ``nequip``: `NequIP <https://github.com/mir-group/nequip>`_
6263
- ``dpa3``: `DPA3 <https://github.com/deepmodeling/deepmd-kit/tree/dpa3-alpha>`_
64+
- ``orb``: `Orb <https://github.com/orbital-materials/orb-models>`_
6365

6466
``extras`` are also listed in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_ under ``[project.optional-dependencies]``.

janus_core/helpers/janus_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ class CorrelationKwargs(TypedDict, total=True):
124124
"sevennet",
125125
"nequip",
126126
"dpa3",
127+
"orb",
127128
]
128129
Devices = Literal["cpu", "cuda", "mps", "xpu"]
129130
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]

janus_core/helpers/mlip_calculators.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,29 @@ def choose_calculator(
257257

258258
calculator = DP(model=model_path, **kwargs)
259259

260+
elif arch == "orb":
261+
from orb_models import __version__
262+
from orb_models.forcefield.calculator import ORBCalculator
263+
from orb_models.forcefield.graph_regressor import GraphRegressor
264+
import orb_models.forcefield.pretrained as orb_ff
265+
266+
# Default model
267+
model_path = model_path if model_path else "orb_v2"
268+
269+
if isinstance(model_path, GraphRegressor):
270+
model = model_path
271+
model_path = "loaded_GraphRegressor"
272+
else:
273+
try:
274+
model = getattr(orb_ff, model_path.replace("-", "_"))()
275+
except AttributeError as e:
276+
raise ValueError(
277+
"`model_path` must be a `GraphRegressor`, pre-trained model label "
278+
"(e.g. 'orb-v2'), or `None` (uses default, orb-v2)"
279+
) from e
280+
281+
calculator = ORBCalculator(model=model, device=device, **kwargs)
282+
260283
else:
261284
raise ValueError(
262285
f"Unrecognized {arch=}. Suported architectures "

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ m3gnet = [
5858
nequip = [
5959
"nequip == 0.6.1",
6060
]
61+
orb = [
62+
"orb-models == 0.4.2",
63+
"pynanoflann",
64+
]
6165
sevennet = [
6266
"sevenn == 0.10.3",
6367
]
@@ -67,6 +71,7 @@ all = [
6771
"janus-core[dpa3]",
6872
"janus-core[m3gnet]",
6973
"janus-core[nequip]",
74+
"janus-core[orb]",
7075
"janus-core[sevennet]",
7176
]
7277

@@ -186,3 +191,4 @@ default-groups = [
186191

187192
[tool.uv.sources]
188193
deepmd-kit = { git = "https://github.com/deepmodeling/deepmd-kit.git", rev = "dpa3-alpha" }
194+
pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" }
36.1 MB
Binary file not shown.

tests/test_mlip_calculators.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from chgnet.model.model import CHGNet
99
from matgl import load_model
10+
from orb_models.forcefield.pretrained import orb_d3_xs_v2
1011
import pytest
1112

1213
from janus_core.helpers.mlip_calculators import choose_calculator
@@ -32,6 +33,9 @@
3233

3334
DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth"
3435

36+
ORB_WEIGHTS_PATH = MODEL_PATH / "orb-d3-xs-v2-20241011.ckpt"
37+
ORB_MODEL = orb_d3_xs_v2(weights_path=ORB_WEIGHTS_PATH)
38+
3539

3640
@pytest.mark.parametrize(
3741
"arch, device, kwargs",
@@ -83,6 +87,7 @@ def test_invalid_arch():
8387
("chgnet", "/invalid/path"),
8488
("nequip", "/invalid/path"),
8589
("dpa3", "/invalid/path"),
90+
("orb", "/invalid/path"),
8691
],
8792
)
8893
def test_invalid_model_path(arch, model_path):
@@ -131,6 +136,8 @@ def test_invalid_device(arch):
131136
("nequip", "cpu", {"model": NEQUIP_PATH}),
132137
("dpa3", "cpu", {"model_path": DPA3_PATH}),
133138
("dpa3", "cpu", {"model": DPA3_PATH}),
139+
("orb", "cpu", {}),
140+
("orb", "cpu", {"model": ORB_MODEL}),
134141
],
135142
)
136143
def test_extra_mlips(arch, device, kwargs):
@@ -191,6 +198,16 @@ def test_extra_mlips(arch, device, kwargs):
191198
"model_path": DPA3_PATH,
192199
"path": DPA3_PATH,
193200
},
201+
{
202+
"arch": "orb",
203+
"model_path": ORB_MODEL,
204+
"model": ORB_MODEL,
205+
},
206+
{
207+
"arch": "orb",
208+
"model_path": ORB_MODEL,
209+
"path": ORB_MODEL,
210+
},
194211
],
195212
)
196213
def test_extra_mlips_invalid(kwargs):

tests/test_single_point.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def test_mlips(arch, device, expected_energy):
295295
),
296296
("nequip", "cpu", -169815.1282456301, "toluene.xyz", {"model_path": NEQUIP_PATH}),
297297
("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}),
298+
("orb", "cpu", -27.088973999023438, "NaCl.cif", {}),
299+
("orb", "cpu", -27.088973999023438, "NaCl.cif", {"model_path": "orb-v2"}),
298300
]
299301

300302

0 commit comments

Comments
 (0)