DQN.py

from typing import Optional
import gymnasium
import numpy as np
import torch
from Architectures import make_atari_nature_cnn, make_mlp
from BaseAgent import BaseAgent, get_new_params
from utils import polyak


class DQN(BaseAgent):
    def __init__(self,
                 *args,
                 gamma: float = 0.99,
                 minimum_epsilon: float = 0.05,
                 exploration_fraction: float = 0.5,
                 initial_epsilon: float = 1.0,
                 use_target_network: bool = False,
                 target_update_interval: Optional[int] = None,
                 polyak_tau: Optional[float] = None,
                 architecture_kwargs: dict = {},
                 **kwargs,
                 ):
        super().__init__(*args, **kwargs)
        self.kwargs = get_new_params(self, locals())
        self.algo_name = 'DQN'
        self.gamma = gamma
        self.minimum_epsilon = minimum_epsilon
        self.exploration_fraction = exploration_fraction
        self.initial_epsilon = initial_epsilon
        self.epsilon = initial_epsilon
        self.use_target_network = use_target_network
        self.target_update_interval = target_update_interval
        self.polyak_tau = polyak_tau
        self.nA = self.env.action_space.n
        self.log_hparams(self.kwargs)

        self.online_qs = self.architecture(**architecture_kwargs)
        self.model = self.online_qs

        if self.use_target_network:
            # Make another instance of the architecture for the target network:
            self.target_qs = self.architecture(**architecture_kwargs)
            self.target_qs.load_state_dict(self.online_qs.state_dict())
            if polyak_tau is not None:
                assert 0 <= polyak_tau <= 1, "Polyak tau must be in the range [0, 1]."
                self.polyak_tau = polyak_tau
            else:
                print("WARNING: No polyak tau specified for soft target updates. Using default tau=1 for hard updates.")
                self.polyak_tau = 1.0

            if target_update_interval is None:
                print("WARNING: Target network update interval not specified. Using default interval of 1 step.")
                self.target_update_interval = 1
        else:
            # Alias the "target" with the online net if no target network is used:
            self.target_qs = self.online_qs
            # Raise a warning if an update interval was specified anyway:
            if target_update_interval is not None:
                print("WARNING: Target network update interval specified but target network is not used.")

        # Make (all) qs learnable:
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    def _on_step(self) -> None:
        super()._on_step()

        # Update epsilon: linear decay with training progress, clamped at minimum_epsilon.
        self.epsilon = max(self.minimum_epsilon,
                           self.initial_epsilon - self.learn_env_steps / (self.total_timesteps * self.exploration_fraction))

        if self.learn_env_steps % self.log_interval == 0:
            self.log_history("train/epsilon", self.epsilon, self.learn_env_steps)

        # Periodically update the target network:
        if self.use_target_network and self.learn_env_steps % self.target_update_interval == 0:
            # Use Polyak averaging as specified:
            polyak(self.online_qs, self.target_qs, self.polyak_tau)

    def exploration_policy(self, state: np.ndarray) -> int:
        # Epsilon-greedy: act randomly with probability epsilon, greedily otherwise.
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        else:
            return self.evaluation_policy(state)

    def evaluation_policy(self, state: np.ndarray) -> int:
        # Get the greedy action from the q values:
        with torch.no_grad():
            qvals = self.online_qs(state)
        qvals = qvals.squeeze()
        return torch.argmax(qvals).item()

    def calculate_loss(self, batch):
        states, actions, rewards, next_states, dones = batch
        actions = actions.long()
        dones = dones.float()

        curr_q = self.online_qs(states).squeeze().gather(1, actions)

        with torch.no_grad():
            if isinstance(self.env.observation_space, gymnasium.spaces.Discrete):
                states = states.squeeze()
                next_states = next_states.squeeze()

            next_qs = self.target_qs(next_states)
            next_v = torch.max(next_qs, dim=-1).values
            next_v = next_v.reshape(-1, 1)

            # Backup equation:
            expected_curr_q = rewards + self.gamma * next_v * (1 - dones)

        # Calculate the q ("critic") loss:
        loss = 0.5 * torch.nn.functional.mse_loss(curr_q, expected_curr_q)

        self.log_history("train/online_q_mean", curr_q.mean().item(), self.learn_env_steps)
        # Log the loss:
        self.log_history("train/loss", loss.item(), self.learn_env_steps)
        return loss
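

# NOTE: `polyak` is imported from utils and its implementation is not shown in
# this file. The sketch below only illustrates what such a soft/hard target
# update presumably does; its name and signature are assumptions for
# illustration, and the class above uses the imported `polyak`, not this sketch.
@torch.no_grad()
def _polyak_update_sketch(online: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # target <- tau * online + (1 - tau) * target  (tau=1 reduces to a hard copy)
    for online_param, target_param in zip(online.parameters(), target.parameters()):
        target_param.mul_(1.0 - tau).add_(tau * online_param)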


if __name__ == '__main__':
    import gymnasium as gym

    env = 'ALE/Pong-v5'

    from Logger import WandBLogger, TensorboardLogger
    logger = TensorboardLogger('logs/atari')
    # logger = WandBLogger(entity='jacobhadamczyk', project='test')

    # mlp = make_mlp(env.unwrapped.observation_space.shape[0], env.unwrapped.action_space.n, hidden_dims=[32, 32])  # , activation=torch.nn.Mish
    # cnn = make_atari_nature_cnn(gym.make(env).action_space.n)

    env = 'CartPole-v1'
    agent = DQN(env,
                architecture=make_mlp,
                architecture_kwargs={'input_dim': gym.make(env).observation_space.shape[0],
                                     'output_dim': gym.make(env).action_space.n,
                                     'hidden_dims': [64, 64]},
                loggers=(logger,),
                learning_rate=0.001,
                train_interval=1,
                gradient_steps=1,
                batch_size=64,
                use_target_network=True,
                target_update_interval=10,
                polyak_tau=1.0,
                learning_starts=1000,
                log_interval=500,
                )
    agent.learn(total_timesteps=60_000)
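
    # A minimal sketch of a greedy evaluation rollout after training; this is
    # illustrative only and not part of the original script. It assumes the
    # trained network accepts raw observations, as evaluation_policy above implies.
    eval_env = gym.make(env)
    obs, info = eval_env.reset()
    episode_return, done = 0.0, False
    while not done:
        action = agent.evaluation_policy(obs)
        obs, reward, terminated, truncated, info = eval_env.step(action)
        episode_return += reward
        done = terminated or truncated
    print(f"Greedy episode return: {episode_return}")
    eval_env.close()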