-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun.py
439 lines (384 loc) · 16.9 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
import argparse
import pickle
import random
import sys
import time
import traceback
from math import isnan
import cv2
import numpy as np
import psutil
import tensorflow as tf
from tensorflow import keras
try:
from gym.utils.play import play
except Exception as e:
print("The following exception is typical for servers because they don't have display stuff installed. "
"It only means that interactive --play won't work because `from gym.utils.play import play` failed with:")
traceback.print_exc()
print("You probably don't need --play on server, so let's continue.")
from atari_wrappers import wrap_deepmind, make_atari
from replay_buffer import ReplayBuffer
from tensor_board_logger import TensorBoardLogger
# --- Q-learning hyper-parameters ---
# Discount applied to the max next-state Q value when building targets in fit_batch().
DISCOUNT_FACTOR_GAMMA = 0.99
# Learning rate handed to the Adam optimizer in create_atari_model().
LEARNING_RATE = 0.0001
# train() fits one batch every UPDATE_EVERY environment steps (once TRAIN_START is reached).
UPDATE_EVERY = 4
# Number of experiences sampled from the replay buffer per fit, and the Keras batch size.
BATCH_SIZE = 64
# The target network receives a copy of the online weights every TARGET_UPDATE_EVERY steps.
TARGET_UPDATE_EVERY = 10000
# First step at which fitting (and the validation sample) happens.
TRAIN_START = 10000
# Passed to ReplayBuffer(); presumably its capacity in experiences - verify against replay_buffer.
REPLAY_BUFFER_SIZE = 100000
# Number of training steps unless --test is given.
MAX_STEPS = 10000000
# A model snapshot is saved every SNAPSHOT_EVERY steps.
SNAPSHOT_EVERY = 500000
# Evaluation runs every EVAL_EVERY steps and lasts up to EVAL_STEPS steps.
EVAL_EVERY = 100000
EVAL_STEPS = 20000
# Exploration rate is annealed linearly from EPSILON_START to EPSILON_FINAL over EPSILON_STEPS steps.
EPSILON_START = 1.0
EPSILON_FINAL = 0.02
EPSILON_STEPS = 100000
# Minimum number of steps between two training log lines (logged at episode end).
LOG_EVERY = 10000
# Number of (goal, observation) pairs sampled at TRAIN_START to track the average max Q value.
VALIDATION_SIZE = 500
# --- Goal / hindsight experience replay settings ---
# The 84x84 frame is split into SIDE_BOXES x SIDE_BOXES goal boxes of BOX_PIXELS pixels per side.
SIDE_BOXES = 4
BOX_PIXELS = 84 // SIDE_BOXES
# Hindsight goal selection: 'final' or 'future' (checked by an assert in main()).
STRATEGY = 'future'
# Number of later trajectory steps sampled per experience by future_goals().
K_EXTRA_GOALS = 4
def box_start(x):
    """Return the first pixel coordinate of the goal box that contains ``x``."""
    # Snapping down to a multiple of BOX_PIXELS is the same as removing the remainder.
    return x - x % BOX_PIXELS
def create_goal(position):
    """Build an 84x84x1 goal mask with the box containing ``position`` set to 255.

    ``position`` is a pair of pixel coordinates; each one is snapped to the
    start of its box with box_start() and the whole box is filled.

    NOTE(review): the first coordinate indexes the first array axis and the
    second coordinate the second axis, while find_agent() returns
    (column, row). goal_reward() indexes the mask the same way, so rewards are
    consistent, but the mask may be transposed relative to the frames it is
    concatenated with - confirm this is intended.
    """
    first, second = (box_start(coordinate) for coordinate in position)
    mask = np.zeros((84, 84, 1))
    mask[first:first + BOX_PIXELS, second:second + BOX_PIXELS, 0] = 255
    return mask
def one_hot_encode(env, action):
    """Return a float vector of length ``env.action_space.n`` with a 1 at ``action``."""
    n_actions = env.action_space.n
    encoded = np.zeros(n_actions)
    encoded[action] = 1
    return encoded
