|
7 | 7 |
|
8 | 8 | from chgnet.model.model import CHGNet |
9 | 9 | from matgl import load_model |
| 10 | +from orb_models.forcefield.pretrained import orb_d3_xs_v2 |
10 | 11 | import pytest |
11 | 12 |
|
12 | 13 | from janus_core.helpers.mlip_calculators import choose_calculator |
|
32 | 33 |
|
33 | 34 | DPA3_PATH = MODEL_PATH / "2025-01-10-dpa3-mptrj.pth" |
34 | 35 |
|
| 36 | +ORB_WEIGHTS_PATH = MODEL_PATH / "orb-d3-xs-v2-20241011.ckpt" |
| 37 | +ORB_MODEL = orb_d3_xs_v2(weights_path=ORB_WEIGHTS_PATH) |
| 38 | + |
35 | 39 |
|
36 | 40 | @pytest.mark.parametrize( |
37 | 41 | "arch, device, kwargs", |
@@ -83,6 +87,7 @@ def test_invalid_arch(): |
83 | 87 | ("chgnet", "/invalid/path"), |
84 | 88 | ("nequip", "/invalid/path"), |
85 | 89 | ("dpa3", "/invalid/path"), |
| 90 | + ("orb", "/invalid/path"), |
86 | 91 | ], |
87 | 92 | ) |
88 | 93 | def test_invalid_model_path(arch, model_path): |
@@ -131,6 +136,8 @@ def test_invalid_device(arch): |
131 | 136 | ("nequip", "cpu", {"model": NEQUIP_PATH}), |
132 | 137 | ("dpa3", "cpu", {"model_path": DPA3_PATH}), |
133 | 138 | ("dpa3", "cpu", {"model": DPA3_PATH}), |
| 139 | + ("orb", "cpu", {}), |
| 140 | + ("orb", "cpu", {"model": ORB_MODEL}), |
134 | 141 | ], |
135 | 142 | ) |
136 | 143 | def test_extra_mlips(arch, device, kwargs): |
@@ -191,6 +198,16 @@ def test_extra_mlips(arch, device, kwargs): |
191 | 198 | "model_path": DPA3_PATH, |
192 | 199 | "path": DPA3_PATH, |
193 | 200 | }, |
| 201 | + { |
| 202 | + "arch": "orb", |
| 203 | + "model_path": ORB_MODEL, |
| 204 | + "model": ORB_MODEL, |
| 205 | + }, |
| 206 | + { |
| 207 | + "arch": "orb", |
| 208 | + "model_path": ORB_MODEL, |
| 209 | + "path": ORB_MODEL, |
| 210 | + }, |
194 | 211 | ], |
195 | 212 | ) |
196 | 213 | def test_extra_mlips_invalid(kwargs): |
|
0 commit comments