 import numpy as np
 import pickle
 import os.path
+import csv

 import torch
 import torch.optim as optim


 class BaseAgent(object):
-    def __init__(self):
+    def __init__(self, config, env, log_dir='/tmp/gym'):
         self.model = None
         self.target_model = None
         self.optimizer = None
-        self.losses = []
+
+        self.td_file = open(os.path.join(log_dir, 'td.csv'), 'a')
+        self.td = csv.writer(self.td_file)
+
+        self.sigma_parameter_mag_file = open(os.path.join(log_dir, 'sig_param_mag.csv'), 'a')
+        self.sigma_parameter_mag = csv.writer(self.sigma_parameter_mag_file)
+
         self.rewards = []
-        self.sigma_parameter_mag = []
+
+        self.action_log_frequency = config.ACTION_SELECTION_COUNT_FREQUENCY
+        self.action_selections = [0 for _ in range(env.action_space.n)]
+        self.action_log_file = open(os.path.join(log_dir, 'action_log.csv'), 'a')
+        self.action_log = csv.writer(self.action_log_file)
+
+    def __del__(self):
+        self.td_file.close()
+        self.sigma_parameter_mag_file.close()
+        self.action_log_file.close()

     def huber(self, x):
         cond = (x.abs() < 1.0).float().detach()
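The hunk above ends mid-function; the remainder of huber is unchanged and therefore elided by the diff. For reference, a minimal sketch of how an elementwise Huber loss typically continues from this masking idiom (an assumption about the omitted body, not shown in the commit):

    def huber(self, x):
        # mask: 1.0 where |x| < 1 (quadratic region), 0.0 elsewhere (linear region)
        cond = (x.abs() < 1.0).float().detach()
        # assumed completion: 0.5*x^2 in the quadratic region, |x| - 0.5 in the linear region
        return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1.0 - cond)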
@@ -45,7 +61,7 @@ def load_replay(self):
         if os.path.isfile(fname):
             self.memory = pickle.load(open(fname, 'rb'))

-    def save_sigma_param_magnitudes(self):
+    def save_sigma_param_magnitudes(self, tstep):
         with torch.no_grad():
             sum_, count = 0.0, 0.0
             for name, param in self.model.named_parameters():
@@ -54,10 +70,21 @@ def save_sigma_param_magnitudes(self):
                     count += np.prod(param.shape)

             if count > 0:
-                self.sigma_parameter_mag.append(sum_ / count)
+                self.sigma_parameter_mag.writerow((tstep, sum_ / count))

-    def save_loss(self, loss):
-        self.losses.append(loss)
+    def save_td(self, td, tstep):
+        self.td.writerow((tstep, td))

     def save_reward(self, reward):
-        self.rewards.append(reward)
+        self.rewards.append(reward)
+
+    def save_action(self, action, tstep):
+        self.action_selections[int(action)] += 1.0 / self.action_log_frequency
+        if (tstep + 1) % self.action_log_frequency == 0:
+            self.action_log.writerow(list([tstep] + self.action_selections))
+            self.action_selections = [0 for _ in range(len(self.action_selections))]
+
+    def flush_data(self):
+        self.action_log_file.flush()
+        self.sigma_parameter_mag_file.flush()
+        self.td_file.flush()
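Taken together, the commit replaces the in-memory self.losses and self.sigma_parameter_mag lists with csv writers that stream TD error, noisy-layer sigma magnitudes, and per-action selection frequencies to disk. A minimal smoke-test sketch of how a training loop might drive these hooks, assuming BaseAgent (as modified above) is in scope; the config and env objects are hypothetical stand-ins, since __init__ only reads ACTION_SELECTION_COUNT_FREQUENCY and action_space.n from them:

import os
import random
from types import SimpleNamespace

# Hypothetical stand-ins for the real config and environment objects.
config = SimpleNamespace(ACTION_SELECTION_COUNT_FREQUENCY=1000)
env = SimpleNamespace(action_space=SimpleNamespace(n=4))

log_dir = '/tmp/gym'
os.makedirs(log_dir, exist_ok=True)  # the constructor assumes log_dir already exists

agent = BaseAgent(config, env, log_dir=log_dir)  # opens td.csv, sig_param_mag.csv, action_log.csv

for tstep in range(5000):
    action = random.randrange(env.action_space.n)  # placeholder for the real policy
    agent.save_action(action, tstep)  # accumulates; writes a row every ACTION_SELECTION_COUNT_FREQUENCY steps
    agent.save_td(0.0, tstep)         # after a learning step, log that step's TD error
    # agent.save_sigma_param_magnitudes(tstep)  # NoisyNet runs only; requires self.model to be set

agent.flush_data()  # push buffered rows to disk without waiting for __del__ to close the files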