Skip to content

Commit a9b2060

Browse files
committed
sender.py: add --debug
1 parent e064e65 commit a9b2060

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

dagger/run_sender.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def sample_action(self, state):
6262
def main():
6363
parser = argparse.ArgumentParser()
6464
parser.add_argument('port', type=int)
65+
parser.add_argument('--debug', action='store_true')
6566
args = parser.parse_args()
6667

67-
sender = Sender(args.port)
68+
sender = Sender(args.port, debug=args.debug)
6869

6970
model_path = path.join(project_root.DIR, 'dagger', 'model', 'model')
7071

env/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*.trace
22
!12mbps.trace
3+
sampling_time

env/sender.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ class Sender(object):
3030
action_mapping = format_actions(["/2.0", "-10.0", "+0.0", "+10.0", "*2.0"])
3131
action_cnt = len(action_mapping)
3232

33-
def __init__(self, port=0, train=False):
33+
def __init__(self, port=0, train=False, debug=False):
3434
self.train = train
35+
self.debug = debug
3536

3637
# UDP socket and poller
3738
self.peer_addr = None
@@ -47,8 +48,9 @@ def __init__(self, port=0, train=False):
4748

4849
self.dummy_payload = 'x' * 1400
4950

50-
self.sampling_file = open(path.join(project_root.DIR, 'env', 'sampling_time'), 'w', 0)
51-
51+
if self.debug:
52+
self.sampling_file = open(path.join(project_root.DIR, 'env', 'sampling_time'), 'w', 0)
53+
5254
# congestion control related
5355
self.seq_num = 0
5456
self.next_ack = 0
@@ -75,7 +77,8 @@ def __init__(self, port=0, train=False):
7577
self.rtt_buf = []
7678

7779
def cleanup(self):
78-
self.sampling_file.close()
80+
if self.debug and self.sampling_file:
81+
self.sampling_file.close()
7982
self.sock.close()
8083

8184
def handshake(self):
@@ -186,9 +189,13 @@ def recv(self):
186189
self.cwnd]
187190

188191
# time how long it takes to get an action from the NN
189-
start_sample = time.time()
192+
if self.debug:
193+
start_sample = time.time()
194+
190195
action = self.sample_action(state)
191-
self.sampling_file.write('%.2f ms\n' % ((time.time() - start_sample) * 1000))
196+
197+
if self.debug:
198+
self.sampling_file.write('%.2f ms\n' % ((time.time() - start_sample) * 1000))
192199

193200
self.take_action(action)
194201

0 commit comments

Comments
 (0)