-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualise.py
290 lines (249 loc) · 9.73 KB
/
visualise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import torch
import numpy as np
import tianshou as ts
import pprint
from tianshou.utils import WandbLogger
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.trainer import offpolicy_trainer, onpolicy_trainer
from torch.serialization import save
from agents import TwoAgentPolicy
from agents.lib_agents import *
from utils.envs import make_envs, MakeEnv
from utils.config import puck_params, bar_params, env_params
import argparse
import os
# Registry mapping the CLI algorithm names (values of --puck/--bar)
# to their corresponding policy classes.
algo_mapping = {
    "sine": SinePolicy,
    "random": RandomPolicy,
    "greedy": GreedyPolicy,
    "smurve": SmurvePolicy,
    "dqn": DQN,
    "sac": SAC,
    "ppo": PPO,
    "ddpg": DDPG,
    "td3": TD3,
}

# Module-level globals shared with the hook functions below:
# `policy` holds the combined two-agent policy, `args` the parsed CLI options.
policy = None
args = None
def train_fn(epoch, env_step):
    """Hook called at the beginning of each training step.

    Sets the exploration epsilon on the module-level ``policy`` according
    to the decay schedule selected via ``--eps-train-decay``.

    Args:
        epoch (int): Epoch number.
        env_step (int): Step number in the environment.

    Raises:
        ValueError: If ``args.eps_train_decay`` is not a known schedule.
    """
    tot_steps = args.epoch * args.step_per_epoch
    if args.eps_train_decay == "const":
        # Constant schedule: always use the final epsilon.
        eps = args.eps_train_final
    elif args.eps_train_decay == "lin":
        # Linear interpolation from eps_train down to eps_train_final.
        eps = args.eps_train - (env_step / tot_steps) * (
            args.eps_train - args.eps_train_final
        )
    elif args.eps_train_decay == "exp":
        # Exponential decay reaching eps_train_final at tot_steps.
        eps = args.eps_train * (
            (args.eps_train_final / args.eps_train) ** (env_step / tot_steps)
        )
    else:
        # Guard against an unknown schedule leaving `eps` unbound
        # (argparse `choices` normally prevents this path).
        raise ValueError(
            "unknown eps-train-decay schedule: {!r}".format(args.eps_train_decay)
        )
    policy.set_eps(eps)
def test_fn(epoch, env_step):
    """Hook invoked before each evaluation phase during training.

    Switches the module-level ``policy`` to its evaluation epsilon.

    Args:
        epoch (int): Epoch number.
        env_step (int): Step number in the environment.
    """
    policy.set_eps(args.eps_test)
def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int):
    """Hook used by the trainer to persist the current policies.

    Saves the puck and bar policy state dicts under
    ``saved_policies/<run_id>/`` and ensures an (empty) log file exists.

    Args:
        epoch (int): Epoch number.
        env_step (int): Step number in the environment.
        gradient_step (int): Number of gradient steps taken so far.

    Returns:
        str: Path of the log file to hand back to the trainer.
    """
    save_folder = "saved_policies/{}".format(args.run_id)
    # exist_ok avoids the race between an isdir() check and makedirs().
    os.makedirs(save_folder, exist_ok=True)

    puck_file_path = "{}/puck_{}.pth".format(save_folder, args.puck)
    print("saving puck")
    torch.save(policy.puck_policy.state_dict(), puck_file_path)

    bar_file_path = "{}/bar_{}.pth".format(save_folder, args.bar)
    print("saving bar")
    torch.save(policy.bar_policy.state_dict(), bar_file_path)

    # Touch the log file without truncating an existing one.
    save_path = "{}/log".format(save_folder)
    if not os.path.isfile(save_path):
        with open(save_path, "w"):
            pass
    return save_path
