-
Notifications
You must be signed in to change notification settings - Fork 35
/
main.py
executable file
·50 lines (41 loc) · 1.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#!/usr/bin/env python2
from misc.util import Struct
import models
import trainers
import worlds
import logging
import numpy as np
import os
import random
import sys
import tensorflow as tf
import traceback
import yaml
def main():
    """Entry point: configure the experiment, build its parts, then train.

    The world, the model and the trainer are each built from the same config
    object through their package's ``load`` function, in that order. The
    trainer is then asked to train the model in the world.
    """
    cfg = configure()
    world = worlds.load(cfg)
    model = models.load(cfg)
    trainer = trainers.load(cfg)
    trainer.train(model, world)
def configure():
    """Load ``config.yaml``, create the experiment directory and set up logging.

    Reads ``config.yaml`` from the current working directory and wraps the
    parsed mapping in a ``Struct``. Creates a fresh directory
    ``experiments/<name>``, where ``name`` is read from the config, and stores
    its path on the config as ``experiment_dir``. Records of level DEBUG and
    above are written to ``run.log`` inside that directory. ``sys.excepthook``
    is replaced so that uncaught exceptions are logged there as well as being
    reported the default way.

    Returns:
        The ``Struct`` holding the configuration, with ``experiment_dir`` set.

    Raises:
        OSError: if the experiment directory already exists, or cannot be
            created (e.g. the parent ``experiments`` directory is missing).
        IOError/OSError: if ``config.yaml`` cannot be opened.
    """
    # load config
    # safe_load: yaml.load without an explicit Loader can construct arbitrary
    # Python objects, is deprecated since PyYAML 5.1 and raises TypeError in
    # PyYAML 6. NOTE(review): a config that relies on python/* YAML tags
    # would need a different loader -- confirm none are used.
    with open("config.yaml") as config_f:
        config = Struct(**yaml.safe_load(config_f))

    # set up experiment
    config.experiment_dir = os.path.join("experiments", config.name)
    # Use an explicit check rather than `assert`, which `python -O` strips.
    # OSError is what os.mkdir itself raises for an existing path.
    if os.path.exists(config.experiment_dir):
        raise OSError("Experiment %s already exists!" % config.experiment_dir)
    os.mkdir(config.experiment_dir)

    # set up logging
    log_name = os.path.join(config.experiment_dir, "run.log")
    logging.basicConfig(filename=log_name, level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-8s %(message)s')

    def handler(exc_type, exc_value, exc_tb):
        # This hook does not run inside an `except` block, so hand the
        # exception to logging explicitly rather than relying on
        # logging.exception() / sys.exc_info().
        logging.error("Uncaught exception: %s", str(exc_value),
                      exc_info=(exc_type, exc_value, exc_tb))
        # Also run the interpreter's default hook, so a crashed run still
        # prints its traceback to stderr instead of dying silently.
        sys.__excepthook__(exc_type, exc_value, exc_tb)
    sys.excepthook = handler

    logging.info("BEGIN")
    logging.info(str(config))
    return config
if __name__ == "__main__":
    # Only start training when this file is executed as a script,
    # not when it is imported as a module.
    main()