def predict(env, model, goals, observations):
    """Return ``model.predict`` for every (goal, observation) pair with all actions unmasked.

    The action mask is filled with ones so that the model's masked output
    exposes the value of every action.
    """
    all_actions_mask = np.ones((len(observations), env.action_space.n))
    inputs = [np.array(observations), all_actions_mask, np.array(goals)]
    return model.predict(inputs)
def save_for_debug(env, model, target_model, batch):
    """Dump everything needed to replay a failing fit_batch() call.

    Writes ``model.h5`` and ``target_model.h5`` via the models' own ``save``
    and pickles ``(env, batch)`` into ``debug.pkl`` in the current directory.
    """
    model.save('model.h5')
    target_model.save('target_model.h5')
    # Use a context manager so the file is flushed and closed even if pickling fails.
    with open('debug.pkl', 'wb') as debug_file:
        pickle.dump((env, batch), debug_file)
def load_for_debug():
    """Load the files written by save_for_debug().

    Returns:
        Tuple ``(env, model, target_model, batch)`` read from ``model.h5``,
        ``target_model.h5`` and ``debug.pkl`` in the current directory.
    """
    model = keras.models.load_model('model.h5')
    target_model = keras.models.load_model('target_model.h5')
    # NOTE: unpickling executes arbitrary code; only load a debug.pkl produced
    # by this program on a trusted machine.
    with open('debug.pkl', 'rb') as debug_file:
        env, batch = pickle.load(debug_file)
    return env, model, target_model, batch
def fit_batch(env, model, target_model, batch):
    """Run one fitting step of ``model`` on ``batch`` and return the loss.

    ``batch`` unpacks into (goals, observations, actions, rewards,
    next_observations, dones). If the loss is NaN, the models and the batch
    are saved with save_for_debug() and the process exits with status 1.
    """
    goals, observations, actions, rewards, next_observations, dones = batch

    # Q values of the next states from the target network (all actions unmasked).
    target_q = predict(env, target_model, goals, next_observations)
    # Terminal states have no future value by definition.
    target_q[dones] = 0.0
    # Bellman target: reward plus the discounted best next-state value.
    best_next_q = np.max(target_q, axis=1)
    q_values = rewards + DISCOUNT_FACTOR_GAMMA * best_next_q

    # The one-hot action serves both as the input mask and as the target mask,
    # so only the taken action contributes to the loss.
    action_masks = np.array([one_hot_encode(env, action) for action in actions])
    targets = action_masks * q_values[:, None]
    history = model.fit(
        x=[observations, action_masks, goals],
        y=targets,
        batch_size=BATCH_SIZE,
        verbose=0,
    )

    loss = history.history['loss'][0]
    if not isnan(loss):
        return loss
    save_for_debug(env, model, target_model, batch)
    print("loss is NaN, saved files for debug")
    sys.exit(1)
def create_atari_model(env):
    """Build and compile the goal-conditioned Q network for ``env``.

    Inputs are the observation frames, an action mask and an 84x84x1 goal
    mask. The output is one value per action multiplied by the action mask.
    """
    n_actions = env.action_space.n
    obs_shape = env.observation_space.shape
    print('n_actions {}'.format(n_actions))
    print(' '.join(env.unwrapped.get_action_meanings()))
    print('obs_shape {}'.format(obs_shape))

    frames_input = keras.layers.Input(obs_shape, name='frames_input')
    actions_input = keras.layers.Input((n_actions,), name='actions_input')
    goals_input = keras.layers.Input((84, 84, 1), name='goals_input')

    # The goal mask is stacked with the frames as an additional channel.
    stacked = keras.layers.concatenate([frames_input, goals_input])
    # Frames (and the goal mask) are assumed to be in 0..255; scale to [0, 1].
    features = keras.layers.Lambda(lambda x: x / 255.0)(stacked)

    # (filters, kernel_size, strides) of the three convolutional layers.
    for filters, kernel_size, strides in ((32, 8, 4), (64, 4, 2), (64, 3, 1)):
        features = keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            activation='relu',
        )(features)

    flattened = keras.layers.Flatten()(features)
    hidden = keras.layers.Dense(512, activation='relu')(flattened)
    q_values = keras.layers.Dense(n_actions)(hidden)
    # Zero out the values of every action that is not selected by the mask.
    masked_q_values = keras.layers.multiply([q_values, actions_input])

    model = keras.models.Model([frames_input, actions_input, goals_input], masked_q_values)
    model.compile(keras.optimizers.Adam(lr=LEARNING_RATE, clipnorm=0.1), loss='mae')
    return model
