-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
273 lines (225 loc) · 9.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import time
import importlib
import threading
import logging
from absl import app
from absl import flags
from pysc2 import maps
from pysc2.env import available_actions_printer
from pysc2.env import sc2_env
from pysc2.lib import stopwatch
import tensorflow as tf
from run_loop import run_loop
from pysc2.env import run_loop as pysc2_run_loop
import numpy as np
# Episode counter shared by every worker thread; it is incremented under LOCK
# in run_thread and may be restored from a snapshot in _main.
COUNTER = 0
LOCK = threading.Lock()

FLAGS = flags.FLAGS
flags.DEFINE_bool("training", False, "Whether to train agents.")
flags.DEFINE_bool("continuation", False, "Continuously training.")
flags.DEFINE_float("learning_rate", 5e-4, "Learning rate for training.")
flags.DEFINE_float("discount", 0.99, "Discount rate for future rewards.")
flags.DEFINE_integer("max_steps", int(1e5), "Total steps for training.")
flags.DEFINE_integer("snapshot_step", int(1e3), "Step for snapshot.")
flags.DEFINE_string("snapshot_path", "./snapshot/", "Path for snapshot.")
flags.DEFINE_string("log_path", "./log/", "Path for log.")
flags.DEFINE_string("device", "0", "Device for training.")

flags.DEFINE_string("map", "MoveToBeacon", "Name of a map to use.")
flags.DEFINE_bool("render", True, "Whether to render with pygame.")
flags.DEFINE_integer("screen_resolution", 64, "Resolution for screen feature layers.")
flags.DEFINE_integer("minimap_resolution", 64, "Resolution for minimap feature layers.")
flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.")

flags.DEFINE_string("agent", "rl_agents.a3c_agent.A3CAgent", "Which agent to run.")
flags.DEFINE_string("net", "atari", "atari or fcn.")
flags.DEFINE_enum("agent_race", None, sc2_env.races.keys(), "Agent's race.")
flags.DEFINE_enum("bot_race", None, sc2_env.races.keys(), "Bot's race.")
flags.DEFINE_enum("difficulty", None, sc2_env.difficulties.keys(), "Bot's strength.")
flags.DEFINE_integer("max_agent_steps", 60, "Total agent steps.")

flags.DEFINE_bool("profile", False, "Whether to turn on code profiling.")
flags.DEFINE_bool("trace", False, "Whether to trace the code execution.")
flags.DEFINE_integer("parallel", 1, "How many instances to run in parallel.")
flags.DEFINE_bool("save_replay", False, "Whether to save a replay at the end.")

# Useful to choose the number of subpolicies the MLSH master controller selects from
flags.DEFINE_integer("num_subpol", 2, "Number of subpolicies used for MLSH.")
# The help text used to be a copy of num_subpol's; this flag is the fifth
# constructor argument of MLSHAgent (see _main).
flags.DEFINE_integer("subpol_steps", 10, "Number of steps per subpolicy for MLSH.")

# original flag not included by xhujoy but useful:
flags.DEFINE_integer("game_steps_per_episode", 0, "Game steps per episode.")

flags.DEFINE_integer("warmup_len", 100, "Number of episodes for warm up period of training master policy.")
flags.DEFINE_integer("joint_len", 500, "Number of episodes after warm up for training master and subpolicies.")

# Flags are parsed at import time because the module-level settings below
# are derived from them.
FLAGS(sys.argv)

if FLAGS.training:
    PARALLEL = FLAGS.parallel
    MAX_AGENT_STEPS = FLAGS.max_agent_steps
    DEVICE = ['/gpu:' + dev for dev in FLAGS.device.split(',')]
else:
    # Evaluation always runs a single instance on the CPU.
    PARALLEL = 1
    MAX_AGENT_STEPS = int(1e5)  # integer, like the max_agent_steps flag
    DEVICE = ['/cpu:0']

# MLSH agents share one output directory; every other agent gets one per map.
_RUN_NAME = 'MLSHAgent' if FLAGS.agent.rsplit(".", 1)[-1] == 'MLSHAgent' else FLAGS.map
# os.path.join keeps the paths valid even when the *_path flags are given
# without a trailing separator.
LOG = os.path.join(FLAGS.log_path, _RUN_NAME, FLAGS.net)
SNAPSHOT = os.path.join(FLAGS.snapshot_path, _RUN_NAME, FLAGS.net)
if not os.path.exists(LOG):
    os.makedirs(LOG)
if not os.path.exists(SNAPSHOT):
    os.makedirs(SNAPSHOT)

