-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
79 lines (63 loc) · 2.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import logging
import numpy as np
import torch
import torch.profiler
from pytorch_lightning.loggers import TensorBoardLogger
from rydberggpt.training.logger import setup_logger
from rydberggpt.training.train import train
from rydberggpt.utils import create_config_from_yaml, load_yaml_file
# Module-level side effect: this runs at import time, before any function
# in this file is called.
# NOTE(review): per the PyTorch docs, "medium" permits lower-precision
# internals for float32 matmuls in exchange for speed — confirm this
# precision trade-off is intended for every run launched from this script.
torch.set_float32_matmul_precision("medium")
def setup_environment(config):
    """Seed the RNGs and record the compute device on ``config``.

    Seeds PyTorch and NumPy with ``config.seed``, then sets
    ``config.device`` to ``"cuda"`` when CUDA is available and to
    ``"cpu"`` otherwise.

    Args:
        config: Object exposing a ``seed`` attribute. It is mutated in
            place: its ``device`` attribute is created or overwritten.
    """
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    if torch.cuda.is_available():
        config.device = "cuda"
    else:
        config.device = "cpu"
def load_configuration(config_path: str, config_name: str):
    """Load a YAML configuration and turn it into a config object.

    Args:
        config_path: Forwarded as the first argument to ``load_yaml_file``.
        config_name: Forwarded as the second argument to
            ``load_yaml_file`` (the CLI help describes it as the file
            name without the ``.yaml`` extension).

    Returns:
        Whatever ``create_config_from_yaml`` builds from the loaded YAML
        data.
    """
    return create_config_from_yaml(load_yaml_file(config_path, config_name))
def main(config_path: str, config_name: str, dataset_path: str) -> None:
    """Load the configuration, set up logging, and launch training.

    Steps, in order: load the config, seed the RNGs and pick the device,
    create a TensorBoard logger and record the hyperparameters, configure
    file/console logging for this run, report GPU availability, and call
    ``train``.

    Args:
        config_path: Forwarded to ``load_configuration``.
        config_name: Forwarded to ``load_configuration``.
        dataset_path: Forwarded unchanged to ``train``.
    """
    config = load_configuration(config_path, config_name)
    setup_environment(config)

    tensorboard_logger = TensorBoardLogger(save_dir="logs")
    tensorboard_logger.log_hyperparams(vars(config))

    # NOTE(review): this path hard-codes the "lightning_logs" sub-directory;
    # presumably it matches TensorBoardLogger's default ``name`` — confirm
    # against the installed pytorch_lightning version.
    log_path = f"logs/lightning_logs/version_{tensorboard_logger.version}"

    # Configure logging BEFORE emitting the first record. Calling
    # ``logging.info`` on an unconfigured root logger drops the message
    # (default level is WARNING) and implicitly runs ``basicConfig()``,
    # which can interfere with the configuration done by ``setup_logger``.
    setup_logger(log_path)
    logging.info(f"Log path: {log_path}")

    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        logging.info(f"Found {num_gpus} GPUs.")
    else:
        logging.info("No GPUs found.")

    train(config, dataset_path, tensorboard_logger, log_path)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Run deep learning model with specified config."
    )
    parser.add_argument(
        "--config_name",
        default="config_small",
        help="Name of the configuration file without the .yaml extension. (default: config_small)",
    )
    parser.add_argument(
        "--config_path",
        default="config/",
        help="Path to the configuration file. (default: config/)",
    )
    parser.add_argument(
        "--dataset_path",
        default="dataset_test/",
        help="Path to the dataset. (default: dataset_test/)",
    )
    # NOTE(review): --devices is parsed but never forwarded to main(); it
    # currently has no effect on the run. Kept so existing command lines
    # that pass it keep working.
    parser.add_argument(
        "--devices",
        type=int,
        default=0,
        help="Number of devices (GPUs) to use. (default: 0)",
    )
    args = parser.parse_args()

    main(
        config_path=args.config_path,
        config_name=args.config_name,
        dataset_path=args.dataset_path,
    )