def get_args():
"""Retuns the arguments for the script
Returns:
Argument object
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--puck", type=str, default="sine", choices=list(algo_mapping.keys())
)
parser.add_argument(
"--bar", type=str, default="ppo", choices=list(algo_mapping.keys())
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--discrete", action="store_true", default=False)
parser.add_argument("--discrete-k", type=int, default=7)
parser.add_argument("--eps-test", type=float, default=0.005)
parser.add_argument("--eps-train", type=float, default=1.0)
parser.add_argument("--eps-train-final", type=float, default=0.05)
parser.add_argument(
"--eps-train-decay", type=str, default="exp", choices=["exp", "lin", "const"]
)
parser.add_argument("--buffer-size", type=int, default=10000)
parser.add_argument("--stack-num", type=int, default=5)
parser.add_argument("--exploration-noise", type=bool, default=True)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=20)
parser.add_argument("--step-per-epoch", type=int, default=10000)
parser.add_argument("--step-per-collect", type=int, default=10)
parser.add_argument("--update-per-step", type=float, default=0.1)
parser.add_argument("--repeat-per-collect", type=int, default=2)
parser.add_argument("--episode-per-test", type=int, default=100)
parser.add_argument("--episode-per-collect", type=int, default=10)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--training-num", type=int, default=10)
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.0)
parser.add_argument(
"--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--wandb-save-interval", type=int, default=1)
parser.add_argument("--wandb-project", type=str, default="test-project")
parser.add_argument("--wandb-name", type=str, required=True)
parser.add_argument("--wandb-entity", type=str, default="penalty-shot-project")
parser.add_argument("--wandb-run-id", type=str, default=None)
parser.add_argument("--trainer", type=str, default="on", choices=["off", "on"])
parser.add_argument("--save", action="store_true", default=False)
parser.add_argument("--load-puck-id", type=str, default=None)
parser.add_argument("--load-bar-id", type=str, default=None)
parser.add_argument("--run-id", type=str, default=None)
parser.add_argument("--save-render", type=str, default=None)
parser.add_argument("--num-episodes-render", type=int, default=10)
return parser.parse_args()
def _build_policy(name, params):
    """Instantiate one agent policy from its configuration dict.

    Configurations come in two shapes: a plain kwargs dict passed directly
    to the algorithm class, or a dict with "init_params"/"call_params"
    where the constructed object is additionally called with the call
    params.

    Args:
        name (str): Key into ``algo_mapping``.
        params (dict): Configuration for that algorithm.

    Returns:
        Policy: The constructed policy.
    """
    if "call_params" in params:
        init_params = params.get("init_params", {})
        call_params = params.get("call_params", {})
        return algo_mapping[name](**init_params)(**call_params)
    return algo_mapping[name](**params)


def init_and_call_policy():
    """Initialise and call the policies for the puck and bar agents.

    Returns:
        Tuple[Policy, Policy]: Policies for puck and bar, in that order.
    """
    # The puck and bar construction logic is identical; share one helper.
    policy_puck = _build_policy(args.puck, puck_params[args.puck])
    policy_bar = _build_policy(args.bar, bar_params[args.bar])
    return (policy_puck, policy_bar)
def _load_state_dict(target_policy, path):
    """Load a saved state dict into a policy, honouring the chosen device.

    Args:
        target_policy (Policy): Policy to load weights into.
        path (str): Checkpoint file path.
    """
    if args.device == "cuda":
        state = torch.load(path)
    else:
        # Checkpoints may have been written on GPU; remap tensors to CPU.
        state = torch.load(path, map_location=torch.device("cpu"))
    target_policy.load_state_dict(state)


def load_policy(policy_puck, policy_bar):
    """Load saved weights when --load-puck-id / --load-bar-id are given.

    Args:
        policy_puck (Policy): Puck policy to be loaded into.
        policy_bar (Policy): Bar policy to be loaded into.

    Returns:
        Tuple[Policy, Policy]: The (possibly updated) policies.
    """
    if args.load_puck_id is not None:
        print("Loading Puck Policy..")
        _load_state_dict(
            policy_puck,
            "saved_policies/{}/puck_{}.pth".format(args.load_puck_id, args.puck),
        )
    if args.load_bar_id is not None:
        print("Loading Bar Policy..")
        _load_state_dict(
            policy_bar,
            "saved_policies/{}/bar_{}.pth".format(args.load_bar_id, args.bar),
        )
    return policy_puck, policy_bar
def visualise():
    """Run the trained puck/bar agents in test environments and render.

    Raises:
        Exception: If module-level ``args`` has not been set.
    """
    # The module declares `policy` as a global shared with the hook
    # functions (train_fn/test_fn); assign it globally instead of
    # shadowing it with a local variable.
    global policy

    if args is None:
        raise Exception("args not set")
    print("Using device: ", args.device)

    # A single (non-vectorised) env is created only to read the space shapes.
    env = MakeEnv(**env_params["train"]).create_env()
    args.state_shape = env.observation_space.shape
    args.action_shape = env.action_space.shape

    # Create testing environments.
    if args.save_render:
        env_params["test"]["save_render_path"] = args.save_render
    (test_envs_obj, test_envs) = make_envs(args.test_num, **env_params["test"])
    test_envs = SubprocVectorEnv(test_envs)
    print(f"Created {args.test_num} test environments..")

    # Seed everything for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # Define policies for puck and bar, loading saved weights if requested.
    print("Initialising Policies..")
    policy_puck, policy_bar = init_and_call_policy()
    policy_puck, policy_bar = load_policy(policy_puck, policy_bar)

    # Combine both agents into a single two-agent policy.
    policy = TwoAgentPolicy(
        (policy_puck, policy_bar),
        observation_space=env.observation_space,
        action_space=env.action_space,
    )

    print("Creating test collector..")
    test_collector = Collector(
        policy, test_envs, exploration_noise=args.exploration_noise
    )
    print(test_collector.collect(n_episode=args.num_episodes_render))
    test_envs.close()
if __name__ == "__main__":
    # Parse CLI options into the module-level `args`, then run visualisation.
    args = get_args()
    visualise()