Add EWC method #31


Open

wants to merge 24 commits into main from feat/ewc_cross

Commits (24)
9e6d0f7
feat: add basic EWC (WIP)
MichalBortkiewicz May 11, 2023
47d412a
feat: add basic EWC (WIP) - need to fix ttyrec usage
MichalBortkiewicz May 11, 2023
0cc15d2
chore: update EWC logic
MichalBortkiewicz May 11, 2023
68ecebc
ewc cross entropy
BartekCupial May 12, 2023
8fb854d
chore: add different EWC - based on Maciek's implementation and link …
MichalBortkiewicz May 12, 2023
7ae2588
chore: probably some bug because ewc_loss=0
MichalBortkiewicz May 12, 2023
6dd5afd
fix: EWC with ttyrec
MichalBortkiewicz May 12, 2023
666a626
fix: EWC - add flag whether we should freeze model completely from th…
MichalBortkiewicz May 12, 2023
036fe06
feat: add basic EWC (WIP)
MichalBortkiewicz May 11, 2023
aac51f9
feat: add basic EWC (WIP) - need to fix ttyrec usage
MichalBortkiewicz May 11, 2023
1808d49
chore: update EWC logic
MichalBortkiewicz May 11, 2023
c159172
ewc cross entropy
BartekCupial May 12, 2023
a47cb45
chore: add different EWC - based on Maciek's implementation and link …
MichalBortkiewicz May 12, 2023
cbf4ed9
chore: probably some bug because ewc_loss=0
MichalBortkiewicz May 12, 2023
158a9e3
fix: EWC with ttyrec
MichalBortkiewicz May 12, 2023
01c2426
fix: EWC - add flag whether we should freeze model completely from th…
MichalBortkiewicz May 12, 2023
3265680
Merge remote-tracking branch 'origin/main' into feat/ewc_cross
MichalBortkiewicz May 14, 2023
8cea3f5
feat: add EWC mrunner configs
MichalBortkiewicz May 14, 2023
f50f4d3
Merge remote-tracking branch 'origin/feat/ewc_cross' into feat/ewc_cross
MichalBortkiewicz May 14, 2023
d962a2d
fix: EWC (there was a bug when teacher was present)
MichalBortkiewicz May 16, 2023
f82687d
Merge remote-tracking branch 'origin/main' into feat/ewc_cross
MichalBortkiewicz May 19, 2023
7a98a86
fix: remove log_forgetting flag from EWC mrunner experiments
MichalBortkiewicz May 20, 2023
b63252a
chore: add basic EWC mrunner config
MichalBortkiewicz May 20, 2023
3c57dcf
Merge remote-tracking branch 'origin/main' into feat/ewc_cross
MichalBortkiewicz May 24, 2023
6 changes: 5 additions & 1 deletion experiment_code/hackrl/config.yaml
@@ -119,6 +119,9 @@ kickstarting_decay: 1.0
kickstarting_loss_bc: 1.0
kickstarting_decay_bc: 1.0
kickstarting_path: /net/pr2/projects/plgrid/plgggmum_crl/bcupial/monk-AA-BC/checkpoint.tar
use_ewc: False
ewc_penalty_scaler: 400
ewc_n_batches: 10
log_forgetting: False
forgetting_dataset: bc1

@@ -172,8 +175,9 @@ ttyrec_cpus: 10
use_checkpoint_actor: False
model_checkpoint_path: /checkpoint/checkpoint.tar
unfreeze_actor_steps: 0
freeze_from_the_beginning: True

eval_checkpoint_every: 50_000_000
eval_rollouts: 1024
eval_batch_size: 256
skip_first_eval: False
skip_first_eval: False
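
Note: the new keys wire EWC into the config. use_ewc toggles the method,
ewc_penalty_scaler is the multiplier lambda applied to the quadratic penalty
lambda * sum_i F_i * (theta_i - theta_i*)^2, and ewc_n_batches sets how many
ttyrec batches are used to estimate the diagonal Fisher F_i (see the EWC class
in experiment.py below). freeze_from_the_beginning makes the checkpoint
freezing in create_model optional.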
109 changes: 100 additions & 9 deletions experiment_code/hackrl/experiment.py
@@ -26,7 +26,8 @@
from nle.dataset import dataset
from nle.dataset import db
from nle.dataset import populate_db

from torch.nn import CrossEntropyLoss
import hackrl.environment
import hackrl.models

@@ -37,11 +38,81 @@
from hackrl.core import nest
from hackrl.core import record
from hackrl.core import vtrace

# TTYREC_ASYNC_ITERATOR = None
# TTYREC_DATA = None
TTYREC_HIDDEN_STATE = None
TTYREC_ENVPOOL = None
EWC_INSTANCE = None


class EWC(object):
def __init__(
self,
model: nn.Module,
n_batches: int = 10,
):
# To make sure we do not interfere with the main training
self.training = model.training

self.model = model
self.model.train(mode=True)

        # Track only trainable policy parameters; skip the baseline (critic)
        # head and any teacher (kickstarting) parameters.
        self.params = {
            n: p
            for n, p in self.model.named_parameters()
            if p.requires_grad
            and "baseline" not in n.casefold()
            and "teacher" not in n.casefold()
        }
        self._precision_matrices = self._diag_fisher(n_batches=n_batches)

        # Snapshot the anchor point (theta*): the parameter values at the
        # time this instance is built.
        self._means = {n: p.detach().clone() for n, p in self.params.items()}

def _diag_fisher(self, n_batches: int = 10):
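        # Estimate the diagonal of the Fisher information matrix: accumulate
        # the squared gradients of a self-labelled cross-entropy loss,
        # averaged over n_batches ttyrec batches.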
global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE

        # Zero-initialised accumulators for the squared gradients.
        precision_matrices = {
            n: torch.zeros_like(p) for n, p in self.params.items()
        }

for _ in range(n_batches):
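            # Pull one ttyrec batch, run it forward (threading this stream's
            # recurrent state), and backprop a self-labelled loss.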
self.model.zero_grad()

ttyrec_data = TTYREC_ENVPOOL.result()
idx = TTYREC_ENVPOOL.idx
ttyrec_predictions, TTYREC_HIDDEN_STATE[idx] = self.model(
ttyrec_data, TTYREC_HIDDEN_STATE[idx]
)
TTYREC_HIDDEN_STATE[idx] = nest.map(
lambda t: t.detach(), TTYREC_HIDDEN_STATE[idx]
)

            logits = torch.flatten(ttyrec_predictions["policy_logits"], 0, 1)
            logits = logits.float()
            # Self-labelled targets: the model's own greedy actions. argmax
            # already gives int64 on the logits' device, and the logits are
            # part of the graph, so no requires_grad_() call is needed (it
            # would raise on a non-leaf tensor).
            label = logits.argmax(dim=1)
            loss = CrossEntropyLoss()(logits, label)
loss.backward()

            for n, p in self.model.named_parameters():
                if n in self.params and p.grad is not None:
                    precision_matrices[n] += p.grad.detach() ** 2 / n_batches

        self.model.zero_grad()

self.model.train(mode=self.training)
return precision_matrices

def penalty(self, model: nn.Module):
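        # Quadratic EWC penalty: sum_i F_i * (theta_i - theta_i*)^2, where
        # theta* are the snapshotted means. No 1/2 factor is applied here;
        # the overall scale comes from ewc_penalty_scaler at the call site.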
        loss = 0.0
        for n, p in model.named_parameters():
            if n in self.params:
                _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2
                loss += _loss.sum()
return loss
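
A condensed sketch of how this class is wired up later in the diff (for
orientation only; the FLAGS names are the ones this PR introduces):

    # In main(): estimate the Fisher diagonal once at startup.
    EWC_INSTANCE = EWC(model, n_batches=FLAGS.ewc_n_batches)

    # In compute_gradients(): penalise drift on every update.
    ewc_penalty = EWC_INSTANCE.penalty(model)
    total_loss += ewc_penalty * FLAGS.ewc_penalty_scaler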


class TtyrecEnvPool:
@@ -576,7 +647,8 @@ def compute_policy_gradient_loss(
stats=None,
):
advantages = advantages.detach()
stats["running_advantages"] += advantages
if stats:
stats["running_advantages"] += advantages

adv = advantages

@@ -587,7 +659,8 @@
else:
sample_adv = adv
advantages = (adv - sample_adv.mean()) / max(1e-3, sample_adv.std())
stats["sample_advantages"] += advantages.mean().item()
if stats:
stats["sample_advantages"] += advantages.mean().item()

if clip_delta_policy:
# APPO policy loss - clip a change in policy fn
@@ -633,7 +706,7 @@ def create_scheduler(optimizer):


def compute_gradients(data, sleep_data, learner_state, stats):
global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE, FORGETTING_ENVPOOL, FORGETTING_HIDDEN_STATE
global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE, EWC_INSTANCE, FORGETTING_ENVPOOL, FORGETTING_HIDDEN_STATE
model = learner_state.model

env_outputs = data["env_outputs"]
@@ -644,6 +717,11 @@ def compute_gradients(data, sleep_data, learner_state, stats):

total_loss = 0

    if FLAGS.use_ewc and EWC_INSTANCE is None:
        # Lazily build the EWC instance if main() has not already done so.
        # The assert guards against the supervised/BC losses consuming the
        # same ttyrec batches that the Fisher estimation needs.
        assert not (
            FLAGS.supervised_loss or FLAGS.behavioural_clone
        ), "EWC Fisher estimation would compete with supervised/BC ttyrec usage"
        EWC_INSTANCE = EWC(model, n_batches=FLAGS.ewc_n_batches)

if FLAGS.supervised_loss or FLAGS.behavioural_clone:
ttyrec_data = TTYREC_ENVPOOL.result()
idx = TTYREC_ENVPOOL.idx
@@ -745,6 +823,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
learner_outputs["policy_logits"], stats
)

# TODO: should we add ewc_penalty to this loss?
pg_loss = compute_policy_gradient_loss(
vtrace_returns.behavior_action_log_probs,
vtrace_returns.target_action_log_probs,
@@ -767,7 +846,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
if FLAGS.ppg_sleep:
for batch_idx in np.random.randint(len(sleep_data), size=FLAGS.ppg_sleep_cycles):
sleep_batch = sleep_data[batch_idx]

sleep_env_outputs = sleep_batch["env_outputs"]
sleep_initial_core_state = sleep_batch["initial_core_state"]

@@ -861,6 +940,11 @@ def compute_gradients(data, sleep_data, learner_state, stats):
stats["kickstarting_loss_bc"] += kickstarting_loss_bc.item()
stats["kickstarting_coeff_bc"] += FLAGS.kickstarting_loss_bc

if FLAGS.use_ewc:
ewc_penalty = EWC_INSTANCE.penalty(model)
total_loss += ewc_penalty * FLAGS.ewc_penalty_scaler
stats["ewc_loss"] += ewc_penalty

# Only call step when you are done with ttyrec_data - it may get overwritten
TTYREC_ENVPOOL.step()

@@ -880,7 +964,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
kick_predictions["kick_policy_logits"],
)
stats["forgetting_loss"] += forgetting_loss.item()

# Only call step when you are done with ttyrec_data - it may get overwritten
FORGETTING_ENVPOOL.step()

@@ -1006,6 +1090,7 @@ def main(cfg):
FLAGS.use_checkpoint_actor = False

if FLAGS.use_kickstarting or FLAGS.use_kickstarting_bc or FLAGS.log_forgetting:
        assert not FLAGS.use_ewc, "EWC cannot be combined with kickstarting or log_forgetting (it triggers a 'grads shrank' error)"
student = hackrl.models.create_model(FLAGS, FLAGS.device)
load_data = torch.load(FLAGS.kickstarting_path)
t_flags = omegaconf.OmegaConf.create(load_data["flags"])
@@ -1095,6 +1180,7 @@
"kickstarting_coeff": StatMean(),
"kickstarting_loss_bc": StatMean(),
"kickstarting_coeff_bc": StatMean(),
"ewc_loss": StatMean(),
"forgetting_loss": StatMean(),
"ppg_kl_loss": StatMean(),
"ppg_baseline_loss": StatMean(),
@@ -1153,7 +1239,7 @@ def signal_handler(signum, frame):
logging.info("Optimising CuDNN kernels")
torch.backends.cudnn.benchmark = True

if FLAGS.supervised_loss or FLAGS.behavioural_clone or FLAGS.use_kickstarting_bc:
if FLAGS.supervised_loss or FLAGS.behavioural_clone or FLAGS.use_kickstarting_bc or FLAGS.use_ewc:
global TTYREC_ENVPOOL, TTYREC_HIDDEN_STATE
tp = concurrent.futures.ThreadPoolExecutor(max_workers=FLAGS.ttyrec_cpus)
TTYREC_HIDDEN_STATE = []
@@ -1205,6 +1291,11 @@ def signal_handler(signum, frame):
is_connected = False
unfreezed = False
checkpoint_steps = -1

if FLAGS.use_ewc:
global EWC_INSTANCE
EWC_INSTANCE = EWC(model, n_batches=FLAGS.ewc_n_batches)
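        # The Fisher estimation consumes ttyrec batches, so this relies on
        # TTYREC_ENVPOOL having been created above (use_ewc is part of the
        # condition that builds the pool).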

while not terminate:
prev_now = now
now = time.time()
@@ -1381,7 +1472,7 @@ def signal_handler(signum, frame):

if __name__ == "__main__":
tempdir = tempfile.mkdtemp()
tempfile.tempdir = tempdir
tempfile.tempdir = tempdir

try:
main()
6 changes: 3 additions & 3 deletions experiment_code/hackrl/models/__init__.py
@@ -87,9 +87,9 @@ def create_model(flags, device):
map_location=torch.device(device),
)
model.load_state_dict(load_data["learner_state"]["model"], strict=False)
freeze(model)
unfreeze_selected(model, ["baseline", "embed_ln"])

if flags.freeze_from_the_beginning:
freeze(model)
unfreeze_selected(model, ["baseline", "embed_ln"])
return model
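
Note: with freeze_from_the_beginning set to False, as in the EWC mrunner
configs below, the checkpoint weights stay trainable; when use_ewc is on, the
EWC penalty rather than hard freezing is what keeps them near the checkpoint.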


61 changes: 61 additions & 0 deletions experiment_code/mrunner_exps/ewc/monk-APPO-EWC-T.py
@@ -0,0 +1,61 @@
from pathlib import Path
from random_words import RandomWords
from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations


name = globals()["script"][:-3]

# params for all exps
config = {
    "exp_tags": [name],
    "connect": "0.0.0.0:4431",
    "exp_set": "2G",
    "exp_point": "monk-T-EWC-final",
    "num_actor_cpus": 20,
    "total_steps": 2_000_000_000,
    "group": "monk-APPO-T-EWC-final",
    "character": "mon-hum-neu-mal",
    "freeze_from_the_beginning": False,
    "use_ewc": True,
    "ewc_penalty_scaler": 8000,
    "ewc_n_batches": 1000,
}


# params different between exps
params_grid = [
{
"seed": list(range(6)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"model_checkpoint_path": ["/net/pr2/projects/plgrid/plgg_pw_crl/mostaszewski/monk-AA-BC/checkpoint.tar"],
# log forgetting
"log_forgetting": [False],
"forgetting_dataset": ["bc_midscore"],
"kickstarting_path": ["/net/pr2/projects/plgrid/plgg_pw_crl/mostaszewski/monk-AA-BC/checkpoint.tar"],
},
]

params_configurations = get_combinations(params_grid)

final_grid = []
for e, cfg in enumerate(params_configurations):
cfg = {key: [value] for key, value in cfg.items()}
r = RandomWords().random_word()
cfg["group"] = [f"{name}_{e}_{r}"]
final_grid.append(dict(cfg))


experiments_list = create_experiments_helper(
experiment_name=name,
project_name="nle",
with_neptune=False,
script="python3 mrunner_run.py",
python_path=".",
tags=[name],
exclude=["checkpoint"],
base_config=config,
params_grid=final_grid,
exclude_git_files=False,
)
@@ -0,0 +1,57 @@
from random_words import RandomWords

from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations


name = globals()["script"][:-3]

# params for all exps
config = {
"exp_tags": [name],
"connect":"0.0.0.0:4431",
"exp_set": "2G",
"exp_point": "monk-APPO-BC1-EWC",
"num_actor_cpus": 20,
"total_steps": 2_000_000_000,
"group": "monk-APPO-BC1-EWC",
"character": "mon-hum-neu-mal",
"ttyrec_batch_size": 256,
"use_ewc": True,
"dataset": "aa",
"freeze_from_the_beginning": False,
"use_checkpoint_actor": True,
"model_checkpoint_path": "/checkpoint/checkpoint.tar",
"unfreeze_actor_steps": 0,
}


# params different between exps
params_grid = [
{
"seed": list(range(5)),
"ewc_penalty_scaler": [1, 400, 4000],
},
]

params_configurations = get_combinations(params_grid)

final_grid = []
for e, cfg in enumerate(params_configurations):
cfg = {key: [value] for key, value in cfg.items()}
r = RandomWords().random_word()
cfg["group"] = [f"{name}_{e}_{r}"]
final_grid.append(dict(cfg))


experiments_list = create_experiments_helper(
experiment_name=name,
project_name="nle",
with_neptune=False,
script="python3 mrunner_run.py",
python_path=".",
tags=[name],
exclude=["checkpoint"],
base_config=config,
params_grid=final_grid,
exclude_git_files=False,
)