Skip to content

Commit ba0f18a

Browse files
alinelenaElliottKasoar
authored andcommitted
add orb support... tricky
1 parent 6d3310f commit ba0f18a

File tree

3 files changed

+39
-1
lines changed

3 files changed

+39
-1
lines changed

janus_core/helpers/janus_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class CorrelationKwargs(TypedDict, total=True):
115115

116116
# Janus specific
117117
Architectures = Literal[
118-
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet"
118+
"mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn", "sevennet", "orb"
119119
]
120120
Devices = Literal["cpu", "cuda", "mps", "xpu"]
121121
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]

janus_core/helpers/mlip_calculators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,36 @@ def choose_calculator(
225225
kwargs.setdefault("sevennet_config", None)
226226
calculator = SevenNetCalculator(model=model_path, device=device, **kwargs)
227227

228+
elif arch == "orb":
229+
__version__ = "0.3"
230+
from orb_models.forcefield.calculator import ORBCalculator
231+
from orb_models.forcefield.graph_regressor import GraphRegressor
232+
import orb_models.forcefield.pretrained as orb_ff
233+
234+
if isinstance(model_path, str):
235+
match model_path:
236+
case "orb-v1":
237+
model = orb_ff.orb_v1()
238+
case "orb-mptraj-only-v1":
239+
model = orb_ff.orb_v1_mptraj_only()
240+
case "orb-d3-v1":
241+
model = orb_ff.orb_d3_v1()
242+
case "orb-d3-xs-v1":
243+
model = orb_ff.orb_d3_xs_v1()
244+
case "orb-d3-sm-v1":
245+
model = orb_ff.orb_d3_sm_v1()
246+
case _:
247+
raise ValueError(
248+
"Please specify `model_path`, as there is no "
249+
f"default model for {arch}"
250+
)
251+
elif isinstance(model_path, GraphRegressor):
252+
model = model_path
253+
else:
254+
model = orb_ff.orb_v1_mptraj_only()
255+
256+
calculator = ORBCalculator(model=model, device=device, **kwargs)
257+
228258
else:
229259
raise ValueError(
230260
f"Unrecognized {arch=}. Suported architectures "

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,18 @@ m3gnet = [
4848
"matgl == 1.1.3",
4949
"dgl == 2.1.0",
5050
]
51+
orb = [
52+
"orb-models == 0.4.1",
53+
"pynanoflann",
54+
]
5155
sevennet = [
5256
"sevenn == 0.10.0",
5357
]
5458
all = [
5559
"janus-core[alignn]",
5660
"janus-core[chgnet]",
5761
"janus-core[m3gnet]",
62+
"janus-core[orb]",
5863
"janus-core[sevennet]",
5964
]
6065

@@ -164,3 +169,6 @@ default-groups = [
164169
"docs",
165170
"pre-commit",
166171
]
172+
173+
[tool.uv.sources]
174+
pynanoflann = { git = "https://github.com/dwastberg/pynanoflann", rev = "af434039ae14bedcbb838a7808924d6689274168" }

0 commit comments

Comments
 (0)