 import numpy as np
 import pickle
 import os.path
+import csv

 import torch
 import torch.optim as optim


 class BaseAgent(object):
-    def __init__(self):
+    def __init__(self, config, env, log_dir='/tmp/gym'):
         self.model = None
         self.target_model = None
         self.optimizer = None
-        self.losses = []
+
+        self.td_file = open(os.path.join(log_dir, 'td.csv'), 'a')
+        self.td = csv.writer(self.td_file)
+
+        self.sigma_parameter_mag_file = open(os.path.join(log_dir, 'sig_param_mag.csv'), 'a')
+        self.sigma_parameter_mag = csv.writer(self.sigma_parameter_mag_file)
+
         self.rewards = []
-        self.sigma_parameter_mag = []
+
+        self.action_log_frequency = config.ACTION_SELECTION_COUNT_FREQUENCY
+        self.action_selections = [0 for _ in range(env.action_space.n)]
+        self.action_log_file = open(os.path.join(log_dir, 'action_log.csv'), 'a')
+        self.action_log = csv.writer(self.action_log_file)
+
+    def __del__(self):
+        self.td_file.close()
+        self.sigma_parameter_mag_file.close()
+        self.action_log_file.close()

     def huber(self, x):
         cond = (x.abs() < 1.0).float().detach()
@@ -45,7 +61,7 @@ def load_replay(self):
         if os.path.isfile(fname):
             self.memory = pickle.load(open(fname, 'rb'))

-    def save_sigma_param_magnitudes(self):
+    def save_sigma_param_magnitudes(self, tstep):
         with torch.no_grad():
             sum_, count = 0.0, 0.0
             for name, param in self.model.named_parameters():
@@ -54,10 +70,21 @@ def save_sigma_param_magnitudes(self):
                     count += np.prod(param.shape)

             if count > 0:
-                self.sigma_parameter_mag.append(sum_ / count)
+                self.sigma_parameter_mag.writerow((tstep, sum_ / count))

-    def save_loss(self, loss):
-        self.losses.append(loss)
+    def save_td(self, td, tstep):
+        self.td.writerow((tstep, td))

     def save_reward(self, reward):
-        self.rewards.append(reward)
+        self.rewards.append(reward)
+
+    def save_action(self, action, tstep):
+        self.action_selections[int(action)] += 1.0 / self.action_log_frequency
+        if (tstep + 1) % self.action_log_frequency == 0:
+            self.action_log.writerow(list([tstep] + self.action_selections))
+            self.action_selections = [0 for _ in range(len(self.action_selections))]
+
+    def flush_data(self):
+        self.action_log_file.flush()
+        self.sigma_parameter_mag_file.flush()
+        self.td_file.flush()
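Below is a minimal, hypothetical usage sketch (not part of this commit) showing how the new CSV-logging hooks might be driven from a training loop. The Config class, the import path, the random policy, and the classic gym step/reset API are assumptions for illustration; only the BaseAgent constructor signature and the save_* / flush_data methods come from the diff above.

# Hypothetical usage sketch -- not part of the commit above.
# Config, the import path, and the random policy are assumptions;
# the constructor arguments and the save_*/flush_data calls match the diff.
import os
import gym

from agents.BaseAgent import BaseAgent  # assumed module path


class Config(object):
    ACTION_SELECTION_COUNT_FREQUENCY = 1000  # read by BaseAgent.__init__


os.makedirs('/tmp/gym', exist_ok=True)       # log_dir must exist before the CSVs are opened
env = gym.make('CartPole-v0')
agent = BaseAgent(config=Config(), env=env, log_dir='/tmp/gym')

obs, done, tstep = env.reset(), False, 0
while not done:
    action = env.action_space.sample()       # stand-in for the agent's policy
    obs, reward, done, _ = env.step(action)

    agent.save_reward(reward)                # buffered in memory
    agent.save_action(action, tstep)         # writes a CSV row every ACTION_SELECTION_COUNT_FREQUENCY steps
    agent.save_td(0.0, tstep)                # the real TD error would come from the learning update
    tstep += 1

agent.flush_data()                           # force buffered CSV rows to disk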