Skip to content

Commit 8d4ea9b

Browse files
committed
suppress printing
1 parent 391c83e commit 8d4ea9b

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

nff/io/mace.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ def _check_non_zero(std):
3737
return std
3838

3939

40-
def get_mace_mp_model_path(model: Optional[str] = None) -> str:
40+
def get_mace_mp_model_path(model: Optional[str] = None, supress_print=True) -> str:
4141
"""Get the default MACE MP model. Replicated from the MACE codebase,
4242
Copyright (c) 2022 ACEsuit/mace and licensed under the MIT license.
4343
4444
Args:
4545
model (str, optional): MACE_MP model that you want to get.
4646
Defaults to None. Can be "small", "medium", "large", or a URL.
47+
supress_print (bool, optional): Whether to suppress print statements. Defaults to True.
4748
4849
Raises:
4950
RuntimeError: raised if the model download fails and no local model is found
@@ -53,7 +54,8 @@ def get_mace_mp_model_path(model: Optional[str] = None) -> str:
5354
"""
5455
if model in (None, "medium") and os.path.isfile(LOCAL_MODEL_PATH):
5556
model_path = LOCAL_MODEL_PATH
56-
print(f"Using local medium Materials Project MACE model for MACECalculator {model}")
57+
if not supress_print:
58+
print(f"Using local medium Materials Project MACE model for MACECalculator {model}")
5759
elif model in (None, "small", "medium", "large") or str(model).startswith("https:"):
5860
try:
5961
checkpoint_url = (
@@ -65,11 +67,13 @@ def get_mace_mp_model_path(model: Optional[str] = None) -> str:
6567
if not os.path.isfile(model_path):
6668
os.makedirs(cache_dir, exist_ok=True)
6769
# download and save to disk
68-
print(f"Downloading MACE model from {checkpoint_url!r}")
6970
urllib.request.urlretrieve(checkpoint_url, model_path)
70-
print(f"Cached MACE model to {model_path}")
71-
msg = f"Loading Materials Project MACE with {model_path}"
72-
print(msg)
71+
if not supress_print:
72+
print(f"Downloading MACE model from {checkpoint_url!r}")
73+
print(f"Cached MACE model to {model_path}")
74+
if not supress_print:
75+
msg = f"Loading Materials Project MACE with {model_path}"
76+
print(msg)
7377
except Exception as exc:
7478
raise RuntimeError("Model download failed and no local model found") from exc
7579
else:

nff/nn/models/mace.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,15 @@ def load_foundations(
280280
model: Literal["small", "medium", "large"] = "medium",
281281
map_location: str = "cpu",
282282
default_dtype: Literal["", "float32", "float64"] = "float32",
283+
suppress_warnings: bool = True,
283284
) -> NffScaleMACE:
284285
"""Load MACE foundational model.
285286
286287
Args:
287288
model (Literal["small", "medium", "large"], optional): model size. Defaults to "medium".
288289
map_location (str, optional): The device to load the model on. Defaults to "cpu".
289290
default_dtype (Literal["", "float32", "float64"], optional): float type of the model. Defaults to "float32".
291+
suppress_warnings (bool, optional): Whether to suppress warnings. Defaults to False.
290292
291293
Returns:
292294
NffScaleMACE: NffScaleMACE foundational model.
@@ -296,13 +298,15 @@ def load_foundations(
296298
init_params = get_init_kwargs_from_model(mace_model)
297299
model_dtype = get_model_dtype(mace_model)
298300
if default_dtype == "":
299-
print(f"No dtype selected, switching to {model_dtype} to match model dtype.")
301+
if not suppress_warnings:
302+
print(f"No dtype selected, switching to {model_dtype} to match model dtype.")
300303
default_dtype = model_dtype
301304
if model_dtype != default_dtype:
302-
print(
303-
f"Default dtype {default_dtype} does not match model dtype {model_dtype}, "
304-
f"converting models to {default_dtype}."
305-
)
305+
if not suppress_warnings:
306+
print(
307+
f"Default dtype {default_dtype} does not match model dtype {model_dtype}, "
308+
f"converting models to {default_dtype}."
309+
)
306310
if default_dtype == "float64":
307311
mace_model.double()
308312
elif default_dtype == "float32":

0 commit comments

Comments
 (0)