From 61a7c77e5da4ebc4dbcddcb603afa9e3f6f65ebc Mon Sep 17 00:00:00 2001 From: Alexander Nikulin Date: Wed, 14 Jun 2023 22:02:33 +0300 Subject: [PATCH] Release! (#19) * init refactor * added datasets to the commit * cql without reward normalization * added reward normalisation * report and iql drafts * cql sweep, finish report script * added iql and rem * revert norm * added rem, awac, iql * deleted discrete iql * a lot of stuff * fix formatting * stats * add dataset downloading from hf * updated requirements and dockerfile * fix bug, fix docker * num workers for rendering as config value * add typings to the algorithms, remove db if exists * more typings, optional memmap cache cleaning * more typings * removed default vector env arg --- .gitignore | 8 +- Dockerfile | 27 +- algorithms/bc_chaotic_lstm.py | 353 ------------ algorithms/bc_resnet_lstm.py | 286 ---------- algorithms/rem_chaotic_lstm.py | 362 ------------ algorithms/small_scale/awac_chaotic_lstm.py | 489 +++++++++++++++++ algorithms/small_scale/bc_chaotic_lstm.py | 405 ++++++++++++++ algorithms/small_scale/cql_chaotic_lstm.py | 476 ++++++++++++++++ algorithms/small_scale/iql_chaotic_lstm.py | 517 ++++++++++++++++++ algorithms/small_scale/rem_chaotic_lstm.py | 485 ++++++++++++++++ .../sweeps/small_scale_awac_chaotic_lstm.yaml | 71 +++ .../sweeps/small_scale_bc_chaotic_lstm.yaml | 71 +++ .../sweeps/small_scale_cql_chaotic_lstm.yaml | 73 +++ .../small_scale_cql_chaotic_lstm_sweep.yaml | 24 + .../sweeps/small_scale_iql_chaotic_lstm.yaml | 71 +++ .../sweeps/small_scale_rem_chaotic_lstm.yaml | 71 +++ katakomba/__init__.py | 3 - katakomba/datasets/__init__.py | 5 - katakomba/datasets/base.py | 23 - katakomba/datasets/builder.py | 150 ----- katakomba/datasets/sa_autoascend.py | 36 -- katakomba/datasets/sa_chaotic_autoascend.py | 54 -- katakomba/datasets/sars_autoascend.py | 82 --- katakomba/datasets/sars_chaotic_autoascend.py | 86 --- katakomba/datasets/state_autoascend.py | 27 - katakomba/env.py | 193 +++++++ katakomba/envs/__init__.py | 1 - katakomba/envs/builder.py | 109 ---- katakomba/envs/envs.py | 367 ------------- katakomba/nn/vit.py | 0 katakomba/tasks.py | 39 -- katakomba/utils/datasets/__init__.py | 3 + katakomba/utils/datasets/large_scale.py | 99 ++++ katakomba/utils/datasets/small_scale.py | 141 +++++ .../utils/datasets/small_scale_buffer.py | 80 +++ katakomba/utils/misc.py | 103 ++++ katakomba/utils/roles.py | 131 +++-- katakomba/utils/scores.py | 127 +++++ katakomba/wrappers/__init__.py | 1 - katakomba/wrappers/base.py | 45 -- katakomba/wrappers/render.py | 4 +- katakomba/wrappers/tty.py | 4 +- requirements.txt | 18 + requirements/requirements.txt | 21 - scripts/generate.sh | 67 +++ scripts/generate_small_dataset.py | 176 ++++++ scripts/guide.py | 113 ---- scripts/loader_benchmark.py | 46 -- scripts/rliable_report.py | 145 +++++ scripts/stats/algorithms_stats.py | 88 +++ scripts/stats/depth.py | 39 ++ scripts/stats/scores.py | 33 ++ scripts/stats/small_scale_stats.py | 33 ++ scripts/test_chaotic_loader.py | 48 -- 54 files changed, 4202 insertions(+), 2327 deletions(-) delete mode 100644 algorithms/bc_chaotic_lstm.py delete mode 100644 algorithms/bc_resnet_lstm.py delete mode 100644 algorithms/rem_chaotic_lstm.py create mode 100644 algorithms/small_scale/awac_chaotic_lstm.py create mode 100644 algorithms/small_scale/bc_chaotic_lstm.py create mode 100644 algorithms/small_scale/cql_chaotic_lstm.py create mode 100644 algorithms/small_scale/iql_chaotic_lstm.py create mode 100644 algorithms/small_scale/rem_chaotic_lstm.py create 
mode 100644 configs/sweeps/small_scale_awac_chaotic_lstm.yaml create mode 100644 configs/sweeps/small_scale_bc_chaotic_lstm.yaml create mode 100644 configs/sweeps/small_scale_cql_chaotic_lstm.yaml create mode 100644 configs/sweeps/small_scale_cql_chaotic_lstm_sweep.yaml create mode 100644 configs/sweeps/small_scale_iql_chaotic_lstm.yaml create mode 100644 configs/sweeps/small_scale_rem_chaotic_lstm.yaml delete mode 100644 katakomba/__init__.py delete mode 100644 katakomba/datasets/__init__.py delete mode 100644 katakomba/datasets/base.py delete mode 100644 katakomba/datasets/builder.py delete mode 100644 katakomba/datasets/sa_autoascend.py delete mode 100644 katakomba/datasets/sa_chaotic_autoascend.py delete mode 100644 katakomba/datasets/sars_autoascend.py delete mode 100644 katakomba/datasets/sars_chaotic_autoascend.py delete mode 100644 katakomba/datasets/state_autoascend.py create mode 100644 katakomba/env.py delete mode 100644 katakomba/envs/__init__.py delete mode 100644 katakomba/envs/builder.py delete mode 100644 katakomba/envs/envs.py delete mode 100644 katakomba/nn/vit.py delete mode 100644 katakomba/tasks.py create mode 100644 katakomba/utils/datasets/__init__.py create mode 100644 katakomba/utils/datasets/large_scale.py create mode 100644 katakomba/utils/datasets/small_scale.py create mode 100644 katakomba/utils/datasets/small_scale_buffer.py create mode 100644 katakomba/utils/misc.py create mode 100644 katakomba/utils/scores.py delete mode 100644 katakomba/wrappers/base.py create mode 100644 requirements.txt delete mode 100644 requirements/requirements.txt create mode 100644 scripts/generate.sh create mode 100644 scripts/generate_small_dataset.py delete mode 100644 scripts/guide.py delete mode 100644 scripts/loader_benchmark.py create mode 100644 scripts/rliable_report.py create mode 100644 scripts/stats/algorithms_stats.py create mode 100644 scripts/stats/depth.py create mode 100644 scripts/stats/scores.py create mode 100644 scripts/stats/small_scale_stats.py delete mode 100644 scripts/test_chaotic_loader.py diff --git a/.gitignore b/.gitignore index ee81856..21179bb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,16 +2,22 @@ dev/ # Data -data/ +data ttyrecs.db # Misc +**/.DS_Store +loading_times.py +*.pkl +algorithms/small_scale/iql_chaotic_lstm_discrete.py +run.sh test.py test.ipynb .ml-job-preset.yml mlc_run.sh wandb algorithms/bc_resnet_lstm_accelerate.py +bin # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Dockerfile b/Dockerfile index b51ca85..b84dc0e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,16 +1,13 @@ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 WORKDIR /workspace -RUN #rm /etc/apt/sources.list.d/cuda.list -RUN #rm /etc/apt/sources.list.d/nvidia-ml.list -# python, dependencies for mujoco-py, from https://github.com/openai/mujoco-py +# python, dependencies for mujoco-py (and in general useful for RL research), +# from https://github.com/openai/mujoco-py RUN apt-get update -q \ && DEBIAN_FRONTEND=noninteractive apt-get install -y \ python3-pip \ build-essential \ patchelf \ - curl \ - git \ libgl1-mesa-dev \ libgl1-mesa-glx \ libglew-dev \ @@ -25,21 +22,6 @@ RUN apt-get update -q \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* -# RUN ln -s /usr/bin/python3 /usr/bin/python -# installing mujoco distr -RUN mkdir -p /root/.mujoco \ - && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \ - && tar -xf mujoco.tar.gz -C /root/.mujoco \ - && rm mujoco.tar.gz -ENV LD_LIBRARY_PATH 
/root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} - -# installing dependencies, optional mujoco_py compilation -COPY requirements.txt requirements.txt -RUN pip install --pre torch --force-reinstall --index-url https://download.pytorch.org/whl/nightly/cu117 -RUN pip install -r requirements.txt - -RUN python3 -c "import mujoco_py" - ### NetHack dependencies RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub RUN apt-get update && apt-get install -yq \ @@ -65,5 +47,6 @@ RUN apt-get update && apt-get install -yq \ kitware-archive-keyring COPY . /opt/nle -# Install package -RUN pip install nle \ No newline at end of file +COPY requirements.txt requirements.txt +RUN pip3 install --upgrade pip setuptools wheel +RUN pip3 install -r requirements.txt diff --git a/algorithms/bc_chaotic_lstm.py b/algorithms/bc_chaotic_lstm.py deleted file mode 100644 index 36be927..0000000 --- a/algorithms/bc_chaotic_lstm.py +++ /dev/null @@ -1,353 +0,0 @@ -""" -Key differneces or uncertanties to the original implementation: - 1. Dones are not used for masking out the rnn_state - 2. (?) Actions are argmaxed, not sampled (not sure yet how it's done in the original implementation) -""" -import pyrallis -from dataclasses import dataclass, asdict - -import time -import random -import wandb -import sys -import os -import uuid -import torch -import torch.nn as nn -from torch.utils.data import DataLoader - -from tqdm.auto import tqdm, trange -from collections import defaultdict -from torch.distributions import Categorical -import numpy as np - -from typing import Optional, Tuple -from katakomba.datasets.sa_chaotic_autoascend import SAChaoticAutoAscendTTYDataset -from katakomba.tasks import make_task_builder -from katakomba.utils.roles import Alignment, Race, Role, Sex -from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder -from katakomba.utils.render import SCREEN_SHAPE - -torch.backends.cudnn.benchmark = True - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -class timeit: - def __enter__(self): - self.start_gpu = torch.cuda.Event(enable_timing=True) - self.end_gpu = torch.cuda.Event(enable_timing=True) - self.start_cpu = time.time() - self.start_gpu.record() - return self - - def __exit__(self, type, value, traceback): - self.end_gpu.record() - torch.cuda.synchronize() - self.elapsed_time_gpu = self.start_gpu.elapsed_time(self.end_gpu) / 1000 - self.elapsed_time_cpu = time.time() - self.start_cpu - - -@dataclass -class TrainConfig: - env: str = "NetHackScore-v0-ttyimg-bot-v0" - data_path: str = "data/nle_data" - db_path: str = "ttyrecs.db" - # Wandb logging - project: str = "NeuralNetHack" - group: str = "ChaoticDwarfen-BC" - name: str = "ChaoticDwarfen-BC" - version: str = "v0" - # Model - rnn_hidden_dim: int = 512 - rnn_layers: int = 1 - # Training - update_steps: int = 180_000 - batch_size: int = 256 - seq_len: int = 32 - n_workers: int = 16 - learning_rate: float = 0.0001 - clip_grad_norm: Optional[float] = 4.0 - checkpoints_path: Optional[str] = None - eval_every: int = 10_000 - eval_episodes_per_seed: int = 1 - eval_seeds: int = 50 - train_seed: int = 42 - use_prev_action: bool = True - - def __post_init__(self): - self.group = f"{self.group}-{self.env}-{self.version}" - self.name = f"{self.name}-{str(uuid.uuid4())[:8]}" - if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.group, self.name - ) - - -def set_seed(seed: int): - os.environ["PYTHONHASHSEED"] 
= str(seed) - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) - - -class BC(nn.Module): - def __init__(self, action_dim: int, rnn_hidden_dim: int = 512, rnn_layers: int = 1, use_prev_action: bool = True): - super().__init__() - # Action dimensions and prev actions - self.num_actions = action_dim - self.use_prev_action = use_prev_action - self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 - - # Encoders - self.topline_encoder = TopLineEncoder() - self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) - - screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) - self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) - - self.h_dim = sum( - [ - self.topline_encoder.hidden_dim, - self.bottomline_encoder.hidden_dim, - self.screen_encoder.hidden_dim, - self.prev_actions_dim, - ] - ) - # Policy - self.rnn = nn.LSTM(self.h_dim, rnn_hidden_dim, num_layers=rnn_layers, batch_first=True) - self.head = nn.Linear(rnn_hidden_dim, self.num_actions) - - def forward(self, inputs, state=None): - B, T, C, H, W = inputs["screen_image"].shape - topline = inputs["tty_chars"][..., 0, :] - bottom_line = inputs["tty_chars"][..., -2:, :] - - encoded_state = [ - self.topline_encoder( - topline.float(memory_format=torch.contiguous_format).view(T * B, -1) - ), - self.bottomline_encoder( - bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) - ), - self.screen_encoder( - inputs["screen_image"] - .float(memory_format=torch.contiguous_format) - .view(T * B, C, H, W) - ), - ] - if self.use_prev_action: - encoded_state.append( - torch.nn.functional.one_hot( - inputs["prev_actions"], self.num_actions - ).view(T * B, -1) - ) - - encoded_state = torch.cat(encoded_state, dim=1) - core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) - policy_logits = self.head(core_output) - - return policy_logits, new_state - - @torch.no_grad() - def act(self, obs, state=None, device="cpu"): - inputs = { - "tty_chars": torch.tensor( - obs["tty_chars"][np.newaxis, np.newaxis, ...], device=device - ), - "tty_colors": torch.tensor( - obs["tty_colors"][np.newaxis, np.newaxis, ...], device=device - ), - "screen_image": torch.tensor( - obs["screen_image"][np.newaxis, np.newaxis, ...], device=device - ), - "prev_actions": torch.tensor( - np.array([obs["prev_actions"]]).reshape(1, 1), - dtype=torch.long, - device=device, - ), - } - logits, new_state = self(inputs, state) - # action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) - return torch.argmax(logits).cpu().item(), new_state - - -@torch.no_grad() -def evaluate( - env_builder, actor: BC, episodes_per_seed: int, device="cpu" -): - actor.eval() - eval_stats = defaultdict(dict) - # TODO: we should not reset hidden state and prev_actions on evaluation, to mimic the training - for (character, env, seed) in tqdm(env_builder.evaluate()): - episodes_rewards = [] - for _ in trange(episodes_per_seed, desc="One seed evaluation", leave=False): - env.seed(seed, reseed=False) - - obs, done, episode_reward = env.reset(), False, 0.0 - rnn_state = None - obs["prev_actions"] = 0 - - while not done: - action, rnn_state = actor.act(obs, rnn_state, device=device) - obs, reward, done, _ = env.step(action) - episode_reward += reward - obs["prev_actions"] = action - - episodes_rewards.append(episode_reward) - - eval_stats[character][seed] = np.mean(episodes_rewards) - - # for each character also log mean across all seeds - for character in eval_stats.keys(): - eval_stats[character]["mean_return"] = np.mean( 
- list(eval_stats[character].values()) - ) - - actor.train() - return eval_stats - - -@pyrallis.wrap() -def train(config: TrainConfig): - print(f"Device: {DEVICE}") - saved_config = asdict(config) - saved_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") - wandb.init( - config=saved_config, - project=config.project, - group=config.group, - name=config.name, - id=str(uuid.uuid4()), - save_code=True, - ) - if config.checkpoints_path is not None: - print(f"Checkpoints path: {config.checkpoints_path}") - os.makedirs(config.checkpoints_path, exist_ok=True) - with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: - pyrallis.dump(config, f) - - set_seed(config.train_seed) - - env_builder, dataset_builder = make_task_builder( - config.env, data_path=config.data_path, db_path=config.db_path - ) - env_builder = ( - env_builder.roles([Role.MONK]) - .races([Race.HUMAN]) - .alignments([Alignment.NEUTRAL]) - .sex([Sex.MALE]) - .eval_seeds(list(range(config.eval_seeds))) - ) - - dataset_builder = dataset_builder.roles([Role.MONK]).races([Race.HUMAN]) - dataset = dataset_builder.build( - batch_size=config.batch_size, - seq_len=config.seq_len, - n_workers=config.n_workers, - auto_ascend_cls=SAChaoticAutoAscendTTYDataset, - ) - - actor = BC( - action_dim=env_builder.get_action_dim(), - use_prev_action=config.use_prev_action, - rnn_hidden_dim=config.rnn_hidden_dim, - rnn_layers=config.rnn_layers, - ).to(DEVICE) - optim = torch.optim.Adam(actor.parameters(), lr=config.learning_rate) - print("Number of parameters:", sum(p.numel() for p in actor.parameters())) - - loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, - pin_memory=True, - ) - scaler = torch.cuda.amp.GradScaler() - - prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) - rnn_state = None - - loader_iter = iter(loader) - for step in trange(config.update_steps, desc="Training"): - with timeit() as timer: - tty_chars, tty_colors, tty_cursor, screen_image, actions = ( - a.to(DEVICE) for a in next(loader_iter) - ) - actions = actions.long() - - wandb.log( - { - "times/batch_loading_cpu": timer.elapsed_time_cpu, - "times/batch_loading_gpu": timer.elapsed_time_gpu, - }, - step=step, - ) - - with timeit() as timer: - with torch.cuda.amp.autocast(): - logits, rnn_state = actor( - inputs={ - "tty_chars": tty_chars, - "tty_colors": tty_colors, - "screen_image": screen_image, - "prev_actions": torch.cat( - [prev_actions.long(), actions[:, :-1]], dim=1 - ), - }, - state=rnn_state, - ) - rnn_state = [a.detach() for a in rnn_state] - - dist = Categorical(logits=logits) - loss = -dist.log_prob(actions).mean() - # update prev_actions for next iteration - prev_actions = actions[:, -1].unsqueeze(-1) - - wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) - - with timeit() as timer: - scaler.scale(loss).backward() - - if config.clip_grad_norm is not None: - scaler.unscale_(optim) - torch.nn.utils.clip_grad_norm_( - actor.parameters(), config.clip_grad_norm - ) - - scaler.step(optim) - scaler.update() - optim.zero_grad(set_to_none=True) - - wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) - - wandb.log( - { - "loss": loss.detach().item(), - "transitions": config.batch_size * config.seq_len * step, - }, - step=step, - ) - - if (step + 1) % config.eval_every == 0: - eval_stats = evaluate( - env_builder, actor, config.eval_episodes_per_seed, device=DEVICE - ) - wandb.log( - dict( - eval_stats, - **{"transitions": 
config.batch_size * config.seq_len * step}, - ), - step=step, - ) - - if config.checkpoints_path is not None: - torch.save( - actor.state_dict(), - os.path.join(config.checkpoints_path, f"{step}.pt"), - ) - - -if __name__ == "__main__": - train() diff --git a/algorithms/bc_resnet_lstm.py b/algorithms/bc_resnet_lstm.py deleted file mode 100644 index fe7f133..0000000 --- a/algorithms/bc_resnet_lstm.py +++ /dev/null @@ -1,286 +0,0 @@ -import pyrallis -from dataclasses import dataclass, asdict - -import time -import random -import wandb -import sys -import os -import uuid -import torch -import torch.nn as nn -from torch.utils.data import DataLoader - -from tqdm.auto import tqdm, trange -from collections import defaultdict -from torch.distributions import Categorical -import numpy as np - -from typing import Optional -from katakomba.datasets.sa_autoascend import SAAutoAscendTTYDataset -from katakomba.tasks import make_task_builder -from katakomba.utils.roles import Alignment, Race, Role, Sex -from katakomba.nn.resnet import ResNet11, ResNet20, ResNet38, ResNet56, ResNet110 - -torch.backends.cudnn.benchmark = True - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -# TODO: -# 1. implement filtering of params for the weight decay groups -# 2. oncycle rl scheduler -# 3. ... -# 4. label smoothing for cross entropy loss -class timeit: - def __enter__(self): - self.start_gpu = torch.cuda.Event(enable_timing=True) - self.end_gpu = torch.cuda.Event(enable_timing=True) - self.start_cpu = time.time() - self.start_gpu.record() - return self - - def __exit__(self, type, value, traceback): - self.end_gpu.record() - torch.cuda.synchronize() - self.elapsed_time_gpu = self.start_gpu.elapsed_time(self.end_gpu) / 1000 - self.elapsed_time_cpu = time.time() - self.start_cpu - - -@dataclass -class TrainConfig: - env: str = "NetHackScore-v0-tty-bot-v0" - data_path: str = "data/nle_data" - db_path: str = "ttyrecs.db" - # Wandb logging - project: str = "NeuralNetHack" - group: str = "DummyBC" - name: str = "DummyBC" - version: str = "v0" - # Model - resnet_type: str = "ResNet11" - lstm_layers: int = 2 - hidden_dim: int = 1024 - width_k: int = 1 - # Training - update_steps: int = 180_000 - batch_size: int = 256 - seq_len: int = 32 - n_workers: int = 16 - learning_rate: float = 3e-4 - clip_grad_norm: Optional[float] = None - checkpoints_path: Optional[str] = None - eval_every: int = 10_000 - eval_episodes_per_seed: int = 1 - eval_seeds: int = 50 - train_seed: int = 42 - - def __post_init__(self): - self.group = f"{self.group}-{self.env}-{self.version}" - self.name = f"{self.name}-{str(uuid.uuid4())[:8]}" - if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.group, self.name - ) - - -def set_seed(seed: int): - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) - - -class Actor(nn.Module): - def __init__(self, action_dim, hidden_dim, lstm_layers, width_k, resnet_type): - super().__init__() - resnet = getattr(sys.modules[__name__], resnet_type) - self.state_encoder = resnet(img_channels=2, out_dim=hidden_dim, k=width_k) - self.norm = nn.LayerNorm(hidden_dim) - self.rnn = nn.LSTM( - input_size=hidden_dim, - hidden_size=hidden_dim, - num_layers=lstm_layers, - batch_first=True, - ) - # TODO: ortho/chrono init for the lstm - self.head = nn.Linear(hidden_dim, action_dim) - - def forward(self, obs, state=None): - # [batch_size, seq_len, ...] 
- batch_size, seq_len, *_ = obs.shape - - out = self.state_encoder(obs.flatten(0, 1)).view(batch_size, seq_len, -1) - out, new_state = self.rnn(self.norm(out), state) - logits = self.head(out) - - return logits, new_state - - @torch.no_grad() - def act(self, obs, state=None, device="cpu"): - assert obs.ndim == 3, "act only for single obs" - obs = torch.tensor(obs, device=device, dtype=torch.float32).permute(2, 0, 1) - logits, new_state = self(obs[None, None, ...], state) - return torch.argmax(logits).cpu().item(), new_state - - -@torch.no_grad() -def evaluate(env_builder, actor, episodes_per_seed, device="cpu"): - actor.eval() - eval_stats = defaultdict(dict) - - for (character, env, seed) in tqdm(env_builder.evaluate()): - episodes_rewards = [] - for _ in trange(episodes_per_seed, desc="One seed evaluation", leave=False): - env.seed(seed, reseed=False) - - obs, done, episode_reward = env.reset(), False, 0.0 - rnn_state = None - - while not done: - action, rnn_state = actor.act(obs, rnn_state, device=device) - obs, reward, done, _ = env.step(action) - episode_reward += reward - episodes_rewards.append(episode_reward) - - eval_stats[character][seed] = np.mean(episodes_rewards) - - # for each character also log mean across all seeds - for character in eval_stats.keys(): - eval_stats[character]["mean_return"] = np.mean( - list(eval_stats[character].values()) - ) - - actor.train() - return eval_stats - - -@pyrallis.wrap() -def train(config: TrainConfig): - print(f"Device: {DEVICE}") - saved_config = asdict(config) - saved_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") - wandb.init( - config=saved_config, - project=config.project, - group=config.group, - name=config.name, - id=str(uuid.uuid4()), - save_code=True, - ) - if config.checkpoints_path is not None: - print(f"Checkpoints path: {config.checkpoints_path}") - os.makedirs(config.checkpoints_path, exist_ok=True) - with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: - pyrallis.dump(config, f) - - set_seed(config.train_seed) - env_builder, dataset_builder = make_task_builder( - config.env, data_path=config.data_path, db_path=config.db_path - ) - env_builder = ( - env_builder.roles([Role.MONK]) - .races([Race.HUMAN]) - .alignments([Alignment.NEUTRAL]) - .sex([Sex.MALE]) - .eval_seeds(list(range(config.eval_seeds))) - ) - dataset = dataset_builder.build( - batch_size=config.batch_size, - seq_len=config.seq_len, - n_workers=config.n_workers, - auto_ascend_cls=SAAutoAscendTTYDataset, - ) - actor = Actor( - resnet_type=config.resnet_type, - action_dim=env_builder.get_action_dim(), - hidden_dim=config.hidden_dim, - lstm_layers=config.lstm_layers, - width_k=config.width_k, - ).to(DEVICE) - print("Number of parameters:", sum(p.numel() for p in actor.parameters())) - # ONLY FOR MLC/TRS - # actor = torch.compile(actor, mode="reduce-overhead") - optim = torch.optim.AdamW(actor.parameters(), lr=config.learning_rate) - - loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, - pin_memory=True, - ) - scaler = torch.cuda.amp.GradScaler() - - rnn_state = None - loader_iter = iter(loader) - for step in trange(config.update_steps, desc="Training"): - with timeit() as timer: - tty_chars, tty_colors, tty_cursor, actions = next(loader_iter) - - wandb.log( - { - "times/batch_loading_cpu": timer.elapsed_time_cpu, - "times/batch_loading_gpu": timer.elapsed_time_gpu, - }, - step=step, - ) - - with timeit() as timer: - with torch.cuda.amp.autocast(): - states = 
torch.stack([tty_chars, tty_colors], axis=-1) - logits, rnn_state = actor( - states.permute(0, 1, 4, 2, 3).to(DEVICE).to(torch.float32), - state=rnn_state, - ) - rnn_state = [a.detach() for a in rnn_state] - - dist = Categorical(logits=logits) - loss = -dist.log_prob(actions.to(DEVICE)).mean() - - wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) - - with timeit() as timer: - scaler.scale(loss).backward() - # loss.backward() - if config.clip_grad_norm is not None: - scaler.unscale_(optim) - torch.nn.utils.clip_grad_norm_( - actor.parameters(), config.clip_grad_norm - ) - # optim.step() - scaler.step(optim) - scaler.update() - optim.zero_grad(set_to_none=True) - - wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) - - wandb.log( - { - "loss": loss.detach().item(), - "transitions": config.batch_size * config.seq_len * step, - }, - step=step, - ) - - if (step + 1) % config.eval_every == 0: - eval_stats = evaluate( - env_builder, actor, config.eval_episodes_per_seed, device=DEVICE - ) - wandb.log( - dict( - eval_stats, - **{"transitions": config.batch_size * config.seq_len * step}, - ), - step=step, - ) - - if config.checkpoints_path is not None: - torch.save( - actor.state_dict(), - os.path.join(config.checkpoints_path, f"{step}.pt"), - ) - - -if __name__ == "__main__": - train() diff --git a/algorithms/rem_chaotic_lstm.py b/algorithms/rem_chaotic_lstm.py deleted file mode 100644 index df9b195..0000000 --- a/algorithms/rem_chaotic_lstm.py +++ /dev/null @@ -1,362 +0,0 @@ -import pyrallis -from dataclasses import dataclass, asdict - -import time -import random -import wandb -import sys -import os -import uuid -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -from tqdm.auto import tqdm, trange -from collections import defaultdict -import numpy as np - -from copy import deepcopy -from typing import Optional -from katakomba.tasks import make_task_builder -from katakomba.utils.roles import Alignment, Race, Role, Sex -from katakomba.datasets import SARSChaoticAutoAscendTTYDataset -from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder -from katakomba.utils.render import SCREEN_SHAPE - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - - -@dataclass -class TrainConfig: - env: str = "NetHackScore-v0-ttyimg-bot-v0" - data_path: str = "data/nle_data" - db_path: str = "ttyrecs.db" - # Wandb logging - project: str = "NeuralNetHack" - group: str = "REM" - name: str = "REM" - version: str = "v0" - # Model - use_prev_action: bool = True - rnn_layers: int = 1 - rnn_hidden_dim: int = 512 - tau: float = 5e-3 - gamma: float = 0.99 - num_heads: int = 50 - # Training - update_steps: int = 180000 - batch_size: int = 256 - seq_len: int = 32 - n_workers: int = 8 - learning_rate: float = 3e-4 - clip_grad_norm: Optional[float] = 4.0 - checkpoints_path: Optional[str] = None - eval_every: int = 10_000 - eval_episodes_per_seed: int = 1 - eval_seeds: int = 50 - train_seed: int = 42 - - def __post_init__(self): - self.group = f"{self.group}-{self.env}-{self.version}" - self.name = f"{self.name}-{str(uuid.uuid4())[:8]}" - if self.checkpoints_path is not None: - self.checkpoints_path = os.path.join( - self.checkpoints_path, self.group, self.name - ) - - -def set_seed(seed: int): - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - random.seed(seed) - torch.manual_seed(seed) - - -def soft_update(target, source, tau): - for tp, sp in zip(target.parameters(), 
source.parameters()): - tp.data.copy_((1 - tau) * tp.data + tau * sp.data) - - -def sample_convex_combination(size, device="cpu"): - weights = torch.rand(size, device=device) - weights = weights / weights.sum() - assert torch.isclose(weights.sum(), torch.tensor([1.0], device=device)) - return weights - - -# def symlog(x): -# return torch.sign(x) * torch.log(torch.abs(x) + 1) -# -# -# def symexp(x): -# return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) - - -def rem_dqn_loss( - critic, - target_critic, - obs, - actions, - rewards, - next_obs, - dones, - rnn_states, - target_rnn_states, - convex_comb_weights, - gamma, -): - with torch.no_grad(): - next_q_values, next_target_rnn_states = target_critic(next_obs, state=target_rnn_states) - next_q_values = (next_q_values * convex_comb_weights).sum(2) - next_q_values = next_q_values.max(dim=-1).values - - assert next_q_values.shape == rewards.shape == dones.shape - q_target = rewards + gamma * (1 - dones) * next_q_values - # q_target = symlog(rewards + gamma * (1 - dones) * symexp(next_q_values)) - - assert actions.dim() == 2 - q_pred, next_rnn_states = critic(obs, state=rnn_states) - q_pred = (q_pred * convex_comb_weights.detach()).sum(2) - q_pred = q_pred.gather(-1, actions.to(torch.long).unsqueeze(-1)).squeeze() - assert q_pred.shape == q_target.shape - - loss = F.mse_loss(q_pred, q_target) - loss_info = { - "loss": loss.item(), - "q_target": q_target.mean().item() - } - return loss, next_rnn_states, next_target_rnn_states, loss_info - - -class Critic(nn.Module): - def __init__(self, action_dim, rnn_hidden_dim, rnn_layers, num_heads, use_prev_action=True): - super().__init__() - self.num_heads = num_heads - self.num_actions = action_dim - self.use_prev_action = use_prev_action - self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 - - # Encoders - self.topline_encoder = torch.jit.script(TopLineEncoder()) - self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) - - screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) - self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) - - self.h_dim = sum([ - self.topline_encoder.hidden_dim, - self.bottomline_encoder.hidden_dim, - self.screen_encoder.hidden_dim, - self.prev_actions_dim, - ]) - # Policy - self.rnn = nn.LSTM(self.h_dim, rnn_hidden_dim, num_layers=rnn_layers, batch_first=True) - self.head = nn.Linear(rnn_hidden_dim, self.num_actions * num_heads) - - def forward(self, inputs, state=None): - # [batch_size, seq_len, ...] 
- B, T, C, H, W = inputs["screen_image"].shape - topline = inputs["tty_chars"][..., 0, :] - bottom_line = inputs["tty_chars"][..., -2:, :] - - encoded_state = [ - self.topline_encoder( - topline.float(memory_format=torch.contiguous_format).view(T * B, -1) - ), - self.bottomline_encoder( - bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) - ), - self.screen_encoder( - inputs["screen_image"] - .float(memory_format=torch.contiguous_format) - .view(T * B, C, H, W) - ), - ] - if self.use_prev_action: - encoded_state.append( - torch.nn.functional.one_hot( - inputs["prev_actions"], self.num_actions - ).view(T * B, -1) - ) - encoded_state = torch.cat(encoded_state, dim=1) - core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) - q_values_ensemble = self.head(core_output).view(B, T, self.num_heads, self.num_actions) - return q_values_ensemble, new_state - - @torch.no_grad() - def act(self, obs, state=None, device="cpu"): - inputs = { - "screen_image": torch.tensor(obs["screen_image"], device=device)[None, None, ...], - "tty_chars": torch.tensor(obs["tty_chars"], device=device)[None, None, ...], - "prev_actions": torch.tensor(obs["prev_actions"], dtype=torch.long, device=device)[None, None, ...], - } - q_values_ensemble, new_state = self(inputs, state) - # mean q value over all heads - q_values = q_values_ensemble.mean(2) - return torch.argmax(q_values).cpu().item(), new_state - - -@torch.no_grad() -def evaluate(env_builder, actor: Critic, episodes_per_seed: int, device="cpu"): - actor.eval() - eval_stats = defaultdict(dict) - # TODO: we should not reset hidden state and prev_actions on evaluation, to mimic the training - for (character, env, seed) in tqdm(env_builder.evaluate()): - episodes_rewards = [] - for _ in trange(episodes_per_seed, desc="One seed evaluation", leave=False): - env.seed(seed, reseed=False) - - obs, done, episode_reward = env.reset(), False, 0.0 - rnn_state = None - obs["prev_actions"] = 0 - - while not done: - action, rnn_state = actor.act(obs, rnn_state, device=device) - obs, reward, done, _ = env.step(action) - episode_reward += reward - obs["prev_actions"] = action - - episodes_rewards.append(episode_reward) - - eval_stats[character][seed] = np.mean(episodes_rewards) - - # for each character also log mean across all seeds - for character in eval_stats.keys(): - eval_stats[character]["mean_return"] = np.mean(list(eval_stats[character].values())) - - actor.train() - return eval_stats - - -@pyrallis.wrap() -def train(config: TrainConfig): - print(f"Device: {DEVICE}") - saved_config = asdict(config) - saved_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") - wandb.init( - config=saved_config, - project=config.project, - group=config.group, - name=config.name, - id=str(uuid.uuid4()), - save_code=True, - ) - if config.checkpoints_path is not None: - print(f"Checkpoints path: {config.checkpoints_path}") - os.makedirs(config.checkpoints_path, exist_ok=True) - with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: - pyrallis.dump(config, f) - - set_seed(config.train_seed) - - env_builder, dataset_builder = make_task_builder( - config.env, data_path=config.data_path, db_path=config.db_path - ) - env_builder = ( - env_builder.roles([Role.MONK]) - .races([Race.HUMAN]) - .alignments([Alignment.NEUTRAL]) - .sex([Sex.MALE]) - .eval_seeds(list(range(config.eval_seeds))) - ) - - dataset_builder = dataset_builder.roles([Role.MONK]).races([Race.HUMAN]) - dataset = dataset_builder.build( - batch_size=config.batch_size, - 
seq_len=config.seq_len, - n_workers=config.n_workers, - auto_ascend_cls=SARSChaoticAutoAscendTTYDataset, - ) - - critic = Critic( - action_dim=env_builder.get_action_dim(), - rnn_hidden_dim=config.rnn_hidden_dim, - rnn_layers=config.rnn_layers, - num_heads=config.num_heads, - use_prev_action=config.use_prev_action, - ).to(DEVICE) - with torch.no_grad(): - target_critic = deepcopy(critic) - - optim = torch.optim.Adam(critic.parameters(), lr=config.learning_rate) - print("Number of parameters:", sum(p.numel() for p in critic.parameters())) - - loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, - pin_memory=True, - ) - scaler = torch.cuda.amp.GradScaler() - - prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) - rnn_state, target_rnn_state = None, None - - loader_iter = iter(loader) - for step in trange(config.update_steps, desc="Training"): - screen_image, tty_chars, actions, rewards, next_screen_image, next_tty_chars, dones = ( - [t.to(DEVICE) for t in next(loader_iter)] - ) - actions = actions.long() - - obs = { - "screen_image": screen_image, - "tty_chars": tty_chars, - "prev_actions": torch.cat([prev_actions, actions[:, :-1]], dim=1) - } - next_obs = { - "screen_image": next_screen_image, - "tty_chars": next_tty_chars, - "prev_actions": actions - } - convex_comb_weights = sample_convex_combination(config.num_heads, device=DEVICE).view(1, 1, -1, 1) - - loss, rnn_state, target_rnn_state, loss_info = rem_dqn_loss( - critic=critic, - target_critic=target_critic, - obs=obs, - actions=actions, - rewards=rewards, - next_obs=next_obs, - dones=dones, - rnn_states=rnn_state, - target_rnn_states=target_rnn_state, - convex_comb_weights=convex_comb_weights, - gamma=config.gamma - ) - rnn_state = [s.detach() for s in rnn_state] - target_rnn_state = [s.detach() for s in target_rnn_state] - - scaler.scale(loss).backward() - if config.clip_grad_norm is not None: - scaler.unscale_(optim) - torch.nn.utils.clip_grad_norm_(critic.parameters(), config.clip_grad_norm) - scaler.step(optim) - scaler.update() - optim.zero_grad(set_to_none=True) - - soft_update(target_critic, critic, tau=config.tau) - prev_actions = actions[:, -1].unsqueeze(-1) - - wandb.log( - dict(loss_info, **{"transitions": config.batch_size * config.seq_len * step}), - step=step, - ) - - if (step + 1) % config.eval_every == 0: - eval_stats = evaluate(env_builder, critic, config.eval_episodes_per_seed, device=DEVICE) - wandb.log( - dict(eval_stats, **{"transitions": config.batch_size * config.seq_len * step}), - step=step, - ) - if config.checkpoints_path is not None: - torch.save( - critic.state_dict(), - os.path.join(config.checkpoints_path, f"{step}.pt"), - ) - - -if __name__ == "__main__": - train() diff --git a/algorithms/small_scale/awac_chaotic_lstm.py b/algorithms/small_scale/awac_chaotic_lstm.py new file mode 100644 index 0000000..50a1888 --- /dev/null +++ b/algorithms/small_scale/awac_chaotic_lstm.py @@ -0,0 +1,489 @@ +import pyrallis +from dataclasses import dataclass, asdict + +import random +import wandb +import os +import uuid +import torch +import torch.nn as nn +import torch.nn.functional as F + +from gym.vector import AsyncVectorEnv +from concurrent.futures import ThreadPoolExecutor +from tqdm.auto import tqdm, trange +import numpy as np + +from copy import deepcopy +from typing import Optional, Dict, Tuple, Any, List + +from multiprocessing import set_start_method +from katakomba.env import NetHackChallenge, 
OfflineNetHackChallengeWrapper +from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder +from katakomba.utils.render import SCREEN_SHAPE, render_screen_image +from katakomba.utils.datasets import SequentialBuffer +from katakomba.utils.misc import Timeit, StatMean + +LSTM_HIDDEN = Tuple[torch.Tensor, torch.Tensor] +UPDATE_INFO = Dict[str, Any] + +torch.backends.cudnn.benchmark = True +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class TrainConfig: + character: str = "mon-hum-neu" + data_mode: str = "compressed" + # Wandb logging + project: str = "NetHack" + group: str = "small_scale_awac" + name: str = "awac" + version: int = 0 + # Model + rnn_hidden_dim: int = 2048 + rnn_layers: int = 2 + use_prev_action: bool = True + rnn_dropout: float = 0.0 + clip_range: float = 10.0 + tau: float = 0.005 + gamma: float = 0.999 + temperature: float = 1.0 + # Training + update_steps: int = 500_000 + batch_size: int = 64 + seq_len: int = 16 + learning_rate: float = 3e-4 + weight_decay: float = 0.0 + clip_grad_norm: Optional[float] = None + checkpoints_path: Optional[str] = None + eval_every: int = 10_000 + eval_episodes: int = 50 + eval_processes: int = 14 + render_processes: int = 14 + eval_seed: int = 50 + train_seed: int = 42 + + def __post_init__(self): + self.group = f"{self.group}-v{str(self.version)}" + self.name = f"{self.name}-{self.character}-{str(uuid.uuid4())[:8]}" + if self.checkpoints_path is not None: + self.checkpoints_path = os.path.join(self.checkpoints_path, self.group, self.name) + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + + +@torch.no_grad() +def filter_wd_params(model: nn.Module) -> Tuple[List[nn.parameter.Parameter], List[nn.parameter.Parameter]]: + no_decay, decay = [], [] + for name, param in model.named_parameters(): + if hasattr(param, 'requires_grad') and not param.requires_grad: + continue + if 'weight' in name and 'norm' not in name and 'bn' not in name: + decay.append(param) + else: + no_decay.append(param) + assert len(no_decay) + len(decay) == len(list(model.parameters())) + return no_decay, decay + + +def dict_to_tensor(data: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]: + return {k: torch.as_tensor(v, dtype=torch.float, device=device) for k, v in data.items()} + + +def soft_update(target: nn.Module, source: nn.Module, tau: float): + for tp, sp in zip(target.parameters(), source.parameters()): + tp.data.copy_((1 - tau) * tp.data + tau * sp.data) + + +class ActorCritic(nn.Module): + def __init__( + self, + action_dim: int, + rnn_hidden_dim: int = 512, + rnn_layers: int = 1, + rnn_dropout: float = 0.0, + use_prev_action: bool = True + ): + super().__init__() + self.num_actions = action_dim + self.use_prev_action = use_prev_action + self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 + + # Encoders + self.topline_encoder = torch.jit.script(TopLineEncoder()) + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + + screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + + self.h_dim = sum( + [ + self.topline_encoder.hidden_dim, + self.bottomline_encoder.hidden_dim, + self.screen_encoder.hidden_dim, + self.prev_actions_dim, + ] + ) + # networks + self.rnn = nn.LSTM( + self.h_dim, + rnn_hidden_dim, + num_layers=rnn_layers, + dropout=rnn_dropout, + batch_first=True + ) + self.qf = nn.Linear(rnn_hidden_dim, 
self.num_actions) + self.policy = nn.Linear(rnn_hidden_dim, self.num_actions) + + def forward(self, inputs, state=None): + # [batch_size, seq_len, ...] + B, T, C, H, W = inputs["screen_image"].shape + topline = inputs["tty_chars"][..., 0, :] + bottom_line = inputs["tty_chars"][..., -2:, :] + + encoded_state = [ + self.topline_encoder( + topline.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.bottomline_encoder( + bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.screen_encoder( + inputs["screen_image"] + .float(memory_format=torch.contiguous_format) + .view(T * B, C, H, W) + ), + ] + if self.use_prev_action: + encoded_state.append( + F.one_hot(inputs["prev_actions"], self.num_actions).view(T * B, -1) + ) + + encoded_state = torch.cat(encoded_state, dim=1) + core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) + qf = self.qf(core_output) + logits = self.policy(core_output) + + return (qf, logits), new_state + + @torch.no_grad() + def vec_act(self, obs, state=None, device="cpu"): + inputs = { + "tty_chars": torch.tensor(obs["tty_chars"][:, None], device=device), + "screen_image": torch.tensor(obs["screen_image"][:, None], device=device), + "prev_actions": torch.tensor(obs["prev_actions"][:, None], dtype=torch.long, device=device) + } + (_, logits), new_state = self(inputs, state) + actions = torch.argmax(logits.squeeze(1), dim=-1) + return actions.cpu().numpy(), new_state + + +def awac_loss( + model: ActorCritic, + target_model: ActorCritic, + obs: Dict[str, torch.Tensor], + next_obs: Dict[str, torch.Tensor], + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + rnn_states: LSTM_HIDDEN, + target_rnn_states: LSTM_HIDDEN, + gamma: float, + temperature: float +) -> Tuple[torch.Tensor, LSTM_HIDDEN, LSTM_HIDDEN, UPDATE_INFO]: + # critic loss + with torch.no_grad(): + (next_q, next_logits), new_target_rnn_states = target_model(next_obs, state=target_rnn_states) + next_actions = torch.distributions.Categorical(logits=next_logits).sample() + next_q_actions = next_q.gather(-1, next_actions.to(torch.long).unsqueeze(-1)).squeeze() + + assert rewards.shape == dones.shape == next_q_actions.shape + target_q = rewards + (1 - dones) * gamma * next_q_actions + + assert actions.dim() == 2 + (q_pred, logits_pred), new_rnn_states = model(obs, state=rnn_states) + q_pred_actions = q_pred.gather(-1, actions.to(torch.long).unsqueeze(-1)).squeeze() + assert q_pred_actions.shape == target_q.shape + td_loss = F.mse_loss(q_pred_actions, target_q) + + # actor loss + with torch.no_grad(): + adv = q_pred_actions - (q_pred * F.softmax(logits_pred, dim=-1)).sum(-1) + + log_probs = torch.distributions.Categorical(logits=logits_pred).log_prob(actions) + weights = torch.exp(temperature * adv).clamp(max=100.0) + actor_loss = torch.mean(-log_probs * weights) + + loss = td_loss + actor_loss + loss_info = { + "td_loss": td_loss.item(), + "actor_loss": actor_loss.item(), + "loss": loss, + "q_target": next_q.mean().item() + } + return loss, new_rnn_states, new_target_rnn_states, loss_info + + +@torch.no_grad() +def vec_evaluate( + vec_env: AsyncVectorEnv, + actor: ActorCritic, + num_episodes: int, + seed: int = 0, + device: str = "cpu" +) -> Dict[str, np.ndarray]: + actor.eval() + # set seed for reproducibility (reseed=False by default) + vec_env.seed(seed) + # all this work is needed to mitigate bias for shorter + # episodes during vectorized evaluation, for more see: + # https://github.com/DLR-RM/stable-baselines3/issues/402 + n_envs 
= vec_env.num_envs + episode_rewards = [] + episode_lengths = [] + episode_depths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(num_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = vec_env.reset() + observations["prev_actions"] = np.zeros(n_envs, dtype=float) + + rnn_states = None + pbar = tqdm(total=num_episodes) + while (episode_counts < episode_count_targets).any(): + # faster to do this here for entire batch, than in wrappers for each env + observations["screen_image"] = render_screen_image( + tty_chars=observations["tty_chars"][:, np.newaxis, ...], + tty_colors=observations["tty_colors"][:, np.newaxis, ...], + tty_cursor=observations["tty_cursor"][:, np.newaxis, ...], + ) + observations["screen_image"] = np.squeeze(observations["screen_image"], 1) + + actions, rnn_states = actor.vec_act(observations, rnn_states, device=device) + + observations, rewards, dones, infos = vec_env.step(actions) + observations["prev_actions"] = actions + + current_rewards += rewards + current_lengths += 1 + + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + if dones[i]: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_depths.append(infos[i]["current_depth"]) + episode_counts[i] += 1 + pbar.update(1) + + current_rewards[i] = 0 + current_lengths[i] = 0 + + pbar.close() + result = { + "reward_median": np.median(episode_rewards), + "reward_mean": np.mean(episode_rewards), + "reward_std": np.std(episode_rewards), + "reward_min": np.min(episode_rewards), + "reward_max": np.max(episode_rewards), + "reward_raw": np.array(episode_rewards), + # depth + "depth_median": np.median(episode_depths), + "depth_mean": np.mean(episode_depths), + "depth_std": np.std(episode_depths), + "depth_min": np.min(episode_depths), + "depth_max": np.max(episode_depths), + "depth_raw": np.array(episode_depths), + } + actor.train() + return result + + +@pyrallis.wrap() +def train(config: TrainConfig): + print(f"Device: {DEVICE}") + wandb.init( + config=asdict(config), + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + save_code=True, + ) + if config.checkpoints_path is not None: + print(f"Checkpoints path: {config.checkpoints_path}") + os.makedirs(config.checkpoints_path, exist_ok=True) + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: + pyrallis.dump(config, f) + + set_seed(config.train_seed) + + def env_fn(): + env = NetHackChallenge( + character=config.character, + observation_keys=["tty_chars", "tty_colors", "tty_cursor"] + ) + env = OfflineNetHackChallengeWrapper(env) + return env + + tmp_env = env_fn() + eval_env = AsyncVectorEnv( + env_fns=[env_fn for _ in range(config.eval_processes)], + copy=False + ) + buffer = SequentialBuffer( + dataset=tmp_env.get_dataset(mode=config.data_mode, scale="small"), + seq_len=config.seq_len, + batch_size=config.batch_size, + seed=config.train_seed, + add_next_step=True # true as this is needed for next_obs + ) + tp = ThreadPoolExecutor(max_workers=config.render_processes) + + model = ActorCritic( + action_dim=eval_env.single_action_space.n, + use_prev_action=config.use_prev_action, + rnn_hidden_dim=config.rnn_hidden_dim, + rnn_layers=config.rnn_layers, + rnn_dropout=config.rnn_dropout, + ).to(DEVICE) + 
with torch.no_grad(): + target_model = deepcopy(model) + + no_decay_params, decay_params = filter_wd_params(model) + optim = torch.optim.AdamW([ + {"params": no_decay_params, "weight_decay": 0.0}, + {"params": decay_params, "weight_decay": config.weight_decay} + ], lr=config.learning_rate) + print("Number of parameters:", sum(p.numel() for p in model.parameters())) + + scaler = torch.cuda.amp.GradScaler() + rnn_state, target_rnn_state = None, None + prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) + + # For reward normalization + reward_stats = StatMean(cumulative=True) + running_rewards = 0.0 + + for step in trange(1, config.update_steps + 1, desc="Training"): + with Timeit() as timer: + batch = buffer.sample() + screen_image = render_screen_image( + tty_chars=batch["tty_chars"], + tty_colors=batch["tty_colors"], + tty_cursor=batch["tty_cursor"], + threadpool=tp, + ) + batch["screen_image"] = screen_image + + # Update reward statistics (as in the original nle implementation) + running_rewards *= config.gamma + running_rewards += batch["rewards"] + reward_stats += running_rewards ** 2 + running_rewards *= (~batch["dones"]).astype(float) + # Normalize the reward + reward_std = reward_stats.mean() ** 0.5 + batch["rewards"] = batch["rewards"] / max(0.01, reward_std) + batch["rewards"] = np.clip(batch["rewards"], -config.clip_range, config.clip_range) + + batch = dict_to_tensor(batch, device=DEVICE) + + wandb.log( + { + "times/batch_loading_cpu": timer.elapsed_time_cpu, + "times/batch_loading_gpu": timer.elapsed_time_gpu, + }, + step=step, + ) + + with Timeit() as timer: + with torch.cuda.amp.autocast(): + obs = { + "screen_image": batch["screen_image"][:, :-1].contiguous(), + "tty_chars": batch["tty_chars"][:, :-1].contiguous(), + "prev_actions": torch.cat([prev_actions.long(), batch["actions"][:, :-2].long()], dim=1) + } + next_obs = { + "screen_image": batch["screen_image"][:, 1:].contiguous(), + "tty_chars": batch["tty_chars"][:, 1:].contiguous(), + "prev_actions": batch["actions"][:, :-1].long() + } + loss, rnn_state, target_rnn_state, loss_info = awac_loss( + model=model, + target_model=target_model, + obs=obs, + next_obs=next_obs, + actions=batch["actions"][:, :-1], + rewards=batch["rewards"][:, :-1], + dones=batch["dones"][:, :-1], + rnn_states=rnn_state, + target_rnn_states=target_rnn_state, + temperature=config.temperature, + gamma=config.gamma + ) + # detaching rnn hidden states for the next iteration + rnn_state = [a.detach() for a in rnn_state] + target_rnn_state = [a.detach() for a in target_rnn_state] + + # update prev_actions for next iteration (-1 is seq_len + 1, so -2) + prev_actions = batch["actions"][:, -2].unsqueeze(-1) + + wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) + + with Timeit() as timer: + scaler.scale(loss).backward() + if config.clip_grad_norm is not None: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad_norm) + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + soft_update(target_model, model, tau=config.tau) + + wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **loss_info}, step=step) + + if step % config.eval_every == 0: + with Timeit() as timer: + eval_stats = vec_evaluate( + eval_env, model, config.eval_episodes, config.eval_seed, device=DEVICE + ) + raw_returns = eval_stats.pop("reward_raw") + raw_depths = eval_stats.pop("depth_raw") + 
normalized_scores = tmp_env.get_normalized_score(raw_returns) + + wandb.log({ + "times/evaluation_gpu": timer.elapsed_time_gpu, + "times/evaluation_cpu": timer.elapsed_time_cpu, + }, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **eval_stats}, step=step) + + if config.checkpoints_path is not None: + torch.save(model.state_dict(), os.path.join(config.checkpoints_path, f"{step}.pt")) + # saving raw logs + np.save(os.path.join(config.checkpoints_path, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(config.checkpoints_path, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(config.checkpoints_path, f"{step}_normalized_scores.npy"), normalized_scores) + + # also saving to wandb files for easier use in the future + np.save(os.path.join(wandb.run.dir, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(wandb.run.dir, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(wandb.run.dir, f"{step}_normalized_scores.npy"), normalized_scores) + + buffer.close() + + +if __name__ == "__main__": + set_start_method("spawn") + train() + diff --git a/algorithms/small_scale/bc_chaotic_lstm.py b/algorithms/small_scale/bc_chaotic_lstm.py new file mode 100644 index 0000000..c94c6f5 --- /dev/null +++ b/algorithms/small_scale/bc_chaotic_lstm.py @@ -0,0 +1,405 @@ +import pyrallis +from dataclasses import dataclass, asdict +import random +import wandb +import os +import uuid +import torch +import torch.nn as nn + +from gym.vector import AsyncVectorEnv +from concurrent.futures import ThreadPoolExecutor +import torch.nn.functional as F +from tqdm.auto import tqdm, trange +from torch.distributions import Categorical +import numpy as np + +from multiprocessing import set_start_method +from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper +from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder +from katakomba.utils.render import SCREEN_SHAPE, render_screen_image +from katakomba.utils.datasets import SequentialBuffer +from katakomba.utils.misc import Timeit +from typing import Optional, Tuple, List, Dict + +torch.backends.cudnn.benchmark = True +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class TrainConfig: + character: str = "mon-hum-neu" + data_mode: str = "compressed" + # Wandb logging + project: str = "NetHack" + group: str = "small_scale_bc" + name: str = "bc" + version: int = 0 + # Model + rnn_hidden_dim: int = 2048 + rnn_layers: int = 2 + use_prev_action: bool = True + rnn_dropout: float = 0.0 + # Training + update_steps: int = 500_000 + batch_size: int = 64 + seq_len: int = 16 + learning_rate: float = 3e-4 + weight_decay: float = 0.0 + clip_grad_norm: Optional[float] = None + checkpoints_path: Optional[str] = None + eval_every: int = 10_000 + eval_episodes: int = 50 + eval_processes: int = 14 + render_processes: int = 14 + eval_seed: int = 50 + train_seed: int = 42 + + def __post_init__(self): + self.group = f"{self.group}-v{str(self.version)}" + self.name = f"{self.name}-{self.character}-{str(uuid.uuid4())[:8]}" + if self.checkpoints_path is not None: + self.checkpoints_path = os.path.join(self.checkpoints_path, self.group, self.name) + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + + +@torch.no_grad() +def filter_wd_params(model: nn.Module) -> Tuple[List[nn.parameter.Parameter], List[nn.parameter.Parameter]]: + no_decay, decay = [], [] + for name, param in 
model.named_parameters(): + if hasattr(param, 'requires_grad') and not param.requires_grad: + continue + if 'weight' in name and 'norm' not in name and 'bn' not in name: + decay.append(param) + else: + no_decay.append(param) + assert len(no_decay) + len(decay) == len(list(model.parameters())) + return no_decay, decay + + +def dict_to_tensor(data: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]: + return {k: torch.as_tensor(v, device=device) for k, v in data.items()} + + +class Actor(nn.Module): + def __init__( + self, + action_dim: int, + rnn_hidden_dim: int = 512, + rnn_layers: int = 1, + rnn_dropout: float = 0.0, + use_prev_action: bool = True + ): + super().__init__() + # Action dimensions and prev actions + self.num_actions = action_dim + self.use_prev_action = use_prev_action + self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 + + # Encoders + self.topline_encoder = TopLineEncoder() + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + + screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + + self.h_dim = sum( + [ + self.topline_encoder.hidden_dim, + self.bottomline_encoder.hidden_dim, + self.screen_encoder.hidden_dim, + self.prev_actions_dim, + ] + ) + # Policy + self.rnn = nn.LSTM( + self.h_dim, + rnn_hidden_dim, + num_layers=rnn_layers, + dropout=rnn_dropout, + batch_first=True + ) + self.head = nn.Linear(rnn_hidden_dim, self.num_actions) + + def forward(self, inputs, state=None): + B, T, C, H, W = inputs["screen_image"].shape + topline = inputs["tty_chars"][..., 0, :] + bottom_line = inputs["tty_chars"][..., -2:, :] + + encoded_state = [ + self.topline_encoder( + topline.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.bottomline_encoder( + bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.screen_encoder( + inputs["screen_image"] + .float(memory_format=torch.contiguous_format) + .view(T * B, C, H, W) + ), + ] + if self.use_prev_action: + encoded_state.append( + F.one_hot(inputs["prev_actions"], self.num_actions).view(T * B, -1) + ) + + encoded_state = torch.cat(encoded_state, dim=1) + core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) + logits = self.head(core_output) + + return logits, new_state + + @torch.no_grad() + def vec_act(self, obs, state=None, device="cpu"): + inputs = { + "tty_chars": torch.tensor(obs["tty_chars"][:, None], device=device), + "screen_image": torch.tensor(obs["screen_image"][:, None], device=device), + "prev_actions": torch.tensor(obs["prev_actions"][:, None], dtype=torch.long, device=device) + } + logits, new_state = self(inputs, state) + actions = torch.argmax(logits.squeeze(1), dim=-1) + return actions.cpu().numpy(), new_state + + +@torch.no_grad() +def vec_evaluate( + vec_env: AsyncVectorEnv, + actor: Actor, + num_episodes: int, + seed: str = 0, + device: str = "cpu" +) -> Dict[str, np.ndarray]: + actor.eval() + # set seed for reproducibility (reseed=False by default) + vec_env.seed(seed) + # all this work is needed to mitigate bias for shorter + # episodes during vectorized evaluation, for more see: + # https://github.com/DLR-RM/stable-baselines3/issues/402 + n_envs = vec_env.num_envs + episode_rewards = [] + episode_lengths = [] + episode_depths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(num_episodes + i) // n_envs 
for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = vec_env.reset() + observations["prev_actions"] = np.zeros(n_envs, dtype=float) + + rnn_states = None + pbar = tqdm(total=num_episodes) + while (episode_counts < episode_count_targets).any(): + # faster to do this here for entire batch, than in wrappers for each env + observations["screen_image"] = render_screen_image( + tty_chars=observations["tty_chars"][:, np.newaxis, ...], + tty_colors=observations["tty_colors"][:, np.newaxis, ...], + tty_cursor=observations["tty_cursor"][:, np.newaxis, ...], + ) + observations["screen_image"] = np.squeeze(observations["screen_image"], 1) + + actions, rnn_states = actor.vec_act(observations, rnn_states, device=device) + + observations, rewards, dones, infos = vec_env.step(actions) + observations["prev_actions"] = actions + + current_rewards += rewards + current_lengths += 1 + + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + if dones[i]: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_depths.append(infos[i]["current_depth"]) + episode_counts[i] += 1 + pbar.update(1) + + current_rewards[i] = 0 + current_lengths[i] = 0 + + pbar.close() + result = { + "reward_median": np.median(episode_rewards), + "reward_mean": np.mean(episode_rewards), + "reward_std": np.std(episode_rewards), + "reward_min": np.min(episode_rewards), + "reward_max": np.max(episode_rewards), + "reward_raw": np.array(episode_rewards), + # depth + "depth_median": np.median(episode_depths), + "depth_mean": np.mean(episode_depths), + "depth_std": np.std(episode_depths), + "depth_min": np.min(episode_depths), + "depth_max": np.max(episode_depths), + "depth_raw": np.array(episode_depths), + } + actor.train() + return result + + +@pyrallis.wrap() +def train(config: TrainConfig): + print(f"Device: {DEVICE}") + wandb.init( + config=asdict(config), + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + save_code=True, + ) + if config.checkpoints_path is not None: + print(f"Checkpoints path: {config.checkpoints_path}") + os.makedirs(config.checkpoints_path, exist_ok=True) + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: + pyrallis.dump(config, f) + + set_seed(config.train_seed) + + def env_fn(): + env = NetHackChallenge( + character=config.character, + observation_keys=["tty_chars", "tty_colors", "tty_cursor"] + ) + env = OfflineNetHackChallengeWrapper(env) + return env + + tmp_env = env_fn() + eval_env = AsyncVectorEnv( + env_fns=[env_fn for _ in range(config.eval_processes)], + copy=False + ) + buffer = SequentialBuffer( + dataset=tmp_env.get_dataset(mode=config.data_mode, scale="small"), + seq_len=config.seq_len, + batch_size=config.batch_size, + seed=config.train_seed, + add_next_step=False + ) + tp = ThreadPoolExecutor(max_workers=config.render_processes) + + actor = Actor( + action_dim=eval_env.single_action_space.n, + use_prev_action=config.use_prev_action, + rnn_hidden_dim=config.rnn_hidden_dim, + rnn_layers=config.rnn_layers, + rnn_dropout=config.rnn_dropout, + ).to(DEVICE) + + no_decay_params, decay_params = filter_wd_params(actor) + optim = torch.optim.AdamW([ + {"params": no_decay_params, "weight_decay": 0.0}, + {"params": decay_params, "weight_decay": config.weight_decay} + ], lr=config.learning_rate) + print("Number of parameters:", sum(p.numel() for p in actor.parameters())) + + scaler = 
torch.cuda.amp.GradScaler() + + rnn_state = None + prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) + for step in trange(1, config.update_steps + 1, desc="Training"): + with Timeit() as timer: + batch = buffer.sample() + screen_image = render_screen_image( + tty_chars=batch["tty_chars"], + tty_colors=batch["tty_colors"], + tty_cursor=batch["tty_cursor"], + threadpool=tp, + ) + batch["screen_image"] = screen_image + batch = dict_to_tensor(batch, device=DEVICE) + + wandb.log( + { + "times/batch_loading_cpu": timer.elapsed_time_cpu, + "times/batch_loading_gpu": timer.elapsed_time_gpu, + }, + step=step, + ) + + with Timeit() as timer: + with torch.cuda.amp.autocast(): + logits, rnn_state = actor( + inputs={ + "screen_image": batch["screen_image"], + "tty_chars": batch["tty_chars"], + "prev_actions": torch.cat( + [prev_actions.long(), batch["actions"][:, :-1].long()], dim=1 + ) + }, + state=rnn_state, + ) + rnn_state = [a.detach() for a in rnn_state] + + dist = Categorical(logits=logits) + loss = -dist.log_prob(batch["actions"]).mean() + # update prev_actions for next iteration + prev_actions = batch["actions"][:, -1].unsqueeze(-1) + + wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) + + with Timeit() as timer: + scaler.scale(loss).backward() + # loss.backward() + if config.clip_grad_norm is not None: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(actor.parameters(), config.clip_grad_norm) + # optim.step() + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + + wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) + + wandb.log({ + "loss": loss.detach().item(), + "transitions": config.batch_size * config.seq_len * step, + }, step=step) + + if step % config.eval_every == 0: + with Timeit() as timer: + eval_stats = vec_evaluate( + eval_env, actor, config.eval_episodes, config.eval_seed, device=DEVICE + ) + raw_returns = eval_stats.pop("reward_raw") + raw_depths = eval_stats.pop("depth_raw") + normalized_scores = tmp_env.get_normalized_score(raw_returns) + + wandb.log({ + "times/evaluation_gpu": timer.elapsed_time_gpu, + "times/evaluation_cpu": timer.elapsed_time_cpu, + }, step=step) + + wandb.log(dict( + eval_stats, + **{"transitions": config.batch_size * config.seq_len * step}, + ), step=step) + + if config.checkpoints_path is not None: + torch.save(actor.state_dict(), os.path.join(config.checkpoints_path, f"{step}.pt")) + # saving raw logs + np.save(os.path.join(config.checkpoints_path, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(config.checkpoints_path, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(config.checkpoints_path, f"{step}_normalized_scores.npy"), normalized_scores) + + # also saving to wandb files for easier use in the future + np.save(os.path.join(wandb.run.dir, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(wandb.run.dir, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(wandb.run.dir, f"{step}_normalized_scores.npy"), normalized_scores) + + buffer.close() + + +if __name__ == "__main__": + set_start_method("spawn") + train() diff --git a/algorithms/small_scale/cql_chaotic_lstm.py b/algorithms/small_scale/cql_chaotic_lstm.py new file mode 100644 index 0000000..082395e --- /dev/null +++ b/algorithms/small_scale/cql_chaotic_lstm.py @@ -0,0 +1,476 @@ +import pyrallis +from dataclasses import dataclass, asdict + +import random +import wandb +import os +import uuid +import torch +import torch.nn as nn +import torch.nn.functional as F + 
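+# Discrete Conservative Q-Learning (CQL). In addition to the usual TD error,
+# cql_loss() below adds the conservative penalty
+#     mean[ logsumexp_a Q(s, a) - Q(s, a_data) ],
+# which pushes down Q-values of actions that never occur in the offline dataset;
+# in this file `alpha` rescales the TD term relative to that penalty.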
+from gym.vector import AsyncVectorEnv +from concurrent.futures import ThreadPoolExecutor +from tqdm.auto import tqdm, trange +import numpy as np + +from copy import deepcopy +from typing import Optional, Dict, Tuple, Any, List + +from multiprocessing import set_start_method +from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper +from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder +from katakomba.utils.render import SCREEN_SHAPE, render_screen_image +from katakomba.utils.datasets import SequentialBuffer +from katakomba.utils.misc import Timeit, StatMean + +LSTM_HIDDEN = Tuple[torch.Tensor, torch.Tensor] +UPDATE_INFO = Dict[str, Any] + +torch.backends.cudnn.benchmark = True +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class TrainConfig: + character: str = "mon-hum-neu" + data_mode: str = "compressed" + # Wandb logging + project: str = "NetHack" + group: str = "small_scale_cql" + name: str = "cql" + version: int = 0 + # Model + rnn_hidden_dim: int = 2048 + rnn_layers: int = 2 + use_prev_action: bool = True + rnn_dropout: float = 0.0 + clip_range: float = 10.0 + tau: float = 0.005 + gamma: float = 0.999 + alpha: float = 2.0 + # Training + update_steps: int = 500_000 + batch_size: int = 64 + seq_len: int = 16 + learning_rate: float = 3e-4 + weight_decay: float = 0.0 + clip_grad_norm: Optional[float] = None + checkpoints_path: Optional[str] = None + eval_every: int = 10_000 + eval_episodes: int = 50 + eval_processes: int = 14 + render_processes: int = 14 + eval_seed: int = 50 + train_seed: int = 42 + + def __post_init__(self): + self.group = f"{self.group}-v{str(self.version)}" + self.name = f"{self.name}-{self.character}-{str(uuid.uuid4())[:8]}" + if self.checkpoints_path is not None: + self.checkpoints_path = os.path.join(self.checkpoints_path, self.group, self.name) + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + + +@torch.no_grad() +def filter_wd_params(model: nn.Module) -> Tuple[List[nn.parameter.Parameter], List[nn.parameter.Parameter]]: + no_decay, decay = [], [] + for name, param in model.named_parameters(): + if hasattr(param, 'requires_grad') and not param.requires_grad: + continue + if 'weight' in name and 'norm' not in name and 'bn' not in name: + decay.append(param) + else: + no_decay.append(param) + assert len(no_decay) + len(decay) == len(list(model.parameters())) + return no_decay, decay + + +def dict_to_tensor(data: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]: + return {k: torch.as_tensor(v, dtype=torch.float, device=device) for k, v in data.items()} + + +def soft_update(target: nn.Module, source: nn.Module, tau: float): + for tp, sp in zip(target.parameters(), source.parameters()): + tp.data.copy_((1 - tau) * tp.data + tau * sp.data) + + +class Critic(nn.Module): + def __init__( + self, + action_dim: int, + rnn_hidden_dim: int = 512, + rnn_layers: int = 1, + rnn_dropout: float = 0.0, + use_prev_action: bool = True + ): + super().__init__() + self.num_actions = action_dim + self.use_prev_action = use_prev_action + self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 + + # Encoders + self.topline_encoder = torch.jit.script(TopLineEncoder()) + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + + screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + + self.h_dim = sum( + [ + 
self.topline_encoder.hidden_dim, + self.bottomline_encoder.hidden_dim, + self.screen_encoder.hidden_dim, + self.prev_actions_dim, + ] + ) + # Policy + self.rnn = nn.LSTM( + self.h_dim, + rnn_hidden_dim, + num_layers=rnn_layers, + dropout=rnn_dropout, + batch_first=True + ) + self.head = nn.Linear(rnn_hidden_dim, self.num_actions) + + def forward(self, inputs, state=None): + # [batch_size, seq_len, ...] + B, T, C, H, W = inputs["screen_image"].shape + topline = inputs["tty_chars"][..., 0, :] + bottom_line = inputs["tty_chars"][..., -2:, :] + + encoded_state = [ + self.topline_encoder( + topline.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.bottomline_encoder( + bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.screen_encoder( + inputs["screen_image"] + .float(memory_format=torch.contiguous_format) + .view(T * B, C, H, W) + ), + ] + if self.use_prev_action: + encoded_state.append( + F.one_hot(inputs["prev_actions"], self.num_actions).view(T * B, -1) + ) + + encoded_state = torch.cat(encoded_state, dim=1) + core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) + q_values = self.head(core_output).view(B, T, self.num_actions) + + return q_values, new_state + + @torch.no_grad() + def vec_act(self, obs, state=None, device="cpu"): + inputs = { + "tty_chars": torch.tensor(obs["tty_chars"][:, None], device=device), + "screen_image": torch.tensor(obs["screen_image"][:, None], device=device), + "prev_actions": torch.tensor(obs["prev_actions"][:, None], dtype=torch.long, device=device) + } + q_values, new_state = self(inputs, state) + actions = torch.argmax(q_values.squeeze(1), dim=-1) + return actions.cpu().numpy(), new_state + + +def cql_loss( + critic: Critic, + target_critic: Critic, + obs: Dict[str, torch.Tensor], + next_obs: Dict[str, torch.Tensor], + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + rnn_states: LSTM_HIDDEN, + target_rnn_states: LSTM_HIDDEN, + alpha: float, + gamma: float, +) -> Tuple[torch.Tensor, LSTM_HIDDEN, LSTM_HIDDEN, UPDATE_INFO]: + with torch.no_grad(): + next_q_values, next_target_rnn_states = target_critic(next_obs, state=target_rnn_states) + next_q_values = next_q_values.max(dim=-1).values + assert next_q_values.shape == rewards.shape == dones.shape + q_target = rewards + gamma * (1 - dones) * next_q_values + + assert actions.dim() == 2 + q_pred, next_rnn_states = critic(obs, state=rnn_states) + q_pred_actions = q_pred.gather(-1, actions.to(torch.long).unsqueeze(-1)).squeeze() + assert q_pred_actions.shape == q_target.shape + + td_loss = F.mse_loss(q_pred_actions, q_target) * alpha + # [batch_size, seq_len, num_actions] -> [batch_size, seq_len] -> 1 + cql_loss = (torch.logsumexp(q_pred, dim=-1) - q_pred_actions).mean() # * alpha + + loss = cql_loss + td_loss + loss_info = { + "td_loss": td_loss.item(), + "cql_loss": cql_loss, + "loss": loss, + "q_target": q_target.mean().item(), + } + return loss, next_rnn_states, next_target_rnn_states, loss_info + + +@torch.no_grad() +def vec_evaluate( + vec_env: AsyncVectorEnv, + actor: Critic, + num_episodes: int, + seed: int = 0, + device: str = "cpu" +) -> Dict[str, np.ndarray]: + actor.eval() + # set seed for reproducibility (reseed=False by default) + vec_env.seed(seed) + # all this work is needed to mitigate bias for shorter + # episodes during vectorized evaluation, for more see: + # https://github.com/DLR-RM/stable-baselines3/issues/402 + n_envs = vec_env.num_envs + episode_rewards = [] + episode_lengths = [] + 
episode_depths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(num_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = vec_env.reset() + observations["prev_actions"] = np.zeros(n_envs, dtype=float) + + rnn_states = None + pbar = tqdm(total=num_episodes) + while (episode_counts < episode_count_targets).any(): + # faster to do this here for entire batch, than in wrappers for each env + observations["screen_image"] = render_screen_image( + tty_chars=observations["tty_chars"][:, np.newaxis, ...], + tty_colors=observations["tty_colors"][:, np.newaxis, ...], + tty_cursor=observations["tty_cursor"][:, np.newaxis, ...], + ) + observations["screen_image"] = np.squeeze(observations["screen_image"], 1) + + actions, rnn_states = actor.vec_act(observations, rnn_states, device=device) + + observations, rewards, dones, infos = vec_env.step(actions) + observations["prev_actions"] = actions + + current_rewards += rewards + current_lengths += 1 + + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + if dones[i]: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_depths.append(infos[i]["current_depth"]) + episode_counts[i] += 1 + pbar.update(1) + + current_rewards[i] = 0 + current_lengths[i] = 0 + + pbar.close() + result = { + "reward_median": np.median(episode_rewards), + "reward_mean": np.mean(episode_rewards), + "reward_std": np.std(episode_rewards), + "reward_min": np.min(episode_rewards), + "reward_max": np.max(episode_rewards), + "reward_raw": np.array(episode_rewards), + # depth + "depth_median": np.median(episode_depths), + "depth_mean": np.mean(episode_depths), + "depth_std": np.std(episode_depths), + "depth_min": np.min(episode_depths), + "depth_max": np.max(episode_depths), + "depth_raw": np.array(episode_depths), + } + actor.train() + return result + + +@pyrallis.wrap() +def train(config: TrainConfig): + print(f"Device: {DEVICE}") + wandb.init( + config=asdict(config), + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + save_code=True, + ) + if config.checkpoints_path is not None: + print(f"Checkpoints path: {config.checkpoints_path}") + os.makedirs(config.checkpoints_path, exist_ok=True) + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: + pyrallis.dump(config, f) + + set_seed(config.train_seed) + + def env_fn(): + env = NetHackChallenge( + character=config.character, + observation_keys=["tty_chars", "tty_colors", "tty_cursor"] + ) + env = OfflineNetHackChallengeWrapper(env) + return env + + tmp_env = env_fn() + eval_env = AsyncVectorEnv( + env_fns=[env_fn for _ in range(config.eval_processes)], + copy=False + ) + buffer = SequentialBuffer( + dataset=tmp_env.get_dataset(mode=config.data_mode, scale="small"), + seq_len=config.seq_len, + batch_size=config.batch_size, + seed=config.train_seed, + add_next_step=True # true as this is needed for next_obs + ) + tp = ThreadPoolExecutor(max_workers=config.render_processes) + + critic = Critic( + action_dim=eval_env.single_action_space.n, + use_prev_action=config.use_prev_action, + rnn_hidden_dim=config.rnn_hidden_dim, + rnn_layers=config.rnn_layers, + rnn_dropout=config.rnn_dropout, + ).to(DEVICE) + with torch.no_grad(): + target_critic = deepcopy(critic) + + 
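+    # target_critic is a frozen copy that supplies the TD targets in cql_loss();
+    # it trails the online critic via Polyak averaging in soft_update():
+    #     theta_target <- (1 - tau) * theta_target + tau * theta_online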
no_decay_params, decay_params = filter_wd_params(critic) + optim = torch.optim.AdamW([ + {"params": no_decay_params, "weight_decay": 0.0}, + {"params": decay_params, "weight_decay": config.weight_decay} + ], lr=config.learning_rate) + print("Number of parameters:", sum(p.numel() for p in critic.parameters())) + + scaler = torch.cuda.amp.GradScaler() + rnn_state, target_rnn_state = None, None + prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) + + # For reward normalization + reward_stats = StatMean(cumulative=True) + running_rewards = 0.0 + + for step in trange(1, config.update_steps + 1, desc="Training"): + with Timeit() as timer: + batch = buffer.sample() + screen_image = render_screen_image( + tty_chars=batch["tty_chars"], + tty_colors=batch["tty_colors"], + tty_cursor=batch["tty_cursor"], + threadpool=tp, + ) + batch["screen_image"] = screen_image + + # Update reward statistics (as in the original nle implementation) + running_rewards *= config.gamma + running_rewards += batch["rewards"] + reward_stats += running_rewards ** 2 + running_rewards *= (~batch["dones"]).astype(float) + # Normalize the reward + reward_std = reward_stats.mean() ** 0.5 + batch["rewards"] = batch["rewards"] / max(0.01, reward_std) + batch["rewards"] = np.clip(batch["rewards"], -config.clip_range, config.clip_range) + + batch = dict_to_tensor(batch, device=DEVICE) + + wandb.log( + { + "times/batch_loading_cpu": timer.elapsed_time_cpu, + "times/batch_loading_gpu": timer.elapsed_time_gpu, + }, + step=step, + ) + + with Timeit() as timer: + with torch.cuda.amp.autocast(): + obs = { + "screen_image": batch["screen_image"][:, :-1].contiguous(), + "tty_chars": batch["tty_chars"][:, :-1].contiguous(), + "prev_actions": torch.cat([prev_actions.long(), batch["actions"][:, :-2].long()], dim=1) + } + next_obs = { + "screen_image": batch["screen_image"][:, 1:].contiguous(), + "tty_chars": batch["tty_chars"][:, 1:].contiguous(), + "prev_actions": batch["actions"][:, :-1].long() + } + loss, rnn_state, target_rnn_state, loss_info = cql_loss( + critic=critic, + target_critic=target_critic, + obs=obs, + next_obs=next_obs, + actions=batch["actions"][:, :-1], + rewards=batch["rewards"][:, :-1], + dones=batch["dones"][:, :-1], + rnn_states=rnn_state, + target_rnn_states=target_rnn_state, + alpha=config.alpha, + gamma=config.gamma + ) + rnn_state = [a.detach() for a in rnn_state] + target_rnn_state = [a.detach() for a in target_rnn_state] + # update prev_actions for next iteration (-1 is seq_len + 1, so -2) + prev_actions = batch["actions"][:, -2].unsqueeze(-1) + + wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) + + with Timeit() as timer: + scaler.scale(loss).backward() + if config.clip_grad_norm is not None: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(critic.parameters(), config.clip_grad_norm) + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + soft_update(target_critic, critic, tau=config.tau) + + wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **loss_info}, step=step) + + if step % config.eval_every == 0: + with Timeit() as timer: + eval_stats = vec_evaluate( + eval_env, critic, config.eval_episodes, config.eval_seed, device=DEVICE + ) + raw_returns = eval_stats.pop("reward_raw") + raw_depths = eval_stats.pop("depth_raw") + normalized_scores = tmp_env.get_normalized_score(raw_returns) + + wandb.log({ + "times/evaluation_gpu": 
timer.elapsed_time_gpu, + "times/evaluation_cpu": timer.elapsed_time_cpu, + }, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **eval_stats}, step=step) + + if config.checkpoints_path is not None: + torch.save(critic.state_dict(), os.path.join(config.checkpoints_path, f"{step}.pt")) + # saving raw logs + np.save(os.path.join(config.checkpoints_path, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(config.checkpoints_path, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(config.checkpoints_path, f"{step}_normalized_scores.npy"), normalized_scores) + + # also saving to wandb files for easier use in the future + np.save(os.path.join(wandb.run.dir, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(wandb.run.dir, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(wandb.run.dir, f"{step}_normalized_scores.npy"), normalized_scores) + + buffer.close() + + +if __name__ == "__main__": + set_start_method("spawn") + train() diff --git a/algorithms/small_scale/iql_chaotic_lstm.py b/algorithms/small_scale/iql_chaotic_lstm.py new file mode 100644 index 0000000..30c95c9 --- /dev/null +++ b/algorithms/small_scale/iql_chaotic_lstm.py @@ -0,0 +1,517 @@ +import pyrallis +from dataclasses import dataclass, asdict + +import random +import wandb +import os +import uuid +import torch +import torch.nn as nn +import torch.nn.functional as F + +from gym.vector import AsyncVectorEnv +from concurrent.futures import ThreadPoolExecutor +from tqdm.auto import tqdm, trange +import numpy as np + +from copy import deepcopy +from typing import Optional, Dict, Tuple, Any, List + +from multiprocessing import set_start_method +from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper +from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder +from katakomba.utils.render import SCREEN_SHAPE, render_screen_image +from katakomba.utils.datasets import SequentialBuffer +from katakomba.utils.misc import Timeit, StatMean + +LSTM_HIDDEN = Tuple[torch.Tensor, torch.Tensor] +UPDATE_INFO = Dict[str, Any] + +torch.backends.cudnn.benchmark = True +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class TrainConfig: + character: str = "mon-hum-neu" + data_mode: str = "compressed" + # Wandb logging + project: str = "NetHack" + group: str = "small_scale_iql" + name: str = "iql" + version: int = 0 + # Model + rnn_hidden_dim: int = 2048 + rnn_layers: int = 2 + use_prev_action: bool = True + rnn_dropout: float = 0.0 + clip_range: float = 10.0 + tau: float = 0.005 + gamma: float = 0.999 + expectile_tau: float = 0.8 + temperature: float = 1.0 + # Training + update_steps: int = 500_000 + batch_size: int = 64 + seq_len: int = 16 + learning_rate: float = 3e-4 + weight_decay: float = 0.0 + clip_grad_norm: Optional[float] = None + checkpoints_path: Optional[str] = None + eval_every: int = 10_000 + eval_episodes: int = 50 + eval_processes: int = 14 + render_processes: int = 14 + eval_seed: int = 50 + train_seed: int = 42 + + def __post_init__(self): + self.group = f"{self.group}-v{str(self.version)}" + self.name = f"{self.name}-{self.character}-{str(uuid.uuid4())[:8]}" + if self.checkpoints_path is not None: + self.checkpoints_path = os.path.join(self.checkpoints_path, self.group, self.name) + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + + +@torch.no_grad() +def filter_wd_params(model: nn.Module) -> 
Tuple[List[nn.parameter.Parameter], List[nn.parameter.Parameter]]: + no_decay, decay = [], [] + for name, param in model.named_parameters(): + if hasattr(param, 'requires_grad') and not param.requires_grad: + continue + if 'weight' in name and 'norm' not in name and 'bn' not in name: + decay.append(param) + else: + no_decay.append(param) + assert len(no_decay) + len(decay) == len(list(model.parameters())) + return no_decay, decay + + +def dict_to_tensor(data: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]: + return {k: torch.as_tensor(v, dtype=torch.float, device=device) for k, v in data.items()} + + +def soft_update(target: nn.Module, source: nn.Module, tau: float): + for tp, sp in zip(target.parameters(), source.parameters()): + tp.data.copy_((1 - tau) * tp.data + tau * sp.data) + + +def asymmetric_l2_loss(u: torch.Tensor, tau: float) -> torch.Tensor: + return torch.mean(torch.abs(tau - (u < 0).float()) * u ** 2) + + +class Critic(nn.Module): + def __init__( + self, + action_dim: int, + rnn_hidden_dim: int = 512, + rnn_layers: int = 1, + rnn_dropout: float = 0.0, + use_prev_action: bool = True + ): + super().__init__() + self.num_actions = action_dim + self.use_prev_action = use_prev_action + self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 + + # Encoders + self.topline_encoder = torch.jit.script(TopLineEncoder()) + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + + screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + + self.h_dim = sum( + [ + self.topline_encoder.hidden_dim, + self.bottomline_encoder.hidden_dim, + self.screen_encoder.hidden_dim, + self.prev_actions_dim, + ] + ) + # networks + self.rnn = nn.LSTM( + self.h_dim, + rnn_hidden_dim, + num_layers=rnn_layers, + dropout=rnn_dropout, + batch_first=True + ) + # in similar style to original: + # https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/ee72d6aac9df00a4a6ab1f501db37a632a75b952/experiment_code/hackrl/models/offline_chaotic_dwarf.py#L538 + self.qf1 = nn.Linear(rnn_hidden_dim + self.num_actions, 1) + self.qf2 = nn.Linear(rnn_hidden_dim + self.num_actions, 1) + self.vf = nn.Linear(rnn_hidden_dim, 1) + self.policy = nn.Linear(rnn_hidden_dim, self.num_actions) + + def forward(self, obs, state=None, actions=None): + # [batch_size, seq_len, ...] 
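+        # Shared torso, multiple heads: the tty/screen encoders feed a single LSTM,
+        # and the policy logits, state value V(s) and twin Q-heads (qf1, qf2) all
+        # read its output. When `actions` is None (as in vec_act), the Q-heads are
+        # skipped and returned as None.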
+ B, T, C, H, W = obs["screen_image"].shape + topline = obs["tty_chars"][..., 0, :] + bottom_line = obs["tty_chars"][..., -2:, :] + + encoded_state = [ + self.topline_encoder( + topline.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.bottomline_encoder( + bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.screen_encoder( + obs["screen_image"] + .float(memory_format=torch.contiguous_format) + .view(T * B, C, H, W) + ), + ] + if self.use_prev_action: + encoded_state.append( + F.one_hot(obs["prev_actions"], self.num_actions).view(T * B, -1) + ) + encoded_state = torch.cat(encoded_state, dim=1) + core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) + # policy + logits = self.policy(core_output) + vf = self.vf(core_output).squeeze(-1) + + if actions is not None: + # state action value function + core_output_actions = torch.cat([ + core_output, F.one_hot(actions, self.num_actions) + ], dim=-1) + q1 = self.qf1(core_output_actions).squeeze(-1) + q2 = self.qf2(core_output_actions).squeeze(-1) + + return logits, vf, q1, q2, new_state + + return logits, vf, None, None, new_state + + @torch.no_grad() + def vec_act(self, obs, state=None, device="cpu"): + inputs = { + "tty_chars": torch.tensor(obs["tty_chars"][:, None], device=device), + "screen_image": torch.tensor(obs["screen_image"][:, None], device=device), + "prev_actions": torch.tensor(obs["prev_actions"][:, None], dtype=torch.long, device=device) + } + logits, *_, new_state = self(inputs, state=state) + actions = torch.argmax(logits.squeeze(1), dim=-1) + return actions.cpu().numpy(), new_state + + +def iql_loss( + critic: Critic, + target_critic: Critic, + obs: Dict[str, torch.Tensor], + next_obs: Dict[str, torch.Tensor], + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + rnn_states: LSTM_HIDDEN, + target_rnn_states: LSTM_HIDDEN, + next_rnn_states: LSTM_HIDDEN, + gamma: float, + expectile_tau: float, + temperature: float +) -> Tuple[torch.Tensor, LSTM_HIDDEN, LSTM_HIDDEN, LSTM_HIDDEN, UPDATE_INFO]: + # state value function loss + with torch.no_grad(): + _, _, target_q1, target_q2, new_target_rnn_states = target_critic(obs, actions=actions.long(), state=target_rnn_states) + target_q = torch.minimum(target_q1, target_q2) + + logits, v_pred, q1_pred, q2_pred, new_rnn_states = critic(obs, actions=actions.long(), state=rnn_states) + assert target_q.shape == v_pred.shape + advantage = target_q - v_pred + + value_loss = asymmetric_l2_loss(advantage, expectile_tau) + + # state action value function loss + with torch.no_grad(): + _, next_v, _, _, new_next_rnn_states = critic(next_obs, state=next_rnn_states) + next_q = rewards + (1 - dones) * gamma * next_v + + assert q1_pred.shape == q2_pred.shape == next_q.shape + td_loss = (F.mse_loss(q1_pred, next_q) + F.mse_loss(q2_pred, next_q)) / 2 + + # actor loss + weights = torch.exp(temperature * advantage.clamp(max=100.0)) + log_probs = torch.distributions.Categorical(logits=logits).log_prob(actions) + + actor_loss = torch.mean(-log_probs * weights.detach()) + + loss = value_loss + td_loss + actor_loss + loss_info = { + "td_loss": td_loss.item(), + "value_loss": value_loss.item(), + "actor_loss": actor_loss.item(), + "loss": loss, + "next_v": next_v.mean().item(), + "q_target": next_q.mean().item() + } + return loss, new_rnn_states, new_target_rnn_states, new_next_rnn_states, loss_info + + +@torch.no_grad() +def vec_evaluate( + vec_env: AsyncVectorEnv, + actor: Critic, + num_episodes: int, + seed: int = 0, + 
device: str = "cpu" +) -> Dict[str, np.ndarray]: + actor.eval() + # set seed for reproducibility (reseed=False by default) + vec_env.seed(seed) + # all this work is needed to mitigate bias for shorter + # episodes during vectorized evaluation, for more see: + # https://github.com/DLR-RM/stable-baselines3/issues/402 + n_envs = vec_env.num_envs + episode_rewards = [] + episode_lengths = [] + episode_depths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(num_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + observations = vec_env.reset() + observations["prev_actions"] = np.zeros(n_envs, dtype=float) + + rnn_states = None + pbar = tqdm(total=num_episodes) + while (episode_counts < episode_count_targets).any(): + # faster to do this here for entire batch, than in wrappers for each env + observations["screen_image"] = render_screen_image( + tty_chars=observations["tty_chars"][:, np.newaxis, ...], + tty_colors=observations["tty_colors"][:, np.newaxis, ...], + tty_cursor=observations["tty_cursor"][:, np.newaxis, ...], + ) + observations["screen_image"] = np.squeeze(observations["screen_image"], 1) + + actions, rnn_states = actor.vec_act(observations, rnn_states, device=device) + + observations, rewards, dones, infos = vec_env.step(actions) + observations["prev_actions"] = actions + + current_rewards += rewards + current_lengths += 1 + + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + if dones[i]: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_depths.append(infos[i]["current_depth"]) + episode_counts[i] += 1 + pbar.update(1) + + current_rewards[i] = 0 + current_lengths[i] = 0 + + pbar.close() + result = { + "reward_median": np.median(episode_rewards), + "reward_mean": np.mean(episode_rewards), + "reward_std": np.std(episode_rewards), + "reward_min": np.min(episode_rewards), + "reward_max": np.max(episode_rewards), + "reward_raw": np.array(episode_rewards), + # depth + "depth_median": np.median(episode_depths), + "depth_mean": np.mean(episode_depths), + "depth_std": np.std(episode_depths), + "depth_min": np.min(episode_depths), + "depth_max": np.max(episode_depths), + "depth_raw": np.array(episode_depths), + } + actor.train() + return result + + +@pyrallis.wrap() +def train(config: TrainConfig): + print(f"Device: {DEVICE}") + wandb.init( + config=asdict(config), + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + save_code=True, + ) + if config.checkpoints_path is not None: + print(f"Checkpoints path: {config.checkpoints_path}") + os.makedirs(config.checkpoints_path, exist_ok=True) + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: + pyrallis.dump(config, f) + + set_seed(config.train_seed) + + def env_fn(): + env = NetHackChallenge( + character=config.character, + observation_keys=["tty_chars", "tty_colors", "tty_cursor"] + ) + env = OfflineNetHackChallengeWrapper(env) + return env + + tmp_env = env_fn() + eval_env = AsyncVectorEnv( + env_fns=[env_fn for _ in range(config.eval_processes)], + copy=False + ) + buffer = SequentialBuffer( + dataset=tmp_env.get_dataset(mode=config.data_mode, scale="small"), + seq_len=config.seq_len, + batch_size=config.batch_size, + seed=config.train_seed, + add_next_step=True # true as 
this is needed for next_obs + ) + tp = ThreadPoolExecutor(max_workers=config.render_processes) + + critic = Critic( + action_dim=eval_env.single_action_space.n, + use_prev_action=config.use_prev_action, + rnn_hidden_dim=config.rnn_hidden_dim, + rnn_layers=config.rnn_layers, + rnn_dropout=config.rnn_dropout, + ).to(DEVICE) + with torch.no_grad(): + target_critic = deepcopy(critic) + + no_decay_params, decay_params = filter_wd_params(critic) + optim = torch.optim.AdamW([ + {"params": no_decay_params, "weight_decay": 0.0}, + {"params": decay_params, "weight_decay": config.weight_decay} + ], lr=config.learning_rate) + print("Number of parameters:", sum(p.numel() for p in critic.parameters())) + + scaler = torch.cuda.amp.GradScaler() + rnn_state, target_rnn_state, next_rnn_state = None, None, None + prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) + + # For reward normalization + reward_stats = StatMean(cumulative=True) + running_rewards = 0.0 + + for step in trange(1, config.update_steps + 1, desc="Training"): + with Timeit() as timer: + batch = buffer.sample() + screen_image = render_screen_image( + tty_chars=batch["tty_chars"], + tty_colors=batch["tty_colors"], + tty_cursor=batch["tty_cursor"], + threadpool=tp, + ) + batch["screen_image"] = screen_image + + # Update reward statistics (as in the original nle implementation) + running_rewards *= config.gamma + running_rewards += batch["rewards"] + reward_stats += running_rewards ** 2 + running_rewards *= (~batch["dones"]).astype(float) + # Normalize the reward + reward_std = reward_stats.mean() ** 0.5 + batch["rewards"] = batch["rewards"] / max(0.01, reward_std) + batch["rewards"] = np.clip(batch["rewards"], -config.clip_range, config.clip_range) + + batch = dict_to_tensor(batch, device=DEVICE) + + wandb.log( + { + "times/batch_loading_cpu": timer.elapsed_time_cpu, + "times/batch_loading_gpu": timer.elapsed_time_gpu, + }, + step=step, + ) + + with Timeit() as timer: + with torch.cuda.amp.autocast(): + obs = { + "screen_image": batch["screen_image"][:, :-1].contiguous(), + "tty_chars": batch["tty_chars"][:, :-1].contiguous(), + "prev_actions": torch.cat([prev_actions.long(), batch["actions"][:, :-2].long()], dim=1) + } + next_obs = { + "screen_image": batch["screen_image"][:, 1:].contiguous(), + "tty_chars": batch["tty_chars"][:, 1:].contiguous(), + "prev_actions": batch["actions"][:, :-1].long() + } + loss, rnn_state, target_rnn_state, next_rnn_state, loss_info = iql_loss( + critic=critic, + target_critic=target_critic, + obs=obs, + next_obs=next_obs, + actions=batch["actions"][:, :-1], + rewards=batch["rewards"][:, :-1], + dones=batch["dones"][:, :-1], + rnn_states=rnn_state, + target_rnn_states=target_rnn_state, + next_rnn_states=next_rnn_state, + expectile_tau=config.expectile_tau, + temperature=config.temperature, + gamma=config.gamma + ) + # detaching rnn hidden states for the next iteration + rnn_state = [a.detach() for a in rnn_state] + target_rnn_state = [a.detach() for a in target_rnn_state] + next_rnn_state = [a.detach() for a in next_rnn_state] + + # update prev_actions for next iteration (-1 is seq_len + 1, so -2) + prev_actions = batch["actions"][:, -2].unsqueeze(-1) + + wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) + + with Timeit() as timer: + scaler.scale(loss).backward() + if config.clip_grad_norm is not None: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(critic.parameters(), config.clip_grad_norm) + scaler.step(optim) + scaler.update() + 
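+            # set_to_none=True releases the gradient tensors instead of zeroing them
+            # in place; afterwards the target network is nudged towards the online
+            # critic via Polyak averaging (soft_update with tau=config.tau).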
optim.zero_grad(set_to_none=True) + soft_update(target_critic, critic, tau=config.tau) + + wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **loss_info}, step=step) + + if step % config.eval_every == 0: + with Timeit() as timer: + eval_stats = vec_evaluate( + eval_env, critic, config.eval_episodes, config.eval_seed, device=DEVICE + ) + raw_returns = eval_stats.pop("reward_raw") + raw_depths = eval_stats.pop("depth_raw") + normalized_scores = tmp_env.get_normalized_score(raw_returns) + + wandb.log({ + "times/evaluation_gpu": timer.elapsed_time_gpu, + "times/evaluation_cpu": timer.elapsed_time_cpu, + }, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **eval_stats}, step=step) + + if config.checkpoints_path is not None: + torch.save(critic.state_dict(), os.path.join(config.checkpoints_path, f"{step}.pt")) + # saving raw logs + np.save(os.path.join(config.checkpoints_path, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(config.checkpoints_path, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(config.checkpoints_path, f"{step}_normalized_scores.npy"), normalized_scores) + + # also saving to wandb files for easier use in the future + np.save(os.path.join(wandb.run.dir, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(wandb.run.dir, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(wandb.run.dir, f"{step}_normalized_scores.npy"), normalized_scores) + + buffer.close() + + +if __name__ == "__main__": + set_start_method("spawn") + train() + diff --git a/algorithms/small_scale/rem_chaotic_lstm.py b/algorithms/small_scale/rem_chaotic_lstm.py new file mode 100644 index 0000000..908388e --- /dev/null +++ b/algorithms/small_scale/rem_chaotic_lstm.py @@ -0,0 +1,485 @@ +import pyrallis +from dataclasses import dataclass, asdict + +import random +import wandb +import os +import uuid +import torch +import torch.nn as nn +import torch.nn.functional as F + +from gym.vector import AsyncVectorEnv +from concurrent.futures import ThreadPoolExecutor +from tqdm.auto import tqdm, trange +import numpy as np + +from copy import deepcopy +from typing import Optional, Dict, Tuple, Any, List + +from multiprocessing import set_start_method +from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper +from katakomba.nn.chaotic_dwarf import TopLineEncoder, BottomLinesEncoder, ScreenEncoder +from katakomba.utils.render import SCREEN_SHAPE, render_screen_image +from katakomba.utils.datasets import SequentialBuffer +from katakomba.utils.misc import Timeit, StatMean + +LSTM_HIDDEN = Tuple[torch.Tensor, torch.Tensor] +UPDATE_INFO = Dict[str, Any] + +torch.backends.cudnn.benchmark = True +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +@dataclass +class TrainConfig: + character: str = "mon-hum-neu" + data_mode: str = "compressed" + # Wandb logging + project: str = "NetHack" + group: str = "small_scale_rem" + name: str = "rem" + version: int = 0 + # Model + rnn_hidden_dim: int = 2048 + rnn_layers: int = 2 + use_prev_action: bool = True + rnn_dropout: float = 0.0 + num_heads: int = 200 + clip_range: float = 10.0 + tau: float = 0.005 + gamma: float = 0.999 + # Training + update_steps: int = 500_000 + batch_size: int = 64 + seq_len: int = 16 + learning_rate: float = 3e-4 + weight_decay: float = 0.0 + clip_grad_norm: Optional[float] = None + checkpoints_path: Optional[str] = None + eval_every: int = 10_000 + eval_episodes: int = 50 + 
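+    # eval_processes AsyncVectorEnv workers run evaluation episodes in parallel;
+    # render_processes threads rasterize tty frames into screen images for training batches.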
eval_processes: int = 14 + render_processes: int = 14 + eval_seed: int = 50 + train_seed: int = 42 + + def __post_init__(self): + self.group = f"{self.group}-v{str(self.version)}" + self.name = f"{self.name}-{self.character}-{str(uuid.uuid4())[:8]}" + if self.checkpoints_path is not None: + self.checkpoints_path = os.path.join(self.checkpoints_path, self.group, self.name) + + +def set_seed(seed: int): + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + + +@torch.no_grad() +def filter_wd_params(model: nn.Module) -> Tuple[List[nn.parameter.Parameter], List[nn.parameter.Parameter]]: + no_decay, decay = [], [] + for name, param in model.named_parameters(): + if hasattr(param, 'requires_grad') and not param.requires_grad: + continue + if 'weight' in name and 'norm' not in name and 'bn' not in name: + decay.append(param) + else: + no_decay.append(param) + assert len(no_decay) + len(decay) == len(list(model.parameters())) + return no_decay, decay + + +def dict_to_tensor(data: Dict[str, np.ndarray], device: str) -> Dict[str, torch.Tensor]: + return {k: torch.as_tensor(v, dtype=torch.float, device=device) for k, v in data.items()} + + +def soft_update(target: nn.Module, source: nn.Module, tau: float): + for tp, sp in zip(target.parameters(), source.parameters()): + tp.data.copy_((1 - tau) * tp.data + tau * sp.data) + + +def sample_convex_combination(size: int, device="cpu") -> torch.Tensor: + weights = torch.rand(size, device=device) + weights = weights / weights.sum() + assert torch.isclose(weights.sum(), torch.tensor([1.0], device=device)) + return weights.view(1, 1, -1, 1) + + +class Critic(nn.Module): + def __init__( + self, + action_dim: int, + num_heads: int, + rnn_hidden_dim: int = 512, + rnn_layers: int = 1, + rnn_dropout: float = 0.0, + use_prev_action: bool = True + ): + super().__init__() + self.num_heads = num_heads + self.num_actions = action_dim + self.use_prev_action = use_prev_action + self.prev_actions_dim = self.num_actions if self.use_prev_action else 0 + + # Encoders + self.topline_encoder = torch.jit.script(TopLineEncoder()) + self.bottomline_encoder = torch.jit.script(BottomLinesEncoder()) + + screen_shape = (SCREEN_SHAPE[1], SCREEN_SHAPE[2]) + self.screen_encoder = torch.jit.script(ScreenEncoder(screen_shape)) + + self.h_dim = sum([ + self.topline_encoder.hidden_dim, + self.bottomline_encoder.hidden_dim, + self.screen_encoder.hidden_dim, + self.prev_actions_dim, + ]) + # Policy + self.rnn = nn.LSTM( + self.h_dim, + rnn_hidden_dim, + dropout=rnn_dropout, + num_layers=rnn_layers, + batch_first=True) + self.head = nn.Linear(rnn_hidden_dim, self.num_actions * num_heads) + + def forward(self, inputs, state=None): + # [batch_size, seq_len, ...] 
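+        # Random Ensemble Mixture: a single linear head outputs num_heads * num_actions
+        # values, reshaped to [B, T, num_heads, num_actions]. rem_loss() mixes the heads
+        # with a freshly sampled convex combination at every update, while vec_act()
+        # simply averages them for greedy action selection.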
+ B, T, C, H, W = inputs["screen_image"].shape + topline = inputs["tty_chars"][..., 0, :] + bottom_line = inputs["tty_chars"][..., -2:, :] + + encoded_state = [ + self.topline_encoder( + topline.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.bottomline_encoder( + bottom_line.float(memory_format=torch.contiguous_format).view(T * B, -1) + ), + self.screen_encoder( + inputs["screen_image"] + .float(memory_format=torch.contiguous_format) + .view(T * B, C, H, W) + ), + ] + if self.use_prev_action: + encoded_state.append( + torch.nn.functional.one_hot(inputs["prev_actions"], self.num_actions).view(T * B, -1) + ) + + encoded_state = torch.cat(encoded_state, dim=1) + core_output, new_state = self.rnn(encoded_state.view(B, T, -1), state) + q_values_ensemble = self.head(core_output).view(B, T, self.num_heads, self.num_actions) + return q_values_ensemble, new_state + + @torch.no_grad() + def vec_act(self, obs, state=None, device="cpu"): + inputs = { + "tty_chars": torch.tensor(obs["tty_chars"][:, None], device=device), + "screen_image": torch.tensor(obs["screen_image"][:, None], device=device), + "prev_actions": torch.tensor(obs["prev_actions"][:, None], dtype=torch.long, device=device) + } + q_values_ensemble, new_state = self(inputs, state) + # [batch_size, seq_len, num_heads, num_actions] + q_values = q_values_ensemble.mean(2) + assert q_values.dim() == 3 + + actions = torch.argmax(q_values.squeeze(1), dim=-1) + return actions.cpu().numpy(), new_state + + +def rem_loss( + critic: Critic, + target_critic: Critic, + obs: Dict[str, torch.Tensor], + next_obs: Dict[str, torch.Tensor], + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + rnn_states: LSTM_HIDDEN, + target_rnn_states: LSTM_HIDDEN, + convex_comb_weights: torch.Tensor, + gamma: float, +) -> Tuple[torch.Tensor, LSTM_HIDDEN, LSTM_HIDDEN, UPDATE_INFO]: + with torch.no_grad(): + next_q_values, next_target_rnn_states = target_critic(next_obs, state=target_rnn_states) + next_q_values = (next_q_values * convex_comb_weights).sum(2) + next_q_values = next_q_values.max(dim=-1).values + + assert next_q_values.shape == rewards.shape == dones.shape + q_target = rewards + gamma * (1 - dones) * next_q_values + + assert actions.dim() == 2 + q_pred, next_rnn_states = critic(obs, state=rnn_states) + q_pred = (q_pred * convex_comb_weights.detach()).sum(2) + q_pred = q_pred.gather(-1, actions.to(torch.long).unsqueeze(-1)).squeeze() + assert q_pred.shape == q_target.shape + + loss = F.mse_loss(q_pred, q_target) + loss_info = { + "loss": loss.item(), + "q_target": q_target.mean().item() + } + return loss, next_rnn_states, next_target_rnn_states, loss_info + + +@torch.no_grad() +def vec_evaluate( + vec_env: AsyncVectorEnv, + actor: Critic, + num_episodes: int, + seed: int = 0, + device: str = "cpu" +) -> Dict[str, np.ndarray]: + actor.eval() + # set seed for reproducibility (reseed=False by default) + vec_env.seed(seed) + # all this work is needed to mitigate bias for shorter + # episodes during vectorized evaluation, for more see: + # https://github.com/DLR-RM/stable-baselines3/issues/402 + n_envs = vec_env.num_envs + episode_rewards = [] + episode_lengths = [] + episode_depths = [] + + episode_counts = np.zeros(n_envs, dtype="int") + # Divides episodes among different sub environments in the vector as evenly as possible + episode_count_targets = np.array([(num_episodes + i) // n_envs for i in range(n_envs)], dtype="int") + + current_rewards = np.zeros(n_envs) + current_lengths = np.zeros(n_envs, dtype="int") + 
observations = vec_env.reset() + observations["prev_actions"] = np.zeros(n_envs, dtype=float) + + rnn_states = None + pbar = tqdm(total=num_episodes) + while (episode_counts < episode_count_targets).any(): + # faster to do this here for entire batch, than in wrappers for each env + observations["screen_image"] = render_screen_image( + tty_chars=observations["tty_chars"][:, np.newaxis, ...], + tty_colors=observations["tty_colors"][:, np.newaxis, ...], + tty_cursor=observations["tty_cursor"][:, np.newaxis, ...], + ) + observations["screen_image"] = np.squeeze(observations["screen_image"], 1) + + actions, rnn_states = actor.vec_act(observations, rnn_states, device=device) + + observations, rewards, dones, infos = vec_env.step(actions) + observations["prev_actions"] = actions + + current_rewards += rewards + current_lengths += 1 + + for i in range(n_envs): + if episode_counts[i] < episode_count_targets[i]: + if dones[i]: + episode_rewards.append(current_rewards[i]) + episode_lengths.append(current_lengths[i]) + episode_depths.append(infos[i]["current_depth"]) + episode_counts[i] += 1 + pbar.update(1) + + current_rewards[i] = 0 + current_lengths[i] = 0 + + pbar.close() + result = { + "reward_median": np.median(episode_rewards), + "reward_mean": np.mean(episode_rewards), + "reward_std": np.std(episode_rewards), + "reward_min": np.min(episode_rewards), + "reward_max": np.max(episode_rewards), + "reward_raw": np.array(episode_rewards), + # depth + "depth_median": np.median(episode_depths), + "depth_mean": np.mean(episode_depths), + "depth_std": np.std(episode_depths), + "depth_min": np.min(episode_depths), + "depth_max": np.max(episode_depths), + "depth_raw": np.array(episode_depths), + } + actor.train() + return result + + +@pyrallis.wrap() +def train(config: TrainConfig): + print(f"Device: {DEVICE}") + wandb.init( + config=asdict(config), + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + save_code=True, + ) + if config.checkpoints_path is not None: + print(f"Checkpoints path: {config.checkpoints_path}") + os.makedirs(config.checkpoints_path, exist_ok=True) + with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f: + pyrallis.dump(config, f) + + set_seed(config.train_seed) + + def env_fn(): + env = NetHackChallenge( + character=config.character, + observation_keys=["tty_chars", "tty_colors", "tty_cursor"] + ) + env = OfflineNetHackChallengeWrapper(env) + return env + + tmp_env = env_fn() + eval_env = AsyncVectorEnv( + env_fns=[env_fn for _ in range(config.eval_processes)], + copy=False + ) + buffer = SequentialBuffer( + dataset=tmp_env.get_dataset(mode=config.data_mode, scale="small"), + seq_len=config.seq_len, + batch_size=config.batch_size, + seed=config.train_seed, + add_next_step=True # true as this is needed for next_obs + ) + tp = ThreadPoolExecutor(max_workers=config.render_processes) + + critic = Critic( + action_dim=eval_env.single_action_space.n, + num_heads=config.num_heads, + use_prev_action=config.use_prev_action, + rnn_hidden_dim=config.rnn_hidden_dim, + rnn_layers=config.rnn_layers, + rnn_dropout=config.rnn_dropout, + ).to(DEVICE) + with torch.no_grad(): + target_critic = deepcopy(critic) + + no_decay_params, decay_params = filter_wd_params(critic) + optim = torch.optim.AdamW([ + {"params": no_decay_params, "weight_decay": 0.0}, + {"params": decay_params, "weight_decay": config.weight_decay} + ], lr=config.learning_rate) + print("Number of parameters:", sum(p.numel() for p in critic.parameters())) + + scaler = 
torch.cuda.amp.GradScaler() + rnn_state, target_rnn_state = None, None + prev_actions = torch.zeros((config.batch_size, 1), dtype=torch.long, device=DEVICE) + + # For reward normalization + reward_stats = StatMean(cumulative=True) + running_rewards = 0.0 + + for step in trange(1, config.update_steps + 1, desc="Training"): + with Timeit() as timer: + batch = buffer.sample() + screen_image = render_screen_image( + tty_chars=batch["tty_chars"], + tty_colors=batch["tty_colors"], + tty_cursor=batch["tty_cursor"], + threadpool=tp, + ) + batch["screen_image"] = screen_image + + # Update reward statistics (as in the original nle implementation) + running_rewards *= config.gamma + running_rewards += batch["rewards"] + reward_stats += running_rewards ** 2 + running_rewards *= (~batch["dones"]).astype(float) + # Normalize the reward + reward_std = reward_stats.mean() ** 0.5 + batch["rewards"] = batch["rewards"] / max(0.01, reward_std) + batch["rewards"] = np.clip(batch["rewards"], -config.clip_range, config.clip_range) + + batch = dict_to_tensor(batch, device=DEVICE) + + wandb.log( + { + "times/batch_loading_cpu": timer.elapsed_time_cpu, + "times/batch_loading_gpu": timer.elapsed_time_gpu, + }, + step=step, + ) + + with Timeit() as timer: + with torch.cuda.amp.autocast(): + obs = { + "screen_image": batch["screen_image"][:, :-1].contiguous(), + "tty_chars": batch["tty_chars"][:, :-1].contiguous(), + "prev_actions": torch.cat([prev_actions.long(), batch["actions"][:, :-2].long()], dim=1) + } + next_obs = { + "screen_image": batch["screen_image"][:, 1:].contiguous(), + "tty_chars": batch["tty_chars"][:, 1:].contiguous(), + "prev_actions": batch["actions"][:, :-1].long() + } + + loss, rnn_state, target_rnn_state, loss_info = rem_loss( + critic=critic, + target_critic=target_critic, + obs=obs, + next_obs=next_obs, + actions=batch["actions"][:, :-1], + rewards=batch["rewards"][:, :-1], + dones=batch["dones"][:, :-1], + rnn_states=rnn_state, + target_rnn_states=target_rnn_state, + convex_comb_weights=sample_convex_combination(config.num_heads, device=DEVICE), + gamma=config.gamma + ) + rnn_state = [a.detach() for a in rnn_state] + target_rnn_state = [a.detach() for a in target_rnn_state] + # update prev_actions for next iteration (-1 is seq_len + 1, so -2) + prev_actions = batch["actions"][:, -2].unsqueeze(-1) + + wandb.log({"times/forward_pass": timer.elapsed_time_gpu}, step=step) + + with Timeit() as timer: + scaler.scale(loss).backward() + if config.clip_grad_norm is not None: + scaler.unscale_(optim) + torch.nn.utils.clip_grad_norm_(critic.parameters(), config.clip_grad_norm) + scaler.step(optim) + scaler.update() + optim.zero_grad(set_to_none=True) + soft_update(target_critic, critic, tau=config.tau) + + wandb.log({"times/backward_pass": timer.elapsed_time_gpu}, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **loss_info}, step=step) + + if step % config.eval_every == 0: + with Timeit() as timer: + + eval_stats = vec_evaluate( + eval_env, critic, config.eval_episodes, config.eval_seed, device=DEVICE + ) + raw_returns = eval_stats.pop("reward_raw") + raw_depths = eval_stats.pop("depth_raw") + normalized_scores = tmp_env.get_normalized_score(raw_returns) + + wandb.log({ + "times/evaluation_gpu": timer.elapsed_time_gpu, + "times/evaluation_cpu": timer.elapsed_time_cpu, + }, step=step) + wandb.log({"transitions": config.batch_size * config.seq_len * step, **eval_stats}, step=step) + + if config.checkpoints_path is not None: + torch.save(critic.state_dict(), 
os.path.join(config.checkpoints_path, f"{step}.pt")) + # saving raw logs + np.save(os.path.join(config.checkpoints_path, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(config.checkpoints_path, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(config.checkpoints_path, f"{step}_normalized_scores.npy"), normalized_scores) + + # also saving to wandb files for easier use in the future + np.save(os.path.join(wandb.run.dir, f"{step}_returns.npy"), raw_returns) + np.save(os.path.join(wandb.run.dir, f"{step}_depths.npy"), raw_depths) + np.save(os.path.join(wandb.run.dir, f"{step}_normalized_scores.npy"), normalized_scores) + + buffer.close() + + +if __name__ == "__main__": + set_start_method("spawn") + train() diff --git a/configs/sweeps/small_scale_awac_chaotic_lstm.yaml b/configs/sweeps/small_scale_awac_chaotic_lstm.yaml new file mode 100644 index 0000000..f845a43 --- /dev/null +++ b/configs/sweeps/small_scale_awac_chaotic_lstm.yaml @@ -0,0 +1,71 @@ +entity: tlab +project: NetHack +program: algorithms/small_scale/awac_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_awac_chaotic_lstm_multiseed" + version: + value: 0 + data_mode: + value: "memmap" + character: + values: [ + "arc-hum-law", + "arc-hum-neu", + "arc-dwa-law", + "arc-gno-neu", + + "bar-hum-neu", + "bar-hum-cha", + "bar-orc-cha", + + "cav-hum-law", + "cav-hum-neu", + "cav-dwa-law", + "cav-gno-neu", + + "hea-hum-neu", + "hea-gno-neu", + + "kni-hum-law", + + "mon-hum-neu", + "mon-hum-law", + "mon-hum-cha", + + "pri-hum-neu", + "pri-hum-law", + "pri-hum-cha", + "pri-elf-cha", + + "ran-hum-neu", + "ran-hum-cha", + "ran-elf-cha", + "ran-gno-neu", + "ran-orc-cha", + + "rog-hum-cha", + "rog-orc-cha", + + "sam-hum-law", + + "tou-hum-neu", + + "val-hum-neu", + "val-hum-law", + "val-dwa-law", + + "wiz-hum-neu", + "wiz-hum-cha", + "wiz-elf-cha", + "wiz-gno-neu", + "wiz-orc-cha", + ] + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/configs/sweeps/small_scale_bc_chaotic_lstm.yaml b/configs/sweeps/small_scale_bc_chaotic_lstm.yaml new file mode 100644 index 0000000..0f64d59 --- /dev/null +++ b/configs/sweeps/small_scale_bc_chaotic_lstm.yaml @@ -0,0 +1,71 @@ +entity: tlab +project: NetHack +program: algorithms/small_scale/bc_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_bc_chaotic_lstm_multiseed" + version: + value: 0 + data_mode: + value: "memmap" + character: + values: [ + "arc-hum-law", + "arc-hum-neu", + "arc-dwa-law", + "arc-gno-neu", + + "bar-hum-neu", + "bar-hum-cha", + "bar-orc-cha", + + "cav-hum-law", + "cav-hum-neu", + "cav-dwa-law", + "cav-gno-neu", + + "hea-hum-neu", + "hea-gno-neu", + + "kni-hum-law", + + "mon-hum-neu", + "mon-hum-law", + "mon-hum-cha", + + "pri-hum-neu", + "pri-hum-law", + "pri-hum-cha", + "pri-elf-cha", + + "ran-hum-neu", + "ran-hum-cha", + "ran-elf-cha", + "ran-gno-neu", + "ran-orc-cha", + + "rog-hum-cha", + "rog-orc-cha", + + "sam-hum-law", + + "tou-hum-neu", + + "val-hum-neu", + "val-hum-law", + "val-dwa-law", + + "wiz-hum-neu", + "wiz-hum-cha", + "wiz-elf-cha", + "wiz-gno-neu", + "wiz-orc-cha", + ] + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/configs/sweeps/small_scale_cql_chaotic_lstm.yaml b/configs/sweeps/small_scale_cql_chaotic_lstm.yaml new file mode 100644 index 0000000..cb5566c --- /dev/null +++ b/configs/sweeps/small_scale_cql_chaotic_lstm.yaml @@ -0,0 +1,73 @@ 
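+# Grid sweep for CQL over every listed character and 3 training seeds, with alpha
+# fixed at 0.0001. Assuming the usual wandb CLI workflow, such a sweep would be
+# launched with something like:
+#   wandb sweep configs/sweeps/small_scale_cql_chaotic_lstm.yaml
+#   wandb agent <entity>/<project>/<sweep_id>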
+entity: tlab +project: NetHack +program: algorithms/small_scale/cql_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_cql_chaotic_lstm_multiseed" + version: + value: 0 + data_mode: + value: "memmap" + character: + values: [ + "arc-hum-law", + "arc-hum-neu", + "arc-dwa-law", + "arc-gno-neu", + + "bar-hum-neu", + "bar-hum-cha", + "bar-orc-cha", + + "cav-hum-law", + "cav-hum-neu", + "cav-dwa-law", + "cav-gno-neu", + + "hea-hum-neu", + "hea-gno-neu", + + "kni-hum-law", + + "mon-hum-neu", + "mon-hum-law", + "mon-hum-cha", + + "pri-hum-neu", + "pri-hum-law", + "pri-hum-cha", + "pri-elf-cha", + + "ran-hum-neu", + "ran-hum-cha", + "ran-elf-cha", + "ran-gno-neu", + "ran-orc-cha", + + "rog-hum-cha", + "rog-orc-cha", + + "sam-hum-law", + + "tou-hum-neu", + + "val-hum-neu", + "val-hum-law", + "val-dwa-law", + + "wiz-hum-neu", + "wiz-hum-cha", + "wiz-elf-cha", + "wiz-gno-neu", + "wiz-orc-cha", + ] + alpha: + value: 0.0001 + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/configs/sweeps/small_scale_cql_chaotic_lstm_sweep.yaml b/configs/sweeps/small_scale_cql_chaotic_lstm_sweep.yaml new file mode 100644 index 0000000..384b76d --- /dev/null +++ b/configs/sweeps/small_scale_cql_chaotic_lstm_sweep.yaml @@ -0,0 +1,24 @@ +entity: tlab +project: NetHack +program: algorithms/small_scale/cql_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_cql_chaotic_lstm_sweep" + version: + value: 1 + data_mode: + value: "memmap" + character: + value: "mon-hum-neu" + update_steps: + value: 250000 + alpha: + values: [0.0001, 0.0005, 0.001, 0.05, 0.01, 0.05, 0.1, 0.5, 1.0] + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/configs/sweeps/small_scale_iql_chaotic_lstm.yaml b/configs/sweeps/small_scale_iql_chaotic_lstm.yaml new file mode 100644 index 0000000..7e3e9c9 --- /dev/null +++ b/configs/sweeps/small_scale_iql_chaotic_lstm.yaml @@ -0,0 +1,71 @@ +entity: tlab +project: NetHack +program: algorithms/small_scale/iql_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_iql_chaotic_lstm_multiseed" + version: + value: 0 + data_mode: + value: "memmap" + character: + values: [ + "arc-hum-law", + "arc-hum-neu", + "arc-dwa-law", + "arc-gno-neu", + + "bar-hum-neu", + "bar-hum-cha", + "bar-orc-cha", + + "cav-hum-law", + "cav-hum-neu", + "cav-dwa-law", + "cav-gno-neu", + + "hea-hum-neu", + "hea-gno-neu", + + "kni-hum-law", + + "mon-hum-neu", + "mon-hum-law", + "mon-hum-cha", + + "pri-hum-neu", + "pri-hum-law", + "pri-hum-cha", + "pri-elf-cha", + + "ran-hum-neu", + "ran-hum-cha", + "ran-elf-cha", + "ran-gno-neu", + "ran-orc-cha", + + "rog-hum-cha", + "rog-orc-cha", + + "sam-hum-law", + + "tou-hum-neu", + + "val-hum-neu", + "val-hum-law", + "val-dwa-law", + + "wiz-hum-neu", + "wiz-hum-cha", + "wiz-elf-cha", + "wiz-gno-neu", + "wiz-orc-cha", + ] + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/configs/sweeps/small_scale_rem_chaotic_lstm.yaml b/configs/sweeps/small_scale_rem_chaotic_lstm.yaml new file mode 100644 index 0000000..1d3b66d --- /dev/null +++ b/configs/sweeps/small_scale_rem_chaotic_lstm.yaml @@ -0,0 +1,71 @@ +entity: tlab +project: NetHack +program: algorithms/small_scale/rem_chaotic_lstm.py +method: grid +parameters: + group: + value: "small_scale_rem_chaotic_lstm_multiseed" + version: + value: 0 + data_mode: 
+ value: "memmap" + character: + values: [ + "arc-hum-law", + "arc-hum-neu", + "arc-dwa-law", + "arc-gno-neu", + + "bar-hum-neu", + "bar-hum-cha", + "bar-orc-cha", + + "cav-hum-law", + "cav-hum-neu", + "cav-dwa-law", + "cav-gno-neu", + + "hea-hum-neu", + "hea-gno-neu", + + "kni-hum-law", + + "mon-hum-neu", + "mon-hum-law", + "mon-hum-cha", + + "pri-hum-neu", + "pri-hum-law", + "pri-hum-cha", + "pri-elf-cha", + + "ran-hum-neu", + "ran-hum-cha", + "ran-elf-cha", + "ran-gno-neu", + "ran-orc-cha", + + "rog-hum-cha", + "rog-orc-cha", + + "sam-hum-law", + + "tou-hum-neu", + + "val-hum-neu", + "val-hum-law", + "val-dwa-law", + + "wiz-hum-neu", + "wiz-hum-cha", + "wiz-elf-cha", + "wiz-gno-neu", + "wiz-orc-cha", + ] + train_seed: + values: [0, 1, 2] +command: + - ${env} + - python3 + - ${program} + - ${args} \ No newline at end of file diff --git a/katakomba/__init__.py b/katakomba/__init__.py deleted file mode 100644 index af45787..0000000 --- a/katakomba/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Katakomba: Deep Data-Driven Reinforcement Learning for NetHack -""" diff --git a/katakomba/datasets/__init__.py b/katakomba/datasets/__init__.py deleted file mode 100644 index 9496077..0000000 --- a/katakomba/datasets/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from katakomba.datasets.base import BaseAutoAscend -from katakomba.datasets.builder import AutoAscendDatasetBuilder -from katakomba.datasets.sars_autoascend import SARSAutoAscendTTYDataset -from katakomba.datasets.state_autoascend import StateAutoAscendTTYDataset -from katakomba.datasets.sars_chaotic_autoascend import SARSChaoticAutoAscendTTYDataset diff --git a/katakomba/datasets/base.py b/katakomba/datasets/base.py deleted file mode 100644 index 0dd409b..0000000 --- a/katakomba/datasets/base.py +++ /dev/null @@ -1,23 +0,0 @@ -from nle.dataset.dataset import TtyrecDataset -from torch.utils.data import IterableDataset - - -class BaseAutoAscend(IterableDataset): - def __init__( - self, - autoascend_iterator_cls, - ttyrecdata: TtyrecDataset, - batch_size: int, - **kwargs - ): - self._autoascend_iterator_cls = autoascend_iterator_cls - self._ttyrecdata = ttyrecdata - self._batch_size = batch_size - self._kwargs = kwargs - - def __iter__(self): - return iter( - self._autoascend_iterator_cls( - ttyrecdata=self._ttyrecdata, batch_size=self._batch_size, **self._kwargs - ) - ) diff --git a/katakomba/datasets/builder.py b/katakomba/datasets/builder.py deleted file mode 100644 index 4026b07..0000000 --- a/katakomba/datasets/builder.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Memes - - Screen may not contain the map on the screen (there can be just a menu, or inventory) - - Default TTYRec dataset fetches datapoints sequentially (i.e. each sample goes one after another within a game) - -What to keep in mind: - - Alignment between dataset actions and environment actions - - Match terminal sizes (seems that the dataset uses 80x24, but original NLE 79x21) -""" -from __future__ import annotations - -import logging -from concurrent.futures import ThreadPoolExecutor -from typing import List, Optional, Tuple - -import nle.dataset as nld - -from katakomba.datasets.base import BaseAutoAscend -from katakomba.datasets.sars_autoascend import SARSAutoAscendTTYDataset -from katakomba.datasets.sa_chaotic_autoascend import SAChaoticAutoAscendTTYDataset -from katakomba.utils.roles import Alignment, Race, Role, Sex - - -class AutoAscendDatasetBuilder: - """ - This is the most basic wrapper. 
- It obeys the original logic of the TTYRec dataset that samples data within-game-sequentially. - """ - - def __init__(self, path: str = "data/nle_data", db_path: str = "ttyrecs.db"): - # Create a sql-lite database for keeping trajectories - if not nld.db.exists(db_path): - nld.db.create(db_path) - nld.add_nledata_directory(path, "autoascend", db_path) - - # Create a connection to specify the database to use - db_conn = nld.db.connect(filename=db_path) - logging.info( - f"AutoAscend Dataset has {nld.db.count_games('autoascend', conn=db_conn)} games." - ) - - self.db_path = db_path - # Pre-init filters - # Note that all strings are further converted to be first-letter-capitalized - # This is how it's stored in dungeons data :shrug: - self._races: Optional[List[str]] = None - self._game_ids: Optional[List[int]] = None - self._alignments: Optional[List[str]] = None - self._sex: Optional[List[str]] = None - self._roles: Optional[List[str]] = None - self._game_versions: List[str] = ["3.6.6"] - - def races(self, races: List[Race]) -> AutoAscendDatasetBuilder: - self._races = [str(race.value).title() for race in races] - return self - - def roles(self, roles: List[Role]) -> AutoAscendDatasetBuilder: - self._roles = [str(role.value).title() for role in roles] - return self - - def sex(self, sex: List[Sex]) -> AutoAscendDatasetBuilder: - self._sex = [str(s.value).title() for s in sex] - return self - - def alignments(self, alignments: List[Alignment]) -> AutoAscendDatasetBuilder: - self._alignments = [str(alignment.value).title() for alignment in alignments] - return self - - def game_ids(self, game_ids: List[int]) -> AutoAscendDatasetBuilder: - self._game_ids = game_ids - return self - - def game_versions(self, versions: List[str]) -> AutoAscendDatasetBuilder: - self._game_versions = versions - return self - - def build( - self, - batch_size: int, - seq_len: int = 1, - n_workers: int = 32, - auto_ascend_cls=SARSAutoAscendTTYDataset, - **kwargs, - ) -> BaseAutoAscend: - """ - Args: - batch_size: well - n_prefetch_states: how many states will be preloaded into the device memory (CPU for now) - """ - # Build a sql query to select only filtered ones - query, query_args = self._build_sql_query() - - tp = ThreadPoolExecutor(max_workers=n_workers) - self._dataset = nld.TtyrecDataset( - dataset_name="autoascend", - dbfilename=self.db_path, - batch_size=batch_size, - seq_length=seq_len, - shuffle=True, - loop_forever=True, - subselect_sql=query, - subselect_sql_args=query_args, - threadpool=tp, - ) - print(f"Total games in the filtered dataset: {len(self._dataset._gameids)}") - - return auto_ascend_cls( - self._dataset, batch_size=batch_size, threadpool=tp, **kwargs - ) - - def _build_sql_query(self) -> Tuple[str, Tuple]: - subselect_sql = "SELECT gameid FROM games WHERE " - - # Game version (there can be potentially recordings from various NetHack versions) - subselect_sql += "version in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._game_versions)) - ) - subselect_sql_args = tuple(self._game_versions) - - # If specific game ids were specified - if self._game_ids is not None: - subselect_sql += "gameid in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._game_ids)) - ) - subselect_sql_args += tuple(self._game_ids) - if self._roles: - subselect_sql += "role in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._roles)) - ) - subselect_sql_args += tuple(self._roles) - if self._races: - subselect_sql += "race in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._races)) - ) - 
subselect_sql_args += tuple(self._races) - if self._alignments: - subselect_sql += "align in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._alignments)) - ) - subselect_sql_args += tuple(self._alignments) - if self._sex: - subselect_sql += "gender in ({seq}) AND ".format( - seq=",".join(["?"] * len(self._sex)) - ) - subselect_sql_args += tuple(self._sex) - - # There will always be an AND at the end - subselect_sql = subselect_sql[:-5] - - return subselect_sql, subselect_sql_args diff --git a/katakomba/datasets/sa_autoascend.py b/katakomba/datasets/sa_autoascend.py deleted file mode 100644 index 7e764e9..0000000 --- a/katakomba/datasets/sa_autoascend.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -from katakomba.datasets.base import BaseAutoAscend -from katakomba.utils.actions import ascii_actions_to_gym_actions -from nle.dataset.dataset import TtyrecDataset -from nle.nethack.actions import ACTIONS - - -class _SAAutoAscendTTYIterator: - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int): - self._ttyrecdata = iter(ttyrecdata) - - # Mapping from ASCII keypresses to the gym env actions - self.action_mapping = np.zeros((256, 1)) - for i, a in enumerate(ACTIONS): - self.action_mapping[a.value][0] = i - - def __iter__(self): - while True: - batch = next(self._ttyrecdata) - # actions = ascii_actions_to_gym_actions(batch["keypresses"]) - # actions = batch["keypresses"] - actions = np.take_along_axis( - self.action_mapping, batch["keypresses"], axis=0 - ) - - yield ( - batch["tty_chars"], - batch["tty_colors"], - batch["tty_cursor"], - actions, - ) - - -class SAAutoAscendTTYDataset(BaseAutoAscend): - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int): - super().__init__(_SAAutoAscendTTYIterator, ttyrecdata, batch_size) diff --git a/katakomba/datasets/sa_chaotic_autoascend.py b/katakomba/datasets/sa_chaotic_autoascend.py deleted file mode 100644 index 40da354..0000000 --- a/katakomba/datasets/sa_chaotic_autoascend.py +++ /dev/null @@ -1,54 +0,0 @@ -import numpy as np -from katakomba.datasets.base import BaseAutoAscend -from nle.dataset.dataset import TtyrecDataset -from nle.nethack.actions import ACTIONS -from concurrent.futures import ThreadPoolExecutor -from katakomba.utils.render import render_screen_image - - -class _SAChaoticAutoAscendTTYIterator: - def __init__( - self, ttyrecdata: TtyrecDataset, batch_size: int, threadpool: ThreadPoolExecutor - ): - self._ttyrecdata = iter(ttyrecdata) - self._threadpool = threadpool - self._batch_size = batch_size - - # Mapping from ASCII keypresses to the gym env actions - self.action_mapping = np.zeros((256, 1)) - for i, a in enumerate(ACTIONS): - self.action_mapping[a.value][0] = i - - def __iter__(self): - while True: - batch = next(self._ttyrecdata) - - actions = np.take_along_axis( - self.action_mapping, batch["keypresses"], axis=0 - ) - screen_image = render_screen_image( - tty_chars=batch["tty_chars"], - tty_colors=batch["tty_colors"], - tty_cursor=batch["tty_cursor"], - threadpool=self._threadpool, - ) - - yield ( - batch["tty_chars"], - batch["tty_colors"], - batch["tty_cursor"], - screen_image, - actions, - ) - - -class SAChaoticAutoAscendTTYDataset(BaseAutoAscend): - def __init__( - self, ttyrecdata: TtyrecDataset, batch_size: int, threadpool: ThreadPoolExecutor - ): - super().__init__( - _SAChaoticAutoAscendTTYIterator, - ttyrecdata, - batch_size, - threadpool=threadpool, - ) diff --git a/katakomba/datasets/sars_autoascend.py b/katakomba/datasets/sars_autoascend.py deleted file mode 100644 index 
db21298..0000000 --- a/katakomba/datasets/sars_autoascend.py +++ /dev/null @@ -1,82 +0,0 @@ -import itertools -from copy import deepcopy - -import numpy as np -from nle.dataset.dataset import TtyrecDataset - -from katakomba.datasets.base import BaseAutoAscend -from katakomba.utils.actions import ascii_actions_to_gym_actions -from katakomba.utils.observations import tty_to_numpy - -from nle.dataset.dataset import TtyrecDataset - -from katakomba.datasets.base import BaseAutoAscend -from katakomba.utils.observations import tty_to_numpy - - -class _SARSAutoAscendTTYIterator: - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int): - self._ttyrecdata = iter(ttyrecdata) - - def __iter__(self): - # A note: I provided how the sequences look like below (+3 is an example) - # so it's easier to understand what's happening with the alignment - cur_batch = self._convert_batch(next(self._ttyrecdata)) - while True: - # [s_n, s_n+1, s_n+2, s_n+3] - # [a_n, a_n+1, a_n+2, a_n+3] - # [r_n-1, r_n, r_n+1, r_n+2] - # [d_n-1, d_n, d_n+1, d_n+2] - states, actions, rewards, dones = cur_batch - - # Alignment for rewards, dones, and next_states - # [r_n, r_n+1, r_n+2, r_n-1] - # [d_n, d_n+1, d_n+2, d_n-1] - # [s_n+1, s_n+2, s_n+3, s_n] - # TODO: gigantic 4x overhead over original loader, somehow we need to optimimze this! - rewards = np.roll(rewards, shift=-1, axis=1) - dones = np.roll(dones, shift=-1, axis=1) - next_states = np.roll(deepcopy(states), shift=-1, axis=1) - - # Replace the last element using the information from the next batch - # [r_n, r_n+1, r_n+2, r_n+3] - # [d_n, d_n+1, d_n+2, d_n+3] - # [s_n+1, s_n+2, s_n+3, s_n+4] - next_batch = self._convert_batch(next(self._ttyrecdata)) - rewards[:, -1] = next_batch[2][:, 0] - dones[:, -1] = next_batch[3][:, 0] - next_states[:, -1] = next_batch[0][:, 0] - - # Move on - cur_batch = next_batch - - # states: [batch_size, seq_len, 24, 80, 3] - # actions: [batch_size, seq_len] - # rewards: [batch_size, seq_len] - # dones: [batch_size, seq_len] - # next_states: [batch_size, seq_len, 24, 80, 3] - yield states, actions, rewards, next_states, dones - - def _convert_batch(self, batch): - # [batch_size, seq_len, 24, 80, 3] - states = tty_to_numpy( - tty_chars=batch["tty_chars"], - tty_colors=batch["tty_colors"], - tty_cursor=batch["tty_cursor"], - ) - - # [batch_size, seq_len] - actions = ascii_actions_to_gym_actions(batch["keypresses"]) - - # [batch_size, seq_len] - rewards = deepcopy(batch["scores"]) - - # [batch_size, seq_len] - dones = deepcopy(batch["done"]) - - return states, actions, rewards, dones - - -class SARSAutoAscendTTYDataset(BaseAutoAscend): - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int): - super().__init__(_SARSAutoAscendTTYIterator, ttyrecdata, batch_size) diff --git a/katakomba/datasets/sars_chaotic_autoascend.py b/katakomba/datasets/sars_chaotic_autoascend.py deleted file mode 100644 index 7a7b092..0000000 --- a/katakomba/datasets/sars_chaotic_autoascend.py +++ /dev/null @@ -1,86 +0,0 @@ -import numpy as np -from copy import deepcopy - -from katakomba.datasets.base import BaseAutoAscend -from nle.dataset.dataset import TtyrecDataset -from nle.nethack.actions import ACTIONS -from concurrent.futures import ThreadPoolExecutor -from katakomba.utils.render import render_screen_image - - -class _SARSChaoticAutoAscendTTYIterator: - def __init__( - self, ttyrecdata: TtyrecDataset, batch_size: int, threadpool: ThreadPoolExecutor - ): - self._ttyrecdata = iter(ttyrecdata) - self._threadpool = threadpool - self._batch_size = batch_size 
- - # Mapping from ASCII keypresses to the gym env actions - self.action_mapping = np.zeros((256, 1)) - for i, a in enumerate(ACTIONS): - self.action_mapping[a.value][0] = i - - def __iter__(self): - # A note: I provided how the sequences look like below (+3 is an example) - # so it's easier to understand what's happening with the alignment - cur_batch = self._convert_batch(next(self._ttyrecdata)) - while True: - # [s_n, s_n+1, s_n+2, s_n+3] - # [a_n, a_n+1, a_n+2, a_n+3] - # [r_n-1, r_n, r_n+1, r_n+2] - # [d_n-1, d_n, d_n+1, d_n+2] - screen_image, tty_chars, actions, rewards, dones = cur_batch - - # Alignment for rewards, dones, and next_states - # [r_n, r_n+1, r_n+2, r_n-1] - # [d_n, d_n+1, d_n+2, d_n-1] - # [s_n+1, s_n+2, s_n+3, s_n] - # TODO: gigantic 4x overhead over original loader, somehow we need to optimimze this! - rewards = np.roll(rewards, shift=-1, axis=1) - dones = np.roll(dones, shift=-1, axis=1) - next_screen_image = np.roll(screen_image, shift=-1, axis=1) - next_tty_chars = np.roll(tty_chars, shift=-1, axis=1) - - # Replace the last element using the information from the next batch - # [r_n, r_n+1, r_n+2, r_n+3] - # [d_n, d_n+1, d_n+2, d_n+3] - # [s_n+1, s_n+2, s_n+3, s_n+4] - next_batch = self._convert_batch(next(self._ttyrecdata)) - - rewards[:, -1] = next_batch[3][:, 0] - dones[:, -1] = next_batch[4][:, 0] - next_screen_image[:, -1] = next_batch[0][:, 0] - next_tty_chars[:, -1] = next_batch[1][:, 0] - - # Move on - cur_batch = next_batch - - yield screen_image, tty_chars, actions, rewards, next_screen_image, next_tty_chars, dones - - def _convert_batch(self, batch): - screen_image = render_screen_image( - tty_chars=batch["tty_chars"], - tty_colors=batch["tty_colors"], - tty_cursor=batch["tty_cursor"], - threadpool=self._threadpool, - ) - tty_chars = deepcopy(batch["tty_chars"]) - - actions = np.take_along_axis(self.action_mapping, batch["keypresses"], axis=0) - # TODO: score difference as reward - rewards = deepcopy(batch["scores"]) - dones = deepcopy(batch["done"]) - return screen_image, tty_chars, actions, rewards, dones - - -class SARSChaoticAutoAscendTTYDataset(BaseAutoAscend): - def __init__( - self, ttyrecdata: TtyrecDataset, batch_size: int, threadpool: ThreadPoolExecutor - ): - super().__init__( - _SARSChaoticAutoAscendTTYIterator, - ttyrecdata, - batch_size, - threadpool=threadpool, - ) diff --git a/katakomba/datasets/state_autoascend.py b/katakomba/datasets/state_autoascend.py deleted file mode 100644 index 2d9095d..0000000 --- a/katakomba/datasets/state_autoascend.py +++ /dev/null @@ -1,27 +0,0 @@ -from nle.dataset.dataset import TtyrecDataset - -from katakomba.datasets.base import BaseAutoAscend -from katakomba.utils.observations import tty_to_numpy - - -class _StateAutoAscendIterator: - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int, yield_freq: float): - self._ttyrecdata = ttyrecdata - self._batch_size = batch_size - self._yield_freq = yield_freq - - def __iter__(self): - for i, batch in enumerate(self._ttyrecdata): - if i % self._yield_freq == 0: - yield tty_to_numpy( - tty_chars=batch["tty_chars"][:, -1], - tty_colors=batch["tty_colors"][:, -1], - tty_cursor=batch["tty_cursor"][:, -1], - ) - - -class StateAutoAscendTTYDataset(BaseAutoAscend): - def __init__(self, ttyrecdata: TtyrecDataset, batch_size: int, yield_freq: float): - super().__init__( - _StateAutoAscendIterator, ttyrecdata, batch_size, yield_freq=yield_freq - ) diff --git a/katakomba/env.py b/katakomba/env.py new file mode 100644 index 0000000..0af071d --- /dev/null +++ 
b/katakomba/env.py @@ -0,0 +1,193 @@ +""" +Adapted from here: https://github.com/facebookresearch/nle/blob/4f5b57ea0e18f80e40fbdc33c1dbf94bbd265e42/nle/env/tasks.py#L9 + +Changes: +1. Removed raise on the seed setting for NetHackChallenge. +2. Added get_normalized_score method. +3. Added get_dataset method. +4. Added get_current_depth method. +""" +import gym + +import nle +import numpy as np +from nle import nethack +from nle.env.tasks import NetHackScore + +from katakomba.utils.scores import MEAN_SCORES_AUTOASCEND +from katakomba.utils.datasets.small_scale import NLDSmallDataset +from katakomba.utils.datasets.large_scale import load_nld_aa_large_dataset +from katakomba.utils.roles import Role, Race, Alignment + +from typing import Optional, Tuple, Union + + +class NetHackChallenge(NetHackScore): + """Environment for the NetHack Challenge. + + The task is an augmentation of the standard NLE task. This is the NLE Score Task + but with some subtle differences: + * the action space is fixed to include the full keyboard + * menus and "" tokens are not skipped + * starting character is randomly assigned + """ + + def __init__( + self, + *args, + character="@", + allow_all_yn_questions=True, + allow_all_modes=True, + penalty_mode="constant", + penalty_step: float = -0.00, + penalty_time: float = -0.0, + max_episode_steps: int = 1e6, + observation_keys=( + "glyphs", + "chars", + "colors", + "specials", + "blstats", + "message", + "inv_glyphs", + "inv_strs", + "inv_letters", + "inv_oclasses", + "tty_chars", + "tty_colors", + "tty_cursor", + "misc", + ), + no_progress_timeout: int = 1000, + **kwargs, + ): + actions = nethack.ACTIONS + kwargs["wizard"] = False + super().__init__( + *args, + actions=actions, + character=character, + allow_all_yn_questions=allow_all_yn_questions, + allow_all_modes=allow_all_modes, + penalty_mode=penalty_mode, + penalty_step=penalty_step, + penalty_time=penalty_time, + max_episode_steps=max_episode_steps, + observation_keys=observation_keys, + **kwargs, + ) + # If the in-game turn count doesn't change for N steps, we abort + self.no_progress_timeout = no_progress_timeout + + def reset(self, *args, **kwargs): + self._turns = None + self._no_progress_count = 0 + return super().reset(*args, **kwargs) + + def _check_abort(self, observation): + """Check if time has stopped and no observation has changed long enough + to trigger an abort.""" + + turns = observation[self._blstats_index][nethack.NLE_BL_TIME] + if self._turns == turns: + self._no_progress_count += 1 + else: + self._turns = turns + self._no_progress_count = 0 + return ( + self._steps >= self._max_episode_steps + or self._no_progress_count >= self.no_progress_timeout + ) + + +class OfflineNetHackChallengeWrapper(gym.Wrapper): + """ + Offline NetHackChallenge wrapper. Adds score normalization and dataset loading. + """ + def __init__(self, env: nle.env.NLE): + super().__init__(env) + self.env = env + + def seed( + self, + core: Optional[int] = None, + disp: Optional[int] = None, + reseed: bool = False, + ) -> Tuple[int, int, bool]: + """ + Sets the state of the NetHack RNGs after the next reset. + + NetHack 3.6 uses two RNGs, core and disp. This is to prevent + RNG-manipulation by e.g. running into walls or other no-ops on the + actual game state. This is a measure against "tool-assisted + speedruns" (TAS). NLE can run in both NetHack's default mode and in + TAS-friendly "no reseeding" if `reseed` is set to False, see below. + + Arguments: + core [int or None]: Seed for the core RNG. If None, choose a random + value.
+ disp [int or None]: Seed for the disp (anti-TAS) RNG. If None, choose + a random value. + reseed [boolean]: As an Anti-TAS (automation) measure, + NetHack 3.6 reseeds with true randomness every now and then. This + flag enables or disables this behavior. If set to True, trajectories + won't be reproducible. + + Returns: + [tuple] The seeds supplied, in the form (core, disp, reseed). + """ + return self.env.seed(core, disp, reseed) + + def get_normalized_score(self, score: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """ + Returns the score normalized against the mean AutoAscend bot score for this exact character. + """ + if self.character.count("-") != 2: + raise ValueError("Reference score is not provided for this character.") + + role, race, align = self.character.split("-") + role, race, align = Role(role), Race(race), Alignment(align) + + ref_mean_score = MEAN_SCORES_AUTOASCEND[(role, race, align)] + + return score / ref_mean_score + + def get_current_depth(self): + """ + Returns the depth the agent is currently at. Note that this is not the same as the dungeon level. + https://nethackwiki.com/wiki/Dungeon_level + + Also note that depth is no longer representative of progress once the agent has descended all the way to the Amulet of Yendor, + but for current state-of-the-art agents this is good enough. + We do not use the dungeon level, as in some cases it can be biased by the agent's experience. + """ + return int( + self.env.last_observation[self.env._blstats_index][nethack.NLE_BL_DEPTH] + ) + + def get_dataset(self, scale: str = "small", **kwargs): + if self.character.count("-") != 2: + raise ValueError("Dataset is only provided for fully specified characters (role-race-alignment).") + + role, race, align = self.character.split("-") + role, race, align = Role(role), Race(race), Alignment(align) + + if scale == "small": + return NLDSmallDataset(role, race, align, **kwargs) + elif scale == "big": + return load_nld_aa_large_dataset( + role=role, + race=race, + align=align, + **kwargs + ) + else: + raise RuntimeError( + "Unknown dataset scale. Please specify 'small' for small" + " scale dataset and 'big' for NLD-AA full dataset." + ) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + info["current_depth"] = self.get_current_depth() + return obs, reward, done, info diff --git a/katakomba/envs/__init__.py b/katakomba/envs/__init__.py deleted file mode 100644 index 0d30ba8..0000000 --- a/katakomba/envs/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from katakomba.envs.envs import NetHackChallenge diff --git a/katakomba/envs/builder.py b/katakomba/envs/builder.py deleted file mode 100644 index cc580fb..0000000 --- a/katakomba/envs/builder.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -from copy import deepcopy -from itertools import product -from typing import List, Optional - -from nle.env.base import NLE - -from katakomba.utils.roles import ALLOWED_COMBOS, Alignment, Race, Role, Sex - -from katakomba.wrappers import NetHackWrapper - - -class NetHackEnvBuilder: - def __init__(self, nethack_env_fn: NLE, wrapper: Optional[NetHackWrapper]): - """ - To keep in mind: - - Not all combinations of character options are allowed by NetHack, we will filter it for you. - - Default behavior is to build environments for each allowed combination, no seeds fixed (i.e.
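# Note (hedged usage sketch for the env/wrapper added above, not part of the patch): build the
# challenge env for a fully specified character, wrap it, and use the dataset / score helpers.
# The character string, dataset scale and mode are illustrative assumptions; any combination
# from ALLOWED_COMBOS in katakomba/utils/roles.py would work, and the small-scale dataset is
# downloaded on first use, so network access is required.
from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper

env = OfflineNetHackChallengeWrapper(NetHackChallenge(character="mon-hum-neu"))
dataset = env.get_dataset(scale="small", mode="compressed")  # -> NLDSmallDataset
print(len(dataset), "episodes; normalized score:", env.get_normalized_score(1500.0))
dataset.close()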
complete evaluation) - """ - self._env_fn = nethack_env_fn - self._env_wrapper = wrapper - - self._races = [race for race in Race] - self._roles = [role for role in Role] - self._sex = [sex for sex in Sex] - self._alignments = [alignment for alignment in Alignment] - self._eval_seeds = None - self._train_seeds = None - - def races(self, races: List[Race]) -> NetHackEnvBuilder: - self._races = races - return self - - def roles(self, roles: List[Role]) -> NetHackEnvBuilder: - self._roles = roles - return self - - def sex(self, sex: List[Sex]) -> NetHackEnvBuilder: - self._sex = sex - return self - - def alignments(self, alignments: List[Alignment]) -> NetHackEnvBuilder: - self._alignments = alignments - return self - - def eval_seeds(self, seeds: List[int]) -> NetHackEnvBuilder: - self._eval_seeds = seeds - return self - - def train_seeds(self, seeds: List[int]) -> NetHackEnvBuilder: - self._train_seeds = seeds - - def evaluate(self): - """ - An iterator over the NLE settings to evaluate against. - """ - - all_valid_combinations = deepcopy(ALLOWED_COMBOS) - valid_combinations = set() - - # Filter only allowed game settings - for role, race, alignment, sex in product( - self._roles, self._races, self._alignments, self._sex - ): - if (role, race, alignment) in all_valid_combinations: - # Valkyries do not have sex - if role is Role.VALKYRIE: - valid_combinations.add((role, race, alignment, None)) - else: - valid_combinations.add((role, race, alignment, sex)) - - # Generate character descriptions for underlying NetHack engine - eval_characters = [] - for (role, race, alignment, sex) in valid_combinations: - if sex is not None: - eval_characters.append( - f"{role.value}-{race.value}-{alignment.value}-{sex.value}" - ) - else: - eval_characters.append(f"{role.value}-{race.value}-{alignment.value}") - - # Environment and its wrapper are dataset-dependent (wrappers are needed for producing images of tty) - if self._env_wrapper: - env_fn = lambda char: self._env_wrapper( - self._env_fn(character=char, savedir=False) - ) - else: - env_fn = lambda char: self._env_fn(character=char, savedir=False) - - # Generate nethack challenges - for character in sorted(eval_characters): - if self._eval_seeds is None: - yield character, env_fn(character), None - else: - for seed in self._eval_seeds: - yield character, env_fn(character), seed - - def get_action_dim(self) -> int: - if self._env_wrapper: - env_fn = lambda char: self._env_wrapper( - self._env_fn(character=char, savedir=False) - ) - else: - env_fn = self._env_fn - - # Environment with a random character (action space does not depent on the character) - dummy_env = env_fn("@") - - return dummy_env.action_space.n diff --git a/katakomba/envs/envs.py b/katakomba/envs/envs.py deleted file mode 100644 index ba7f098..0000000 --- a/katakomba/envs/envs.py +++ /dev/null @@ -1,367 +0,0 @@ -""" -Our versions of different environments - - Not all is used and adopted at the moment - -Adopted from here: https://github.com/facebookresearch/nle/blob/4f5b57ea0e18f80e40fbdc33c1dbf94bbd265e42/nle/env/tasks.py#L9 -""" - -import enum - -import numpy as np -from nle import nethack -from nle.env import base - -TASK_ACTIONS = tuple( - [nethack.MiscAction.MORE] - + list(nethack.CompassDirection) - + list(nethack.CompassDirectionLonger) - + list(nethack.MiscDirection) - + [nethack.Command.KICK, nethack.Command.EAT, nethack.Command.SEARCH] -) - - -class NetHackScore(base.NLE): - """Environment for "score" task. - - The task is an augmentation of the standard NLE task. 
The return function is - defined as: - :math:`\text{score}_t - \text{score}_{t-1} + \text{TP}`, - where the :math:`\text{TP}` is a time penalty that grows with the amount of - environment steps that do not change the state (such as navigating menus). - - Args: - penalty_mode (str): name of the mode for calculating the time step - penalty. Can be ``constant``, ``exp``, ``square``, ``linear``, or - ``always``. Defaults to ``constant``. - penalty_step (float): constant applied to amount of frozen steps. - Defaults to -0.01. - penalty_time (float): constant applied to amount of frozen steps. - Defaults to -0.0. - - """ - - def __init__( - self, - *args, - penalty_mode="constant", - penalty_step: float = -0.01, - penalty_time: float = -0.0, - **kwargs, - ): - self.penalty_mode = penalty_mode - self.penalty_step = penalty_step - self.penalty_time = penalty_time - - self._frozen_steps = 0 - - actions = kwargs.pop("actions", TASK_ACTIONS) - super().__init__(*args, actions=actions, **kwargs) - - def _get_time_penalty(self, last_observation, observation): - blstats_old = last_observation[self._blstats_index] - blstats_new = observation[self._blstats_index] - - old_time = blstats_old[nethack.NLE_BL_TIME] - new_time = blstats_new[nethack.NLE_BL_TIME] - - if old_time == new_time: - self._frozen_steps += 1 - else: - self._frozen_steps = 0 - - penalty = 0 - if self.penalty_mode == "constant": - if self._frozen_steps > 0: - penalty += self.penalty_step - elif self.penalty_mode == "exp": - penalty += 2**self._frozen_steps * self.penalty_step - elif self.penalty_mode == "square": - penalty += self._frozen_steps**2 * self.penalty_step - elif self.penalty_mode == "linear": - penalty += self._frozen_steps * self.penalty_step - elif self.penalty_mode == "always": - penalty += self.penalty_step - else: # default - raise ValueError("Unknown penalty_mode '%s'" % self.penalty_mode) - penalty += (new_time - old_time) * self.penalty_time - return penalty - - def _reward_fn(self, last_observation, action, observation, end_status): - """Score delta, but with added a state loop penalty.""" - score_diff = super()._reward_fn( - last_observation, action, observation, end_status - ) - time_penalty = self._get_time_penalty(last_observation, observation) - return score_diff + time_penalty - - -class NetHackStaircase(NetHackScore): - """Environment for "staircase" task. - - This task requires the agent to get on top of a staircase down (>). - The reward function is :math:`I + \text{TP}`, where :math:`I` is 1 if the - task is successful, and 0 otherwise, and :math:`\text{TP}` is the time step - function as defined by `NetHackScore`. - """ - - class StepStatus(enum.IntEnum): - ABORTED = -1 - RUNNING = 0 - DEATH = 1 - TASK_SUCCESSFUL = 2 - - def _is_episode_end(self, observation): - internal = observation[self._internal_index] - stairs_down = internal[4] - if stairs_down: - return self.StepStatus.TASK_SUCCESSFUL - return self.StepStatus.RUNNING - - def _reward_fn(self, last_observation, action, observation, end_status): - del action # Unused - time_penalty = self._get_time_penalty(last_observation, observation) - if end_status == self.StepStatus.TASK_SUCCESSFUL: - reward = 1 - else: - reward = 0 - return reward + time_penalty - - -class NetHackStaircasePet(NetHackStaircase): - """Environment for "staircase-pet" task. - - This task requires the agent to get on top of a staircase down (>), while - having their pet next to it. See `NetHackStaircase` for the reward function. 
- """ - - def _is_episode_end(self, observation): - internal = observation[self._internal_index] - stairs_down = internal[4] - if stairs_down: - glyphs = observation[self._glyph_index] - blstats = observation[self._blstats_index] - x, y = blstats[:2] - - neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2] - if np.any(nethack.glyph_is_pet(neighbors)): - return self.StepStatus.TASK_SUCCESSFUL - return self.StepStatus.RUNNING - - -class NetHackOracle(NetHackStaircase): - """Environment for "oracle" task. - - This task requires the agent to reach the oracle (by standing next to it). - See `NetHackStaircase` for the reward function. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.oracle_glyph = None - for glyph in range(nethack.GLYPH_MON_OFF, nethack.GLYPH_PET_OFF): - if nethack.permonst(nethack.glyph_to_mon(glyph)).mname == "Oracle": - self.oracle_glyph = glyph - break - assert self.oracle_glyph is not None - - def _is_episode_end(self, observation): - glyphs = observation[self._glyph_index] - blstats = observation[self._blstats_index] - x, y = blstats[:2] - - neighbors = glyphs[y - 1 : y + 2, x - 1 : x + 2] - if np.any(neighbors == self.oracle_glyph): - return self.StepStatus.TASK_SUCCESSFUL - return self.StepStatus.RUNNING - - -class NetHackGold(NetHackScore): - """Environment for the "gold" task. - - The task is similar to the one defined by `NetHackScore`, but the reward - uses changes in the amount of gold collected by the agent, rather than the - score. - - The agent will pickup gold automatically by walking on top of it. - """ - - def __init__(self, *args, **kwargs): - options = kwargs.pop("options", None) - - if options is None: - # Copy & swap out "pickup_types". - options = [] - for option in nethack.NETHACKOPTIONS: - if option.startswith("pickup_types"): - options.append("pickup_types:$") - continue - options.append(option) - - super().__init__(*args, options=options, **kwargs) - - def _reward_fn(self, last_observation, action, observation, end_status): - """Difference between previous gold and new gold.""" - del end_status # Unused - del action # Unused - if not self.nethack.in_normal_game(): - # Before game started and after it ended blstats are zero. - return 0.0 - - old_blstats = last_observation[self._blstats_index] - blstats = observation[self._blstats_index] - - old_gold = old_blstats[nethack.NLE_BL_GOLD] - gold = blstats[nethack.NLE_BL_GOLD] - - time_penalty = self._get_time_penalty(last_observation, observation) - - return gold - old_gold + time_penalty - - -# FIXME: the way the reward function is currently structured means the -# agents gets a penalty of -1 every other step (since the -# uhunger increases by that) -# thus the step penalty becomes irrelevant -class NetHackEat(NetHackScore): - """Environment for the "eat" task. - - The task is similar to the one defined by `NetHackScore`, but the reward - uses positive changes in the character's hunger level (e.g. by consuming - comestibles or monster corpses), rather than the score. - """ - - def _reward_fn(self, last_observation, action, observation, end_status): - """Difference between previous hunger and new hunger.""" - del end_status # Unused - del action # Unused - - if not self.nethack.in_normal_game(): - # Before game started and after it ended blstats are zero. 
- return 0.0 - - old_internal = last_observation[self._internal_index] - internal = observation[self._internal_index] - - old_uhunger = old_internal[7] - uhunger = internal[7] - - reward = max(0, uhunger - old_uhunger) - - time_penalty = self._get_time_penalty(last_observation, observation) - - return reward + time_penalty - - -class NetHackScout(NetHackScore): - """Environment for the "scout" task. - - The task is similar to the one defined by `NetHackScore`, but the score is - defined by the changes in glyphs discovered by the agent. - """ - - def reset(self, *args, **kwargs): - self.dungeon_explored = {} - return super().reset(*args, **kwargs) - - def _reward_fn(self, last_observation, action, observation, end_status): - del end_status # Unused - del action # Unused - - if not self.nethack.in_normal_game(): - # Before game started and after it ended blstats are zero. - return 0.0 - - reward = 0 - glyphs = observation[self._glyph_index] - blstats = observation[self._blstats_index] - - dungeon_num = blstats[nethack.NLE_BL_DNUM] - dungeon_level = blstats[nethack.NLE_BL_DLEVEL] - - key = (dungeon_num, dungeon_level) - explored = np.sum(glyphs != nethack.GLYPH_CMAP_OFF) - explored_old = 0 - if key in self.dungeon_explored: - explored_old = self.dungeon_explored[key] - reward = explored - explored_old - self.dungeon_explored[key] = explored - time_penalty = self._get_time_penalty(last_observation, observation) - return reward + time_penalty - - -class NetHackChallenge(NetHackScore): - """Environment for the NetHack Challenge. - - The task is an augmentation of the standard NLE task. This is the NLE Score Task - but with some subtle differences: - * the action space is fixed to include the full keyboard - * menus and "" tokens are not skipped - * starting character is randomly assigned - """ - - def __init__( - self, - *args, - character="@", - allow_all_yn_questions=True, - allow_all_modes=True, - penalty_mode="constant", - penalty_step: float = -0.00, - penalty_time: float = -0.0, - max_episode_steps: int = 1e6, - observation_keys=( - "glyphs", - "chars", - "colors", - "specials", - "blstats", - "message", - "inv_glyphs", - "inv_strs", - "inv_letters", - "inv_oclasses", - "tty_chars", - "tty_colors", - "tty_cursor", - "misc", - ), - no_progress_timeout: int = 1000, - **kwargs, - ): - actions = nethack.ACTIONS - kwargs["wizard"] = False - super().__init__( - *args, - actions=actions, - character=character, - allow_all_yn_questions=allow_all_yn_questions, - allow_all_modes=allow_all_modes, - penalty_mode=penalty_mode, - penalty_step=penalty_step, - penalty_time=penalty_time, - max_episode_steps=max_episode_steps, - observation_keys=observation_keys, - **kwargs, - ) - # If the in-game turn count doesn't change for N steps, we abort - self.no_progress_timeout = no_progress_timeout - - def reset(self, *args, **kwargs): - self._turns = None - self._no_progress_count = 0 - return super().reset(*args, **kwargs) - - def _check_abort(self, observation): - """Check if time has stopped and no observations has changed long enough - to trigger an abort.""" - - turns = observation[self._blstats_index][nethack.NLE_BL_TIME] - if self._turns == turns: - self._no_progress_count += 1 - else: - self._turns = turns - self._no_progress_count = 0 - return ( - self._steps >= self._max_episode_steps - or self._no_progress_count >= self.no_progress_timeout - ) diff --git a/katakomba/nn/vit.py b/katakomba/nn/vit.py deleted file mode 100644 index e69de29..0000000 diff --git a/katakomba/tasks.py b/katakomba/tasks.py 
deleted file mode 100644 index 9d11c3f..0000000 --- a/katakomba/tasks.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -from typing import Tuple - -from katakomba.datasets.builder import AutoAscendDatasetBuilder -from katakomba.envs import NetHackChallenge -from katakomba.envs.builder import NetHackEnvBuilder -from katakomba.wrappers import TTYWrapper, CropRenderWrapper - -TASKS = { - "NetHackScore-v0-tty-bot-v0": { - "env_fn": NetHackChallenge, - "wrapper_fn": TTYWrapper, - "dataset_builder_fn": AutoAscendDatasetBuilder, - }, - "NetHackScore-v0-ttyimg-bot-v0": { - "env_fn": NetHackChallenge, - "wrapper_fn": CropRenderWrapper, - "dataset_builder_fn": AutoAscendDatasetBuilder, - }, -} - - -def make_task_builder( - task: str, data_path: str = "data/nle_data", db_path: str = "ttyrecs.db" -) -> Tuple[NetHackEnvBuilder, AutoAscendDatasetBuilder]: - """ - Creates environment and dataset builders for a task, which you can further configure for your needs. - """ - if task not in TASKS: - raise Exception(f"There is no such task: {task}") - - env_fn = TASKS[task]["env_fn"] - wrapper_fn = TASKS[task]["wrapper_fn"] - dataset_builder_fn = TASKS[task]["dataset_builder_fn"] - - return NetHackEnvBuilder(env_fn, wrapper_fn), dataset_builder_fn( - path=data_path, db_path=db_path - ) diff --git a/katakomba/utils/datasets/__init__.py b/katakomba/utils/datasets/__init__.py new file mode 100644 index 0000000..1feef4c --- /dev/null +++ b/katakomba/utils/datasets/__init__.py @@ -0,0 +1,3 @@ +from katakomba.utils.datasets.small_scale import load_nld_aa_small_dataset, NLDSmallDataset +from katakomba.utils.datasets.large_scale import load_nld_aa_large_dataset +from katakomba.utils.datasets.small_scale_buffer import SequentialBuffer \ No newline at end of file diff --git a/katakomba/utils/datasets/large_scale.py b/katakomba/utils/datasets/large_scale.py new file mode 100644 index 0000000..f308c65 --- /dev/null +++ b/katakomba/utils/datasets/large_scale.py @@ -0,0 +1,99 @@ +import shutil +import nle.dataset as nld + +from katakomba.utils.roles import Alignment, Race, Role, Sex +from typing import Tuple, Sequence, Optional +from concurrent.futures import ThreadPoolExecutor + + +def load_nld_aa_large_dataset( + data_path: str, + db_path: str, + seq_len: int, + batch_size: int, + num_workers: int = 8, + role: Optional[Role] = None, + race: Optional[Race] = None, + align: Optional[Alignment] = None, + **kwargs +) -> nld.TtyrecDataset: + if nld.db.exists(db_path): + # if the db was not properly initialized previously for some reason + # (i.e., a wrong path and then fixed) we need to delete it and recreate from scratch + shutil.rmtree(db_path) + + nld.db.create(db_path) + nld.add_nledata_directory(data_path, "autoascend", db_path) + + # how to write it more clearly? 
+ query, query_args = build_dataset_sql_query( + roles=[str(role.value).title()] if role is not None else None, + races=[str(race.value).title()] if race is not None else None, + alignments=[str(align.value).title()] if align is not None else None, + **kwargs + ) + tp = ThreadPoolExecutor(max_workers=num_workers) + + dataset = nld.TtyrecDataset( + dataset_name="autoascend", + dbfilename=db_path, + batch_size=batch_size, + seq_length=seq_len, + shuffle=True, + loop_forever=True, + subselect_sql=query, + subselect_sql_args=query_args, + threadpool=tp, + ) + print(f"Total games in the filtered dataset: {len(dataset._gameids)}") + + return dataset + + +def build_dataset_sql_query( + roles: Optional[Sequence[str]] = None, + races: Optional[Sequence[str]] = None, + alignments: Optional[Sequence[str]] = None, + genders: Optional[Sequence[str]] = None, + game_versions: Optional[Sequence[str]] = ("3.6.6",), + game_ids: Optional[Tuple[int]] = None +) -> Tuple[str, Tuple]: + subselect_sql = "SELECT gameid FROM games WHERE " + + # Game version (there can be potentially recordings from various NetHack versions) + subselect_sql += "version in ({seq}) AND ".format( + seq=",".join(["?"] * len(game_versions)) + ) + subselect_sql_args = tuple(game_versions) + + # If specific game ids were specified + if game_ids is not None: + subselect_sql += "gameid in ({seq}) AND ".format( + seq=",".join(["?"] * len(game_ids)) + ) + subselect_sql_args += tuple(game_ids) + if roles is not None: + subselect_sql += "role in ({seq}) AND ".format( + seq=",".join(["?"] * len(roles)) + ) + subselect_sql_args += tuple(roles) + if races is not None: + subselect_sql += "race in ({seq}) AND ".format( + seq=",".join(["?"] * len(races)) + ) + subselect_sql_args += tuple(races) + if alignments is not None: + subselect_sql += "align in ({seq}) AND ".format( + seq=",".join(["?"] * len(alignments)) + ) + subselect_sql_args += tuple(alignments) + if genders is not None: + subselect_sql += "gender in ({seq}) AND ".format( + seq=",".join(["?"] * len(genders)) + ) + subselect_sql_args += tuple(genders) + + # There will always be an AND at the end + subselect_sql = subselect_sql[:-5] + + return subselect_sql, subselect_sql_args \ No newline at end of file diff --git a/katakomba/utils/datasets/small_scale.py b/katakomba/utils/datasets/small_scale.py new file mode 100644 index 0000000..9b37fce --- /dev/null +++ b/katakomba/utils/datasets/small_scale.py @@ -0,0 +1,141 @@ +import os +import h5py +import shutil +import numpy as np + +import urllib +from typing import Optional +from katakomba.utils.roles import Role, Race, Alignment, ALLOWED_COMBOS +from tqdm.auto import tqdm +from typing import Tuple, List, Dict, Any + +BASE_REPO_ID = os.environ.get('KATAKOMBA_REPO_ID', os.path.expanduser('Howuhh/katakomba')) +DATA_PATH = os.environ.get('KATAKOMBA_DATA_DIR', os.path.expanduser('~/.katakomba/datasets')) +CACHE_PATH = os.environ.get('KATAKOMBA_CACHE_DIR', os.path.expanduser('~/.katakomba/cache')) + + +# similar to huggingface_hub function hf_hub_url +def download_dataset( + repo_id: str, + filename: str, + subfolder: Optional[str] = None +): + dataset_path = os.path.join(DATA_PATH, filename) + if subfolder is not None: + filename = f"{subfolder}/{filename}" + dataset_url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{filename}" + + print(f"Downloading dataset: {dataset_url} to {DATA_PATH}") + urllib.request.urlretrieve(dataset_url, dataset_path) + + if not os.path.exists(os.path.join(DATA_PATH, filename.split("/")[-1])): + raise 
IOError(f"Failed to download dataset from {dataset_url}") + + +def _flush_to_memmap(filename: str, array: np.ndarray) -> np.ndarray: + if os.path.exists(filename): + mmap = np.load(filename, mmap_mode="r") + else: + mmap = np.memmap(filename, mode="w+", dtype=array.dtype, shape=array.shape) + mmap[:] = array + mmap.flush() + + return mmap + + +def load_nld_aa_small_dataset( + role: Role, + race: Race, + align: Alignment, + mode: str = "in_memory" +) -> Tuple[h5py.File, List[Dict[str, Any]]]: + os.makedirs(DATA_PATH, exist_ok=True) + if (role, race, align) not in ALLOWED_COMBOS: + raise RuntimeError( + "Invalid character combination! " + "Please see all allowed combos in the katakomba/utils/roles.py" + ) + dataset_name = f"data-{role.value}-{race.value}-{align.value}-any.hdf5" + if not os.path.exists(os.path.join(DATA_PATH, dataset_name)): + download_dataset( + repo_id=BASE_REPO_ID, + subfolder="data", + filename=dataset_name, + ) + + dataset_path = os.path.join(DATA_PATH, dataset_name) + df = h5py.File(dataset_path, "r") + + if mode == "in_memory": + trajectories = {} + for episode in tqdm(df["/"].keys(), leave=False): + episode_data = { + k: df[episode][k][()] for k in df[episode].keys() + } + trajectories[episode] = episode_data + + elif mode == "memmap": + os.makedirs(CACHE_PATH, exist_ok=True) + + trajectories = {} + for episode in tqdm(df["/"].keys(), leave=False): + cache_name = f"memmap-{dataset_name.split('.')[0]}" + episode_cache_path = os.path.join(CACHE_PATH, cache_name, str(episode)) + + os.makedirs(episode_cache_path, exist_ok=True) + episode_data = { + k: _flush_to_memmap( + filename=os.path.join(episode_cache_path, str(k)), + array=df[episode][k][()] + ) + for k in df[episode].keys() + } + trajectories[episode] = episode_data + + elif mode == "compressed": + trajectories = {} + for episode in tqdm(df["/"].keys(), leave=False): + # we do not copy data here! it will decompress it during reading or slicing + episode_data = {k: df[episode][k] for k in df[episode].keys()} + trajectories[episode] = episode_data + else: + raise RuntimeError("Unknown mode for dataset loading! Please use one of: 'compressed', 'in_memory', 'memmap'") + + # TODO: or return NLDSmallDataset here similar to nld loading? 
+ return df, trajectories + + +class NLDSmallDataset: + def __init__( + self, + role: Role, + race: Race, + align: Alignment, + mode: str = "compressed" + ): + self.hdf5_file, self.data = load_nld_aa_small_dataset(role, race, align, mode=mode) + self.gameids = list(self.data.keys()) + + self.role = role + self.race = race + self.align = align + self.mode = mode + + def __getitem__(self, idx): + gameid = self.gameids[idx] + return self.data[gameid] + + def __len__(self): + return len(self.gameids) + + def metadata(self, idx): + gameid = self.gameids[idx] + return dict(self.hdf5_file[gameid].attrs) + + def close(self, clear_cache=True): + self.hdf5_file.close() + if self.mode == "memmap" and clear_cache: + print("Cleaning memmap cache...") + # remove memmap cache files from the disk upon closing + cache_name = f"memmap-data-{self.role.value}-{self.race.value}-{self.align.value}-any" + shutil.rmtree(os.path.join(CACHE_PATH, cache_name)) diff --git a/katakomba/utils/datasets/small_scale_buffer.py b/katakomba/utils/datasets/small_scale_buffer.py new file mode 100644 index 0000000..1e42b9e --- /dev/null +++ b/katakomba/utils/datasets/small_scale_buffer.py @@ -0,0 +1,80 @@ +import random +import numpy as np +from itertools import cycle +from katakomba.utils.datasets.small_scale import NLDSmallDataset + +from typing import Dict, List + + +# simple utility functions, you can also use map_tree analogs from dm-tree or optree +def dict_slice( + data: Dict[str, np.ndarray], + start: int, + end: int +) -> Dict[str, np.ndarray]: + return {k: v[start:end] for k, v in data.items()} + + +def dict_concat(datas: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: + return {k: np.concatenate([d[k] for d in datas]) for k in datas[0].keys()} + + +def dict_stack(datas: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: + return {k: np.stack([d[k] for d in datas]) for k in datas[0].keys()} + + +class SequentialBuffer: + def __init__( + self, + dataset: NLDSmallDataset, + batch_size: int, + seq_len: int, + add_next_step: bool = False, + seed: int = 0 + ): + self.traj = dataset + self.traj_idxs = list(range(len(self.traj))) + # shuffle starting trajectories indices + random.seed(seed) + random.shuffle(self.traj_idxs) + # iterator over next free trajectories to pick + self.free_traj = cycle(self.traj_idxs) + # index of the current trajectory for each row in batch + self.curr_traj = np.array([next(self.free_traj) for _ in range(batch_size)], dtype=int) + # index withing the current trajectory for each row in batch + self.curr_idx = np.zeros(batch_size, dtype=int) + + self.batch_size = batch_size + # it will return seq_len + 1, but also will start next traj from seq_len + 1, not seq_len + 2 as in nle + # this is very useful for DQN-like algorithms training with RNNs + self.add_next_step = add_next_step + self.seq_len = seq_len + 1 if add_next_step else seq_len + + def sample(self): + batch = [] + for i in range(self.batch_size): + traj_idx = self.curr_traj[i] + start_idx = self.curr_idx[i] + data = dict_slice(self.traj[traj_idx], start_idx, start_idx + self.seq_len) + + if len(data["actions"]) < self.seq_len: + # if next traj will have total_len < seq_len, then get next until data is seq_len + while len(data["actions"]) < self.seq_len: + traj_idx = next(self.free_traj) + len_diff = self.seq_len - len(data["actions"]) + + data = dict_concat([ + data, + dict_slice(self.traj[traj_idx], 0, len_diff), + ]) + self.curr_traj[i] = traj_idx + self.curr_idx[i] = len_diff - 1 if self.add_next_step else len_diff + else: + 
self.curr_idx[i] += self.seq_len - 1 if self.add_next_step else self.seq_len + + batch.append(data) + + return dict_stack(batch) + + def close(self, clear_cache=True): + return self.traj.close(clear_cache=clear_cache) diff --git a/katakomba/utils/misc.py b/katakomba/utils/misc.py new file mode 100644 index 0000000..73fb9ac --- /dev/null +++ b/katakomba/utils/misc.py @@ -0,0 +1,103 @@ +import math +import time + +import numpy as np +import torch +from dataclasses import dataclass + + +class Timeit: + def __enter__(self): + if torch.cuda.is_available(): + self.start_gpu = torch.cuda.Event(enable_timing=True) + self.end_gpu = torch.cuda.Event(enable_timing=True) + self.start_gpu.record() + self.start_cpu = time.time() + return self + + def __exit__(self, type, value, traceback): + if torch.cuda.is_available(): + self.end_gpu.record() + torch.cuda.synchronize() + self.elapsed_time_gpu = self.start_gpu.elapsed_time(self.end_gpu) / 1000 + else: + self.elapsed_time_gpu = -1.0 + self.elapsed_time_cpu = time.time() - self.start_cpu + + +# taken from: +# https://github.com/dungeonsdatasubmission/dungeonsdata-neurips2022/blob/ee72d6aac9df00a4a6ab1f501db37a632a75b952/experiment_code/hackrl/offline_experiment.py#L179 +@dataclass +class StatMean: + # Compute using Welford'd Online Algorithm + # Algo: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + # Math: https://jonisalonen.com/2013/deriving-welfords-method-for-computing-variance/ + n: int = 0 + mu: float = 0 + m2: float = 0 + cumulative: bool = False + + def result(self): + if self.n == 0: + return None + return self.mu + + def mean(self): + return self.mu + + def std(self): + if self.n < 1: + return None + return math.sqrt(self.m2 / self.n) + + def __sub__(self, other): + assert isinstance(other, StatMean) + n_new = self.n - other.n + if n_new == 0: + return StatMean(0, 0, 0) + mu_new = (self.mu * self.n - other.mu * other.n) / n_new + delta = other.mu - mu_new + m2_new = self.m2 - other.m2 - (delta**2) * n_new * other.n / self.n + return StatMean(n_new, mu_new, m2_new) + + def __iadd__(self, other): + if isinstance(other, StatMean): + other_n = other.n + other_mu = other.mu + other_m2 = other.m2 + elif isinstance(other, torch.Tensor): + other_n = other.numel() + other_mu = other.mean().item() + other_m2 = ((other - other_mu) ** 2).sum().item() + elif isinstance(other, np.ndarray): + other_n = other.size + other_mu = other.mean().item() + other_m2 = ((other - other_mu) ** 2).sum().item() + else: + other_n = 1 + other_mu = other + other_m2 = 0 + # See parallelized Welford in wiki + new_n = other_n + self.n + delta = other_mu - self.mu + self.mu += delta * (other_n / max(new_n, 1)) + delta2 = other_mu - self.mu + self.m2 += other_m2 + (delta2**2) * (self.n * other_n / max(new_n, 1)) + self.n = new_n + return self + + def reset(self): + if not self.cumulative: + self.mu = 0 + self.n = 0 + + def decay_cumulative(self, n=1e6): + """Adjust sample size downwards to upweight recent samples""" + if not self.cumulative: + return + if self.n > n: + self.m2 *= n / self.n + self.n = n + + def __repr__(self): + return repr(self.result()) \ No newline at end of file diff --git a/katakomba/utils/roles.py b/katakomba/utils/roles.py index c034bb2..9bae739 100644 --- a/katakomba/utils/roles.py +++ b/katakomba/utils/roles.py @@ -4,7 +4,6 @@ """ import enum - class Role(enum.Enum): ARCHEOLOGIST = "arc" BARBARIAN = "bar" @@ -42,45 +41,91 @@ class Sex(enum.Enum): ### These combinations are allowed by NetHack ### On sex: both are always available except 
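# Note (hedged sketch, not part of the patch): this is how the small-scale dataset and the
# SequentialBuffer above are typically combined to draw fixed-length sequence batches for the
# recurrent agents in algorithms/small_scale. The character and hyperparameters here are
# illustrative assumptions, not values from the configs in this patch.
from katakomba.utils.datasets import NLDSmallDataset, SequentialBuffer
from katakomba.utils.roles import Role, Race, Alignment

dataset = NLDSmallDataset(Role.MONK, Race.HUMAN, Alignment.NEUTRAL, mode="compressed")
buffer = SequentialBuffer(dataset, batch_size=64, seq_len=16, add_next_step=True, seed=0)
batch = buffer.sample()         # dict of arrays shaped [batch_size, seq_len + 1, ...]
print(batch["actions"].shape)   # e.g. (64, 17)
buffer.close()                  # also clears the memmap cache when mode="memmap"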
Valkyrie (which has no sex) -ALLOWED_COMBOS = set( - [ - (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL), - (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL), - (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL), - (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL), - (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL), - (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC), - (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC), - (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL), - (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL), - (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL), - (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL), - (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL), - (Role.HEALER, Race.GNOME, Alignment.NEUTRAL), - (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL), - (Role.MONK, Race.HUMAN, Alignment.NEUTRAL), - (Role.MONK, Race.HUMAN, Alignment.LAWFUL), - (Role.MONK, Race.HUMAN, Alignment.CHAOTIC), - (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL), - (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL), - (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC), - (Role.PRIEST, Race.ELF, Alignment.CHAOTIC), - (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL), - (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC), - (Role.RANGER, Race.ELF, Alignment.CHAOTIC), - (Role.RANGER, Race.GNOME, Alignment.NEUTRAL), - (Role.RANGER, Race.ORC, Alignment.CHAOTIC), - (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC), - (Role.ROGUE, Race.ORC, Alignment.CHAOTIC), - (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL), - (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL), - (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL), - (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL), - (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL), - (Role.WIZARD, Race.HUMAN, Alignment.NEUTRAL), - (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC), - (Role.WIZARD, Race.ELF, Alignment.CHAOTIC), - (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL), - (Role.WIZARD, Race.ORC, Alignment.CHAOTIC), - ] -) +ALLOWED_COMBOS = set([ + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL), + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL), + (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL), + (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL), + (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL), + (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC), + (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC), + (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL), + (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL), + (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL), + (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL), + (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL), + (Role.HEALER, Race.GNOME, Alignment.NEUTRAL), + (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL), + (Role.MONK, Race.HUMAN, Alignment.NEUTRAL), + (Role.MONK, Race.HUMAN, Alignment.LAWFUL), + (Role.MONK, Race.HUMAN, Alignment.CHAOTIC), + (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL), + (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL), + (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC), + (Role.PRIEST, Race.ELF, Alignment.CHAOTIC), + (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL), + (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC), + (Role.RANGER, Race.ELF, Alignment.CHAOTIC), + (Role.RANGER, Race.GNOME, Alignment.NEUTRAL), + (Role.RANGER, Race.ORC, Alignment.CHAOTIC), + (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC), + (Role.ROGUE, Race.ORC, Alignment.CHAOTIC), + (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL), + (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL), + (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL), + (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL), + (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL), + (Role.WIZARD, Race.HUMAN, 
Alignment.NEUTRAL), + (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC), + (Role.WIZARD, Race.ELF, Alignment.CHAOTIC), + (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL), + (Role.WIZARD, Race.ORC, Alignment.CHAOTIC), +]) + +# These are combinations for the splits from the paper +BASE_COMBOS = set([ + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL), + (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL), + (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL), + (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL), + (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL), + (Role.MONK, Race.HUMAN, Alignment.NEUTRAL), + (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL), + (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL), + (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC), + (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL), + (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL), + (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL), + (Role.WIZARD, Race.HUMAN, Alignment.NEUTRAL) +]) + +EXTENDED_COMBOS = set([ + (Role.PRIEST, Race.ELF, Alignment.CHAOTIC), + (Role.RANGER, Race.ELF, Alignment.CHAOTIC), + (Role.WIZARD, Race.ELF, Alignment.CHAOTIC), + (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL), + (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL), + (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL), + (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL), + (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL), + (Role.HEALER, Race.GNOME, Alignment.NEUTRAL), + (Role.RANGER, Race.GNOME, Alignment.NEUTRAL), + (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL), + (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC), + (Role.RANGER, Race.ORC, Alignment.CHAOTIC), + (Role.ROGUE, Race.ORC, Alignment.CHAOTIC), + (Role.WIZARD, Race.ORC, Alignment.CHAOTIC) +]) + +COMPLETE_COMBOS = set([ + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL), + (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL), + (Role.MONK, Race.HUMAN, Alignment.LAWFUL), + (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL), + (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL), + (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC), + (Role.MONK, Race.HUMAN, Alignment.CHAOTIC), + (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC), + (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC), + (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC) +]) diff --git a/katakomba/utils/scores.py b/katakomba/utils/scores.py new file mode 100644 index 0000000..080206e --- /dev/null +++ b/katakomba/utils/scores.py @@ -0,0 +1,127 @@ +""" +Here are the statistics for results and the normalization all around. 
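+
+More concretely: MEAN_SCORES_AUTOASCEND / MIN_SCORES_AUTOASCEND / MAX_SCORES_AUTOASCEND below hold
+per-(Role, Race, Alignment) in-game score statistics of the AutoAscend bot for every allowed character
+combination. The mean values are the reference scores used for normalization, e.g. (a sketch, with
+illustrative variable names, of how scripts/stats/algorithms_stats.py applies them):
+
+    normalized_score = raw_score / MEAN_SCORES_AUTOASCEND[(role, race, align)]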
+""" +from katakomba.utils.roles import Role, Race, Alignment + +MEAN_SCORES_AUTOASCEND = { + (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL): 5445.69, + (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL): 5316.57, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL): 5826.35, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL): 6636.44, + (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC): 18228.11, + (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL): 17836.68, + (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC): 17594.38, + (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL): 11893.48, + (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL): 10083.06, + (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL): 12462.82, + (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL): 12113.87, + (Role.HEALER, Race.GNOME, Alignment.NEUTRAL): 3783.93, + (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL): 4068.27, + (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL): 14137.06, + (Role.MONK, Race.HUMAN, Alignment.CHAOTIC): 18353.30, + (Role.MONK, Race.HUMAN, Alignment.LAWFUL): 16091.57, + (Role.MONK, Race.HUMAN, Alignment.NEUTRAL): 17456.05, + (Role.PRIEST, Race.ELF, Alignment.CHAOTIC): 7109.35, + (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC): 8262.56, + (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL): 6847.99, + (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL): 7732.69, + (Role.RANGER, Race.ELF, Alignment.CHAOTIC): 9014.18, + (Role.RANGER, Race.GNOME, Alignment.NEUTRAL): 6965.04, + (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC): 8378.50, + (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL): 8067.99, + (Role.RANGER, Race.ORC, Alignment.CHAOTIC): 7608.48, + (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC): 4818.20, + (Role.ROGUE, Race.ORC, Alignment.CHAOTIC): 4897.69, + (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL): 11009.36, + (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL): 4211.47, + (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL): 23473.61, + (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL): 26103.03, + (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL): 18624.77, + (Role.WIZARD, Race.ELF, Alignment.CHAOTIC): 5005.16, + (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL): 4317.51, + (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC): 5316.82, + (Role.WIZARD, Race.HUMAN, Alignment.NEUTRAL): 5323.48, + (Role.WIZARD, Race.ORC, Alignment.CHAOTIC): 5016.74, +} + +MIN_SCORES_AUTOASCEND = { + (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL): 0.00, + (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL): 0.00, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL): 2.00, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC): 0.00, + (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC): 0.00, + (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL): 0.00, + (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL): 0.00, + (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL): 0.00, + (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.HEALER, Race.GNOME, Alignment.NEUTRAL): 0.00, + (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL): 0.00, + (Role.MONK, Race.HUMAN, Alignment.CHAOTIC): 0.00, + (Role.MONK, Race.HUMAN, Alignment.LAWFUL): 7.00, + (Role.MONK, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.PRIEST, Race.ELF, Alignment.CHAOTIC): 0.00, + (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC): 0.00, + (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL): 0.00, + (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.RANGER, Race.ELF, Alignment.CHAOTIC): 0.00, + (Role.RANGER, Race.GNOME, 
Alignment.NEUTRAL): 0.00, + (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC): 3.00, + (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.RANGER, Race.ORC, Alignment.CHAOTIC): 3.00, + (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC): 0.00, + (Role.ROGUE, Race.ORC, Alignment.CHAOTIC): 0.00, + (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL): 0.00, + (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL): 0.00, + (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL): 0.00, + (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL): 16.00, + (Role.WIZARD, Race.ELF, Alignment.CHAOTIC): 0.00, + (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL): 0.00, + (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC): 0.00, + (Role.WIZARD, Race.HUMAN, Alignment.NEUTRAL): 0.00, + (Role.WIZARD, Race.ORC, Alignment.CHAOTIC): 0.00, +} + +MAX_SCORES_AUTOASCEND = { + (Role.ARCHEOLOGIST, Race.DWARF, Alignment.LAWFUL): 83496.00, + (Role.ARCHEOLOGIST, Race.GNOME, Alignment.NEUTRAL): 110054.00, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.LAWFUL): 84823.00, + (Role.ARCHEOLOGIST, Race.HUMAN, Alignment.NEUTRAL): 138103.00, + (Role.BARBARIAN, Race.HUMAN, Alignment.CHAOTIC): 164446.00, + (Role.BARBARIAN, Race.HUMAN, Alignment.NEUTRAL): 292342.00, + (Role.BARBARIAN, Race.ORC, Alignment.CHAOTIC): 164296.00, + (Role.CAVEMAN, Race.DWARF, Alignment.LAWFUL): 161682.00, + (Role.CAVEMAN, Race.GNOME, Alignment.NEUTRAL): 142460.00, + (Role.CAVEMAN, Race.HUMAN, Alignment.LAWFUL): 156966.00, + (Role.CAVEMAN, Race.HUMAN, Alignment.NEUTRAL): 258978.00, + (Role.HEALER, Race.GNOME, Alignment.NEUTRAL): 69566.00, + (Role.HEALER, Race.HUMAN, Alignment.NEUTRAL): 64337.00, + (Role.KNIGHT, Race.HUMAN, Alignment.LAWFUL): 419154.00, + (Role.MONK, Race.HUMAN, Alignment.CHAOTIC): 223997.00, + (Role.MONK, Race.HUMAN, Alignment.LAWFUL): 190783.00, + (Role.MONK, Race.HUMAN, Alignment.NEUTRAL): 171224.00, + (Role.PRIEST, Race.ELF, Alignment.CHAOTIC): 83744.00, + (Role.PRIEST, Race.HUMAN, Alignment.CHAOTIC): 58367.00, + (Role.PRIEST, Race.HUMAN, Alignment.LAWFUL): 99250.00, + (Role.PRIEST, Race.HUMAN, Alignment.NEUTRAL): 114269.00, + (Role.RANGER, Race.ELF, Alignment.CHAOTIC): 66690.00, + (Role.RANGER, Race.GNOME, Alignment.NEUTRAL): 58137.00, + (Role.RANGER, Race.HUMAN, Alignment.CHAOTIC): 62599.00, + (Role.RANGER, Race.HUMAN, Alignment.NEUTRAL): 54874.00, + (Role.RANGER, Race.ORC, Alignment.CHAOTIC): 69244.00, + (Role.ROGUE, Race.HUMAN, Alignment.CHAOTIC): 68628.00, + (Role.ROGUE, Race.ORC, Alignment.CHAOTIC): 54892.00, + (Role.SAMURAI, Race.HUMAN, Alignment.LAWFUL): 155163.00, + (Role.TOURIST, Race.HUMAN, Alignment.NEUTRAL): 59484.00, + (Role.VALKYRIE, Race.DWARF, Alignment.LAWFUL): 1136591.00, + (Role.VALKYRIE, Race.HUMAN, Alignment.LAWFUL): 428274.00, + (Role.VALKYRIE, Race.HUMAN, Alignment.NEUTRAL): 313858.00, + (Role.WIZARD, Race.ELF, Alignment.CHAOTIC): 71664.00, + (Role.WIZARD, Race.GNOME, Alignment.NEUTRAL): 37376.00, + (Role.WIZARD, Race.HUMAN, Alignment.CHAOTIC): 55185.00, + (Role.WIZARD, Race.HUMAN, Alignment.NEUTRAL): 71709.00, + (Role.WIZARD, Race.ORC, Alignment.CHAOTIC): 40871.00, +} \ No newline at end of file diff --git a/katakomba/wrappers/__init__.py b/katakomba/wrappers/__init__.py index 3d077be..07d95bd 100644 --- a/katakomba/wrappers/__init__.py +++ b/katakomba/wrappers/__init__.py @@ -1,3 +1,2 @@ -from katakomba.wrappers.base import NetHackWrapper from katakomba.wrappers.tty import TTYWrapper from katakomba.wrappers.render import CropRenderWrapper diff --git a/katakomba/wrappers/base.py b/katakomba/wrappers/base.py 
deleted file mode 100644 index b7445d5..0000000 --- a/katakomba/wrappers/base.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Optional, Tuple - -import gym -from nle.env.base import NLE - - -class NetHackWrapper(gym.Wrapper): - """ - NetHack needs a modified gym-wrapper due to its multiple-seeding strategy. - """ - - def __init__(self, env: NLE): - super().__init__(env) - - self.env: NLE = env - - def seed( - self, - core: Optional[int] = None, - disp: Optional[int] = None, - reseed: bool = False, - ) -> Tuple[int, int, bool]: - """ - Sets the state of the NetHack RNGs after the next reset. - - NetHack 3.6 uses two RNGs, core and disp. This is to prevent - RNG-manipulation by e.g. running into walls or other no-ops on the - actual game state. This is a measure against "tool-assisted - speedruns" (TAS). NLE can run in both NetHack's default mode and in - TAS-friendly "no reseeding" if `reseed` is set to False, see below. - - Arguments: - core [int or None]: Seed for the core RNG. If None, chose a random - value. - disp [int or None]: Seed for the disp (anti-TAS) RNG. If None, chose - a random value. - reseed [boolean]: As an Anti-TAS (automation) measure, - NetHack 3.6 reseeds with true randomness every now and then. This - flag enables or disables this behavior. If set to True, trajectories - won't be reproducible. - - Returns: - [tuple] The seeds supplied, in the form (core, disp, reseed). - """ - return self.env.seed(core, disp, reseed) diff --git a/katakomba/wrappers/render.py b/katakomba/wrappers/render.py index 4bd8d92..305708b 100644 --- a/katakomba/wrappers/render.py +++ b/katakomba/wrappers/render.py @@ -2,10 +2,10 @@ import numpy as np from katakomba.utils.render import SCREEN_SHAPE, render_screen_image -from katakomba.wrappers.base import NetHackWrapper +from katakomba.env import OfflineNetHackChallengeWrapper -class CropRenderWrapper(NetHackWrapper): +class CropRenderWrapper(OfflineNetHackChallengeWrapper): """ Populates observation with: - screen_image: [3, crop_width, crop_height]. For specific values see d5rl/utils/render.py diff --git a/katakomba/wrappers/tty.py b/katakomba/wrappers/tty.py index bd4df86..d8a5966 100644 --- a/katakomba/wrappers/tty.py +++ b/katakomba/wrappers/tty.py @@ -3,10 +3,10 @@ from nle import nethack from nle.env import base -from katakomba.wrappers.base import NetHackWrapper +from katakomba.env import OfflineNetHackChallengeWrapper -class TTYWrapper(NetHackWrapper): +class TTYWrapper(OfflineNetHackChallengeWrapper): """ An observation wrapper that converts tty_* to a numpy array. 
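
    A hypothetical usage sketch (it assumes a plain NLE gym env is wrapped directly and that the
    OfflineNetHackChallengeWrapper base in katakomba/env.py accepts it as-is; evaluation in this
    repository constructs environments through its own helpers):

        import gym
        import nle  # noqa: F401, registers the NetHack envs
        from katakomba.wrappers import TTYWrapper

        env = TTYWrapper(gym.make("NetHackChallenge-v0"))
        obs = env.reset()  # tty_chars / tty_colors / tty_cursor are exposed as numpy arrays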
""" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..825824e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,18 @@ +cython +psutil +opencv-python +scikit-build +pybind11==2.10.3 +numba==0.56.4 +tqdm==4.64.0 +wandb==0.12.21 +numpy==1.23.1 +gym==0.23.0 +pyrallis==0.3.1 +h5py==3.8.0 +nle==0.9.0 +rliable==1.0.8 +matplotlib==3.6.2 +seaborn==0.12.1 +Pillow==9.2.0 +torch diff --git a/requirements/requirements.txt b/requirements/requirements.txt deleted file mode 100644 index e9c585d..0000000 --- a/requirements/requirements.txt +++ /dev/null @@ -1,21 +0,0 @@ -# Our typical stuff for offline-rl and jax -git+https://github.com/tinkoff-ai/d4rl@master#egg=d4rl -tqdm==4.64.0 -wandb==0.12.21 -mujoco-py==2.1.2.14 -numpy==1.23.1 -gym[mujoco_py]==0.23.0 - -sortedcontainers==2.4.0 -pyrallis==0.3.1 - -# Some useful stuff seems to be -cython -scikit-build -pybind11 -numba -opencv-python - -transformers==4.25.1 -datasets==2.8.0 -accelerate==0.15.0 diff --git a/scripts/generate.sh b/scripts/generate.sh new file mode 100644 index 0000000..a066d1a --- /dev/null +++ b/scripts/generate.sh @@ -0,0 +1,67 @@ +#!/bin/bash +export DATA_PATH="../data/nle_data" +export SAVE_PATH="../data/nle_small_scale_data" + +# All allowed role-race-align combos. See also: katakomba/utils/roles.py +combos=( + "arc hum law" + "arc hum neu" + "arc dwa law" + "arc gno neu" + + "bar hum neu" + "bar hum cha" + "bar orc cha" + + "cav hum law" + "cav hum neu" + "cav dwa law" + "cav gno neu" + + "hea hum neu" + "hea gno neu" + + "kni hum law" + + "mon hum neu" + "mon hum law" + "mon hum cha" + + "pri hum neu" + "pri hum law" + "pri hum cha" + "pri elf cha" + + "ran hum neu" + "ran hum cha" + "ran elf cha" + "ran gno neu" + "ran orc cha" + + "rog hum cha" + "rog orc cha" + + "sam hum law" + + "tou hum neu" + + "val hum neu" + "val hum law" + "val dwa law" + + "wiz hum neu" + "wiz hum cha" + "wiz elf cha" + "wiz gno neu" + "wiz orc cha" +) + +for tup in "${combos[@]}" +do + set -- $tup + python3 generate_small_dataset.py \ + --data_path=$DATA_PATH \ + --save_path=$SAVE_PATH \ + --role="$1" --race="$2" --alignment="$3" \ + --num_episodes=700 +done \ No newline at end of file diff --git a/scripts/generate_small_dataset.py b/scripts/generate_small_dataset.py new file mode 100644 index 0000000..c8c9bf1 --- /dev/null +++ b/scripts/generate_small_dataset.py @@ -0,0 +1,176 @@ +import os +import h5py +import numpy as np +import nle.dataset as nld + +import zipfile +import random +import pyrallis +from dataclasses import dataclass +from typing import Optional + +from tqdm.auto import tqdm +from collections import defaultdict +from nle.nethack.actions import ACTIONS + +ACTION_MAPPING = np.zeros(256) +for i, a in enumerate(ACTIONS): + ACTION_MAPPING[a.value] = i + + +@dataclass +class Config: + data_path: str = "data/nle_data" + save_path: str = "data/nle_data_converted" + race: Optional[str] = None + role: Optional[str] = None + alignment: Optional[str] = None + gender: Optional[str] = None + sampling: Optional[str] = None # "sort" or "stratify" + num_episodes: Optional[int] = None + num_bins: int = 50 + random_seed: int = 32 + clean_db_after: bool = False + compress_after: bool = False + + +def stratified_sample(x, scores, num_samples, num_bins=100): + num_total = len(x) + + bins, edges = np.histogram(scores, bins=num_bins) + assert sum(bins) == num_total, "change number of bins" + n_strat_samples = [int(num_samples * (num_bin / num_total)) for num_bin in bins] + + bin_ids = np.digitize(scores, edges) + + sampled_ids = 
[] + for sample_size, bin_id in zip(n_strat_samples, range(1, num_bins + 1)): + sample = np.random.choice(x[bin_ids == bin_id], size=sample_size, replace=False) + assert sample.shape[0] == sample_size + + sampled_ids.extend(sample.tolist()) + + return np.array(sampled_ids) + + +def reward_as_score_diff(scores): + rewards = np.zeros(len(scores)) + for i in range(len(scores) - 1): + # score at step i: the in-game score at this timestep (the result of the action at the previous timestep) + rewards[i] = scores[i + 1] - scores[i] + # last reward will be repeated (it is not defined, as we don't have t + 1 score for last state) + rewards[-1] = rewards[-2] + # clip as for some reason last steps after death can have zero scores + return rewards.clip(0) + + +def load_game(dataset, game_id): + raw_data = defaultdict(list) + for step in dataset.get_ttyrec(game_id, 1)[:-1]: + # check that this step is not padding + assert step["gameids"][0, 0] != 0 + raw_data["tty_chars"].append(step["tty_chars"].squeeze()) + raw_data["tty_colors"].append(step["tty_colors"].squeeze()) + raw_data["tty_cursor"].append(step["tty_cursor"].squeeze()) + raw_data["actions"].append(ACTION_MAPPING[step["keypresses"].item()]) + raw_data["scores"].append(step["scores"].item()) + + data = { + "tty_chars": np.stack(raw_data["tty_chars"]), + "tty_colors": np.stack(raw_data["tty_colors"]), + "tty_cursor": np.stack(raw_data["tty_cursor"]), + "actions": np.array(raw_data["actions"]).astype(np.int16), + "rewards": reward_as_score_diff(raw_data["scores"]).astype(np.int32), + # dones are broken in NLD-AA, so we just rewrite them with always done at last step + # see: https://github.com/facebookresearch/nle/issues/355 + "dones": np.zeros(len(raw_data["actions"]), dtype=bool) + } + data["dones"][-1] = True + return data + + +def optional_eq(x, cond): + if cond is not None: + return x == cond + return True + + +def name(role, race, align, gender): + return f"{role or 'any'}-{race or 'any'}-{align or 'any'}-{gender or 'any'}" + + +@pyrallis.wrap() +def main(config: Config): + os.makedirs(config.save_path, exist_ok=True) + + dbfilename = "tmp_ttyrecs.db" + if not nld.db.exists(dbfilename): + nld.db.create(dbfilename) + nld.add_nledata_directory(config.data_path, "autoascend", dbfilename) + + dataset = nld.TtyrecDataset( + "autoascend", + batch_size=1, + seq_length=1, + dbfilename=dbfilename, + ) + # retrieving and filtering metadata from the dataset + metadata = {game_id: dict(dataset.get_meta(game_id)) for game_id in dataset._gameids} + metadata = { + k: v for k, v in metadata.items() if ( + optional_eq(v["role"].lower(), config.role) and + optional_eq(v["race"].lower(), config.race) and + optional_eq(v["align"].lower(), config.alignment) and + optional_eq(v["gender"].lower(), config.gender) + ) + } + file_name = name(config.role, config.race, config.alignment, config.gender) + + game_ids = np.array(list(metadata.keys())) + assert len(game_ids) != 0, "dataset does not have episodes with such configuration" + if config.sampling is not None: + scores = np.array([metadata[game_id]["points"] for game_id in game_ids]) + + if config.sampling == "stratify": + random.seed(config.random_seed) + np.random.seed(config.random_seed) + + game_ids = stratified_sample(game_ids, scores, config.num_episodes, num_bins=config.num_bins) + print(f"Sampled {len(game_ids)} episodes with stratified sampling!") + elif config.sampling == "sort": + game_ids = game_ids[np.argsort(scores)][-config.num_episodes:] + mean_score = 
np.mean(np.sort(scores)[-config.num_episodes:]) + print(f"Sampled episodes with top {config.num_episodes} scores. Mean score: {mean_score}") + else: + raise RuntimeError("Unknown sampling type") + + # saving episodes data as uncompressed hdf5 + with h5py.File(os.path.join(config.save_path, f"data-{file_name}.hdf5"), "w", track_order=True) as df: + for ep_id in tqdm(game_ids): + data = load_game(dataset, game_id=ep_id) + + g = df.create_group(str(ep_id)) + g.create_dataset("tty_chars", data=data["tty_chars"], compression="gzip") + g.create_dataset("tty_colors", data=data["tty_colors"], compression="gzip") + g.create_dataset("tty_cursor", data=data["tty_cursor"], compression="gzip") + g.create_dataset("actions", data=data["actions"], compression="gzip") + g.create_dataset("rewards", data=data["rewards"], compression="gzip") + g.create_dataset("dones", data=data["dones"], compression="gzip") + # also save metadata as attrs + for key, value in metadata[ep_id].items(): + g.attrs[key] = value + + # clearing and compressing at the end + if config.compress_after: + hdf5_path = os.path.join(config.save_path, f"data-{file_name}.hdf5") + + with zipfile.ZipFile(f"{hdf5_path}.zip", "w", zipfile.ZIP_DEFLATED) as z: + z.write(os.path.join(config.save_path, f"data-{file_name}.hdf5")) + os.remove(hdf5_path) + + if nld.db.exists(dbfilename) and config.clean_db_after: + os.remove(dbfilename) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/guide.py b/scripts/guide.py deleted file mode 100644 index faf06a4..0000000 --- a/scripts/guide.py +++ /dev/null @@ -1,113 +0,0 @@ -from torch.utils.data import DataLoader - -from katakomba.tasks import make_task_builder -from katakomba.utils.roles import Alignment, Race, Role, Sex - - -""" -Task Builder. - -For each task (now it's just NetHackScore-v0-tty-bot-v0), make_task_builder -will output both environment builder and dataset builder, which could -further be used for precise evaluation and the dataset you use for -your training purposes. -""" -env_builder, dataset_builder = make_task_builder("NetHackScore-v0-tty-bot-v0") - - -""" -Environment Builder. - -- This class allows you to specify exactly which character traits combination you evaluate against. -- You can specify all that's typically allowed by NetHack game: [roles, races, alignments, sex]. -- Note that not all of the combinations are allowed by the NetHack, we filter out them for you. -- If you do not provide a specification, we assume that you evaluate against all possbile settings. -- You can also specify the seeds which are used for evaluation - - This way you can make sure that you evaluate against the same dungeons - - Not specifying anything would result in random dungeons at each evaluation - (This is what should be done for reporting scores the ultimate true score) -""" -env_builder = ( - env_builder.roles([Role.MONK]) - .races([Race.HUMAN]) - .alignments([Alignment.NEUTRAL]) - .sex([Sex.MALE]) - .eval_seeds([1, 2, 3]) -) - - -""" -An evaluation example. - -- character is a short description that stands for 'role-race-alignment-sex' - - Note that some characters do not possess sex. - - Note that these are short labels (see utils/roles.py) for actual values. -- env is a typical gym env (with a reseed flag exception in reset) -- seed is the evaluation seed - - In case none were provided, it will be None. - - Note that you should specify the seed yourself. 
-""" -for character, env, seed in env_builder.evaluate(): - """ - In case you specified a certain set of seeds. - Make sure you pass reseed=False in order to get the same dungeon for the same seed. - """ - env.seed(seed, reseed=False) - # your_super_eval_function(your_super_agent, env) - - -""" -Dataset Builder. - -- This class allows you to specify which games will be in your training dataset. -- You can specify all that's typically allowed by NetHack game: [roles, races, alignments, sex]. -- You can also specify concrete game_ids (but you better know what you're up to). -- You can also specify NetHack game versions to filter for (but you really better know what you're up to) - - By default, we rely only on NetHack 3.6.6 trajectories. -- Note that not all of the combinations are allowed by the NetHack, we filter out them for you. -- If you do not provide a specification, we assume that you want the whole dataset for game_version=3.6.6. -- (!) You need to call build to get a dataset - - batch_size is well, you know - - seq_len is well, you also know - - (!) sequences move by the seq_len, e.g., seq_len=4 - 1st batch sequence timesteps = [1, 2, 3, 4] - 2nd batch sequence timesteps = [5, 6, 7, 9] -""" -dataset = ( - dataset_builder.roles([Role.MONK]) - .races([Race.HUMAN]) - .alignments([Alignment.NEUTRAL]) - .sex([Sex.MALE]) - .build(batch_size=4, seq_len=100) -) - - -""" -PyTorch DataLoader. - -- As the dataset is already batched, we need to disable automatic batching. -""" -loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, -) - - -""" -An iterator example. This will run indefinitely. -""" -for batch in loader: - states, actions, rewards, dones, next_states = batch - - # [4, 100, 24, 80, 3] - print(states.size()) - # [4, 100] - print(actions.size()) - # [4, 100] - print(rewards.size()) - # [4, 100] - print(dones.size()) - # [4, 100, 24, 80, 3] - print(next_states.size()) diff --git a/scripts/loader_benchmark.py b/scripts/loader_benchmark.py deleted file mode 100644 index 308525f..0000000 --- a/scripts/loader_benchmark.py +++ /dev/null @@ -1,46 +0,0 @@ -import time - -from torch.utils.data import DataLoader - -from katakomba.tasks import make_task_builder -from katakomba.datasets.sa_autoascend import SAAutoAscendTTYDataset - - -NUM_BATCHES = 200 -BATCH_SIZE = 256 -SEQ_LEN = 32 -N_WORKERS = 8 -DEVICE = "cpu" - -env_builder, dataset_builder = make_task_builder( - "NetHackScore-v0-tty-bot-v0", data_path="../nethack/nle_data" -) - -dataset = dataset_builder.build( - batch_size=BATCH_SIZE, - seq_len=SEQ_LEN, - n_workers=N_WORKERS, - auto_ascend_cls=SAAutoAscendTTYDataset, -) - -loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, -) - -start = time.time() -for ind, batch in enumerate(loader): - device_batch = [t.to(DEVICE) for t in batch] - - if (ind + 1) == NUM_BATCHES: - break -end = time.time() -elapsed = end - start -print( - f"Fetching {NUM_BATCHES} batches of [batch_size={BATCH_SIZE}, seq_len={SEQ_LEN}] took: {elapsed} seconds." 
-) -print(f"1 batch takes around {elapsed / NUM_BATCHES} seconds.") -print(f"Total frames fetched: {NUM_BATCHES * BATCH_SIZE * SEQ_LEN}") -print(f"Frames / s: {NUM_BATCHES * BATCH_SIZE * SEQ_LEN / elapsed}") diff --git a/scripts/rliable_report.py b/scripts/rliable_report.py new file mode 100644 index 0000000..049fbf4 --- /dev/null +++ b/scripts/rliable_report.py @@ -0,0 +1,145 @@ +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +import pickle +import pyrallis +from dataclasses import dataclass +from itertools import product +from rliable import library as rly +from rliable import metrics +from rliable import plot_utils + +from katakomba.utils.roles import Role, Race, Alignment +from katakomba.utils.roles import ( + ALLOWED_COMBOS, BASE_COMBOS, EXTENDED_COMBOS, COMPLETE_COMBOS +) + +@dataclass +class Config: + scores_path: str = "cached_algo_stats.pkl" + metric_name: str = "normalized_scores" # normalized_scores | returns | depths + setting: str = "full" + + +@pyrallis.wrap() +def main(config: Config): + setting_combos = { + "full": ALLOWED_COMBOS, + "base": BASE_COMBOS, + "extended": EXTENDED_COMBOS, + "complete": COMPLETE_COMBOS, + } + + with open(config.scores_path, "rb") as f: + cached_stats = pickle.load(f) + + algorithms_scores = {} + for algo in cached_stats: + all_metrics = [] + for character in cached_stats[algo]: + role, race, align = character.split("-") + role, race, align = Role(role), Race(race), Alignment(align) + + if (role, race, align) not in setting_combos[config.setting]: + continue + + character_metrics = cached_stats[algo][character][config.metric_name].ravel() + if config.metric_name == "depths": + # levels start from 1, not 0 + character_metrics = character_metrics + 1 + elif config.metric_name == "normalized_scores": + character_metrics = character_metrics * 100.0 + + if algo == "AUTOASCEND": + # sample min trajectories, to align shapes + np.random.seed(32) + character_metrics = np.random.choice(character_metrics, size=675, replace=False) + + all_metrics.append(character_metrics) + + algorithms_scores[algo] = np.stack(all_metrics, axis=0).T + + # exclude it for now + algorithms_scores.pop("AUTOASCEND") + + xlabels = { + "normalized_scores": "Normalized Score", + "returns": "Score", + "depths": "Death Level" + } + metrics_thresholds = { + "normalized_scores": np.linspace(0.0, 100.0, 16), + "returns": np.linspace(0.0, 5000.0, 16), + "depths": np.linspace(1.0, 3.0, 6) + } + metrics_ticks = { + "normalized_scores": {"xticks": None, "yticks": [0.0, 0.10, 0.25, 0.5, 0.75, 1.0]}, + "returns": {"xticks": None, "yticks": None}, + "depths": {"xticks": None, "yticks": np.linspace(0.0, 0.1, 5)}, + } + + # plotting aggregate metrics with 95% stratified bootstrap CIs + aggregate_func = lambda x: np.array([ + metrics.aggregate_median(x), + metrics.aggregate_iqm(x), + metrics.aggregate_mean(x), + metrics.aggregate_optimality_gap(x) + ]) + aggregate_scores, aggregate_score_cis = rly.get_interval_estimates( + algorithms_scores, aggregate_func, reps=10000 + ) + fig, axes = plot_utils.plot_interval_estimates( + aggregate_scores, aggregate_score_cis, + metric_names=['Median', 'IQM', 'Mean', 'Optimality Gap'], + algorithms=list(algorithms_scores.keys()), + xlabel=xlabels[config.metric_name], + xlabel_y_coordinate=-0.25, + ) + plt.savefig(f"reliable_metrics_{config.metric_name}.pdf", bbox_inches="tight", format="pdf") + + # plotting performance profiles + thresholds = metrics_thresholds[config.metric_name] + score_distributions, score_distributions_cis = 
rly.create_performance_profile( + algorithms_scores, thresholds + ) + fig, ax = plt.subplots(ncols=1, figsize=(7, 5)) + + plot_utils.plot_performance_profiles( + score_distributions, thresholds, + performance_profile_cis=score_distributions_cis, + colors=dict(zip( + list(algorithms_scores.keys()), sns.color_palette('colorblind') + )), + xlabel=fr"{xlabels[config.metric_name]} $(\tau)$", + ax=ax, + yticks=metrics_ticks[config.metric_name]["yticks"], + xticks=metrics_ticks[config.metric_name]["xticks"] + ) + plt.legend() + plt.savefig(f"reliable_performance_profile_{config.metric_name}.pdf", bbox_inches="tight", format="pdf") + + # plotting probability of improvement + paired_scores = {} + for algo_y in algorithms_scores.keys(): + if algo_y != "BC": + paired_scores[f"BC,{algo_y}"] = ( + algorithms_scores["BC"], + algorithms_scores[algo_y] + ) + # for algo_x, algo_y in product(algorithms_scores.keys(), algorithms_scores.keys()): + # if algo_x != algo_y: + # paired_scores[f"{algo_x},{algo_y}"] = ( + # algorithms_scores[algo_x], + # algorithms_scores[algo_y] + # ) + + average_probabilities, average_prob_cis = rly.get_interval_estimates( + paired_scores, metrics.probability_of_improvement, reps=500 + ) + plot_utils.plot_probability_of_improvement(average_probabilities, average_prob_cis, xticks=[0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]) + plt.savefig(f"reliable_probability_of_improvement_{config.metric_name}.pdf", bbox_inches="tight", format="pdf") + + +if __name__ == "__main__": + main() diff --git a/scripts/stats/algorithms_stats.py b/scripts/stats/algorithms_stats.py new file mode 100644 index 0000000..4c604c0 --- /dev/null +++ b/scripts/stats/algorithms_stats.py @@ -0,0 +1,88 @@ +import os +import wandb +import numpy as np + +import pickle +import pyrallis +from typing import Optional +from dataclasses import dataclass +from tqdm.auto import tqdm + +from katakomba.utils.roles import ALLOWED_COMBOS +from katakomba.utils.datasets.small_scale import load_nld_aa_small_dataset +from katakomba.utils.scores import MEAN_SCORES_AUTOASCEND + + +@dataclass +class Config: + bc_wandb_group: Optional[str] = "small_scale_bc_chaotic_lstm_multiseed-v0" + cql_wandb_group: Optional[str] = "small_scale_cql_chaotic_lstm_multiseed-v0" + awac_wandb_group: Optional[str] = "small_scale_awac_chaotic_lstm_multiseed-v0" + iql_wandb_group: Optional[str] = "small_scale_iql_chaotic_lstm_multiseed-v0" + rem_wandb_group: Optional[str] = "small_scale_rem_chaotic_lstm_multiseed-v0" + checkpoint: int = 500000 + cache_path: str = "cached_algo_stats.pkl" + + +def get_character_scores(runs, character, filename): + multiseed_scores = [] + + runs = [run for run in runs if run.config["character"] == character] + for run in runs: + run.file(filename).download(replace=True) + multiseed_scores.append(np.load(filename)) + + os.remove(filename) + return np.array(multiseed_scores) + + +def get_autoascend_scores(): + characters_metrics = {} + for role, race, align in ALLOWED_COMBOS: + df, traj = load_nld_aa_small_dataset(role=role, race=race, align=align, mode="compressed") + returns = np.array([df[gameid].attrs["points"] for gameid in list(traj.keys())]) + norm_scores = returns / MEAN_SCORES_AUTOASCEND[(role, race, align)] + depths = np.array([df[gameid].attrs["deathlev"] for gameid in list(traj.keys())]) + + characters_metrics[f"{role.value}-{race.value}-{align.value}"] = { + "normalized_scores": norm_scores, + "returns": returns, + "depths": depths, + } + df.close() + + return characters_metrics + + +@pyrallis.wrap() +def main(config: Config): 
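+    # Flow of the code below: for each algorithm's W&B group, download the per-character
+    # evaluation arrays (normalized scores, raw returns, death levels) saved at the given
+    # checkpoint, add the AUTOASCEND reference metrics computed from the small-scale dataset,
+    # and pickle everything into `cache_path` for later use by scripts/rliable_report.py.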
+    algo_groups = {
+        "BC": config.bc_wandb_group,
+        "CQL": config.cql_wandb_group,
+        "AWAC": config.awac_wandb_group,
+        "IQL": config.iql_wandb_group,
+        "REM": config.rem_wandb_group,
+    }
+
+    if not os.path.exists(config.cache_path):
+        algorithms_scores = {algo_name: {} for algo_name in algo_groups.keys()}
+        algorithms_scores["AUTOASCEND"] = get_autoascend_scores()
+
+        api = wandb.Api()
+        for algo, group in tqdm(algo_groups.items(), desc="Downloading algorithms scores"):
+            algo_runs = [run for run in api.runs("tlab/Nethack") if run.group == group]
+
+            for role, race, align in tqdm(ALLOWED_COMBOS, desc="Downloading character scores", leave=False):
+                character = f"{role.value}-{race.value}-{align.value}"
+                algorithms_scores[algo][character] = {
+                    "normalized_scores": get_character_scores(algo_runs, character, f"{config.checkpoint}_normalized_scores.npy"),
+                    "returns": get_character_scores(algo_runs, character, f"{config.checkpoint}_returns.npy"),
+                    "depths": get_character_scores(algo_runs, character, f"{config.checkpoint}_depths.npy"),
+                }
+
+        with open(config.cache_path, "wb") as f:
+            pickle.dump(algorithms_scores, f)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/stats/depth.py b/scripts/stats/depth.py
new file mode 100644
index 0000000..796a21d
--- /dev/null
+++ b/scripts/stats/depth.py
@@ -0,0 +1,39 @@
+"""
+This is a script for extracting per-character death levels (the depth statistics, complementary to utils/scores.py)
+"""
+import nle.dataset as nld
+import numpy as np
+
+from nle.dataset.db import db as nld_database
+
+data_path = "data/nle_data"
+db_path = "ttyrecs.db"
+
+if not nld.db.exists(db_path):
+    nld.db.create(db_path)
+    nld.add_nledata_directory(data_path, "autoascend", db_path)
+
+db_connection = nld.db.connect(filename=db_path)
+
+with nld_database(conn=db_connection) as connection:
+    c = connection.execute(
+        "SELECT games.role, games.race, games.align0, games.deathlev "
+        "FROM games "
+        "JOIN datasets ON games.gameid=datasets.gameid "
+        "WHERE datasets.dataset_name='autoascend' "
+    )
+
+all_levels = {}
+global_levels = []
+for row in c:
+    role, race, alignment, deathlev = row
+    key = (role, race, alignment)
+    if key not in all_levels:
+        all_levels[key] = [deathlev]
+    else:
+        all_levels[key].append(deathlev)
+    global_levels.append(deathlev)
+
+print("All dataset: ", np.median(global_levels), np.mean(global_levels))
+for key in all_levels:
+    print(key, np.median(all_levels[key]), np.mean(all_levels[key]))
diff --git a/scripts/stats/scores.py b/scripts/stats/scores.py
new file mode 100644
index 0000000..5bc3af8
--- /dev/null
+++ b/scripts/stats/scores.py
@@ -0,0 +1,33 @@
+"""
+This is a script for extracting the scores (used for populating values in utils/scores.py)
+"""
+import nle.dataset as nld
+
+from nle.dataset.db import db as nld_database
+from katakomba.utils.roles import Role, Race, Alignment
+
+data_path = "data/nle_data"
+db_path = "ttyrecs.db"
+
+if not nld.db.exists(db_path):
+    nld.db.create(db_path)
+    nld.add_nledata_directory(data_path, "autoascend", db_path)
+
+db_connection = nld.db.connect(filename=db_path)
+
+with nld_database(conn=db_connection) as connection:
+    c = connection.execute(
+        "SELECT games.role, games.race, games.align0, AVG(games.points) "
+        "FROM games "
+        "JOIN datasets ON games.gameid=datasets.gameid "
+        "WHERE datasets.dataset_name='autoascend' "
+        "GROUP BY games.role, games.race, games.align0;",
+    )
+
+    for row in c:
+        role, race, alignment, avg_score = row
+        copypaste_string = f"({Role._value2member_map_[str.lower(role)]}, " +
copypaste_string += f"{Race._value2member_map_[str.lower(race)]}, " + copypaste_string += f"{Alignment._value2member_map_[str.lower(alignment)]})" + copypaste_string += f": {avg_score:.2f}," + print(copypaste_string) \ No newline at end of file diff --git a/scripts/stats/small_scale_stats.py b/scripts/stats/small_scale_stats.py new file mode 100644 index 0000000..a5b5e7b --- /dev/null +++ b/scripts/stats/small_scale_stats.py @@ -0,0 +1,33 @@ +import numpy as np + +from katakomba.utils.roles import ALLOWED_COMBOS +from katakomba.utils.datasets.small_scale import load_nld_aa_small_dataset + + +def main(): + for role, race, align in ALLOWED_COMBOS: + df, traj = load_nld_aa_small_dataset(role=role, race=race, align=align, mode="compressed") + transitions = [t["actions"].shape[0] for t in traj.values()] + median_score = np.median([df[gameid].attrs["points"] for gameid in list(traj.keys())]) + median_depth = np.median([df[gameid].attrs["deathlev"] for gameid in list(traj.keys())]) + median_length = np.median(transitions) + + total_transitions = np.sum(transitions) + total_trajectories = len(traj.keys()) + + total_bytes = 0 + for episode in traj.values(): + total_bytes += sum([arr.nbytes for arr in episode.values()]) + + total_gb = round(total_bytes / 1e+9, 1) + + print(f"{role.value}-{race.value}-{align.value}", total_gb) + + print(f"{role.value}-{race.value}-{align.value}") + print(f"& {total_transitions} & {median_length} & {median_score} & {median_depth} & {total_gb} \\") + + df.close() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_chaotic_loader.py b/scripts/test_chaotic_loader.py deleted file mode 100644 index 9bb44d9..0000000 --- a/scripts/test_chaotic_loader.py +++ /dev/null @@ -1,48 +0,0 @@ -import time - -from torch.utils.data import DataLoader - -from katakomba.tasks import make_task_builder -from katakomba.datasets.sa_chaotic_autoascend import SAChaoticAutoAscendTTYDataset - - -NUM_BATCHES = 200 -BATCH_SIZE = 256 -SEQ_LEN = 32 -N_WORKERS = 8 -DEVICE = "cpu" - -env_builder, dataset_builder = make_task_builder( - "NetHackScore-v0-tty-bot-v0", data_path="../nethack/nle_data" -) - -dataset = dataset_builder.build( - batch_size=BATCH_SIZE, - seq_len=SEQ_LEN, - n_workers=N_WORKERS, - auto_ascend_cls=SAChaoticAutoAscendTTYDataset, -) - -loader = DataLoader( - dataset=dataset, - # Disable automatic batching - batch_sampler=None, - batch_size=None, -) - -start = time.time() -for ind, batch in enumerate(loader): - device_batch = [t.to(DEVICE) for t in batch] - print(device_batch[-2].size()) - break - - if (ind + 1) == NUM_BATCHES: - break -end = time.time() -elapsed = end - start -print( - f"Fetching {NUM_BATCHES} batches of [batch_size={BATCH_SIZE}, seq_len={SEQ_LEN}] took: {elapsed} seconds." -) -print(f"1 batch takes around {elapsed / NUM_BATCHES} seconds.") -print(f"Total frames fetched: {NUM_BATCHES * BATCH_SIZE * SEQ_LEN}") -print(f"Frames / s: {NUM_BATCHES * BATCH_SIZE * SEQ_LEN / elapsed}")
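
For reference, a minimal sketch of reading back the per-character HDF5 files produced by scripts/generate_small_dataset.py above. The path and character combination are assumptions that follow the script's defaults and its name() helper (unspecified fields become "any"); the repository's own loaders live in katakomba/utils/datasets/small_scale.py.

    import h5py

    # e.g. --role=mon --race=hum --alignment=neu -> "data-mon-hum-neu-any.hdf5"
    path = "data/nle_data_converted/data-mon-hum-neu-any.hdf5"

    with h5py.File(path, "r") as df:
        for game_id in df:                  # one HDF5 group per episode
            ep = df[game_id]
            actions = ep["actions"][:]      # (T,) int16, indices into nle.nethack.actions.ACTIONS
            rewards = ep["rewards"][:]      # (T,) int32, clipped in-game score deltas
            dones = ep["dones"][:]          # (T,) bool, True only at the final step
            tty_chars = ep["tty_chars"][:]  # (T, rows, cols) terminal screen characters
            # NLD metadata is stored as group attributes, e.g. the final in-game score
            print(game_id, ep.attrs["points"], actions.shape, tty_chars.shape)
            break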