Skip to content

Commit

Permalink
Merge pull request #229 from choderalab/fix-device-command-line-argument
Browse files Browse the repository at this point in the history
Fix command line argument
  • Loading branch information
wiederm authored Aug 13, 2024
2 parents abf8076 + dedf9d4 commit 4711e7e
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 116 deletions.
3 changes: 3 additions & 0 deletions modelforge/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def test_download_from_url(prep_temp_dir):
)


@pytest.mark.skip(
reason="This test seems to time out on the CI frequently. Will be refactoring and not need this soon."
)
def test_download_from_figshare(prep_temp_dir):
url = "https://figshare.com/ndownloader/files/22247589"
name = download_from_figshare(
Expand Down
5 changes: 3 additions & 2 deletions modelforge/train/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def device_index_must_be_positive(cls, v) -> Union[int, List[int]]:
for device in v:
if device < 0:
raise ValueError("device_index must be positive")
if v < 0:
raise ValueError("device_index must be positive")
else:
if v < 0:
raise ValueError("device_index must be positive")
return v
158 changes: 46 additions & 112 deletions modelforge/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,15 +1197,15 @@ def _add_tags(self, tags: List[str]) -> List[str]:
# add potential name
tags.append(self.potential_parameter.potential_name)
# add information about what is included in the loss
str_loss_property = "-".join(self.training_parameter.loss_parameter.loss_property)
str_loss_property = "-".join(
self.training_parameter.loss_parameter.loss_property
)
tags.append(f"loss-{str_loss_property}")

return tags


from typing import List, Optional, Union


def read_config(
condensed_config_path: Optional[str] = None,
training_parameter_path: Optional[str] = None,
Expand All @@ -1223,147 +1223,81 @@ def read_config(
simulation_environment: Optional[str] = None,
):
"""
Reads one or more TOML configuration files and loads them into the pydantic models
Reads one or more TOML configuration files and loads them into the pydantic models.
Parameters
----------
condensed_config_path : str, optional
Path to the TOML configuration that contains all parameters for the dataset, potential, training, and runtime parameters.
Any other provided configuration files will be ignored.
training_parameter_path : str, optional
Path to the TOML file defining the training parameters.
dataset_parameter_path : str, optional
Path to the TOML file defining the dataset parameters.
potential_parameter_path : str, optional
Path to the TOML file defining the potential parameters.
runtime_parameter_path : str, optional
Path to the TOML file defining the runtime parameters. If this is not provided, the code will attempt to use
the runtime parameters provided as arguments.
accelerator : str, optional
Accelerator type to use. If provided, this overrides the accelerator type in the runtime_defaults configuration.
devices : int|List[int], optional
Device index/indices to use. If provided, this overrides the devices in the runtime_defaults configuration.
number_of_nodes : int, optional
Number of nodes to use. If provided, this overrides the number of nodes in the runtime_defaults configuration.
experiment_name : str, optional
Name of the experiment. If provided, this overrides the experiment name in the runtime_defaults configuration.
save_dir : str, optional
Directory to save the model. If provided, this overrides the save directory in the runtime_defaults configuration.
local_cache_dir : str, optional
Local cache directory. If provided, this overrides the local cache directory in the runtime_defaults configuration.
checkpoint_path : str, optional
Path to the checkpoint file. If provided, this overrides the checkpoint path in the runtime_defaults configuration.
log_every_n_steps : int, optional
Number of steps to log. If provided, this overrides the log_every_n_steps in the runtime_defaults configuration.
simulation_environment : str, optional
Simulation environment. If provided, this overrides the simulation environment in the runtime_defaults configuration.
(Parameters as described earlier...)
Returns
-------
Tuple
Tuple containing the training, dataset, potential, and runtime parameters.
"""
import toml

use_runtime_variables_instead_of_toml = False
# Initialize the config dictionaries
training_config_dict = {}
dataset_config_dict = {}
potential_config_dict = {}
runtime_config_dict = {}

if condensed_config_path is not None:
config = toml.load(condensed_config_path)
log.info(f"Reading config from : {condensed_config_path}")

training_config_dict = config["training"]
dataset_config_dict = config["dataset"]
potential_config_dict = config["potential"]
runtime_config_dict = config["runtime"]
training_config_dict = config.get("training", {})
dataset_config_dict = config.get("dataset", {})
potential_config_dict = config.get("potential", {})
runtime_config_dict = config.get("runtime", {})

else:
if training_parameter_path is None:
raise ValueError("Training configuration not provided.")
if dataset_parameter_path is None:
raise ValueError("Dataset configuration not provided.")
if potential_parameter_path is None:
raise ValueError("Potential configuration not provided.")

training_config_dict = toml.load(training_parameter_path)["training"]
dataset_config_dict = toml.load(dataset_parameter_path)["dataset"]
potential_config_dict = toml.load(potential_parameter_path)["potential"]

# if the runtime_parameter_path is not defined, let us see if runtime variables are passed
if runtime_parameter_path is None:
use_runtime_variables_instead_of_toml = True
log.info(
"Runtime configuration not provided. The code will try to use runtime arguments."
)
# we can just create a dict with the runtime variables; the pydantic model will then validate them
runtime_config_dict = {
"save_dir": save_dir,
"experiment_name": experiment_name,
"local_cache_dir": local_cache_dir,
"checkpoint_path": checkpoint_path,
"log_every_n_steps": log_every_n_steps,
"simulation_environment": simulation_environment,
"accelerator": accelerator,
"devices": devices,
"number_of_nodes": number_of_nodes,
}
else:
runtime_config_dict = toml.load(runtime_parameter_path)["runtime"]

if training_parameter_path:
training_config_dict = toml.load(training_parameter_path).get("training", {})
if dataset_parameter_path:
dataset_config_dict = toml.load(dataset_parameter_path).get("dataset", {})
if potential_parameter_path:
potential_config_dict = toml.load(potential_parameter_path).get("potential", {})
if runtime_parameter_path:
runtime_config_dict = toml.load(runtime_parameter_path).get("runtime", {})

# Override runtime configuration with command-line arguments if provided
runtime_overrides = {
"accelerator": accelerator,
"devices": devices,
"number_of_nodes": number_of_nodes,
"experiment_name": experiment_name,
"save_dir": save_dir,
"local_cache_dir": local_cache_dir,
"checkpoint_path": checkpoint_path,
"log_every_n_steps": log_every_n_steps,
"simulation_environment": simulation_environment,
}

for key, value in runtime_overrides.items():
if value is not None:
runtime_config_dict[key] = value

# Load and instantiate the data classes with the merged configuration
from modelforge.potential import _Implemented_NNP_Parameters
from modelforge.dataset.dataset import DatasetParameters
from modelforge.train.parameters import TrainingParameters, RuntimeParameters

potential_name = potential_config_dict["potential_name"]
PotentialParameters = (
_Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)
)
PotentialParameters = _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)

