Skip to content

Commit 4711e7e

Browse files
authored
Merge pull request #229 from choderalab/fix-device-command-line-argument
Fix command line argument
2 parents abf8076 + dedf9d4 commit 4711e7e

File tree

6 files changed

+84
-116
lines changed

6 files changed

+84
-116
lines changed

modelforge/tests/test_remote.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ def test_download_from_url(prep_temp_dir):
7171
)
7272

7373

74+
@pytest.mark.skip(
75+
reason="This test seems to time out on the CI frequently. Will be refactoring and not need this soon."
76+
)
7477
def test_download_from_figshare(prep_temp_dir):
7578
url = "https://figshare.com/ndownloader/files/22247589"
7679
name = download_from_figshare(

modelforge/train/parameters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def device_index_must_be_positive(cls, v) -> Union[int, List[int]]:
346346
for device in v:
347347
if device < 0:
348348
raise ValueError("device_index must be positive")
349-
if v < 0:
350-
raise ValueError("device_index must be positive")
349+
else:
350+
if v < 0:
351+
raise ValueError("device_index must be positive")
351352
return v

modelforge/train/training.py

Lines changed: 46 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,15 +1197,15 @@ def _add_tags(self, tags: List[str]) -> List[str]:
11971197
# add potential name
11981198
tags.append(self.potential_parameter.potential_name)
11991199
# add information about what is included in the loss
1200-
str_loss_property = "-".join(self.training_parameter.loss_parameter.loss_property)
1200+
str_loss_property = "-".join(
1201+
self.training_parameter.loss_parameter.loss_property
1202+
)
12011203
tags.append(f"loss-{str_loss_property}")
12021204

12031205
return tags
12041206

12051207

12061208
from typing import List, Optional, Union
1207-
1208-
12091209
def read_config(
12101210
condensed_config_path: Optional[str] = None,
12111211
training_parameter_path: Optional[str] = None,
@@ -1223,147 +1223,81 @@ def read_config(
12231223
simulation_environment: Optional[str] = None,
12241224
):
12251225
"""
1226-
Reads one or more TOML configuration files and loads them into the pydantic models
1227-
1226+
Reads one or more TOML configuration files and loads them into the pydantic models.
1227+
12281228
Parameters
12291229
----------
1230-
condensed_config_path : str, optional
1231-
Path to the TOML configuration that contains all parameters for the dataset, potential, training, and runtime parameters.
1232-
Any other provided configuration files will be ignored.
1233-
training_parameter_path : str, optional
1234-
Path to the TOML file defining the training parameters.
1235-
dataset_parameter_path : str, optional
1236-
Path to the TOML file defining the dataset parameters.
1237-
potential_parameter_path : str, optional
1238-
Path to the TOML file defining the potential parameters.
1239-
runtime_parameter_path : str, optional
1240-
Path to the TOML file defining the runtime parameters. If this is not provided, the code will attempt to use
1241-
the runtime parameters provided as arguments.
1242-
accelerator : str, optional
1243-
Accelerator type to use. If provided, this overrides the accelerator type in the runtime_defaults configuration.
1244-
devices : int|List[int], optional
1245-
Device index/indices to use. If provided, this overrides the devices in the runtime_defaults configuration.
1246-
number_of_nodes : int, optional
1247-
Number of nodes to use. If provided, this overrides the number of nodes in the runtime_defaults configuration.
1248-
experiment_name : str, optional
1249-
Name of the experiment. If provided, this overrides the experiment name in the runtime_defaults configuration.
1250-
save_dir : str, optional
1251-
Directory to save the model. If provided, this overrides the save directory in the runtime_defaults configuration.
1252-
local_cache_dir : str, optional
1253-
Local cache directory. If provided, this overrides the local cache directory in the runtime_defaults configuration.
1254-
checkpoint_path : str, optional
1255-
Path to the checkpoint file. If provided, this overrides the checkpoint path in the runtime_defaults configuration.
1256-
log_every_n_steps : int, optional
1257-
Number of steps to log. If provided, this overrides the log_every_n_steps in the runtime_defaults configuration.
1258-
simulation_environment : str, optional
1259-
Simulation environment. If provided, this overrides the simulation environment in the runtime_defaults configuration.
1230+
(Parameters as described earlier...)
12601231
12611232
Returns
12621233
-------
12631234
Tuple
12641235
Tuple containing the training, dataset, potential, and runtime parameters.
1265-
12661236
"""
12671237
import toml
12681238

1269-
use_runtime_variables_instead_of_toml = False
1239+
# Initialize the config dictionaries
1240+
training_config_dict = {}
1241+
dataset_config_dict = {}
1242+
potential_config_dict = {}
1243+
runtime_config_dict = {}
1244+
12701245
if condensed_config_path is not None:
12711246
config = toml.load(condensed_config_path)
12721247
log.info(f"Reading config from : {condensed_config_path}")
12731248

1274-
training_config_dict = config["training"]
1275-
dataset_config_dict = config["dataset"]
1276-
potential_config_dict = config["potential"]
1277-
runtime_config_dict = config["runtime"]
1249+
training_config_dict = config.get("training", {})
1250+
dataset_config_dict = config.get("dataset", {})
1251+
potential_config_dict = config.get("potential", {})
1252+
runtime_config_dict = config.get("runtime", {})
12781253

12791254
else:
1280-
if training_parameter_path is None:
1281-
raise ValueError("Training configuration not provided.")
1282-
if dataset_parameter_path is None:
1283-
raise ValueError("Dataset configuration not provided.")
1284-
if potential_parameter_path is None:
1285-
raise ValueError("Potential configuration not provided.")
1286-
1287-
training_config_dict = toml.load(training_parameter_path)["training"]
1288-
dataset_config_dict = toml.load(dataset_parameter_path)["dataset"]
1289-
potential_config_dict = toml.load(potential_parameter_path)["potential"]
1290-
1291-
# if the runtime_parameter_path is not defined, let us see if runtime variables are passed
1292-
if runtime_parameter_path is None:
1293-
use_runtime_variables_instead_of_toml = True
1294-
log.info(
1295-
"Runtime configuration not provided. The code will try to use runtime arguments."
1296-
)
1297-
# we can just create a dict with the runtime variables; the pydantic model will then validate them
1298-
runtime_config_dict = {
1299-
"save_dir": save_dir,
1300-
"experiment_name": experiment_name,
1301-
"local_cache_dir": local_cache_dir,
1302-
"checkpoint_path": checkpoint_path,
1303-
"log_every_n_steps": log_every_n_steps,
1304-
"simulation_environment": simulation_environment,
1305-
"accelerator": accelerator,
1306-
"devices": devices,
1307-
"number_of_nodes": number_of_nodes,
1308-
}
1309-
else:
1310-
runtime_config_dict = toml.load(runtime_parameter_path)["runtime"]
1311-
1255+
if training_parameter_path:
1256+
training_config_dict = toml.load(training_parameter_path).get("training", {})
1257+
if dataset_parameter_path:
1258+
dataset_config_dict = toml.load(dataset_parameter_path).get("dataset", {})
1259+
if potential_parameter_path:
1260+
potential_config_dict = toml.load(potential_parameter_path).get("potential", {})
1261+
if runtime_parameter_path:
1262+
runtime_config_dict = toml.load(runtime_parameter_path).get("runtime", {})
1263+
1264+
# Override runtime configuration with command-line arguments if provided
1265+
runtime_overrides = {
1266+
"accelerator": accelerator,
1267+
"devices": devices,
1268+
"number_of_nodes": number_of_nodes,
1269+
"experiment_name": experiment_name,
1270+
"save_dir": save_dir,
1271+
"local_cache_dir": local_cache_dir,
1272+
"checkpoint_path": checkpoint_path,
1273+
"log_every_n_steps": log_every_n_steps,
1274+
"simulation_environment": simulation_environment,
1275+
}
1276+
1277+
for key, value in runtime_overrides.items():
1278+
if value is not None:
1279+
runtime_config_dict[key] = value
1280+
1281+
# Load and instantiate the data classes with the merged configuration
13121282
from modelforge.potential import _Implemented_NNP_Parameters
13131283
from modelforge.dataset.dataset import DatasetParameters
13141284
from modelforge.train.parameters import TrainingParameters, RuntimeParameters
13151285

13161286
potential_name = potential_config_dict["potential_name"]
1317-
PotentialParameters = (
1318-
_Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)
1319-
)
1287+
PotentialParameters = _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)
13201288

