diff --git a/README.md b/README.md index e59004cd4..ba026c063 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Tools for machine learnt interatomic potentials - alignn = 2024.5.27 (optional) - nequip = 0.6.1 (optional) - deepmd-kit = dpa3-alpha (optional) +- orb-models = 0.4.2 (optional) All required and optional dependencies can be found in [pyproject.toml](pyproject.toml). @@ -96,6 +97,7 @@ Current and planned features include: - SevenNet (experimental) - NequIP (experimental) - DPA3 (experimental) + - Orb (experimental) - [x] Single point calculations - [x] Geometry optimisation - [x] Molecular Dynamics diff --git a/docs/source/getting_started/getting_started.rst b/docs/source/getting_started/getting_started.rst index a726775eb..c0786599e 100644 --- a/docs/source/getting_started/getting_started.rst +++ b/docs/source/getting_started/getting_started.rst @@ -16,6 +16,7 @@ Dependencies - alignn = 2024.5.27 (optional) - nequip = 0.6.1 (optional) - deepmd-kit = dpa3-alpha (optional) +- orb-models = 0.4.2 (optional) All required and optional dependencies can be found in `pyproject.toml `_. @@ -60,5 +61,6 @@ Currently supported extras are: - ``sevenn``: `SevenNet `_ - ``nequip``: `NequIP `_ - ``dpa3``: `DPA3 `_ +- ``orb``: `Orb `_ ``extras`` are also listed in `pyproject.toml `_ under ``[project.optional-dependencies]``. diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index b21fccc7c..ee1c8db2d 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -124,6 +124,7 @@ class CorrelationKwargs(TypedDict, total=True): "sevennet", "nequip", "dpa3", + "orb", ] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh", "nvt-csvr", "npt-mtk"] diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index 8f4bec151..ce92cd153 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -257,6 +257,29 @@ def choose_calculator( calculator = DP(model=model_path, **kwargs) + elif arch == "orb": + from orb_models import __version__ + from orb_models.forcefield.calculator import ORBCalculator + from orb_models.forcefield.graph_regressor import GraphRegressor + import orb_models.forcefield.pretrained as orb_ff + + # Default model + model_path = model_path if model_path else "orb_v2" + + if isinstance(model_path, GraphRegressor): + model = model_path + model_path = "loaded_GraphRegressor" + else: + try: + model = getattr(orb_ff, model_path.replace("-", "_"))() + except AttributeError as e: + raise ValueError( + "`model_path` must be a `GraphRegressor`, pre-trained model label " + "(e.g. 'orb-v2'), or `None` (uses default, orb-v2)" + ) from e + + calculator = ORBCalculator(model=model, device=device, **kwargs) + else: raise ValueError( f"Unrecognized {arch=}. Suported architectures " diff --git a/pyproject.toml b/pyproject.toml index ac58da9cc..9f946445a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,10 @@ m3gnet = [ nequip = [ "nequip == 0.6.1", ] +orb = [ + "orb-models == 0.4.2", + "pynanoflann", +] sevennet = [ "sevenn == 0.10.3", ] @@ -67,6 +71,7 @@ all = [ "janus-core[dpa3]", "janus-core[m3gnet]", "janus-core[nequip]", + "janus-core[orb]", "janus-core[sevennet]", ] @@ -186,3 +191,4 @@ default-groups = [ [tool.uv.sources] deepmd-kit = { git = "https://github.com/deepmodeling/deepmd-kit.git", rev = "dpa3-alpha" } +pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" } diff --git a/tests/models/orb-d3-xs-v2-20241011.ckpt b/tests/models/orb-d3-xs-v2-20241011.ckpt new file mode 100644 index 000000000..bffda267f Binary files /dev/null and b/tests/models/orb-d3-xs-v2-20241011.ckpt differ diff --git a/tests/test_mlip_calculators.py b/tests/test_mlip_calculators.py index eeb4aa64a..957903cdc 100644 --- a/tests/test_mlip_calculators.py +++ b/tests/test_mlip_calculators.py @@ -7,6 +7,7 @@ from chgnet.model.model import CHGNet from matgl import load_model +from orb_models.forcefield.pretrained import orb_d3_xs_v2 import pytest from janus_core.helpers.mlip_calculators import choose_calculator @@ -32,6 +33,9 @@ DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth" +ORB_WEIGHTS_PATH = MODEL_PATH / "orb-d3-xs-v2-20241011.ckpt" +ORB_MODEL = orb_d3_xs_v2(weights_path=ORB_WEIGHTS_PATH) + @pytest.mark.parametrize( "arch, device, kwargs", @@ -83,6 +87,7 @@ def test_invalid_arch(): ("chgnet", "/invalid/path"), ("nequip", "/invalid/path"), ("dpa3", "/invalid/path"), + ("orb", "/invalid/path"), ], ) def test_invalid_model_path(arch, model_path): @@ -131,6 +136,8 @@ def test_invalid_device(arch): ("nequip", "cpu", {"model": NEQUIP_PATH}), ("dpa3", "cpu", {"model_path": DPA3_PATH}), ("dpa3", "cpu", {"model": DPA3_PATH}), + ("orb", "cpu", {}), + ("orb", "cpu", {"model": ORB_MODEL}), ], ) def test_extra_mlips(arch, device, kwargs): @@ -191,6 +198,16 @@ def test_extra_mlips(arch, device, kwargs): "model_path": DPA3_PATH, "path": DPA3_PATH, }, + { + "arch": "orb", + "model_path": ORB_MODEL, + "model": ORB_MODEL, + }, + { + "arch": "orb", + "model_path": ORB_MODEL, + "path": ORB_MODEL, + }, ], ) def test_extra_mlips_invalid(kwargs): diff --git a/tests/test_single_point.py b/tests/test_single_point.py index cbd21369c..767d729cd 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -295,6 +295,8 @@ def test_mlips(arch, device, expected_energy): ), ("nequip", "cpu", -169815.1282456301, "toluene.xyz", {"model_path": NEQUIP_PATH}), ("dpa3", "cpu", -27.053507387638092, "NaCl.cif", {"model_path": DPA3_PATH}), + ("orb", "cpu", -27.088973999023438, "NaCl.cif", {}), + ("orb", "cpu", -27.088973999023438, "NaCl.cif", {"model_path": "orb-v2"}), ]