Skip to content

Commit 40a4255

Browse files
committed
Train a universal Indigo
1 parent 1d4af64 commit 40a4255

File tree

4 files changed

+27
-25
lines changed

4 files changed

+27
-25
lines changed

dagger/dagger.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self, cluster, server, worker_tasks):
2929
self.num_workers = len(worker_tasks)
3030
self.aggregated_states = []
3131
self.aggregated_actions = []
32-
self.max_eps = 500
32+
self.max_eps = 1000
3333
self.checkpoint_delta = 10
3434
self.checkpoint = self.checkpoint_delta
3535
self.learn_rate = 0.01
@@ -56,7 +56,7 @@ def __init__(self, cluster, server, worker_tasks):
5656
self.sync_op = tf.group(*[v1.assign(v2) for v1, v2 in zip(
5757
cpu_vars, gpu_vars)])
5858

59-
self.default_batch_size = 280
59+
self.default_batch_size = 300
6060
self.default_init_state = self.global_network.zero_init_state(
6161
self.default_batch_size)
6262

@@ -248,14 +248,10 @@ def train(self):
248248

249249
self.sess.run(self.global_network.add_one)
250250

251-
print 'DaggerLeader:before sync'
252-
print 'DaggerLeader:global_network:cnt', self.sess.run(self.global_network.cnt)
253-
print 'DaggerLeader:global_network_cpu:cnt', self.sess.run(self.global_network_cpu.cnt)
254-
255251
# copy trained variables from GPU to CPU
256252
self.sess.run(self.sync_op)
257253

258-
print 'DaggerLeader:after sync'
254+
print 'DaggerLeader:global_network:cnt', self.sess.run(self.global_network.cnt)
259255
print 'DaggerLeader:global_network_cpu:cnt', self.sess.run(self.global_network_cpu.cnt)
260256
sys.stdout.flush()
261257

@@ -433,13 +429,9 @@ def run(self, debug=False):
433429
(self.task_idx, self.curr_ep))
434430

435431
# Reset local parameters to global
436-
print 'DaggerWorker:before sync'
437-
print 'DaggerWorker:global_network_cpu:cnt', self.sess.run(self.global_network_cpu.cnt)
438-
print 'DaggerWorker:local_network:cnt', self.sess.run(self.local_network.cnt)
439-
440432
self.sess.run(self.sync_op)
441433

442-
print 'DaggerWorker:after sync'
434+
print 'DaggerWorker:global_network_cpu:cnt', self.sess.run(self.global_network_cpu.cnt)
443435
print 'DaggerWorker:local_network:cnt', self.sess.run(self.local_network.cnt)
444436
sys.stdout.flush()
445437

dagger/worker.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -47,41 +47,51 @@ def create_env(task_index):
4747
best_cwnds_file = path.join(project_root.DIR, 'dagger', 'best_cwnds.yml')
4848
best_cwnd_map = yaml.load(open(best_cwnds_file))
4949

50-
if task_index <= 15:
51-
bandwidth = [10, 20, 50, 100]
50+
if task_index == 0:
51+
trace_path = path.join(project_root.DIR, 'env', '0.57mbps-poisson.trace')
52+
mm_cmd = 'mm-delay 28 mm-loss uplink 0.0477 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=14' % (trace_path, trace_path)
53+
best_cwnd = 5
54+
elif task_index == 1:
55+
trace_path = path.join(project_root.DIR, 'env', '2.64mbps-poisson.trace')
56+
mm_cmd = 'mm-delay 88 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=130' % (trace_path, trace_path)
57+
best_cwnd = 40
58+
elif task_index == 2:
59+
trace_path = path.join(project_root.DIR, 'env', '3.04mbps-poisson.trace')
60+
mm_cmd = 'mm-delay 130 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=426' % (trace_path, trace_path)
61+
best_cwnd = 70
62+
elif task_index <= 22:
63+
bandwidth = [5, 10, 20, 50, 100]
5264
delay = [10, 20, 40, 80]
5365

5466
cartesian = [(b, d) for b in bandwidth for d in delay]
55-
bandwidth, delay = cartesian[task_index]
67+
bandwidth, delay = cartesian[task_index - 3]
5668

5769
uplink_trace, downlink_trace = prepare_traces(bandwidth)
5870
mm_cmd = 'mm-delay %d mm-link %s %s' % (delay, uplink_trace, downlink_trace)
5971
best_cwnd = best_cwnd_map[bandwidth][delay]
60-
61-
elif task_index == 16:
72+
elif task_index == 23:
6273
trace_path = path.join(project_root.DIR, 'env', '100.42mbps.trace')
6374
mm_cmd = 'mm-delay 27 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=173' % (trace_path, trace_path)
6475
best_cwnd = 500
65-
elif task_index == 17:
76+
elif task_index == 24:
6677
trace_path = path.join(project_root.DIR, 'env', '77.72mbps.trace')
6778
mm_cmd = 'mm-delay 51 mm-loss uplink 0.0006 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=94' % (trace_path, trace_path)
6879
best_cwnd = 690
69-
elif task_index == 18:
80+
elif task_index == 25:
7081
trace_path = path.join(project_root.DIR, 'env', '114.68mbps.trace')
7182
mm_cmd = 'mm-delay 45 mm-link %s %s --uplink-queue=droptail --uplink-queue-args=packets=450' % (trace_path, trace_path)
7283
best_cwnd = 870
73-
elif task_index <= 22:
84+
elif task_index <= 29:
7485
bandwidth = [200]
7586
delay = [10, 20, 40, 80]
7687

7788
cartesian = [(b, d) for b in bandwidth for d in delay]
78-
bandwidth, delay = cartesian[task_index - 19]
89+
bandwidth, delay = cartesian[task_index - 26]
7990

8091
uplink_trace, downlink_trace = prepare_traces(bandwidth)
8192
mm_cmd = 'mm-delay %d mm-link %s %s' % (delay, uplink_trace, downlink_trace)
8293
best_cwnd = best_cwnd_map[bandwidth][delay]
8394

84-
8595
env = Environment(mm_cmd)
8696
env.best_cwnd = best_cwnd
8797

env/sender.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Sender(object):
2626
# RL exposed class/static variables
2727
max_steps = 1000
2828
state_dim = 4
29-
action_mapping = format_actions(["/2.0", "-10.0", "+10.0", "*2.0"])
29+
action_mapping = format_actions(["/2.0", "-10.0", "+0.0", "+10.0", "*2.0"])
3030
action_cnt = len(action_mapping)
3131

3232
def __init__(self, port=0, train=False):
@@ -140,7 +140,7 @@ def take_action(self, action_idx):
140140
op, val = self.action_mapping[action_idx]
141141

142142
self.cwnd = apply_op(op, self.cwnd, val)
143-
self.cwnd = max(5.0, self.cwnd)
143+
self.cwnd = max(2.0, self.cwnd)
144144

145145
def window_is_open(self):
146146
return self.seq_num - self.next_ack < self.cwnd

helpers/assistant.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def run_cmd(args, host, procs):
2121

2222
elif cmd == 'git_pull':
2323
cmd_in_ssh = ('cd %s && git fetch --all && '
24-
'git checkout indigo-broad && '
24+
'git checkout indigo-universal && '
2525
'git reset --hard @~1 && git pull' % args.rlcc_dir)
2626

2727
elif cmd == 'rm_history':

0 commit comments

Comments
 (0)