Skip to content

Commit 543313e

Browse files
committed
Fix orb model selection
1 parent 197d34b commit 543313e

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

janus_core/helpers/mlip_calculators.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,20 @@ def choose_calculator(
263263
from orb_models.forcefield.graph_regressor import GraphRegressor
264264
import orb_models.forcefield.pretrained as orb_ff
265265

266-
model = getattr(orb_ff, model_path.sub("-", "_"), None)()
267-
if model is None:
268-
raise ValueError(
269-
f"Please specify `model_path`, as there is no default model for {arch}"
270-
)
266+
# Default model
267+
model_path = model_path if model_path else "orb_v2"
268+
271269
if isinstance(model_path, GraphRegressor):
272270
model = model_path
271+
model_path = "loaded_GraphRegressor"
273272
else:
274-
model = orb_ff.orb_v2()
273+
try:
274+
model = getattr(orb_ff, model_path.replace("-", "_"))()
275+
except AttributeError as e:
276+
raise ValueError(
277+
"`model_path` must be a `GraphRegressor`, pre-trained model label "
278+
"(e.g. 'orb-v2'), or `None` (uses default, orb-v2)"
279+
) from e
275280

276281
calculator = ORBCalculator(model=model, device=device, **kwargs)
277282

0 commit comments

Comments
 (0)