diff --git a/experiment_code/hackrl/config.yaml b/experiment_code/hackrl/config.yaml
index 9863dca..6350c83 100644
--- a/experiment_code/hackrl/config.yaml
+++ b/experiment_code/hackrl/config.yaml
@@ -119,6 +119,9 @@ kickstarting_decay: 1.0
 kickstarting_loss_bc: 1.0
 kickstarting_decay_bc: 1.0
 kickstarting_path: /net/pr2/projects/plgrid/plgggmum_crl/bcupial/monk-AA-BC/checkpoint.tar
+use_ewc: False
+ewc_penalty_scaler: 400
+ewc_n_batches: 10
 log_forgetting: False
 forgetting_dataset: bc1
 
@@ -172,8 +175,9 @@ ttyrec_cpus: 10
 use_checkpoint_actor: False
 model_checkpoint_path: /checkpoint/checkpoint.tar
 unfreeze_actor_steps: 0
+freeze_from_the_beginning: True
 
 eval_checkpoint_every: 50_000_000
 eval_rollouts: 1024
 eval_batch_size: 256
-skip_first_eval: False
+skip_first_eval: False
\ No newline at end of file
diff --git a/experiment_code/hackrl/experiment.py b/experiment_code/hackrl/experiment.py
index 59cfebe..b56bec0 100644
--- a/experiment_code/hackrl/experiment.py
+++ b/experiment_code/hackrl/experiment.py
@@ -26,7 +26,8 @@
 from nle.dataset import dataset
 from nle.dataset import db
 from nle.dataset import populate_db
-
+from copy import deepcopy
+from torch.nn import CrossEntropyLoss
 import hackrl.environment
 import hackrl.models
 
@@ -37,11 +38,81 @@
 from hackrl.core import nest
 from hackrl.core import record
 from hackrl.core import vtrace
-
 # TTYREC_ASYNC_ITERATOR = None
 # TTYREC_DATA = None
 TTYREC_HIDDEN_STATE = None
 TTYREC_ENVPOOL = None
+EWC_INSTANCE = None
+
+
+class EWC(object):
+    def __init__(
+        self,
+        model: nn.Module,
+        n_batches: int = 10,
+    ):
+        # To make sure we do not interfere with the main training
+        self.training = model.training
+
+        self.model = model
+        self.model.train(mode=True)
+
+        self.params = {
+            n: p
+            for n, p in self.model.named_parameters()
+            if p.requires_grad
+            and "baseline".casefold() not in n.casefold()
+            and "teacher" not in n.casefold()
+        }
+        self._means = {}
+        self._precision_matrices = self._diag_fisher(n_batches=n_batches)
+
+        for n, p in copy.deepcopy(self.params).items():
+            self._means[n] = p.data
+
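+    # Estimate the diagonal of the Fisher information matrix from a few
+    # ttyrec batches: the model's own greedy action serves as the target of a
+    # cross-entropy loss, and the squared gradients are averaged per parameter.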
+    def _diag_fisher(self, n_batches: int = 10):
+        global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE
+
+        precision_matrices = {}
+        for n, p in deepcopy(self.params).items():
+            p.data.zero_()
+            precision_matrices[n] = p.data
+
+        for _ in range(n_batches):
+            self.model.zero_grad()
+
+            ttyrec_data = TTYREC_ENVPOOL.result()
+            idx = TTYREC_ENVPOOL.idx
+            ttyrec_predictions, TTYREC_HIDDEN_STATE[idx] = self.model(
+                ttyrec_data, TTYREC_HIDDEN_STATE[idx]
+            )
+            TTYREC_HIDDEN_STATE[idx] = nest.map(
+                lambda t: t.detach(), TTYREC_HIDDEN_STATE[idx]
+            )
+
+            logits = torch.flatten(ttyrec_predictions["policy_logits"], 0, 1)
+            logits = logits.float().requires_grad_()
+            label = logits.max(1)[1].view(-1).type(torch.LongTensor).cuda()
+            loss = CrossEntropyLoss()(logits, label)
+            loss.backward()
+
+            for n, p in self.model.named_parameters():
+                if n in list(self.params.keys()):
+                    precision_matrices[n].data += p.grad.data**2 / n_batches
+
+        precision_matrices = {n: p for n, p in precision_matrices.items()}
+        self.model.zero_grad()
+
+        self.model.train(mode=self.training)
+        return precision_matrices
+
+    def penalty(self, model: nn.Module):
+        loss = 0
+        for n, p in model.named_parameters():
+            if n in list(self.params.keys()):
+                _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
+                loss += _loss.sum()
+        return loss
 
 
 class TtyrecEnvPool:
@@ -576,7 +647,8 @@ def compute_policy_gradient_loss(
     stats=None,
 ):
     advantages = advantages.detach()
-    stats["running_advantages"] += advantages
+    if stats:
+        stats["running_advantages"] += advantages
 
     adv = advantages
 
@@ -587,7 +659,8 @@ def compute_policy_gradient_loss(
         else:
             sample_adv = adv
         advantages = (adv - sample_adv.mean()) / max(1e-3, sample_adv.std())
-        stats["sample_advantages"] += advantages.mean().item()
+        if stats:
+            stats["sample_advantages"] += advantages.mean().item()
 
     if clip_delta_policy:
         # APPO policy loss - clip a change in policy fn
@@ -633,7 +706,7 @@ def create_scheduler(optimizer):
 
 
 def compute_gradients(data, sleep_data, learner_state, stats):
-    global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE, FORGETTING_ENVPOOL, FORGETTING_HIDDEN_STATE
+    global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE, EWC_INSTANCE, FORGETTING_ENVPOOL, FORGETTING_HIDDEN_STATE
     model = learner_state.model
 
     env_outputs = data["env_outputs"]
@@ -644,6 +717,11 @@ def compute_gradients(data, sleep_data, learner_state, stats):
 
     total_loss = 0
 
+    if EWC_INSTANCE is None:
+        assert not (FLAGS.supervised_loss or FLAGS.behavioural_clone), "There is something wrong with config"
+        # It means that we do not load TTYREC
+        EWC_INSTANCE = EWC(model)
+
     if FLAGS.supervised_loss or FLAGS.behavioural_clone:
         ttyrec_data = TTYREC_ENVPOOL.result()
         idx = TTYREC_ENVPOOL.idx
@@ -745,6 +823,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
         learner_outputs["policy_logits"],
         stats
     )
+    # TODO: should we add ewc_penalty to this loss?
     pg_loss = compute_policy_gradient_loss(
         vtrace_returns.behavior_action_log_probs,
         vtrace_returns.target_action_log_probs,
@@ -767,7 +846,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
     if FLAGS.ppg_sleep:
         for batch_idx in np.random.randint(len(sleep_data), size=FLAGS.ppg_sleep_cycles):
             sleep_batch = sleep_data[batch_idx]
-
+
             sleep_env_outputs = sleep_batch["env_outputs"]
             sleep_initial_core_state = sleep_batch["initial_core_state"]
 
@@ -861,6 +940,11 @@ def compute_gradients(data, sleep_data, learner_state, stats):
         stats["kickstarting_loss_bc"] += kickstarting_loss_bc.item()
         stats["kickstarting_coeff_bc"] += FLAGS.kickstarting_loss_bc
 
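+    # EWC regularisation: quadratic penalty that pulls the weights back towards
+    # the snapshot stored in EWC_INSTANCE, weighted by the Fisher diagonal and
+    # scaled by ewc_penalty_scaler.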
+    if FLAGS.use_ewc:
+        ewc_penalty = EWC_INSTANCE.penalty(model)
+        total_loss += ewc_penalty * FLAGS.ewc_penalty_scaler
+        stats["ewc_loss"] += ewc_penalty
+
     # Only call step when you are done with ttyrec_data - it may get overwritten
     TTYREC_ENVPOOL.step()
 
@@ -880,7 +964,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
            kick_predictions["kick_policy_logits"],
         )
         stats["forgetting_loss"] += forgetting_loss.item()
-
+
         # Only call step when you are done with ttyrec_data - it may get overwritten
         FORGETTING_ENVPOOL.step()
 
@@ -1006,6 +1090,7 @@ def main(cfg):
         FLAGS.use_checkpoint_actor = False
 
     if FLAGS.use_kickstarting or FLAGS.use_kickstarting_bc or FLAGS.log_forgetting:
+        assert FLAGS.use_ewc==False, "Cannot use EWC with log forgetting, there is a problem with 'grads shrank'"
         student = hackrl.models.create_model(FLAGS, FLAGS.device)
         load_data = torch.load(FLAGS.kickstarting_path)
         t_flags = omegaconf.OmegaConf.create(load_data["flags"])
@@ -1095,6 +1180,7 @@ def main(cfg):
         "kickstarting_coeff": StatMean(),
         "kickstarting_loss_bc": StatMean(),
         "kickstarting_coeff_bc": StatMean(),
+        "ewc_loss": StatMean(),
         "forgetting_loss": StatMean(),
         "ppg_kl_loss": StatMean(),
         "ppg_baseline_loss": StatMean(),
@@ -1153,7 +1239,7 @@ def signal_handler(signum, frame):
     logging.info("Optimising CuDNN kernels")
     torch.backends.cudnn.benchmark = True
 
-    if FLAGS.supervised_loss or FLAGS.behavioural_clone or FLAGS.use_kickstarting_bc:
+    if FLAGS.supervised_loss or FLAGS.behavioural_clone or FLAGS.use_kickstarting_bc or FLAGS.use_ewc:
         global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE
         tp = concurrent.futures.ThreadPoolExecutor(max_workers=FLAGS.ttyrec_cpus)
         TTYREC_HIDDEN_STATE = []
@@ -1205,6 +1291,11 @@ def signal_handler(signum, frame):
     is_connected = False
     unfreezed = False
     checkpoint_steps = -1
+
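+    # Build the EWC helper once before the training loop: it snapshots the
+    # current model weights and estimates the Fisher diagonal from
+    # FLAGS.ewc_n_batches ttyrec batches.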
+    if FLAGS.use_ewc:
+        global EWC_INSTANCE
+        EWC_INSTANCE = EWC(model, n_batches=FLAGS.ewc_n_batches)
+
     while not terminate:
         prev_now = now
         now = time.time()
@@ -1381,7 +1472,7 @@ def signal_handler(signum, frame):
 
 if __name__ == "__main__":
     tempdir = tempfile.mkdtemp()
-    tempfile.tempdir = tempdir
+    tempfile.tempdir = tempdir
 
     try:
         main()
diff --git a/experiment_code/hackrl/models/__init__.py b/experiment_code/hackrl/models/__init__.py
index afd9841..aefce37 100644
--- a/experiment_code/hackrl/models/__init__.py
+++ b/experiment_code/hackrl/models/__init__.py
@@ -87,9 +87,9 @@ def create_model(flags, device):
             map_location=torch.device(device),
         )
         model.load_state_dict(load_data["learner_state"]["model"], strict=False)
-        freeze(model)
-        unfreeze_selected(model, ["baseline", "embed_ln"])
-
+        if flags.freeze_from_the_beginning:
+            freeze(model)
+            unfreeze_selected(model, ["baseline", "embed_ln"])
     return model
diff --git a/experiment_code/mrunner_exps/ewc/monk-APPO-EWC-T.py b/experiment_code/mrunner_exps/ewc/monk-APPO-EWC-T.py
new file mode 100644
index 0000000..debd8ca
--- /dev/null
+++ b/experiment_code/mrunner_exps/ewc/monk-APPO-EWC-T.py
@@ -0,0 +1,61 @@
+from pathlib import Path
+from random_words import RandomWords
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-T-EWC-final",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    'group': "monk-APPO-T-EWC-final",
+    "character": "mon-hum-neu-mal",
+    "freeze_from_the_beginning": False,
+    "use_ewc": True,
+    "ewc_penalty_scaler": 8000,
+    "ewc_n_batches": 1000
+}
+
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(6)),
+        # load from checkpoint
+        "unfreeze_actor_steps": [0],
+        "use_checkpoint_actor": [True],
+        "model_checkpoint_path": ["/net/pr2/projects/plgrid/plgg_pw_crl/mostaszewski/monk-AA-BC/checkpoint.tar"],
+        # log forgetting
+        "log_forgetting": [False],
+        "forgetting_dataset": ["bc_midscore"],
+        "kickstarting_path": ["/net/pr2/projects/plgrid/plgg_pw_crl/mostaszewski/monk-AA-BC/checkpoint.tar"],
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)
diff --git a/experiment_code/mrunner_exps/paper_baselines/2023_11_05_monk-APPO-AA-EWC-T.py b/experiment_code/mrunner_exps/paper_baselines/2023_11_05_monk-APPO-AA-EWC-T.py
new file mode 100644
index 0000000..b4af6f9
--- /dev/null
+++ b/experiment_code/mrunner_exps/paper_baselines/2023_11_05_monk-APPO-AA-EWC-T.py
@@ -0,0 +1,57 @@
+from random_words import RandomWords
+
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-APPO-BC1-EWC",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    "group": "monk-APPO-BC1-EWC",
+    "character": "mon-hum-neu-mal",
+    "ttyrec_batch_size": 256,
+    "use_ewc": True,
+    "dataset": "aa",
+    "freeze_from_the_beginning": False,
+    "use_checkpoint_actor": True,
+    "model_checkpoint_path": "/checkpoint/checkpoint.tar",
+    "unfreeze_actor_steps": 0,
+}
+
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(5)),
+        "ewc_penalty_scaler": [1, 400, 4000],
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)
diff --git a/experiment_code/mrunner_exps/paper_baselines/monk-APPO-BC1-EWC.py b/experiment_code/mrunner_exps/paper_baselines/monk-APPO-BC1-EWC.py
new file mode 100644
index 0000000..1f1f61a
--- /dev/null
+++ b/experiment_code/mrunner_exps/paper_baselines/monk-APPO-BC1-EWC.py
@@ -0,0 +1,57 @@
+from random_words import RandomWords
+
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-APPO-BC1-EWC",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    "group": "monk-APPO-BC1-EWC",
+    "character": "mon-hum-neu-mal",
+    "ttyrec_batch_size": 256,
+    "use_ewc": True,
+    "dataset": "bc1",
+    "freeze_from_the_beginning": False,
+    "use_checkpoint_actor": True,
+    "model_checkpoint_path": "/checkpoint/checkpoint.tar",
+    "unfreeze_actor_steps": 0,
+}
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(5)),
+        "ewc_penalty_scaler": [1, 400, 2000],
+        "ewc_n_batches": [10, 100]
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)
diff --git a/experiment_code/mrunner_exps/paper_baselines_deep/2023_13_05_monk-APPO-T-EWC_deep.py b/experiment_code/mrunner_exps/paper_baselines_deep/2023_13_05_monk-APPO-T-EWC_deep.py
new file mode 100644
index 0000000..e5d90ec
--- /dev/null
+++ b/experiment_code/mrunner_exps/paper_baselines_deep/2023_13_05_monk-APPO-T-EWC_deep.py
@@ -0,0 +1,61 @@
+from random_word import RandomWords
+
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-APPO",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    "group": "monk-APPO",
+    "character": "mon-hum-neu-mal",
+    "freeze_from_the_beginning": False,
+    "use_ewc": True,
+    "ewc_penalty_scaler": 8000,
+    "ewc_n_batches": 1000
+}
+
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(6)),
+        # load from checkpoint
+        "unfreeze_actor_steps": [0],
+        "use_checkpoint_actor": [True],
+        "model_checkpoint_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_0/checkpoint/hackrl/nle/monk-AA-BC_deep_0/checkpoint.tar"],
+        # log forgetting
+        "log_forgetting": [False],
+        "forgetting_dataset": ["bc_deep"],
+        "kickstarting_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_0/checkpoint/hackrl/nle/monk-AA-BC_deep_0/checkpoint.tar"],
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().get_random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)
diff --git a/experiment_code/mrunner_exps/paper_baselines_highscore/2023_13_05_monk-APPO-T-EWC_highscore.py b/experiment_code/mrunner_exps/paper_baselines_highscore/2023_13_05_monk-APPO-T-EWC_highscore.py
new file mode 100644
index 0000000..e056bc4
--- /dev/null
+++ b/experiment_code/mrunner_exps/paper_baselines_highscore/2023_13_05_monk-APPO-T-EWC_highscore.py
@@ -0,0 +1,62 @@
+
+from random_word import RandomWords
+
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-APPO",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    "group": "monk-APPO",
+    "character": "mon-hum-neu-mal",
+    "freeze_from_the_beginning": False,
+    "use_ewc": True,
+    "ewc_penalty_scaler": 8000,
+    "ewc_n_batches": 1000
+}
+
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(6)),
+        # load from checkpoint
+        "unfreeze_actor_steps": [0],
+        "use_checkpoint_actor": [True],
+        "model_checkpoint_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_10/checkpoint/hackrl/nle/monk-AA-BC_deep_10/checkpoint.tar"],
+        # log forgetting
+        "log_forgetting": [False],
+        "forgetting_dataset": ["bc_midscore"],
+        "kickstarting_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_10/checkpoint/hackrl/nle/monk-AA-BC_deep_10/checkpoint.tar"],
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().get_random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)
diff --git a/experiment_code/mrunner_exps/paper_baselines_midscore/2023_13_05_monk-APPO-T-EWC_midscore.py b/experiment_code/mrunner_exps/paper_baselines_midscore/2023_13_05_monk-APPO-T-EWC_midscore.py
new file mode 100644
index 0000000..de42a58
--- /dev/null
+++ b/experiment_code/mrunner_exps/paper_baselines_midscore/2023_13_05_monk-APPO-T-EWC_midscore.py
@@ -0,0 +1,61 @@
+from random_word import RandomWords
+
+from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations
+
+
+name = globals()["script"][:-3]
+
+# params for all exps
+config = {
+    "exp_tags": [name],
+    "connect":"0.0.0.0:4431",
+    "exp_set": "2G",
+    "exp_point": "monk-APPO",
+    "num_actor_cpus": 20,
+    "total_steps": 2_000_000_000,
+    "group": "monk-APPO",
+    "character": "mon-hum-neu-mal",
+    "freeze_from_the_beginning": False,
+    "use_ewc": True,
+    "ewc_penalty_scaler": 8000,
+    "ewc_n_batches": 1000
+}
+
+
+# params different between exps
+params_grid = [
+    {
+        "seed": list(range(6)),
+        # load from checkpoint
+        "unfreeze_actor_steps": [0],
+        "use_checkpoint_actor": [True],
+        "model_checkpoint_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_5/checkpoint/hackrl/nle/monk-AA-BC_deep_5/checkpoint.tar"],
+        # log forgetting
+        "log_forgetting": [False],
+        "forgetting_dataset": ["bc_midscore"],
+        "kickstarting_path": ["/net/tscratch/people/plgbartekcupial/mrunner_scratch/nle/10_05-09_22-awesome_heisenberg/monk-aa-bc-deep_hp0i_5/checkpoint/hackrl/nle/monk-AA-BC_deep_5/checkpoint.tar"],
+    },
+]
+
+params_configurations = get_combinations(params_grid)
+
+final_grid = []
+for e, cfg in enumerate(params_configurations):
+    cfg = {key: [value] for key, value in cfg.items()}
+    r = RandomWords().get_random_word()
+    cfg["group"] = [f"{name}_{e}_{r}"]
+    final_grid.append(dict(cfg))
+
+
+experiments_list = create_experiments_helper(
+    experiment_name=name,
+    project_name="nle",
+    with_neptune=False,
+    script="python3 mrunner_run.py",
+    python_path=".",
+    tags=[name],
+    exclude=["checkpoint"],
+    base_config=config,
+    params_grid=final_grid,
+    exclude_git_files=False,
+)