# Minigames assigned round-robin to MLSH worker threads (see _main).
MLSH_TRAIN_MAPS = ["MoveToBeacon", "CollectMineralShards", "DefeatRoaches", "FindAndDefeatZerglings"]

logger = logging.getLogger('starcraft_agent')
logger.setLevel(logging.DEBUG)
# create file handler which logs even debug messages
fh = logging.FileHandler(os.path.join(LOG, 'main.log'))
fh.setLevel(logging.DEBUG)
# create console handler with a higher log level
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
# create formatter and add it to the handlers
formatter = logging.Formatter('[%(asctime)s %(levelname)s] [%(threadName)s]\t %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
# add the handlers to the logger
logger.addHandler(fh)
logger.addHandler(ch)
def pysc2_run_thread(agent_cls, map_name, visualize):
    """Run one agent with the stock pysc2 loop (adapted from pysc2.bin.agent).

    Args:
        agent_cls: agent class; it is instantiated with no arguments.
        map_name: name of the map handed to the SC2 environment.
        visualize: forwarded to SC2Env as its `visualize` argument.
    """
    screen_px = (FLAGS.screen_resolution, FLAGS.screen_resolution)
    minimap_px = (FLAGS.minimap_resolution, FLAGS.minimap_resolution)
    env_settings = dict(
        map_name=map_name,
        agent_race=FLAGS.agent_race,
        bot_race=FLAGS.bot_race,
        difficulty=FLAGS.difficulty,
        step_mul=FLAGS.step_mul,
        game_steps_per_episode=FLAGS.game_steps_per_episode,
        screen_size_px=screen_px,
        minimap_size_px=minimap_px,
        visualize=visualize,
    )
    with sc2_env.SC2Env(**env_settings) as raw_env:
        # Wrap the environment in pysc2's AvailableActionsPrinter.
        wrapped_env = available_actions_printer.AvailableActionsPrinter(raw_env)
        pysc2_run_loop.run_loop([agent_cls()], wrapped_env, FLAGS.max_agent_steps)
        if FLAGS.save_replay:
            # The replay is saved under the agent class name.
            wrapped_env.save_replay(agent_cls.__name__)
def run_thread(agent, map_name, visualize, mlsh=False):
    """Drive one agent in its own SC2 environment, learning after each episode.

    Iterates the project's run_loop, which yields (recorder, is_done) pairs.
    When FLAGS.training is set, every recorder is buffered and, once an
    episode ends, the buffer is handed to agent.update. Episode scores are
    logged whether or not training is enabled.

    Args:
        agent: agent exposing update(), save_model() and a `name` attribute.
        map_name: minigame to launch.
        visualize: forwarded to SC2Env as its `visualize` argument.
        mlsh: forwarded to run_loop as its `mlsh` argument.
    """
    global COUNTER

    episode_scores = []
    logger.info('Launching new SC2 environment...')
    env_settings = dict(
        map_name=map_name,
        agent_race=FLAGS.agent_race,
        bot_race=FLAGS.bot_race,
        difficulty=FLAGS.difficulty,
        step_mul=FLAGS.step_mul,
        screen_size_px=(FLAGS.screen_resolution, FLAGS.screen_resolution),
        minimap_size_px=(FLAGS.minimap_resolution, FLAGS.minimap_resolution),
        visualize=visualize,
    )
    with sc2_env.SC2Env(**env_settings) as env:
        env = available_actions_printer.AvailableActionsPrinter(env)
        logger.info('New SC2 environment launched successfully')
        logger.info('Minigame: %s', map_name)

        local_episodes = 0  # episodes finished by this thread only
        trajectory = []  # recorders gathered during the current episode

        steps = run_loop([agent], env, MAX_AGENT_STEPS, mlsh=mlsh,
                         warmup=FLAGS.warmup_len, joint=FLAGS.joint_len)
        for recorder, is_done in steps:
            if FLAGS.training:
                trajectory.append(recorder)
            if not is_done:
                continue

            if FLAGS.training:
                # Episode finished: learn from everything recorded during it.
                with LOCK:
                    # COUNTER counts episodes across all threads.
                    COUNTER += 1
                    global_episode = COUNTER
                # Learning rate schedule
                progress = global_episode / FLAGS.max_steps
                learning_rate = FLAGS.learning_rate * (1 - 0.9 * progress)
                agent.update(trajectory, FLAGS.discount, learning_rate, global_episode)
                # Rebind rather than clear: the agent was handed the old list.
                trajectory = []
                if global_episode % FLAGS.snapshot_step == 1:
                    logger.info('Saving model to %s', SNAPSHOT)
                    agent.save_model(SNAPSHOT, global_episode)
                if global_episode >= FLAGS.max_steps:
                    break

            local_episodes += 1
            final_obs = recorder[-1].observation
            score = final_obs["score_cumulative"][0]
            episode_scores.append(score)
            logger.info('[Episode %s] Episode score: %.2f, mean score: %.2f, max score: %.2f',
                        local_episodes, score, np.mean(episode_scores[-300:]), np.max(episode_scores))

        if FLAGS.save_replay:
            env.save_replay(agent.name)
def _minigame_for_worker(worker_index, mlsh):
    """Return the map for a worker: FLAGS.map, or a round-robin MLSH minigame."""
    if not mlsh:
        return FLAGS.map
    return MLSH_TRAIN_MAPS[worker_index % len(MLSH_TRAIN_MAPS)]


def _main(unused_argv):
    """Run agents.

    A3CAgent and MLSHAgent are built here, share one TensorFlow session and
    are driven by run_thread; any other agent class goes through the stock
    pysc2 loop via pysc2_run_thread.

    Args:
        unused_argv: positional arguments passed by absl; ignored.

    Raises:
        ValueError: if fewer than one A3C/MLSH agent is requested.
    """
    global COUNTER

    stopwatch.sw.enabled = FLAGS.profile or FLAGS.trace
    stopwatch.sw.trace = FLAGS.trace

    maps.get(FLAGS.map)  # Assert the map exists.
    logger.info('Launching main script')

    # Setup agents
    agent_module, agent_name = FLAGS.agent.rsplit(".", 1)
    agent_cls = getattr(importlib.import_module(agent_module), agent_name)
    logger.info('Creating %s agents of type %s', PARALLEL, agent_name)

    if agent_name == "A3CAgent" or agent_name == "MLSHAgent":
        # these agents cannot be initiated similarly to classic agents
        if PARALLEL < 1:
            # Previously this surfaced later as an unbound-local error.
            raise ValueError('--parallel must be at least 1, got %s' % PARALLEL)

        mlsh = (agent_name == "MLSHAgent")
        agents = []
        for i in range(PARALLEL):
            if agent_name == "A3CAgent":
                agent = agent_cls(FLAGS.training, FLAGS.minimap_resolution, FLAGS.screen_resolution)
            else:  # i.e. MLSHAgent
                agent = agent_cls(FLAGS.training, FLAGS.minimap_resolution, FLAGS.screen_resolution,
                                  FLAGS.num_subpol, FLAGS.subpol_steps, i + 1)
            # The first argument is False for the first agent only; devices
            # are assigned round-robin.
            agent.build_model(i > 0, DEVICE[i % len(DEVICE)], FLAGS.net)
            agents.append(agent)

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        # setup tensorflow logs
        summary_writer = tf.summary.FileWriter(LOG)

        try:
            for agent in agents:
                agent.setup(sess, summary_writer)

            # Initialisation and snapshot loading go through the last agent,
            # the same object the original reached via its loop variable.
            last_agent = agents[-1]
            last_agent.initialize()
            if not FLAGS.training or FLAGS.continuation:
                COUNTER = last_agent.load_model(SNAPSHOT)
                logger.info('Loaded model, starting from step %s', COUNTER)

            # Run threads
            threads = []
            for i in range(PARALLEL - 1):
                # MLSH workers each get a different minigame.
                worker_map = _minigame_for_worker(i, mlsh)
                t = threading.Thread(target=run_thread, args=(agents[i], worker_map, False, mlsh))
                threads.append(t)
                t.daemon = True
                t.start()
                time.sleep(5)

            # The last agent runs in the main thread; only it may render.
            main_map = _minigame_for_worker(len(agents) - 1, mlsh)
            run_thread(last_agent, main_map, FLAGS.render, mlsh=mlsh)
            logger.info('All threads created')

            for t in threads:
                t.join()

            if FLAGS.profile:
                print(stopwatch.sw)
        finally:
            # Close the writer so buffered summaries reach disk, then release
            # the session. On the normal path every worker has been joined.
            summary_writer.close()
            sess.close()
    else:
        # other agents just call the usual main loop from pysc2
        threads = []
        for _ in range(FLAGS.parallel - 1):
            t = threading.Thread(target=pysc2_run_thread, args=(agent_cls, FLAGS.map, False))
            threads.append(t)
            t.start()

        pysc2_run_thread(agent_cls, FLAGS.map, FLAGS.render)

        for t in threads:
            t.join()

        if FLAGS.profile:
            print(stopwatch.sw)
        logger.info('All threads created')
# Script entry point: hand _main to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(_main)