-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
39 lines (28 loc) · 1.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
from argparse import Namespace
import yaml
from experiments.config import NestedLoader, cli_arguments
def main() -> None:
    """Run a LagrangeBench training or inference job from the command line.

    The YAML config is taken either directly from the ``config`` CLI argument
    (to start or restart training) or from ``config.yaml`` inside the
    ``model_dir`` CLI argument (to run inference). Command-line values
    override values loaded from the file. The merged settings are printed,
    GPU/XLA environment variables are set, and the run is handed off to
    ``experiments.run.train_or_infer``.

    Raises:
        SystemExit: if neither ``config`` nor ``model_dir`` is present in the
            parsed command-line arguments.
    """
    cli_args = cli_arguments()
    if "config" in cli_args:  # to (re)start training
        config_path = cli_args["config"]
    elif "model_dir" in cli_args:  # to run inference
        config_path = os.path.join(cli_args["model_dir"], "config.yaml")
    else:
        # Without either key there is no config to load; fail with a clear
        # message instead of a NameError on the unbound ``config_path``.
        raise SystemExit(
            "No config source given: expected a 'config' or 'model_dir' "
            "command-line argument."
        )

    with open(config_path, "r", encoding="utf-8") as f:
        # yaml.load returns None for an empty document; treat that as
        # "no settings from file" so the update below still works.
        # NOTE(review): assumes the YAML top level is a mapping.
        args = yaml.load(f, NestedLoader) or {}

    # priority to command line arguments
    args.update(cli_args)
    args = Namespace(config=Namespace(**args), info=Namespace())

    import pprint

    print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
    pprint.pprint(vars(args.config))
    print("#" * 79)

    # specify cuda device
    # These variables are set before JAX / experiments.run are imported below,
    # presumably so they are in place before JAX initializes its GPU backend.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152 from TensorFlow
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.config.gpu)
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(args.config.xla_mem_fraction)

    if args.config.f64:
        from jax import config

        # Enable 64-bit (double precision) arrays globally in JAX.
        config.update("jax_enable_x64", True)

    # Deferred import: runs only after the environment/JAX config above is set.
    from experiments.run import train_or_infer

    train_or_infer(args)


if __name__ == "__main__":
    main()