dataset_parameters = DatasetParameters(**dataset_config_dict)
training_parameters = TrainingParameters(**training_config_dict)
runtime_parameters = RuntimeParameters(**runtime_config_dict)
potential_parameter_paths = PotentialParameters(**potential_config_dict)

# if accelerator, devices, or number_of_nodes are provided, override the runtime_defaults parameters
# note, since these are being set in the runtime data model, they will be validated by the model
# if we use the runtime variables instead of the toml file, these have already been set so we can skip this step.

if use_runtime_variables_instead_of_toml == False:
if accelerator:
runtime_parameters.accelerator = accelerator
log.info(f"Using accelerator: {accelerator}")
if devices:
runtime_parameters.device_index = devices
log.info(f"Using device index: {devices}")
if number_of_nodes:
runtime_parameters.number_of_nodes = number_of_nodes
log.info(f"Using number of nodes: {number_of_nodes}")
if experiment_name:
runtime_parameters.experiment_name = experiment_name
log.info(f"Using experiment name: {experiment_name}")
if save_dir:
runtime_parameters.save_dir = save_dir
log.info(f"Using save directory: {save_dir}")
if local_cache_dir:
runtime_parameters.local_cache_dir = local_cache_dir
log.info(f"Using local cache directory: {local_cache_dir}")
if checkpoint_path:
runtime_parameters.checkpoint_path = checkpoint_path
log.info(f"Using checkpoint path: {checkpoint_path}")
if log_every_n_steps:
runtime_parameters.log_every_n_steps = log_every_n_steps
log.info(f"Logging every {log_every_n_steps} steps.")
if simulation_environment:
runtime_parameters.simulation_environment = simulation_environment
log.info(f"Using simulation environment: {simulation_environment}")
potential_parameter = PotentialParameters(**potential_config_dict)

return (
training_parameters,
dataset_parameters,
potential_parameter_paths,
potential_parameter,
runtime_parameters,
)


def read_config_and_train(
condensed_config_path: Optional[str] = None,
training_parameter_path: Optional[str] = None,
Expand Down
29 changes: 29 additions & 0 deletions modelforge/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,32 @@ def check_import(module: str):

imported_module = import_(module)
del imported_module


from typing import Union, List


def parse_devices(value: str) -> Union[int, List[int]]:
"""
Parse the devices argument which can be either a single integer or a list of
integers.
Parameters
----------
value : str
The input string representing either a single integer or a list of
integers.
Returns
-------
Union[int, List[int]]
Either a single integer or a list of integers.
"""
import ast

# if multiple comma delimited values are passed, split them into a list
if value.startswith("[") and value.endswith("]"):
# Safely evaluate the string as a Python literal (list of ints)
return list(ast.literal_eval(value))
else:
return int(value)
3 changes: 2 additions & 1 deletion scripts/perform_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
from modelforge.train.training import read_config_and_train
from typing import Union, List
from modelforge.utils.io import parse_devices

parse = argparse.ArgumentParser(description="Perform Training Using Modelforge")

Expand All @@ -25,7 +26,7 @@
)
parse.add_argument("--accelerator", type=str, help="Accelerator to use for training")
parse.add_argument(
"--devices", type=Union[int, List[int]], help="Device(s) to use for training"
"--devices", type=parse_devices, help="Device(s) to use for training"
)
parse.add_argument(
"--number_of_nodes", type=int, help="Number of nodes to use for training"
Expand Down
2 changes: 1 addition & 1 deletion scripts/training_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# training run with a small number of epochs with default
# parameters

python perform_training.py config.toml
python perform_training.py --condensed_config_path config.toml

0 comments on commit 4711e7e

Please sign in to comment.