-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
77 lines (64 loc) · 2.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Main training script.
Examples
--------
python train.py --directory=test --problem=conv_train

Arguments
---------
--directory : str
    Directory passed to the strategy builder. Defaults to "weights".
--vgpu : int >= 1
    (debug) Number of virtual GPUs to create for testing. If 1, no virtual
    GPUs are created, and a mirrored strategy is created with all physical
    GPUs.
--vram : int
    Amount of VRAM to allocate per virtual GPU if testing.
--cpu : bool
    Whether to run on CPU instead of GPU.
--gpus : int[]
    Comma-separated list of GPU indices to use on a multi-GPU system.
--initialize : bool
    If set, only initializes and does not actually train.
--strategy : str
    Strategy to use.
--policy : str
    Policy to train.
--presets : str[]
    Comma-separated list of presets to apply.
(all other args)
    Passed as overrides to strategy/policy building.
"""
import os
import sys
from config import get_default, get_preset, ArgParser
# Parse the command line before anything else; the statements further down
# pull their settings out of `args` one option at a time.
args = ArgParser(sys.argv[1:])
# Finally ready to import tensorflow.
# The log-level variable is set ahead of the `import tensorflow` that follows
# so that it is already in the environment when TensorFlow loads.
# NOTE(review): '2' presumably hides TensorFlow's INFO and WARNING native log
# output — confirm against the TensorFlow documentation.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
import l2o
from gpu_setup import create_distribute
# Directory option (default "weights"); forwarded to l2o.build further down.
directory = args.pop_get("--directory", default="weights")

# Device setup: read the hardware-related options, then create the
# distribution object whose scope wraps strategy construction.
vgpus = int(args.pop_get("--vgpu", default=1))
memory_limit = int(args.pop_get("--vram", default=12000))
gpus = args.pop_get("--gpus", default=None)
# NOTE(review): how dtype=bool converts a raw value such as "False" depends
# on ArgParser.pop_get — confirm it is not a plain bool() cast.
cpu = args.pop_get("--cpu", default=False, dtype=bool)
distribute = create_distribute(
    vgpus=vgpus,
    memory_limit=memory_limit,
    gpus=gpus,
    do_cpu=cpu,
)
# Consume the --initialize flag up front, before the leftover arguments are
# converted into overrides at the end of this section.
initialize_only = args.pop_check("--initialize")

# Baseline configuration for the selected strategy/policy pair
strategy = args.pop_get("--strategy", "repeat")
policy = args.pop_get("--policy", "rnnprop")
default = get_default(strategy=strategy, policy=policy)

# Overrides: presets come first (in the order given), followed by every
# argument still left on the command line.
presets = args.pop_get("--presets", "")
preset_names = presets.split(",") if presets != "" else []
overrides = []
for preset_name in preset_names:
    overrides.extend(get_preset(preset_name))
overrides.extend(args.to_overrides())
# Build (and, unless told otherwise, train) inside the distribution scope.
with distribute.scope():
    # Assemble the training strategy from the defaults plus overrides
    strategy = l2o.build(default, overrides, directory=directory, strict=True)

    # With --initialize the run stops here: built, but never trained
    if not initialize_only:
        strategy.train()