Skip to content

Commit ab23701

Browse files
committed
Created testing sender for Dagger
1 parent 65854b5 commit ab23701

File tree

4 files changed

+91
-2
lines changed

4 files changed

+91
-2
lines changed

a3c/run_sender.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from os import path
88
from env.sender import Sender
99
from models import ActorCriticNetwork
10-
from a3c import ewma
10+
from helpers.helpers import ewma
1111

1212

1313
class Learner(object):

dagger/dagger.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ def sample_action(self, step_state_buf):
294294
Appends to the state/action buffers the state and the
295295
"correct" action to take according to the expert.
296296
"""
297-
start_time = time.time()
297+
if self.is_chief:
298+
start_time = time.time()
298299

299300
# For ewma delay, only want first component, the one-way delay
300301
# For the cwnd, try only the most recent cwnd

dagger/run_sender.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import project_root
5+
import numpy as np
6+
import tensorflow as tf
7+
from os import path
8+
from env.sender import Sender
9+
from models import DaggerNetwork
10+
from helpers.helpers import ewma
11+
12+
13+
def softmax(x):
    """Return the softmax of *x* along its first axis.

    Shifts by the maximum before exponentiating so that large inputs
    cannot overflow np.exp (the shift cancels out in the ratio).
    """
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=0)
16+
17+
18+
class Learner(object):
    """Restores a trained DaggerNetwork policy and samples actions from it.

    Built for the run-time sender: it only does inference, no training.
    """

    def __init__(self, state_dim, action_cnt, restore_vars):
        """Build the policy network and restore weights from a checkpoint.

        Args:
            state_dim: dimensionality of the network's input state.
            action_cnt: number of discrete actions the policy outputs.
            restore_vars: path to the TensorFlow checkpoint to restore.
        """
        with tf.variable_scope('local'):
            self.pi = DaggerNetwork(state_dim=state_dim, action_cnt=action_cnt)

        self.ewma_window = 3  # alpha = 2 / (window + 1)
        self.session = tf.Session()

        # restore saved variables
        saver = tf.train.Saver(self.pi.trainable_vars)
        saver.restore(self.session, restore_vars)

        # init the remaining vars, especially those created by optimizer
        uninit_vars = set(tf.global_variables()) - set(self.pi.trainable_vars)
        self.session.run(tf.variables_initializer(uninit_vars))

    def sample_action(self, step_state_buf):
        """Sample an action index from the policy for the current step.

        Args:
            step_state_buf: sequence of (one_way_delay, cwnd) state tuples
                collected during the step.

        Returns:
            Integer index of the sampled action.
        """
        # For ewma delay, only want first component, the one-way delay
        # For the cwnd, try only the most recent cwnd
        owd_buf = np.asarray([state[0] for state in step_state_buf])
        ewma_delay = ewma(owd_buf, self.ewma_window)
        last_cwnd = step_state_buf[-1][1]

        # Get probability of each action from the local network.
        # Bug fix: the original referenced self.local_network and self.sess,
        # neither of which exists -- __init__ stores self.pi and self.session,
        # so sample_action raised AttributeError on every call.
        pi = self.pi
        action_probs = self.session.run(pi.action_probs,
                                        feed_dict={pi.states: [[ewma_delay,
                                                                last_cwnd]]})

        # action = np.argmax(action_probs[0])
        # action = np.argmax(np.random.multinomial(1, action_probs[0] - 1e-5))
        # Sample with temperature-softened probabilities; the 1e-5 offset
        # guards multinomial against probabilities summing slightly above 1.
        temperature = 1.0
        temp_probs = softmax(action_probs[0] / temperature)
        action = np.argmax(np.random.multinomial(1, temp_probs - 1e-5))
        return action
55+
56+
57+
def main():
    """Parse the port argument, restore the policy, and drive the sender."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('port', type=int)
    cli_args = arg_parser.parse_args()

    sender = Sender(cli_args.port)

    # Checkpoint of the trained Dagger policy to restore.
    model_path = path.join(project_root.DIR, 'dagger', 'logs',
                           '2017-07-31--06-32-01-true-expert-2',
                           'checkpoint-1100')

    learner = Learner(state_dim=Sender.state_dim,
                      action_cnt=Sender.action_cnt,
                      restore_vars=model_path)
    sender.set_sample_action(learner.sample_action)

    try:
        sender.handshake()
        sender.run()
    except KeyboardInterrupt:
        pass
    finally:
        # Always release the sender's resources, even on error or Ctrl-C.
        sender.cleanup()


if __name__ == '__main__':
    main()

env/sender.py

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class Sender(object):
3232
action_cnt = len(action_mapping)
3333

3434
def __init__(self, port=0, train=False, debug=False):
35+
self.step_time_file = open('/tmp/step_time', 'a')
36+
3537
self.train = train
3638
self.debug = debug
3739

@@ -183,6 +185,7 @@ def recv(self):
183185
self.step_start_ms = curr_ts_ms()
184186

185187
if curr_ts_ms() - self.step_start_ms > self.step_len_ms: # step's end
188+
self.step_time_file.write('step length: %f ms\n' % (curr_ts_ms() - self.step_start_ms))
186189
action = self.sample_action(self.step_state_buf)
187190
self.take_action(action)
188191

0 commit comments

Comments
 (0)