Skip to content

Commit 91c7717

Browse files
committed
Add MLIP tests for orb
1 parent 543313e commit 91c7717

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed
36.1 MB
Binary file not shown.

tests/test_mlip_calculators.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from chgnet.model.model import CHGNet
99
from matgl import load_model
10+
from orb_models.forcefield.pretrained import orb_d3_xs_v2
1011
import pytest
1112

1213
from janus_core.helpers.mlip_calculators import choose_calculator
@@ -32,6 +33,9 @@
3233

3334
DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth"
3435

36+
ORB_WEIGHTS_PATH = MODEL_PATH / "orb-d3-xs-v2-20241011.ckpt"
37+
ORB_MODEL = orb_d3_xs_v2(weights_path=ORB_WEIGHTS_PATH)
38+
3539

3640
@pytest.mark.parametrize(
3741
"arch, device, kwargs",
@@ -83,6 +87,7 @@ def test_invalid_arch():
8387
("chgnet", "/invalid/path"),
8488
("nequip", "/invalid/path"),
8589
("dpa3", "/invalid/path"),
90+
("orb", "/invalid/path"),
8691
],
8792
)
8893
def test_invalid_model_path(arch, model_path):
@@ -131,6 +136,8 @@ def test_invalid_device(arch):
131136
("nequip", "cpu", {"model": NEQUIP_PATH}),
132137
("dpa3", "cpu", {"model_path": DPA3_PATH}),
133138
("dpa3", "cpu", {"model": DPA3_PATH}),
139+
("orb", "cpu", {}),
140+
("orb", "cpu", {"model": ORB_MODEL}),
134141
],
135142
)
136143
def test_extra_mlips(arch, device, kwargs):
@@ -191,6 +198,16 @@ def test_extra_mlips(arch, device, kwargs):
191198
"model_path": DPA3_PATH,
192199
"path": DPA3_PATH,
193200
},
201+
{
202+
"arch": "orb",
203+
"model_path": ORB_MODEL,
204+
"model": ORB_MODEL,
205+
},
206+
{
207+
"arch": "orb",
208+
"model_path": ORB_MODEL,
209+
"path": ORB_MODEL,
210+
},
194211
],
195212
)
196213
def test_extra_mlips_invalid(kwargs):

0 commit comments

Comments
 (0)