Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Tools for machine learnt interatomic potentials
- alignn = 2024.5.27 (optional)
- nequip = 0.6.1 (optional)
- deepmd-kit = dpa3-alpha (optional)
- orb-models = 0.4.2 (optional)

All required and optional dependencies can be found in [pyproject.toml](pyproject.toml).

Expand Down Expand Up @@ -96,6 +97,7 @@ Current and planned features include:
- SevenNet (experimental)
- NequIP (experimental)
- DPA3 (experimental)
- Orb (experimental)
- [x] Single point calculations
- [x] Geometry optimisation
- [x] Molecular Dynamics
Expand Down
2 changes: 2 additions & 0 deletions docs/source/getting_started/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Dependencies
- alignn = 2024.5.27 (optional)
- nequip = 0.6.1 (optional)
- deepmd-kit = dpa3-alpha (optional)
- orb-models = 0.4.2 (optional)

All required and optional dependencies can be found in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_.

Expand Down Expand Up @@ -60,5 +61,6 @@ Currently supported extras are:
- ``sevenn``: `SevenNet <https://github.com/MDIL-SNU/SevenNet/>`_
- ``nequip``: `NequIP <https://github.com/mir-group/nequip>`_
- ``dpa3``: `DPA3 <https://github.com/deepmodeling/deepmd-kit/tree/dpa3-alpha>`_
- ``orb``: `Orb <https://github.com/orbital-materials/orb-models>`_

``extras`` are also listed in `pyproject.toml <https://github.com/stfc/janus-core/blob/main/pyproject.toml>`_ under ``[project.optional-dependencies]``.
1 change: 1 addition & 0 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ class CorrelationKwargs(TypedDict, total=True):
"sevennet",
"nequip",
"dpa3",
"orb",
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"]
Expand Down
23 changes: 23 additions & 0 deletions janus_core/helpers/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,29 @@ def choose_calculator(

calculator = DP(model=model_path, **kwargs)

elif arch == "orb":
from orb_models import __version__
from orb_models.forcefield.calculator import ORBCalculator
from orb_models.forcefield.graph_regressor import GraphRegressor
import orb_models.forcefield.pretrained as orb_ff

# Default model
model_path = model_path if model_path else "orb_v2"

if isinstance(model_path, GraphRegressor):
model = model_path
model_path = "loaded_GraphRegressor"
else:
try:
model = getattr(orb_ff, model_path.replace("-", "_"))()
except AttributeError as e:
raise ValueError(
"`model_path` must be a `GraphRegressor`, pre-trained model label "
"(e.g. 'orb-v2'), or `None` (uses default, orb-v2)"
) from e

calculator = ORBCalculator(model=model, device=device, **kwargs)

else:
raise ValueError(
f"Unrecognized {arch=}. Suported architectures "
Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ m3gnet = [
nequip = [
"nequip == 0.6.1",
]
orb = [
"orb-models == 0.4.2",
"pynanoflann",
]
sevennet = [
"sevenn == 0.10.3",
]
Expand All @@ -67,6 +71,7 @@ all = [
"janus-core[dpa3]",
"janus-core[m3gnet]",
"janus-core[nequip]",
"janus-core[orb]",
"janus-core[sevennet]",
]

Expand Down Expand Up @@ -186,3 +191,4 @@ default-groups = [

[tool.uv.sources]
deepmd-kit = { git = "https://github.com/deepmodeling/deepmd-kit.git", rev = "dpa3-alpha" }
pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" }
Binary file added tests/models/orb-d3-xs-v2-20241011.ckpt
Binary file not shown.
17 changes: 17 additions & 0 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from chgnet.model.model import CHGNet
from matgl import load_model
from orb_models.forcefield.pretrained import orb_d3_xs_v2
import pytest

from janus_core.helpers.mlip_calculators import choose_calculator
Expand All @@ -32,6 +33,9 @@

DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth"

ORB_WEIGHTS_PATH = MODEL_PATH / "orb-d3-xs-v2-20241011.ckpt"
ORB_MODEL = orb_d3_xs_v2(weights_path=ORB_WEIGHTS_PATH)


@pytest.mark.parametrize(
"arch, device, kwargs",
Expand Down Expand Up @@ -83,6 +87,7 @@ def test_invalid_arch():
("chgnet", "/invalid/path"),
("nequip", "/invalid/path"),
("dpa3", "/invalid/path"),
("orb", "/invalid/path"),
],
)
def test_invalid_model_path(arch, model_path):
Expand Down Expand Up @@ -131,6 +136,8 @@ def test_invalid_device(arch):
("nequip", "cpu", {"model": NEQUIP_PATH}),
("dpa3", "cpu", {"model_path": DPA3_PATH}),
("dpa3", "cpu", {"model": DPA3_PATH}),
("orb", "cpu", {}),
("orb", "cpu", {"model": ORB_MODEL}),
],
)
def test_extra_mlips(arch, device, kwargs):
Expand Down Expand Up @@ -191,6 +198,16 @@ def test_extra_mlips(arch, device, kwargs):
"model_path": DPA3_PATH,
"path": DPA3_PATH,
},
{
"arch": "orb",
"model_path": ORB_MODEL,
"model": ORB_MODEL,
},
{
"arch": "orb",
"model_path": ORB_MODEL,
"path": ORB_MODEL,
},
],
)
def test_extra_mlips_invalid(kwargs):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def test_mlips(arch, device, expected_energy):
),
("nequip", "cpu", -169815.1282456301, "toluene.xyz", {"model_path": NEQUIP_PATH}),
("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}),
("orb", "cpu", -27.088973999023438, "NaCl.cif", {}),
("orb", "cpu", -27.088973999023438, "NaCl.cif", {"model_path": "orb-v2"}),
]


Expand Down