Skip to content

[DRAFT] BC + ResNet + LSTM #5

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Data
data/
ttyrecs.db
test.py
test.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
136 changes: 136 additions & 0 deletions algoritms/bc_resnet_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import pyrallis
from dataclasses import dataclass

import sys
import os
import uuid
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch.distributions import Categorical
import numpy as np

from typing import Optional, Tuple
from d5rl.tasks import NetHackEnvBuilder, make_task_builder
from d5rl.utils.roles import Alignment, Race, Role, Sex
from d5rl.nn.resnet import ResNet11, ResNet20, ResNet38, ResNet56, ResNet110

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class TrainConfig:
    """Hyper-parameters and bookkeeping options for the BC + ResNet + LSTM run.

    After construction ``group`` and ``name`` are rewritten (see
    ``__post_init__``), so the values passed in act only as building blocks
    for the final run identifiers.
    """

    # Task identifier handed to ``make_task_builder``.
    env: str = "NetHackScore-v0-tty-bot-v0"
    # Wandb logging
    project: str = "NeuralNetHack"
    group: str = "DummyBC"
    name: str = "DummyBC"
    version: str = "v0"
    # Model
    # Name of the ResNet class to use as the observation encoder.
    resnet_type: str = "ResNet11"
    lstm_layers: int = 2
    hidden_dim: int = 512
    # Forwarded to the ResNet constructor as ``k``.
    width_k: int = 1
    # Training
    # Upper bound on the number of batches consumed by the training loop.
    update_steps: int = 100_000
    batch_size: int = 256
    seq_len: int = 64
    learning_rate: float = 3e-4
    # Max gradient norm; ``None`` disables clipping.
    clip_grad: Optional[float] = None
    # Base directory for checkpoints; ``None`` leaves the path unset.
    checkpoints_path: Optional[str] = None
    eval_every: int = 1000
    eval_episodes: int = 10
    # Variable-length tuple of seeds (the previous ``Tuple[int]`` annotation
    # described a 1-tuple, which contradicted the five-element default).
    eval_seeds: Tuple[int, ...] = (228, 1337, 1307, 2, 10000)
    train_seed: int = 42

    def __post_init__(self):
        """Derive the final group/name and the per-run checkpoint directory."""
        self.group = f"{self.env}-{self.name}-{self.version}"
        # A short random suffix keeps names of repeated runs distinct.
        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)