def epsilon_for_step(step):
    """Return the exploration rate for ``step``.

    The rate falls linearly from EPSILON_START by a fixed amount per step and
    is never lower than EPSILON_FINAL.
    """
    slope = (EPSILON_FINAL - EPSILON_START) / EPSILON_STEPS
    annealed = slope * step + EPSILON_START
    return max(EPSILON_FINAL, annealed)
def greedy_action(env, model, goal, observation):
    """Return the index of the action with the highest predicted value for one observation."""
    # predict() works on batches, so wrap the single goal/observation in lists.
    q_values = predict(env, model, [goal], [observation])
    return np.argmax(q_values)
def epsilon_greedy_action(env, model, goal, observation, epsilon):
    """Pick a random action with probability ``epsilon``, otherwise the greedy one."""
    explore = random.random() < epsilon
    if explore:
        return env.action_space.sample()
    return greedy_action(env, model, goal, observation)
def save_model(model, step, logdir, name):
    """Save ``model`` as ``<logdir>/<name>-<step>.h5`` and return that filename."""
    path = '{directory}/{prefix}-{index}.h5'.format(directory=logdir, prefix=name, index=step)
    model.save(path)
    print('Saved {}'.format(path))
    return path
def save_image(env, episode, step):
    """Write the current rendered frame to ``<episode>_<step, 6 digits>.png``."""
    rgb_frame = env.render(mode='rgb_array')
    # cv2.imwrite expects BGR channel order.
    bgr_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR)
    target = "{}_{:06d}.png".format(episode, step)
    cv2.imwrite(target, bgr_frame, params=[cv2.IMWRITE_PNG_COMPRESSION, 9])
def evaluate(env, model, view=False, images=False, eval_steps=EVAL_STEPS):
    """Run the policy for ``eval_steps`` steps and summarise the episode returns.

    A fresh random goal is sampled for every episode, actions are chosen
    epsilon-greedily with EPSILON_FINAL, and the return of an episode is the
    sum of goal_reward() over its steps.

    Args:
        env: Environment to evaluate in.
        model: Model used to select actions.
        view: Render the environment after every reset and step.
        images: Save a PNG of the environment after every reset and step.
        eval_steps: Number of environment steps to run.

    Returns:
        Tuple ``(average, minimum, maximum)`` of the returns of the episodes
        that finished within the step budget. If no episode finished, the
        statistics describe the single unfinished episode.
    """
    done = True
    episode = 0
    finished_returns = []
    # Run exactly eval_steps environment steps (steps are numbered from 1).
    for step in range(1, eval_steps + 1):
        if done:
            if episode > 0:
                print("eval episode {} steps {} return {}".format(
                    episode,
                    episode_steps,
                    episode_return,
                ))
                finished_returns.append(episode_return)
            obs = env.reset()
            episode += 1
            episode_return = 0.0
            episode_steps = 0
            goal = sample_goal()
            if view:
                env.render()
            if images:
                save_image(env, episode, step)
        else:
            obs = next_obs
        action = epsilon_greedy_action(env, model, goal, obs, EPSILON_FINAL)
        next_obs, _, done, _ = env.step(action)
        episode_return += goal_reward(next_obs, goal)
        episode_steps += 1
        if view:
            env.render()
        if images:
            save_image(env, episode, step)
    assert episode > 0
    # The episode in flight when the budget ran out was not recorded inside the
    # loop. Count it if it finished on the very last step, or if it is the only
    # data available; otherwise drop it because a truncated return would bias
    # the statistics.
    if done or not finished_returns:
        print("eval episode {} steps {} return {}".format(
            episode,
            episode_steps,
            episode_return,
        ))
        finished_returns.append(episode_return)
    # Average over the episodes that actually contributed a return.
    episode_return_avg = sum(finished_returns) / len(finished_returns)
    return episode_return_avg, min(finished_returns), max(finished_returns)
