Skip to content

Commit 115781c

Browse files
alinelenaElliottKasoar
authored andcommitted
f
1 parent da76825 commit 115781c

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

janus_core/helpers/mlip_calculators.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -246,24 +246,12 @@ def choose_calculator(
246246
from orb_models.forcefield.graph_regressor import GraphRegressor
247247
import orb_models.forcefield.pretrained as orb_ff
248248

249-
if isinstance(model_path, str):
250-
match model_path:
251-
case "orb-v2":
252-
model = orb_ff.orb_v2()
253-
case "orb-mptraj-only-v2":
254-
model = orb_ff.orb_v2_mptraj_only()
255-
case "orb-d3-v2":
256-
model = orb_ff.orb_d3_v2()
257-
case "orb-d3-xs-v2":
258-
model = orb_ff.orb_d3_xs_v2()
259-
case "orb-d3-sm-v2":
260-
model = orb_ff.orb_d3_sm_v2()
261-
case _:
262-
raise ValueError(
263-
"Please specify `model_path`, as there is no "
264-
f"default model for {arch}"
265-
)
266-
elif isinstance(model_path, GraphRegressor):
249+
model = getattr(orb_ff, model_path.sub("-", "_"), None)()
250+
if model is None:
251+
raise ValueError(
252+
f"Please specify `model_path`, as there is no default model for {arch}"
253+
)
254+
if isinstance(model_path, GraphRegressor):
267255
model = model_path
268256
else:
269257
model = orb_ff.orb_v2()

0 commit comments

Comments
 (0)