Commit a4cb8ef

a3c.py: fix self.reward_buf bug
rename send/recv_ts_diff to send/recv_interval; add more .gitignore entries
1 parent d7d7d32 commit a4cb8ef

File tree

4 files changed: +21 additions, -18 deletions


a3c/a3c.py

Lines changed: 12 additions & 12 deletions
@@ -11,16 +11,16 @@
 def normalize_states(states):
     norm_states = np.array(states, dtype=np.float32)

-    # queuing_delay, target range [0, 200]
-    queuing_delays = norm_states[:, 0]
-    queuing_delays /= 100.0
-    queuing_delays -= 1.0
+    # queuing_delay, target range [0, 2000]
+    queuing_delay = norm_states[:, 0]
+    queuing_delay /= 1000.0
+    queuing_delay -= 1.0

-    # send_ts_diff and recv_ts_diff, target range [0, 100]
+    # send_interval and recv_interval, target range [0, 500]
     for i in [1, 2]:
-        ts_diffs = norm_states[:, i]
-        ts_diffs /= 50.0
-        ts_diffs -= 1.0
+        interval = norm_states[:, i]
+        interval /= 250.0
+        interval -= 1.0

     # cwnd, target range [0, 100]
     cwnd = norm_states[:, 3]
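
Taken together, the hunk is a per-feature min-max rescaling into [-1, 1]: each feature is divided by half its target range, then shifted down by one. A minimal standalone sketch of the function as it reads after this commit; the cwnd divisor and the return statement are assumptions extrapolated from the [0, 100] comment, since those lines fall outside the hunk:

import numpy as np

def normalize_states(states):
    # Each column slice below is a view into norm_states, so the
    # in-place operators rescale the array itself.
    norm_states = np.array(states, dtype=np.float32)

    # queuing_delay, target range [0, 2000] -> [-1, 1]
    queuing_delay = norm_states[:, 0]
    queuing_delay /= 1000.0
    queuing_delay -= 1.0

    # send_interval and recv_interval, target range [0, 500] -> [-1, 1]
    for i in [1, 2]:
        interval = norm_states[:, i]
        interval /= 250.0
        interval -= 1.0

    # cwnd, target range [0, 100] -> [-1, 1]
    # (divisor assumed by analogy; not visible in the diff)
    cwnd = norm_states[:, 3]
    cwnd /= 50.0
    cwnd -= 1.0

    return norm_states
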
@@ -109,7 +109,7 @@ def build_loss(self):
         entropy = -tf.reduce_mean(pi.action_probs * log_action_probs)

         # total loss and gradients
-        loss = policy_loss + 0.5 * value_loss - 0.2 * entropy
+        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy
         grads = tf.gradients(loss, pi.trainable_vars)
         grads, _ = tf.clip_by_global_norm(grads, 10.0)
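
The only change here is the entropy coefficient, dropped from 0.2 to 0.01, which weakens the exploration bonus relative to the policy and value terms. A rough TF1-style sketch of how such an actor-critic loss is typically assembled; the placeholder names (actions, advantages, returns, logits, state_values) and the action count are assumptions for illustration, not taken from the repo:

import tensorflow as tf  # TF 1.x, matching the repo's tf.gradients usage

num_actions = 5  # assumed
actions = tf.placeholder(tf.int32, [None])
advantages = tf.placeholder(tf.float32, [None])
returns = tf.placeholder(tf.float32, [None])
logits = tf.placeholder(tf.float32, [None, num_actions])
state_values = tf.placeholder(tf.float32, [None])

action_probs = tf.nn.softmax(logits)
log_action_probs = tf.log(action_probs + 1e-8)

# policy gradient term, weighted by advantages
picked = tf.reduce_sum(
    log_action_probs * tf.one_hot(actions, num_actions), axis=1)
policy_loss = -tf.reduce_mean(picked * advantages)

# value regression and entropy bonus
value_loss = tf.reduce_mean(tf.square(returns - state_values))
entropy = -tf.reduce_mean(action_probs * log_action_probs)

# entropy weight is 0.01 after this commit (was 0.2)
loss = policy_loss + 0.5 * value_loss - 0.01 * entropy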

@@ -196,10 +196,10 @@ def rollout(self):
         if self.gamma == 1.0:
             self.reward_buf = np.full(episode_len, final_reward)
         else:
-            reward_buf = np.zeros(episode_len)
-            reward_buf[-1] = final_reward
+            self.reward_buf = np.zeros(episode_len)
+            self.reward_buf[-1] = final_reward
             for i in reversed(xrange(episode_len - 1)):
-                reward_buf[i] = reward_buf[i + 1] * self.gamma
+                self.reward_buf[i] = self.reward_buf[i + 1] * self.gamma

         # compute advantages
         self.adv_buf = self.reward_buf - np.asarray(self.value_buf)
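
This is the bug named in the commit title: the else branch filled a local reward_buf, so self.reward_buf (used just below to compute advantages) was never set when gamma < 1. A standalone sketch of the fixed computation; range stands in for the Python 2 xrange used in the source:

import numpy as np

def discounted_rewards(episode_len, final_reward, gamma):
    # Only the terminal step earns final_reward; every earlier step
    # receives it discounted once more per remaining step.
    reward_buf = np.zeros(episode_len)
    reward_buf[-1] = final_reward
    for i in reversed(range(episode_len - 1)):  # xrange in the Python 2 source
        reward_buf[i] = reward_buf[i + 1] * gamma
    return reward_buf

print(discounted_rewards(4, 10.0, 0.9))  # [ 7.29  8.1   9.   10.  ]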

a3c/worker.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@ def prepare_traces(bandwidth):
         uplink_trace = path.join(trace_dir, '%dmbps.trace' % bandwidth)
         downlink_trace = uplink_trace
     else:
-        trace_path = '/usr/share/mahimahi/traces/' + bandwidth
+        trace_path = path.join(trace_dir, bandwidth)
         # intentionally switch uplink and downlink traces due to sender first
         uplink_trace = trace_path + '.down'
         downlink_trace = trace_path + '.up'
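
The fix makes the cellular-trace branch resolve against the same trace_dir as the synthetic branch, instead of a hard-coded mahimahi install path. A sketch of the whole function under that reading; the isinstance dispatch and the trace_dir default are assumptions, since only the else branch is visible in the hunk:

from os import path

def prepare_traces(bandwidth, trace_dir='traces'):
    # trace_dir default is illustrative; the repo resolves its own directory
    if isinstance(bandwidth, int):
        # synthetic constant-rate link, e.g. '12mbps.trace'
        uplink_trace = path.join(trace_dir, '%dmbps.trace' % bandwidth)
        downlink_trace = uplink_trace
    else:
        # recorded cellular trace, e.g. 'Verizon-LTE-driving'
        trace_path = path.join(trace_dir, bandwidth)
        # intentionally switch uplink and downlink traces due to sender first
        uplink_trace = trace_path + '.down'
        downlink_trace = trace_path + '.up'
    return uplink_trace, downlink_trace
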
@@ -35,9 +35,9 @@ def prepare_traces(bandwidth):


 def create_env(task_index):
-    bandwidth = 12  # or 'Verizon-LTE-driving'
+    bandwidth = 12
     delay = 20
-    queue = 200  # or None
+    queue = 200

     uplink_trace, downlink_trace = prepare_traces(bandwidth)
     mm_cmd = ('mm-delay %d mm-link %s %s' %
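
create_env pins the emulated link to 12 Mbps with 20 ms of one-way delay and a 200-packet queue, and the visible format string hands the traces to mm-link. A hedged completion of that command construction; the queue flags follow mahimahi's mm-link CLI but are an assumption, as they fall outside the hunk:

delay = 20          # ms, passed to mm-delay
queue = 200         # packets
uplink_trace = '12mbps.trace'    # illustrative paths
downlink_trace = '12mbps.trace'

mm_cmd = ('mm-delay %d mm-link %s %s' %
          (delay, uplink_trace, downlink_trace))
if queue is not None:
    mm_cmd += (' --uplink-queue=droptail'
               ' --uplink-queue-args=packets=%d' % queue)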

env/sender.py

Lines changed: 4 additions & 3 deletions
@@ -133,8 +133,8 @@ def update_state(self, ack):
         self.prev_send_ts = send_ts
         self.prev_recv_ts = recv_ts

-        send_ts_diff = send_ts - self.prev_send_ts
-        recv_ts_diff = recv_ts - self.prev_recv_ts
+        send_interval = send_ts - self.prev_send_ts
+        recv_interval = recv_ts - self.prev_recv_ts
         self.prev_send_ts = send_ts
         self.prev_recv_ts = recv_ts

@@ -152,7 +152,8 @@ def update_state(self, ack):
         if curr_ts_ms() - self.runtime_start > self.max_runtime:
             self.running = False

-        return [queuing_delay, send_ts_diff, recv_ts_diff, self.cwnd]
+        state = [queuing_delay, send_interval, recv_interval, self.cwnd]
+        return state

     def take_action(self, action):
         self.cwnd += self.action_mapping[action]
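
After the rename, the state handed to the learner is [queuing_delay, send_interval, recv_interval, self.cwnd], matching the four columns normalized in a3c.py. A hypothetical helper isolating the interval bookkeeping: the differences must be taken against the previous timestamps before those are overwritten for the next ACK:

class IntervalTracker(object):
    """Hypothetical helper mirroring the bookkeeping in update_state."""

    def __init__(self, first_send_ts, first_recv_ts):
        self.prev_send_ts = first_send_ts
        self.prev_recv_ts = first_recv_ts

    def update(self, send_ts, recv_ts):
        # Diff against the previous timestamps first, then record the
        # current ones for the next ACK.
        send_interval = send_ts - self.prev_send_ts
        recv_interval = recv_ts - self.prev_recv_ts
        self.prev_send_ts = send_ts
        self.prev_recv_ts = recv_ts
        return send_interval, recv_interval

tracker = IntervalTracker(0, 0)
print(tracker.update(8, 9))    # (8, 9)
print(tracker.update(15, 20))  # (7, 11)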

helpers/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+*.trace
+TABLE
