Skip to content

Commit 7eec73d

Browse files
committed
Fix adding D3
1 parent 5130b24 commit 7eec73d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

ml_peg/models/get_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,25 @@ def load_models(models: None | str | Iterable = None) -> dict[str, Any]:
7373
task_name=kwargs.get("task_name", "omat"),
7474
device=cfg.get("device", "cpu"),
7575
overrides=kwargs.get("overrides", {}),
76+
add_d3=cfg.get("add_d3", False),
77+
d3_kwargs=cfg.get("d3_kwargs", {}),
7678
)
7779
elif cfg["class_name"] == "OrbCalc":
7880
kwargs = cfg.get("kwargs", {})
7981
loaded_models[name] = OrbCalc(
8082
name=kwargs["name"],
8183
device=cfg.get("device", "cpu"),
84+
add_d3=cfg.get("add_d3", False),
85+
d3_kwargs=cfg.get("d3_kwargs", {}),
8286
)
8387
else:
8488
loaded_models[name] = GenericASECalc(
8589
module=cfg["module"],
8690
class_name=cfg["class_name"],
8791
device=cfg.get("device", "auto"),
8892
kwargs=cfg.get("kwargs", {}),
93+
add_d3=cfg.get("add_d3", False),
94+
d3_kwargs=cfg.get("d3_kwargs", {}),
8995
)
9096

9197
return loaded_models

0 commit comments

Comments
 (0)