def find_agent(obs):
    """Locate the agent in the last channel of ``obs``.

    The agent is taken to be the first pixel (in row-major order) whose value
    is 110 - presumably the agent's intensity in the preprocessed frame;
    verify against the game being played.

    Args:
        obs: Observation indexable as ``obs[:, :, -1]`` (rows, columns, channels).

    Returns:
        ``(x, y)`` with ``x`` the column and ``y`` the row of that pixel, or
        ``None`` when no pixel matches.
    """
    image = np.asarray(obs[:, :, -1])
    indices = np.flatnonzero(image == 110)
    if len(indices) == 0:
        return None
    # Flat indices are row-major, so split by the actual frame width instead
    # of assuming 84 columns.
    y, x = divmod(indices[0], image.shape[1])
    return x, y
def goal_reward(obs, goal):
    """Return 1.0 if the agent found in ``obs`` is inside the goal box, else 0.0.

    Args:
        obs: Observation passed to find_agent().
        goal: Goal mask indexed with the agent position; non-zero entries mark
            the goal box.

    Returns:
        ``1.0`` when the mask is set at the agent position, ``0.0`` otherwise
        or when no agent is found.
    """
    agent_position = find_agent(obs)
    if agent_position is None:
        return 0.0
    # Indexing the (84, 84, 1) mask with an (x, y) pair yields a length-1
    # array; reduce it to a scalar explicitly because float() on a non-0-d
    # array is deprecated in recent NumPy versions.
    return float(np.any(goal[agent_position] > 0))
def final_goal(trajectory):
    """Return a goal for the last position in ``trajectory`` where the agent was found.

    Experiences are scanned from the end; the fifth element of each one is the
    next observation. Returns ``None`` when the agent is never found.
    """
    positions = (find_agent(experience[4]) for experience in reversed(trajectory))
    last_seen = next((position for position in positions if position is not None), None)
    if last_seen is None:
        return None
    return create_goal(last_seen)
def future_goals(i, trajectory):
    """Return goals reached at randomly chosen steps after step ``i``.

    K_EXTRA_GOALS step indices are drawn (with replacement) from the steps
    following ``i``; every drawn step where the agent is found contributes one
    goal. Returns ``None`` when ``i`` is the last step of the trajectory.
    """
    horizon = len(trajectory)
    if i + 1 >= horizon:
        return None
    later_steps = np.random.randint(i + 1, horizon, K_EXTRA_GOALS)
    # Element 4 of an experience is the observation after the action.
    reached = (find_agent(trajectory[index][4]) for index in later_steps)
    return [create_goal(position) for position in reached if position is not None]
def sample_goal():
    """Return a goal mask for a uniformly random pixel position in the 84x84 frame."""
    return create_goal(np.random.randint(0, 84, 2))
