Support for Stochastic Weight Averaging (SWA), closes #321 #320

Draft · wants to merge 5 commits into master
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
gym==0.21
stable-baselines3[extra,tests,docs]>=1.6.2
torchcontrib==0.0.2
sb3-contrib>=1.6.2
box2d-py==2.3.8
pybullet
33 changes: 33 additions & 0 deletions rl_zoo3/exp_manager.py
@@ -42,6 +42,7 @@

# For custom activation fn
from torch import nn as nn # noqa: F401
from torchcontrib.optim import SWA

# Register custom envs
import rl_zoo3.import_envs # noqa: F401 pytype: disable=import-error
@@ -50,6 +51,23 @@
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule


def make_swa_opt_class(optimizer, opt_kwargs, swa_kwargs):
class MySWA(SWA):
def __init__(self, params, **kwargs):
# split the merged kwargs into those meant for the base optimizer and those meant for SWA
opt_kwargs_ = {k: v for k, v in kwargs.items() if k in opt_kwargs}
swa_kwargs_ = {k: v for k, v in kwargs.items() if k in swa_kwargs}
opt = optimizer(params, **opt_kwargs_)
super().__init__(opt, **swa_kwargs_)
# copy the base optimizer's attributes onto the wrapper, otherwise we get errors
# such as "defaults" not being defined for MySWA
attrs = [a for a in dir(opt) if not a.startswith("__") and not callable(getattr(opt, a))]
for attr in attrs:
setattr(self, attr, getattr(opt, attr))

return MySWA
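
For context, a minimal standalone sketch (not part of the patch) of how the generated class is meant to be used; the toy nn.Linear model and the kwarg values are purely illustrative, and SB3 itself instantiates the optimizer as optimizer_class(params, **optimizer_kwargs) using the merged kwargs prepared in setup_experiment below:

import torch as th
from torch import nn

model = nn.Linear(4, 2)
opt_kwargs = {"lr": 1e-3}
swa_kwargs = {"swa_start": 10, "swa_freq": 5, "swa_lr": 0.05}

# SB3 calls optimizer_class(params, **optimizer_kwargs), so MySWA.__init__
# receives the merged kwargs and splits them again
OptClass = make_swa_opt_class(th.optim.Adam, opt_kwargs, swa_kwargs)
optimizer = OptClass(model.parameters(), **{**opt_kwargs, **swa_kwargs})

loss = model(th.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()  # behaves like the wrapped Adam (torchcontrib SWA auto mode)
# optimizer.swap_swa_sgd()  # after training, swap in the averaged weights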


class ExperimentManager:
"""
Experiment manager: read the hyperparameters,
@@ -173,6 +191,7 @@ def __init__(
self.log_path, f"{self.env_name}_{get_latest_run_id(self.log_path, self.env_name) + 1}{uuid_str}"
)
self.params_path = f"{self.save_path}/{self.env_name}"
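# SWA is considered enabled when policy_kwargs contains a "swa" dict whose swa_start is > 0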
self.use_swa = self.custom_hyperparams.get("policy_kwargs", {}).get("swa", {"swa_start": -1})["swa_start"] > 0

def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
"""
@@ -199,6 +218,17 @@ def setup_experiment(self) -> Optional[Tuple[BaseAlgorithm, Dict[str, Any]]]:
env.close()
return None
else:
if self.use_swa:
# swa_start <= 0 means SWA is disabled (see self.use_swa in __init__)
p_kwargs = self._hyperparams["policy_kwargs"]
swa_kwargs = p_kwargs["swa"]
base_opt_class = p_kwargs.get("optimizer_class", th.optim.Adam)
base_opt_kwargs = p_kwargs.get("optimizer_kwargs", {})
p_kwargs["optimizer_class"] = make_swa_opt_class(base_opt_class, base_opt_kwargs, swa_kwargs)
p_kwargs["optimizer_kwargs"] = {**base_opt_kwargs, **swa_kwargs}

self._hyperparams["policy_kwargs"].pop("swa", None) # remove swa if it exists
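
For illustration, a policy_kwargs entry that would enable this branch might look roughly like the following; the values are hypothetical and the "swa" keys mirror torchcontrib's auto-mode arguments (swa_start, swa_freq, swa_lr):

policy_kwargs = {
    "optimizer_class": th.optim.Adam,
    "optimizer_kwargs": {"lr": 3e-4},
    # swa_start > 0 enables SWA (hypothetical values)
    "swa": {"swa_start": 10_000, "swa_freq": 1_000, "swa_lr": 1e-4},
}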

# Train an agent from scratch
model = ALGOS[self.algo](
env=env,
@@ -231,6 +261,9 @@ def learn(self, model: BaseAlgorithm) -> None:

try:
model.learn(self.n_timesteps, **kwargs)
if self.use_swa:
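# swap the SWA running average into the model parameters so the averaged weights are the ones evaluated and saved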
model.policy.optimizer.swap_swa_sgd()

except KeyboardInterrupt:
# this allows to save the model when interrupting training
pass