-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsingle-agent-lun.py
40 lines (34 loc) · 954 Bytes
/
single-agent-lun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gym
import torch
import matplotlib.pyplot as plt
import numpy as np
from pytorch.DQN import Agent
from pytorch.QNetwork import FCQ
from pytorch.ReplayBuffer import ReplayBuffer
if __name__ == "__main__":
    # Train a DQN agent on LunarLander several times and plot the
    # episode-reward curve averaged across runs (to smooth the noise of
    # any single training run).
    N_EPISODES = 200  # episodes per training run; also the length of agent.rewards
    N_RUNS = 10       # independent runs averaged into the final curve

    args = {
        "env_fn": lambda: gym.make("LunarLander-v2"),
        "Qnet": FCQ,
        "buffer": ReplayBuffer,
        "net_args": {
            "hidden_layers": (512, 256, 128),
            "activation_fn": torch.nn.functional.relu,
            "optimizer": torch.optim.Adam,
            "learning_rate": 0.0005,
        },
        # Epsilon-greedy exploration schedule: anneals from max_epsilon to
        # min_epsilon over decay_steps (exact schedule defined in Agent —
        # NOTE(review): confirm against pytorch/DQN.py).
        "max_epsilon": 1.0,
        "min_epsilon": 0.1,
        "decay_steps": 5000,
        "gamma": 0.99,  # discount factor
        # Presumably how often the target network is synced and the minimum
        # replay-buffer fill before learning starts — verify in Agent.
        "target_update_rate": 15,
        "min_buffer": 64,
    }

    # Accumulate per-episode rewards across runs; length must match the
    # number of episodes trained so the element-wise += lines up.
    rewards = np.zeros(N_EPISODES)
    for run in range(N_RUNS):
        agent = Agent(**args)
        agent.train(N_EPISODES)
        print(agent.step_count)
        rewards += agent.rewards

    plt.plot(rewards / N_RUNS)
    plt.show()