def train(env, env_eval, model, max_steps, name):
    """Train ``model`` with goal-conditioned Q-learning and hindsight experience replay.

    Each episode follows a randomly sampled goal; the reward is goal_reward()
    rather than the environment reward. When an episode ends, its experiences
    are added to the replay buffer together with copies relabelled with goals
    chosen according to STRATEGY.

    Args:
        env: Training environment.
        env_eval: Environment passed to evaluate().
        model: Model to train (a target model is created internally).
        max_steps: Number of environment steps to run.
        name: Prefix for the log directory (``<name>-log``) and snapshots.

    Pressing Ctrl+C saves a snapshot and stops training.
    """

    def to_gb(in_bytes):
        return in_bytes / 1024 / 1024 / 1024

    target_model = create_atari_model(env)
    replay = ReplayBuffer(REPLAY_BUFFER_SIZE)
    done = True
    episode = 0
    logdir = '{}-log'.format(name)
    board = TensorBoardLogger(logdir)
    print('Created {}'.format(logdir))
    steps_after_logging = 0
    loss = 0.0
    for step in range(1, max_steps + 1):
        try:
            if step % SNAPSHOT_EVERY == 0:
                save_model(model, step, logdir, name)
            if done:
                if episode > 0:
                    extra_goals = []
                    if STRATEGY == 'final':
                        # final_goal() returns None when the agent was never
                        # found; do not relabel with a missing goal.
                        hindsight_goal = final_goal(trajectory)
                        if hindsight_goal is not None:
                            extra_goals = [hindsight_goal]
                    for i, experience in enumerate(trajectory):
                        goal, obs, action, reward, next_obs, done = experience
                        replay.add(goal, obs, action, reward, next_obs, done)
                        # Hindsight Experience Replay - add experiences with extra goals that were reached
                        if STRATEGY == 'future':
                            extra_goals = future_goals(i, trajectory)
                        if extra_goals:
                            for extra_goal in extra_goals:
                                replay.add(extra_goal, obs, action, goal_reward(next_obs, extra_goal), next_obs, done)
                    if steps_after_logging >= LOG_EVERY:
                        steps_after_logging = 0
                        episode_end = time.time()
                        episode_seconds = episode_end - episode_start
                        episode_steps = step - episode_start_step
                        steps_per_second = episode_steps / episode_seconds
                        memory = psutil.virtual_memory()
                        print(
                            "episode {} "
                            "steps {}/{} "
                            "loss {:.7f} "
                            "return {} "
                            "in {:.2f}s "
                            "{:.1f} steps/s "
                            "{:.1f}/{:.1f} GB RAM".format(
                                episode,
                                episode_steps,
                                step,
                                loss,
                                episode_return,
                                episode_seconds,
                                steps_per_second,
                                to_gb(memory.used),
                                to_gb(memory.total),
                            ))
                        board.log_scalar('episode_return', episode_return, step)
                        board.log_scalar('episode_steps', episode_steps, step)
                        board.log_scalar('episode_seconds', episode_seconds, step)
                        board.log_scalar('steps_per_second', steps_per_second, step)
                        board.log_scalar('epsilon', epsilon_for_step(step), step)
                        board.log_scalar('memory_used', to_gb(memory.used), step)
                        board.log_scalar('loss', loss, step)
                # Start a new episode with a fresh goal.
                trajectory = []
                goal = sample_goal()
                episode_start = time.time()
                episode_start_step = step
                obs = env.reset()
                episode += 1
                episode_return = 0.0
                # Epsilon is held constant for the whole episode.
                epsilon = epsilon_for_step(step)
            else:
                obs = next_obs
            action = epsilon_greedy_action(env, model, goal, obs, epsilon)
            next_obs, _, done, _ = env.step(action)
            # The environment's own reward is ignored; only reaching the goal counts.
            reward = goal_reward(next_obs, goal)
            episode_return += reward
            trajectory.append((goal, obs, action, reward, next_obs, done))
            if step >= TRAIN_START and step % UPDATE_EVERY == 0:
                if step % TARGET_UPDATE_EVERY == 0:
                    target_model.set_weights(model.get_weights())
                batch = replay.sample(BATCH_SIZE)
                loss = fit_batch(env, model, target_model, batch)
            if step == TRAIN_START:
                # Fixed sample used to track the average max Q value over time.
                validation_goals, validation_observations, _, _, _, _ = replay.sample(VALIDATION_SIZE)
            if step >= TRAIN_START and step % EVAL_EVERY == 0:
                episode_return_avg, episode_return_min, episode_return_max = evaluate(env_eval, model)
                q_values = predict(env, model, validation_goals, validation_observations)
                max_q_values = np.max(q_values, axis=1)
                avg_max_q_value = np.mean(max_q_values)
                print(
                    "episode {} "
                    "step {} "
                    "episode_return_avg {:.1f} "
                    "episode_return_min {:.1f} "
                    "episode_return_max {:.1f} "
                    "avg_max_q_value {:.1f}".format(
                        episode,
                        step,
                        episode_return_avg,
                        episode_return_min,
                        episode_return_max,
                        avg_max_q_value,
                    ))
                board.log_scalar('episode_return_avg', episode_return_avg, step)
                board.log_scalar('episode_return_min', episode_return_min, step)
                board.log_scalar('episode_return_max', episode_return_max, step)
                board.log_scalar('avg_max_q_value', avg_max_q_value, step)
            steps_after_logging += 1
        except KeyboardInterrupt:
            save_model(model, step, logdir, name)
            break
