forked from lweitkamp/option-critic-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_continue.py
72 lines (52 loc) · 2.3 KB
/
train_continue.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from pathlib import Path
import yaml
import torch as th
from envs.common import init_train_eval_envs
from envs.util import get_atari_identifier
from utils.callbacks import init_callbacks
from options.ppo import load_agent
# Environment ID (ALE namespace) of the experiment that should be resumed.
ENV_NAME: str = "ALE/Kangaroo-v5"
# Experiment name; used as the sub-directory below the game's output folder
# and as the experiment name handed to the callbacks.
MODEL_NAME: str = "reward-shaping/v2-2"
# Root directory under which all experiment outputs are stored.
OUT_BASE_PATH: str = "out/"
def run(env_name=ENV_NAME, model_name=MODEL_NAME, out_base_path=OUT_BASE_PATH):
    """Resume training of a previously started experiment.

    Reads ``config.yaml`` from ``<out_base_path>/<game identifier>/<model_name>``,
    reloads the latest (not the best) agent via ``load_agent`` and trains it for
    the timesteps still missing to reach ``training.total_timesteps`` from the
    config. If that budget is already used up, a message is printed and the
    function returns without creating evaluation environments or callbacks.

    Args:
        env_name: Environment ID the experiment was run on.
        model_name: Experiment name, i.e. the sub-directory holding its files.
        out_base_path: Root output directory containing all experiments.
    """
    game_identifier = get_atari_identifier(env_name)
    model_dir = Path(out_base_path, game_identifier, model_name)

    # Retrieve experiment configuration.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. This is
    # acceptable for configs written by this project itself, but must never be
    # pointed at untrusted files (prefer yaml.safe_load if the configs allow it).
    config_path = model_dir / "config.yaml"
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    n_envs = config["cores"]
    model = load_agent(model_dir=model_dir, best_model=False, n_envs=n_envs, train=True)

    environment = config["environment"]
    evaluation = config["evaluation"]
    training = config["training"]

    # int(float(...)) also accepts scientific notation such as "1e7", which
    # YAML may hand back as a string instead of a number.
    total_timesteps = int(float(training["total_timesteps"]))
    remaining_timesteps = total_timesteps - model.num_timesteps

    # Bail out before building evaluation envs and callbacks, so a finished
    # run does not create (and then abandon) environments or callback state.
    if remaining_timesteps <= 0:
        print("No timesteps remain for training, it was already finished.")
        return

    th.manual_seed(config["seed"])
    object_centric = environment["object_centric"]
    n_eval_envs = config["cores"]
    ckpt_path = model_dir / "checkpoints"

    # Only evaluation envs are needed here; the agent already carries its own
    # training envs from load_agent(..., train=True).
    _, eval_env = init_train_eval_envs(n_train_envs=0,
                                       n_eval_envs=n_eval_envs,
                                       seed=config["seed"],
                                       **environment)
    try:
        cb_list = init_callbacks(exp_name=model_name,
                                 total_timestamps=total_timesteps,
                                 may_use_reward_shaping=object_centric,
                                 n_envs=n_envs,
                                 eval_env=eval_env,
                                 n_eval_episodes=4 * n_eval_envs,
                                 ckpt_path=ckpt_path,
                                 eval_kwargs=evaluation)

        print(f"Continuing experiment {model_name}.")
        print(f"Started {type(model).__name__} training for {remaining_timesteps} steps "
              f"with {n_envs} actors and {n_eval_envs} evaluators...")
        # reset_num_timesteps=False keeps the agent's step counter, so
        # total_timesteps here is the number of *additional* steps to run.
        model.learn(total_timesteps=remaining_timesteps,
                    callback=cb_list,
                    log_interval=None,
                    reset_num_timesteps=False)
    finally:
        # Release evaluation env resources (e.g. worker processes) even if
        # training fails. Assumes eval_env exposes close() like Gymnasium/SB3
        # envs do — TODO confirm against init_train_eval_envs.
        eval_env.close()
# Script entry point: resume training of the experiment selected by the
# module-level ENV_NAME / MODEL_NAME / OUT_BASE_PATH constants.
if __name__ == "__main__":
    run()