@@ -1197,15 +1197,15 @@ def _add_tags(self, tags: List[str]) -> List[str]:
         # add potential name
         tags.append(self.potential_parameter.potential_name)
         # add information about what is included in the loss
-        str_loss_property = "-".join(self.training_parameter.loss_parameter.loss_property)
+        str_loss_property = "-".join(
+            self.training_parameter.loss_parameter.loss_property
+        )
         tags.append(f"loss-{str_loss_property}")

         return tags


 from typing import List, Optional, Union
-
-
 def read_config(
     condensed_config_path: Optional[str] = None,
     training_parameter_path: Optional[str] = None,
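
For context, the change in this hunk only reflows the `join` call; the tag it builds is unchanged. A minimal sketch of the resulting tag, assuming `loss_property` is a list of property-name strings (the values below are hypothetical, not taken from this commit):

```python
# Hypothetical loss_property values; the real list comes from the
# loss_parameter section of the training TOML file.
loss_property = ["per_system_energy", "per_atom_force"]

str_loss_property = "-".join(loss_property)
tag = f"loss-{str_loss_property}"
print(tag)  # -> loss-per_system_energy-per_atom_force
```
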
@@ -1223,147 +1223,81 @@ def read_config(
     simulation_environment: Optional[str] = None,
 ):
     """
-    Reads one or more TOML configuration files and loads them into the pydantic models
-
+    Reads one or more TOML configuration files and loads them into the pydantic models.
+
     Parameters
     ----------
-    condensed_config_path : str, optional
-        Path to the TOML configuration that contains all parameters for the dataset, potential, training, and runtime parameters.
-        Any other provided configuration files will be ignored.
-    training_parameter_path : str, optional
-        Path to the TOML file defining the training parameters.
-    dataset_parameter_path : str, optional
-        Path to the TOML file defining the dataset parameters.
-    potential_parameter_path : str, optional
-        Path to the TOML file defining the potential parameters.
-    runtime_parameter_path : str, optional
-        Path to the TOML file defining the runtime parameters. If this is not provided, the code will attempt to use
-        the runtime parameters provided as arguments.
-    accelerator : str, optional
-        Accelerator type to use. If provided, this overrides the accelerator type in the runtime_defaults configuration.
-    devices : int|List[int], optional
-        Device index/indices to use. If provided, this overrides the devices in the runtime_defaults configuration.
-    number_of_nodes : int, optional
-        Number of nodes to use. If provided, this overrides the number of nodes in the runtime_defaults configuration.
-    experiment_name : str, optional
-        Name of the experiment. If provided, this overrides the experiment name in the runtime_defaults configuration.
-    save_dir : str, optional
-        Directory to save the model. If provided, this overrides the save directory in the runtime_defaults configuration.
-    local_cache_dir : str, optional
-        Local cache directory. If provided, this overrides the local cache directory in the runtime_defaults configuration.
-    checkpoint_path : str, optional
-        Path to the checkpoint file. If provided, this overrides the checkpoint path in the runtime_defaults configuration.
-    log_every_n_steps : int, optional
-        Number of steps to log. If provided, this overrides the log_every_n_steps in the runtime_defaults configuration.
-    simulation_environment : str, optional
-        Simulation environment. If provided, this overrides the simulation environment in the runtime_defaults configuration.
+    (Parameters as described earlier...)

     Returns
     -------
     Tuple
         Tuple containing the training, dataset, potential, and runtime parameters.
-
     """
     import toml

-    use_runtime_variables_instead_of_toml = False
+    # Initialize the config dictionaries
+    training_config_dict = {}
+    dataset_config_dict = {}
+    potential_config_dict = {}
+    runtime_config_dict = {}
+
     if condensed_config_path is not None:
         config = toml.load(condensed_config_path)
         log.info(f"Reading config from : {condensed_config_path}")

-        training_config_dict = config["training"]
-        dataset_config_dict = config["dataset"]
-        potential_config_dict = config["potential"]
-        runtime_config_dict = config["runtime"]
+        training_config_dict = config.get("training", {})
+        dataset_config_dict = config.get("dataset", {})
+        potential_config_dict = config.get("potential", {})
+        runtime_config_dict = config.get("runtime", {})

     else:
-        if training_parameter_path is None:
-            raise ValueError("Training configuration not provided.")
-        if dataset_parameter_path is None:
-            raise ValueError("Dataset configuration not provided.")
-        if potential_parameter_path is None:
-            raise ValueError("Potential configuration not provided.")
-
-        training_config_dict = toml.load(training_parameter_path)["training"]
-        dataset_config_dict = toml.load(dataset_parameter_path)["dataset"]
-        potential_config_dict = toml.load(potential_parameter_path)["potential"]
-
-        # if the runtime_parameter_path is not defined, let us see if runtime variables are passed
-        if runtime_parameter_path is None:
-            use_runtime_variables_instead_of_toml = True
-            log.info(
-                "Runtime configuration not provided. The code will try to use runtime arguments."
-            )
-            # we can just create a dict with the runtime variables; the pydantic model will then validate them
-            runtime_config_dict = {
-                "save_dir": save_dir,
-                "experiment_name": experiment_name,
-                "local_cache_dir": local_cache_dir,
-                "checkpoint_path": checkpoint_path,
-                "log_every_n_steps": log_every_n_steps,
-                "simulation_environment": simulation_environment,
-                "accelerator": accelerator,
-                "devices": devices,
-                "number_of_nodes": number_of_nodes,
-            }
-        else:
-            runtime_config_dict = toml.load(runtime_parameter_path)["runtime"]
-
+        if training_parameter_path:
+            training_config_dict = toml.load(training_parameter_path).get("training", {})
+        if dataset_parameter_path:
+            dataset_config_dict = toml.load(dataset_parameter_path).get("dataset", {})
+        if potential_parameter_path:
+            potential_config_dict = toml.load(potential_parameter_path).get("potential", {})
+        if runtime_parameter_path:
+            runtime_config_dict = toml.load(runtime_parameter_path).get("runtime", {})
+
+    # Override runtime configuration with command-line arguments if provided
+    runtime_overrides = {
+        "accelerator": accelerator,
+        "devices": devices,
+        "number_of_nodes": number_of_nodes,
+        "experiment_name": experiment_name,
+        "save_dir": save_dir,
+        "local_cache_dir": local_cache_dir,
+        "checkpoint_path": checkpoint_path,
+        "log_every_n_steps": log_every_n_steps,
+        "simulation_environment": simulation_environment,
+    }
+
+    for key, value in runtime_overrides.items():
+        if value is not None:
+            runtime_config_dict[key] = value
+
+    # Load and instantiate the data classes with the merged configuration
     from modelforge.potential import _Implemented_NNP_Parameters
     from modelforge.dataset.dataset import DatasetParameters
     from modelforge.train.parameters import TrainingParameters, RuntimeParameters

     potential_name = potential_config_dict["potential_name"]
-    PotentialParameters = (
-        _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)
-    )
+    PotentialParameters = _Implemented_NNP_Parameters.get_neural_network_parameter_class(potential_name)

     dataset_parameters = DatasetParameters(**dataset_config_dict)
     training_parameters = TrainingParameters(**training_config_dict)
     runtime_parameters = RuntimeParameters(**runtime_config_dict)
-    potential_parameter_paths = PotentialParameters(**potential_config_dict)
-
-    # if accelerator, devices, or number_of_nodes are provided, override the runtime_defaults parameters
-    # note, since these are being set in the runtime data model, they will be validated by the model
-    # if we use the runtime variables instead of the toml file, these have already been set so we can skip this step.
-
-    if use_runtime_variables_instead_of_toml == False:
-        if accelerator:
-            runtime_parameters.accelerator = accelerator
-            log.info(f"Using accelerator: {accelerator}")
-        if devices:
-            runtime_parameters.device_index = devices
-            log.info(f"Using device index: {devices}")
-        if number_of_nodes:
-            runtime_parameters.number_of_nodes = number_of_nodes
-            log.info(f"Using number of nodes: {number_of_nodes}")
-        if experiment_name:
-            runtime_parameters.experiment_name = experiment_name
-            log.info(f"Using experiment name: {experiment_name}")
-        if save_dir:
-            runtime_parameters.save_dir = save_dir
-            log.info(f"Using save directory: {save_dir}")
-        if local_cache_dir:
-            runtime_parameters.local_cache_dir = local_cache_dir
-            log.info(f"Using local cache directory: {local_cache_dir}")
-        if checkpoint_path:
-            runtime_parameters.checkpoint_path = checkpoint_path
-            log.info(f"Using checkpoint path: {checkpoint_path}")
-        if log_every_n_steps:
-            runtime_parameters.log_every_n_steps = log_every_n_steps
-            log.info(f"Logging every {log_every_n_steps} steps.")
-        if simulation_environment:
-            runtime_parameters.simulation_environment = simulation_environment
-            log.info(f"Using simulation environment: {simulation_environment}")
+    potential_parameter = PotentialParameters(**potential_config_dict)

     return (
         training_parameters,
         dataset_parameters,
-        potential_parameter_paths,
+        potential_parameter,
         runtime_parameters,
     )

-
 def read_config_and_train(
     condensed_config_path: Optional[str] = None,
     training_parameter_path: Optional[str] = None,
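
To illustrate the merge semantics introduced above: each TOML table seeds a plain dict, any non-`None` keyword argument then overwrites the matching `runtime` key, and only the merged dicts are handed to the pydantic models, so command-line overrides are validated the same way as file values. A minimal usage sketch, assuming a `config.toml` that contains `[training]`, `[dataset]`, `[potential]`, and `[runtime]` tables (the path and override values are hypothetical):

```python
training, dataset, potential, runtime = read_config(
    condensed_config_path="config.toml",  # single file supplying all four tables
    accelerator="gpu",                    # non-None, so it replaces runtime.accelerator from the TOML
    devices=[0, 1],                       # likewise replaces the devices entry
)

# The keyword override wins over whatever the [runtime] table specified.
print(runtime.accelerator)  # "gpu"
```
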