
Commit 2fe6e39

Improved stat collection: at every timestep, collect the average temporal difference, action selection frequency, and sigma parameter magnitude, and plot all results. Only Q-learning-based methods support the improved plotting and stat collection for now.
1 parent a4cba58 commit 2fe6e39

19 files changed: +330 −56 lines

01.DQN.ipynb (+4 −4)

@@ -169,8 +169,8 @@
     "outputs": [],
     "source": [
     "class Model(BaseAgent):\n",
-    "    def __init__(self, static_policy=False, env=None, config=None):\n",
-    "        super(Model, self).__init__()\n",
+    "    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):\n",
+    "        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)\n",
     "        self.device = config.device\n",
     "\n",
     "        self.gamma = config.GAMMA\n",
@@ -279,8 +279,8 @@
     "        self.optimizer.step()\n",
     "\n",
     "        self.update_target_model()\n",
-    "        self.save_loss(loss.item())\n",
-    "        self.save_sigma_param_magnitudes()\n",
+    "        self.save_loss(loss.item(), frame)\n",
+    "        self.save_sigma_param_magnitudes(frame)\n",
     "\n",
     "    def get_action(self, s, eps=0.1):\n",
     "        with torch.no_grad():\n",

02.NStep_DQN.ipynb (+4 −4)

@@ -98,8 +98,8 @@
     "outputs": [],
     "source": [
     "class Model(BaseAgent):\n",
-    "    def __init__(self, static_policy=False, env=None, config=None):\n",
-    "        super(Model, self).__init__()\n",
+    "    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):\n",
+    "        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)\n",
     "        self.device = config.device\n",
     "\n",
     "        self.gamma = config.GAMMA\n",
@@ -219,8 +219,8 @@
     "        self.optimizer.step()\n",
     "\n",
     "        self.update_target_model()\n",
-    "        self.save_loss(loss.item())\n",
-    "        self.save_sigma_param_magnitudes()\n",
+    "        self.save_loss(loss.item(), frame)\n",
+    "        self.save_sigma_param_magnitudes(frame)\n",
     "\n",
     "    def get_action(self, s, eps=0.1):\n",
     "        with torch.no_grad():\n",

12.A2C.ipynb (+5 −5)

@@ -208,8 +208,8 @@
     "outputs": [],
     "source": [
     "class Model(BaseAgent):\n",
-    "    def __init__(self, static_policy=False, env=None, config=None):\n",
-    "        super(Model, self).__init__()\n",
+    "    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):\n",
+    "        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)\n",
     "        self.device = config.device\n",
     "\n",
     "        self.noisy=config.USE_NOISY_NETS\n",
@@ -316,16 +316,16 @@
     "        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_max)\n",
     "        self.optimizer.step()\n",
     "\n",
-    "        self.save_loss(loss.item(), action_loss.item(), value_loss.item(), dist_entropy.item())\n",
+    "        #self.save_loss(loss.item(), action_loss.item(), value_loss.item(), dist_entropy.item())\n",
     "        #self.save_sigma_param_magnitudes()\n",
     "\n",
     "        return value_loss.item(), action_loss.item(), dist_entropy.item()\n",
     "\n",
-    "    def save_loss(self, loss, policy_loss, value_loss, entropy_loss):\n",
+    "    '''def save_loss(self, loss, policy_loss, value_loss, entropy_loss):\n",
     "        super(Model, self).save_loss(loss)\n",
     "        self.policy_losses.append(policy_loss)\n",
     "        self.value_losses.append(value_loss)\n",
-    "        self.entropy_losses.append(entropy_loss)"
+    "        self.entropy_losses.append(entropy_loss)'''"
     ]
    },
    {

14.PPO.ipynb (+1 −1)

@@ -250,7 +250,7 @@
     "        dist_entropy_epoch /= (self.ppo_epoch * self.num_mini_batch)\n",
     "        total_loss = value_loss_epoch + action_loss_epoch + dist_entropy_epoch\n",
     "\n",
-    "        self.save_loss(total_loss, action_loss_epoch, value_loss_epoch, dist_entropy_epoch)\n",
+    "        #self.save_loss(total_loss, action_loss_epoch, value_loss_epoch, dist_entropy_epoch)\n",
     "\n",
     "        return action_loss_epoch, value_loss_epoch, dist_entropy_epoch"
     ]

agents/A2C.py (+6 −6)

@@ -12,8 +12,8 @@
 from timeit import default_timer as timer
 
 class Model(BaseAgent):
-    def __init__(self, static_policy=False, env=None, config=None):
-        super(Model, self).__init__()
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
+        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)
         self.device = config.device
 
         self.noisy=config.USE_NOISY_NETS
@@ -120,13 +120,13 @@ def update(self, rollout):
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm_max)
         self.optimizer.step()
 
-        self.save_loss(loss.item(), action_loss.item(), value_loss.item(), dist_entropy.item())
+        #self.save_loss(loss.item(), action_loss.item(), value_loss.item(), dist_entropy.item())
         #self.save_sigma_param_magnitudes()
 
         return value_loss.item(), action_loss.item(), dist_entropy.item()
 
-    def save_loss(self, loss, policy_loss, value_loss, entropy_loss):
-        super(Model, self).save_loss(loss)
+    '''def save_loss(self, loss, policy_loss, value_loss, entropy_loss):
+        super(Model, self).save_td(loss)
         self.policy_losses.append(policy_loss)
         self.value_losses.append(value_loss)
-        self.entropy_losses.append(entropy_loss)
+        self.entropy_losses.append(entropy_loss)'''

agents/BaseAgent.py (+35 −8)

@@ -1,19 +1,35 @@
 import numpy as np
 import pickle
 import os.path
+import csv
 
 import torch
 import torch.optim as optim
 
 
 class BaseAgent(object):
-    def __init__(self):
+    def __init__(self, config, env, log_dir='/tmp/gym'):
         self.model=None
         self.target_model=None
         self.optimizer = None
-        self.losses = []
+
+        self.td_file = open(os.path.join(log_dir, 'td.csv'), 'a')
+        self.td = csv.writer(self.td_file)
+
+        self.sigma_parameter_mag_file = open(os.path.join(log_dir, 'sig_param_mag.csv'), 'a')
+        self.sigma_parameter_mag = csv.writer(self.sigma_parameter_mag_file)
+
         self.rewards = []
-        self.sigma_parameter_mag=[]
+
+        self.action_log_frequency = config.ACTION_SELECTION_COUNT_FREQUENCY
+        self.action_selections = [0 for _ in range(env.action_space.n)]
+        self.action_log_file = open(os.path.join(log_dir, 'action_log.csv'), 'a')
+        self.action_log = csv.writer(self.action_log_file)
+
+    def __del__(self):
+        self.td_file.close()
+        self.sigma_parameter_mag_file.close()
+        self.action_log_file.close()
 
     def huber(self, x):
         cond = (x.abs() < 1.0).float().detach()
@@ -45,7 +61,7 @@ def load_replay(self):
         if os.path.isfile(fname):
             self.memory = pickle.load(open(fname, 'rb'))
 
-    def save_sigma_param_magnitudes(self):
+    def save_sigma_param_magnitudes(self, tstep):
         with torch.no_grad():
             sum_, count = 0.0, 0.0
             for name, param in self.model.named_parameters():
@@ -54,10 +70,21 @@ def save_sigma_param_magnitudes(self):
                     count += np.prod(param.shape)
 
             if count > 0:
-                self.sigma_parameter_mag.append(sum_/count)
+                self.sigma_parameter_mag.writerow((tstep, sum_/count))
 
-    def save_loss(self, loss):
-        self.losses.append(loss)
+    def save_td(self, td, tstep):
+        self.td.writerow((tstep, td))
 
     def save_reward(self, reward):
-        self.rewards.append(reward)
+        self.rewards.append(reward)
+
+    def save_action(self, action, tstep):
+        self.action_selections[int(action)] += 1.0/self.action_log_frequency
+        if (tstep+1) % self.action_log_frequency == 0:
+            self.action_log.writerow(list([tstep]+self.action_selections))
+            self.action_selections = [0 for _ in range(len(self.action_selections))]
+
+    def flush_data(self):
+        self.action_log_file.flush()
+        self.sigma_parameter_mag_file.flush()
+        self.td_file.flush()
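
To make the new per-window bookkeeping concrete, here is a self-contained sketch (not part of the commit) of just the save_action windowing logic, using an illustrative window of 4 steps and 3 actions in place of the configured ACTION_SELECTION_COUNT_FREQUENCY and env.action_space.n, and an in-memory buffer in place of action_log.csv:

# Illustrative stand-in for BaseAgent.save_action; window/action counts are made up.
import csv, io, random

ACTION_LOG_FREQUENCY = 4          # stand-in for config.ACTION_SELECTION_COUNT_FREQUENCY
NUM_ACTIONS = 3                   # stand-in for env.action_space.n

action_selections = [0.0 for _ in range(NUM_ACTIONS)]
buf = io.StringIO()               # stands in for the 'action_log.csv' file handle
writer = csv.writer(buf)

for tstep in range(12):
    action = random.randrange(NUM_ACTIONS)
    action_selections[int(action)] += 1.0 / ACTION_LOG_FREQUENCY
    if (tstep + 1) % ACTION_LOG_FREQUENCY == 0:
        # one row per window: timestep, then the fraction of the window spent on each action
        writer.writerow([tstep] + action_selections)
        action_selections = [0.0 for _ in range(NUM_ACTIONS)]

print(buf.getvalue())
# e.g. "3,0.5,0.25,0.25" -> in steps 0-3, action 0 was chosen twice, actions 1 and 2 once each

Each row therefore records the timestep that closes a window plus the fraction of that window spent on each action, which is presumably what the new action-frequency plot consumes.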

agents/Categorical_DQN.py (+2 −2)

@@ -8,14 +8,14 @@
 
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
         self.atoms = config.ATOMS
         self.v_max = config.V_MAX
         self.v_min = config.V_MIN
         self.supports = torch.linspace(self.v_min, self.v_max, self.atoms).view(1, 1, self.atoms).to(config.device)
         self.delta = (self.v_max - self.v_min) / (self.atoms - 1)
 
-        super(Model, self).__init__(static_policy, env, config)
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
     def declare_networks(self):
         self.model = CategoricalDQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init, atoms=self.atoms)

agents/DQN.py (+4 −4)

@@ -11,8 +11,8 @@
 from timeit import default_timer as timer
 
 class Model(BaseAgent):
-    def __init__(self, static_policy=False, env=None, config=None):
-        super(Model, self).__init__()
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
+        super(Model, self).__init__(config=config, env=env, log_dir=log_dir)
         self.device = config.device
 
         self.noisy=config.USE_NOISY_NETS
@@ -145,8 +145,8 @@ def update(self, s, a, r, s_, frame=0):
         self.optimizer.step()
 
         self.update_target_model()
-        self.save_loss(loss.item())
-        self.save_sigma_param_magnitudes()
+        self.save_td(loss.item(), frame)
+        self.save_sigma_param_magnitudes(frame)
 
     def get_action(self, s, eps=0.1): #faster
         with torch.no_grad():

agents/DRQN.py (+2 −2)

@@ -10,10 +10,10 @@
 from networks.network_bodies import AtariBody, SimpleBody
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
         self.sequence_length=config.SEQUENCE_LENGTH
 
-        super(Model, self).__init__(static_policy, env, config)
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
         self.reset_hx()
 

agents/Double_DQN.py (+2 −2)

@@ -5,8 +5,8 @@
 from agents.DQN import Model as DQN_Agent
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
-        super(Model, self).__init__(static_policy, env, config)
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
     def get_max_next_state_action(self, next_states):
         return self.model(next_states).max(dim=1)[1].view(-1, 1)

agents/Dueling_DQN.py (+2 −2)

@@ -6,8 +6,8 @@
 from networks.networks import DuelingDQN
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
-        super(Model, self).__init__(static_policy, env, config)
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
     def declare_networks(self):
         self.model = DuelingDQN(self.env.observation_space.shape, self.env.action_space.n, noisy=self.noisy, sigma_init=self.sigma_init)

agents/PPO.py (+1 −1)

@@ -71,7 +71,7 @@ def update(self, rollout):
         dist_entropy_epoch /= (self.ppo_epoch * self.num_mini_batch)
         total_loss = value_loss_epoch + action_loss_epoch + dist_entropy_epoch
 
-        self.save_loss(total_loss, action_loss_epoch, value_loss_epoch, dist_entropy_epoch)
+        #self.save_loss(total_loss, action_loss_epoch, value_loss_epoch, dist_entropy_epoch)
         #self.save_sigma_param_magnitudes()
 
         return action_loss_epoch, value_loss_epoch, dist_entropy_epoch

agents/QuantileRegression_DQN.py (+2 −2)

@@ -7,12 +7,12 @@
 from networks.networks import QRDQN
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
         self.num_quantiles = config.QUANTILES
         self.cumulative_density = torch.tensor((2 * np.arange(self.num_quantiles) + 1) / (2.0 * self.num_quantiles), device=config.device, dtype=torch.float)
         self.quantile_weight = 1.0 / self.num_quantiles
 
-        super(Model, self).__init__(static_policy, env, config)
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
 
     def declare_networks(self):

agents/Quantile_Rainbow.py (+2 −2)

@@ -8,12 +8,12 @@
 from utils.ReplayMemory import PrioritizedReplayMemory
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
         self.num_quantiles = config.QUANTILES
         self.cumulative_density = torch.tensor((2 * np.arange(self.num_quantiles) + 1) / (2.0 * self.num_quantiles), device=config.device, dtype=torch.float)
         self.quantile_weight = 1.0 / self.num_quantiles
 
-        super(Model, self).__init__(static_policy, env, config)
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
         self.nsteps=max(self.nsteps, 3)
 

agents/Rainbow.py (+2 −2)

@@ -7,14 +7,14 @@
 from utils.ReplayMemory import PrioritizedReplayMemory
 
 class Model(DQN_Agent):
-    def __init__(self, static_policy=False, env=None, config=None):
+    def __init__(self, static_policy=False, env=None, config=None, log_dir='/tmp/gym'):
         self.atoms=config.ATOMS
         self.v_max=config.V_MAX
         self.v_min=config.V_MIN
         self.supports = torch.linspace(self.v_min, self.v_max, self.atoms).view(1, 1, self.atoms).to(config.device)
         self.delta = (self.v_max - self.v_min) / (self.atoms - 1)
 
-        super(Model, self).__init__(static_policy, env, config)
+        super(Model, self).__init__(static_policy, env, config, log_dir=log_dir)
 
         self.nsteps=max(self.nsteps,3)
 

dqn_devel.py (+13 −4)

@@ -9,7 +9,7 @@
 from utils.wrappers import *
 from utils.hyperparameters import Config
 from agents.DQN import Model
-from utils.plot import plot_reward
+from utils.plot import plot_all_data
 
 config = Config()
 
@@ -58,14 +58,20 @@
 #DRQN Parameters
 config.SEQUENCE_LENGTH = 8
 
+#data logging parameters
+config.ACTION_SELECTION_COUNT_FREQUENCY = 1000
+
 if __name__=='__main__':
     start=timer()
 
     log_dir = "/tmp/gym/"
     try:
         os.makedirs(log_dir)
     except OSError:
-        files = glob.glob(os.path.join(log_dir, '*.monitor.csv'))
+        files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) \
+            + glob.glob(os.path.join(log_dir, '*td.csv')) \
+            + glob.glob(os.path.join(log_dir, '*sig_param_mag.csv')) \
+            + glob.glob(os.path.join(log_dir, '*action_log.csv'))
         for f in files:
             os.remove(f)
 
@@ -74,7 +80,7 @@
     env = bench.Monitor(env, os.path.join(log_dir, env_id))
     env = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=True)
     env = WrapPyTorch(env)
-    model = Model(env=env, config=config)
+    model = Model(env=env, config=config, log_dir=log_dir)
 
     episode_reward = 0
 
@@ -83,6 +89,8 @@
         epsilon = config.epsilon_by_frame(frame_idx)
 
         action = model.get_action(observation, epsilon)
+        model.save_action(action, frame_idx) #log action selection
+
         prev_observation=observation
         observation, reward, done, _ = env.step(action)
         observation = None if done else observation
@@ -100,7 +108,8 @@
         if frame_idx % 10000 == 0:
             try:
                 print('frame %s. time: %s' % (frame_idx, timedelta(seconds=int(timer()-start))))
-                plot_reward(log_dir, env_id, 'DRQN', config.MAX_FRAMES, bin_size=10, smooth=1, time=timedelta(seconds=int(timer()-start)), ipynb=False)
+                model.flush_data() #make sure all data is flushed to files
+                plot_all_data(log_dir, env_id, 'DRQN', config.MAX_FRAMES, bin_size=(10, 100, 100, 1), smooth=1, time=timedelta(seconds=int(timer()-start)), ipynb=False)
             except IOError:
                 pass
 
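
For quick manual inspection of the new logs, something like the following works. This is a hypothetical snippet, not the plot_all_data implementation in utils/plot.py (which is not shown in this diff); the column layout follows the writerow calls in agents/BaseAgent.py, and it assumes the three CSVs already exist and are non-empty under /tmp/gym/:

# Hypothetical ad-hoc read-back of the CSVs produced by BaseAgent during training.
import os
import pandas as pd
import matplotlib.pyplot as plt

log_dir = '/tmp/gym/'

td = pd.read_csv(os.path.join(log_dir, 'td.csv'), header=None, names=['tstep', 'td'])
sig = pd.read_csv(os.path.join(log_dir, 'sig_param_mag.csv'), header=None, names=['tstep', 'sigma_mag'])
actions = pd.read_csv(os.path.join(log_dir, 'action_log.csv'), header=None)

fig, axes = plt.subplots(3, 1, sharex=True)
axes[0].plot(td['tstep'], td['td'].rolling(100, min_periods=1).mean())  # smoothed TD error
axes[1].plot(sig['tstep'], sig['sigma_mag'])                            # mean |sigma| of noisy layers
axes[2].plot(actions[0], actions.iloc[:, 1:])                           # per-action selection fraction
axes[2].set_xlabel('timestep')
plt.show()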

results.png

86.8 KB

utils/hyperparameters.py (+3)

@@ -64,6 +64,9 @@ def __init__(self):
         #DRQN Parameters
         self.SEQUENCE_LENGTH=8
 
+        #data logging parameters
+        self.ACTION_SELECTION_COUNT_FREQUENCY = 1000
+
 
 '''
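
With the default of 1000, each call to save_action (see agents/BaseAgent.py above) adds 1/1000 to the chosen action's slot, so every row flushed to action_log.csv holds the fraction of the preceding 1000 timesteps on which each action was selected; a value of 0.6, for example, means that action was chosen on 600 of those steps, as illustrated in the sketch after the agents/BaseAgent.py diff.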