13211289
dataset_parameters = DatasetParameters(**dataset_config_dict)
13221290
training_parameters = TrainingParameters(**training_config_dict)
13231291
runtime_parameters = RuntimeParameters(**runtime_config_dict)
1324-
potential_parameter_paths = PotentialParameters(**potential_config_dict)
1325-
1326-
# if accelerator, devices, or number_of_nodes are provided, override the runtime_defaults parameters
1327-
# note, since these are being set in the runtime data model, they will be validated by the model
1328-
# if we use the runtime variables instead of the toml file, these have already been set so we can skip this step.
1329-
1330-
if use_runtime_variables_instead_of_toml == False:
1331-
if accelerator:
1332-
runtime_parameters.accelerator = accelerator
1333-
log.info(f"Using accelerator: {accelerator}")
1334-
if devices:
1335-
runtime_parameters.device_index = devices
1336-
log.info(f"Using device index: {devices}")
1337-
if number_of_nodes:
1338-
runtime_parameters.number_of_nodes = number_of_nodes
1339-
log.info(f"Using number of nodes: {number_of_nodes}")
1340-
if experiment_name:
1341-
runtime_parameters.experiment_name = experiment_name
1342-
log.info(f"Using experiment name: {experiment_name}")
1343-
if save_dir:
1344-
runtime_parameters.save_dir = save_dir
1345-
log.info(f"Using save directory: {save_dir}")
1346-
if local_cache_dir:
1347-
runtime_parameters.local_cache_dir = local_cache_dir
1348-
log.info(f"Using local cache directory: {local_cache_dir}")
1349-
if checkpoint_path:
1350-
runtime_parameters.checkpoint_path = checkpoint_path
1351-
log.info(f"Using checkpoint path: {checkpoint_path}")
1352-
if log_every_n_steps:
1353-
runtime_parameters.log_every_n_steps = log_every_n_steps
1354-
log.info(f"Logging every {log_every_n_steps} steps.")
1355-
if simulation_environment:
1356-
runtime_parameters.simulation_environment = simulation_environment
1357-
log.info(f"Using simulation environment: {simulation_environment}")
1292+
potential_parameter = PotentialParameters(**potential_config_dict)
13581293