class Actor(nn.Module):
    """Recurrent policy: ResNet observation encoder -> LSTM -> linear action head."""

    def __init__(self, action_dim, hidden_dim, lstm_layers, width_k, resnet_type):
        """Build the encoder, the recurrent core and the action head.

        Args:
            action_dim: number of output logits.
            hidden_dim: encoder output size; also the LSTM input/hidden size.
            lstm_layers: number of stacked LSTM layers.
            width_k: passed to the ResNet constructor as ``k``.
            resnet_type: name of a ResNet class available in this module's
                namespace (e.g. ``"ResNet11"``).

        Raises:
            AttributeError: if ``resnet_type`` is not a name in this module.
        """
        super().__init__()
        # Resolve the encoder class by name among this module's globals.
        resnet = getattr(sys.modules[__name__], resnet_type)
        self.state_encoder = resnet(img_channels=3, out_dim=hidden_dim, k=width_k)
        self.rnn = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
        )
        self.head = nn.Linear(hidden_dim, action_dim)

    def forward(self, obs, state=None):
        """Return per-step action logits and the updated LSTM state.

        Args:
            obs: tensor shaped ``[batch_size, seq_len, ...]``.
            state: optional LSTM state from a previous call.

        Returns:
            ``(logits, new_state)`` where ``logits`` is
            ``[batch_size, seq_len, action_dim]``.
        """
        # [batch_size, seq_len, ...]
        batch_size, seq_len, *_ = obs.shape

        # Encode every frame independently, then restore the time axis.
        encoded = self.state_encoder(obs.flatten(0, 1))
        out = encoded.view(batch_size, seq_len, -1)
        out, new_state = self.rnn(out, state)
        logits = self.head(out)

        return logits, new_state

    @torch.no_grad()
    def act(self, obs, state=None, device="cpu"):
        """Pick the greedy action for a single observation.

        Args:
            obs: one 3-D observation (array or tensor).
                NOTE(review): ``train`` feeds channels-first frames to
                ``forward``; presumably ``obs`` must use the same layout —
                confirm with the caller.
            state: optional LSTM state returned by a previous call.
            device: device on which the observation tensor is created.

        Returns:
            ``(action, new_state)`` with ``action`` a Python ``int``.
        """
        assert obs.ndim == 3, "act only for single obs"
        # Fix: the tensor used to be built from ``state`` instead of ``obs``,
        # which crashed for ``state=None`` and otherwise fed the LSTM state
        # in as the observation.
        obs = torch.as_tensor(obs, device=device, dtype=torch.float32)
        # Add batch and time axes of size one.
        obs = obs[None, None, ...]
        logits, new_state = self(obs, state)
        return torch.argmax(logits).cpu().item(), new_state


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train the recurrent ``Actor`` by behavioral cloning.

    Batches of ``(states, actions, ...)`` are drawn from the task dataset and
    the negative log-likelihood of the dataset actions is minimised. The loop
    stops after ``config.update_steps`` batches or when the loader is
    exhausted, whichever comes first.

    Args:
        config: run configuration, populated by pyrallis.
    """
    env_builder, dataset_builder = make_task_builder(config.env)

    env_builder = env_builder.eval_seeds(config.eval_seeds)
    dataset = dataset_builder.build(
        batch_size=config.batch_size,
        seq_len=config.seq_len,
    )
    actor = Actor(
        resnet_type=config.resnet_type,
        action_dim=env_builder.get_action_dim(),
        hidden_dim=config.hidden_dim,
        lstm_layers=config.lstm_layers,
        width_k=config.width_k,
    )
    # Fix: inputs are moved to DEVICE below, so the model has to live there
    # too; otherwise the forward pass fails on CUDA machines.
    actor = actor.to(DEVICE)
    optim = torch.optim.AdamW(actor.parameters(), lr=config.learning_rate)

    loader = DataLoader(
        dataset=dataset,
        # Disable automatic batching
        batch_sampler=None,
        batch_size=None,
    )

    # NOTE(review): the LSTM state is carried from one batch to the next,
    # which presumes consecutive batches continue the same sequences —
    # confirm against the dataset implementation.
    rnn_state = None
    for idx, batch in enumerate(loader):
        if idx >= config.update_steps:
            break

        states, actions, *_ = batch
        # Move the channel axis from last to third position
        # ("b t x y c -> b t c x y") before feeding the encoder.
        # NOTE(review): reviewers proposed expressing this with an einops
        # pattern for readability; whether einops adds overhead was left
        # unverified.
        logits, rnn_state = actor(
            states.permute(0, 1, 4, 2, 3).to(DEVICE).to(torch.float32),
            state=rnn_state,
        )
        # Truncate backpropagation through time at the batch boundary.
        rnn_state = tuple(a.detach() for a in rnn_state)

        dist = Categorical(logits=logits)
        loss = -dist.log_prob(actions.to(DEVICE)).mean()

        optim.zero_grad()
        loss.backward()
        if config.clip_grad is not None:
            # ``clip_grad_norm`` is deprecated; the in-place variant is the
            # supported API.
            torch.nn.utils.clip_grad_norm_(actor.parameters(), config.clip_grad)
        optim.step()

        print(loss)


if __name__ == "__main__":
train()
8 changes: 4 additions & 4 deletions d5rl/datasets/sars_autoascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __iter__(self):
# [r_n, r_n+1, r_n+2, r_n-1]
# [d_n, d_n+1, d_n+2, d_n-1]
# [s_n+1, s_n+2, s_n+3, s_n]
# TODO: gigantic overhead over the original loader, somehow we need to optimize this!
rewards = np.roll(rewards, shift=-1, axis=1)
dones = np.roll(dones, shift=-1, axis=1)
next_states = np.roll(deepcopy(states), shift=-1, axis=1)
Expand All @@ -59,11 +60,10 @@ def __iter__(self):
def _convert_batch(self, batch):
# [batch_size, seq_len, 24, 80, 3]
states = tty_to_numpy(
tty_chars=batch["tty_chars"].squeeze(),
tty_colors=batch["tty_colors"].squeeze(),
tty_cursor=batch["tty_cursor"].squeeze(),
tty_chars=batch["tty_chars"],
tty_colors=batch["tty_colors"],
tty_cursor=batch["tty_cursor"],
)

# [batch_size, seq_len]
actions = ascii_actions_to_gym_actions(batch["keypresses"])

Expand Down
1 change: 0 additions & 1 deletion d5rl/envs/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def evaluate(self):
"""
An iterator over the NLE settings to evaluate against.
"""

all_valid_combinations = deepcopy(ALLOWED_COMBOS)
valid_combinations = set()

Expand Down
1 change: 0 additions & 1 deletion d5rl/utils/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def tty_to_numpy(tty_chars, tty_colors, tty_cursor) -> np.ndarray:
shape=(batch_size, seq_len, TERMINAL_SHAPE[0], TERMINAL_SHAPE[1], 3),
dtype=np.uint8,
)

obs[:, :, :, :, 0] = tty_chars
obs[:, :, :, :, 1] = tty_colors

Expand Down