def load_or_create_model(env, model_filename):
    """Load the model stored in ``model_filename`` or build a new one for ``env``.

    A new model is created when ``model_filename`` is empty or ``None``. The
    model summary is printed before the model is returned.
    """
    if not model_filename:
        model = create_atari_model(env)
    else:
        model = keras.models.load_model(model_filename)
        print('Loaded {}'.format(model_filename))
    model.summary()
    return model
def set_seed(env, seed):
    """Seed Python's ``random``, NumPy, TensorFlow and ``env`` with ``seed``."""
    seeders = (random.seed, np.random.seed, tf.set_random_seed, env.seed)
    for apply_seed in seeders:
        apply_seed(seed)
def print_weights(model):
    """Print every weight array of every layer of ``model`` in full.

    Each array is followed by a blank line, a dashed separator and another
    blank line.
    """
    separator = '--------------------------------------------------------------------'
    for layer in model.layers:
        for weights in layer.get_weights():
            # A very large threshold stops NumPy from abbreviating big arrays.
            print(np.array2string(weights, threshold=100000000))
            print()
            print(separator)
            print()
def main(args):
    """Run the mode selected on the command line.

    Modes, in order of precedence: ``--weights`` prints the model weights,
    ``--debug`` replays a saved failing batch, ``--play`` starts interactive
    play, ``--view``/``--images``/``--eval`` run a single evaluation, and
    otherwise the model is trained (for only 100 steps with ``--test``,
    followed by a save/load round trip).

    Args:
        args: Parsed command-line arguments (see the argument parser at the
            bottom of this file).
    """
    # Sanity checks on the module-level configuration.
    assert BATCH_SIZE <= TRAIN_START <= REPLAY_BUFFER_SIZE
    assert TARGET_UPDATE_EVERY % UPDATE_EVERY == 0
    assert 84 % SIDE_BOXES == 0
    assert STRATEGY in ['final', 'future']
    print(args)
    env_id = '{}NoFrameskip-v4'.format(args.env)
    env = make_atari(env_id)
    set_seed(env, args.seed)
    env_train = wrap_deepmind(env, frame_stack=True, episode_life=True, clip_rewards=True)
    if args.weights:
        model = load_or_create_model(env_train, args.model)
        print_weights(model)
    elif args.debug:
        env, model, target_model, batch = load_for_debug()
        fit_batch(env, model, target_model, batch)
    elif args.play:
        env = wrap_deepmind(env)
        play(env)
    else:
        # Evaluation gets its own emulator. Wrapping the training environment
        # a second time would share the underlying game, so every periodic
        # evaluation would reset and advance the game in the middle of a
        # training episode.
        eval_base = make_atari(env_id)
        eval_base.seed(args.seed)
        env_eval = wrap_deepmind(eval_base, frame_stack=True)
        model = load_or_create_model(env_train, args.model)
        if args.view or args.images or args.eval:
            evaluate(env_eval, model, args.view, args.images)
        else:
            max_steps = 100 if args.test else MAX_STEPS
            train(env_train, env_eval, model, max_steps, args.name)
            if args.test:
                # Smoke-test that a saved model can be loaded again.
                filename = save_model(model, EVAL_STEPS, logdir='.', name='test')
                load_or_create_model(env_train, filename)
if __name__ == '__main__':
    # Command-line interface: every option is registered in the order listed
    # here, which is also the order shown by --help.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    switch = {'action': 'store_true', 'default': False}
    options = [
        ('--debug', dict(switch, help='load debug files and run fit_batch with them')),
        ('--env', dict(action='store', default='Breakout', help='Atari game name')),
        ('--eval', dict(switch, help='run evaluation with log only')),
        ('--images', dict(switch, help='save images during evaluation')),
        ('--model', dict(action='store', default=None, help='model filename to load')),
        ('--name', dict(action='store', default=time.strftime("%m-%d-%H-%M"), help='name for saved files')),
        ('--play', dict(switch, help='play with WSAD + Space')),
        ('--seed', dict(action='store', type=int, help='pseudo random number generator seed')),
        ('--test', dict(switch, help='run tests')),
        ('--view', dict(switch, help='view evaluation in a window')),
        ('--weights', dict(switch, help='print model weights')),
    ]
    for option, settings in options:
        parser.add_argument(option, **settings)
    main(parser.parse_args())