13591294
return (
13601295
training_parameters,
13611296
dataset_parameters,
1362-
potential_parameter_paths,
1297+
potential_parameter,
13631298
runtime_parameters,
13641299
)
13651300

1366-
13671301
def read_config_and_train(
13681302
condensed_config_path: Optional[str] = None,
13691303
training_parameter_path: Optional[str] = None,

modelforge/utils/io.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,32 @@ def check_import(module: str):
244244

245245
imported_module = import_(module)
246246
del imported_module
247+
248+
249+
from typing import Union, List
250+
251+
252+
def parse_devices(value: str) -> Union[int, List[int]]:
253+
"""
254+
Parse the devices argument which can be either a single integer or a list of
255+
integers.
256+
257+
Parameters
258+
----------
259+
value : str
260+
The input string representing either a single integer or a list of
261+
integers.
262+
263+
Returns
264+
-------
265+
Union[int, List[int]]
266+
Either a single integer or a list of integers.
267+
"""
268+
import ast
269+
270+
# if multiple comma delimited values are passed, split them into a list
271+
if value.startswith("[") and value.endswith("]"):
272+
# Safely evaluate the string as a Python literal (list of ints)
273+
return list(ast.literal_eval(value))
274+
else:
275+
return int(value)

scripts/perform_training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import argparse
44
from modelforge.train.training import read_config_and_train
55
from typing import Union, List
6+
from modelforge.utils.io import parse_devices
67

78
parse = argparse.ArgumentParser(description="Perform Training Using Modelforge")
89

@@ -25,7 +26,7 @@
2526
)
2627
parse.add_argument("--accelerator", type=str, help="Accelerator to use for training")
2728
parse.add_argument(
28-
"--devices", type=Union[int, List[int]], help="Device(s) to use for training"
29+
"--devices", type=parse_devices, help="Device(s) to use for training"
2930
)
3031
parse.add_argument(
3132
"--number_of_nodes", type=int, help="Number of nodes to use for training"

scripts/training_run.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# training run with a small number of epochs with default
33
# parameters
44

5-
python perform_training.py config.toml
5+
python perform_training.py --condensed_config_path config.toml

0 commit comments

Comments
 (0)