19 commits
e93b862 first draft of bc with resnet and lstm (Howuhh)
55c7f56 added evaluation and logging (Howuhh)
0632340 Merge branch 'main' into howuhh/bc (Howuhh)
4151e81 removed diffs (Howuhh)
dc30e73 tqdm added (Howuhh)
e65d18b small comment (Howuhh)
c8b469f added path to the builder and timers for profiling (Howuhh)
f813bd3 added db path to the dataloader too (Howuhh)
1718cd5 amp testing (Howuhh)
ddf8ae2 run on full dataset (Howuhh)
f7fa328 update gitignore (Howuhh)
2726557 merged main (Howuhh)
1b49525 resolve conflict (Howuhh)
4dc4b60 eval seeds as range (Howuhh)
f9ad576 fix warn with clip norm, add global mean return (Howuhh)
1663f48 testing dataloader for BC (Howuhh)
bfa22ca added numba conver variant also (Howuhh)
1103146 bc with new dataloader (Howuhh)
a5fe887 fix resnet (Howuhh)
.gitignore
@@ -1,6 +1,8 @@
# Data
data/
ttyrecs.db
test.py
test.ipynb

# Byte-compiled / optimized / DLL files
__pycache__/
New file (the BC training script):
@@ -0,0 +1,136 @@
import pyrallis
from dataclasses import dataclass

import sys
import os
import uuid
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch.distributions import Categorical
import numpy as np

from typing import Optional, Tuple
from d5rl.tasks import NetHackEnvBuilder, make_task_builder
from d5rl.utils.roles import Alignment, Race, Role, Sex
from d5rl.nn.resnet import ResNet11, ResNet20, ResNet38, ResNet56, ResNet110

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@dataclass
class TrainConfig:
    env: str = "NetHackScore-v0-tty-bot-v0"
    # Wandb logging
    project: str = "NeuralNetHack"
    group: str = "DummyBC"
    name: str = "DummyBC"
    version: str = "v0"
    # Model
    resnet_type: str = "ResNet11"
    lstm_layers: int = 2
    hidden_dim: int = 512
    width_k: int = 1
    # Training
    update_steps: int = 100_000
    batch_size: int = 256
    seq_len: int = 64
    learning_rate: float = 3e-4
    clip_grad: Optional[float] = None
    checkpoints_path: Optional[str] = None
    eval_every: int = 1000
    eval_episodes: int = 10
    eval_seeds: Tuple[int, ...] = (228, 1337, 1307, 2, 10000)
    train_seed: int = 42

    def __post_init__(self):
        self.group = f"{self.env}-{self.name}-{self.version}"
        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)

class Actor(nn.Module):
    def __init__(self, action_dim, hidden_dim, lstm_layers, width_k, resnet_type):
        super().__init__()
        # Resolve the ResNet class by name from the imports above
        resnet = getattr(sys.modules[__name__], resnet_type)
        self.state_encoder = resnet(img_channels=3, out_dim=hidden_dim, k=width_k)
        self.rnn = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True
        )
        self.head = nn.Linear(hidden_dim, action_dim)

    def forward(self, obs, state=None):
        # obs: [batch_size, seq_len, ...]
        batch_size, seq_len, *_ = obs.shape

        # Encode each frame independently, then restore the [batch, seq] layout
        out = self.state_encoder(obs.flatten(0, 1)).view(batch_size, seq_len, -1)
        out, new_state = self.rnn(out, state)
        logits = self.head(out)

        return logits, new_state

    @torch.no_grad()
    def act(self, obs, state=None, device="cpu"):
        assert obs.ndim == 3, "act only for single obs"
        # Add batch and sequence dims for the single observation
        obs = torch.tensor(obs[None, None, ...], device=device, dtype=torch.float32)
        logits, new_state = self(obs, state)
        return torch.argmax(logits).cpu().item(), new_state

@pyrallis.wrap()
def train(config: TrainConfig):
    env_builder, dataset_builder = make_task_builder(config.env)

    env_builder = env_builder.eval_seeds(config.eval_seeds)
    dataset = dataset_builder.build(
        batch_size=config.batch_size,
        seq_len=config.seq_len
    )
    actor = Actor(
        resnet_type=config.resnet_type,
        action_dim=env_builder.get_action_dim(),
        hidden_dim=config.hidden_dim,
        lstm_layers=config.lstm_layers,
        width_k=config.width_k
    ).to(DEVICE)
    optim = torch.optim.AdamW(actor.parameters(), lr=config.learning_rate)

    loader = DataLoader(
        dataset=dataset,
        # Disable automatic batching: the dataset already yields full batches
        batch_sampler=None,
        batch_size=None,
    )

    rnn_state = None
    for idx, batch in enumerate(loader):
        if idx >= config.update_steps:
            break

        states, actions, *_ = batch
        # Dataset yields [batch, seq_len, H, W, C]; the encoder expects channels-first
        logits, rnn_state = actor(
            states.permute(0, 1, 4, 2, 3).to(DEVICE).to(torch.float32),
            state=rnn_state
        )
        # Detach the LSTM state so gradients do not flow across update steps
        rnn_state = tuple(s.detach() for s in rnn_state)

        dist = Categorical(logits=logits)
        loss = -dist.log_prob(actions.to(DEVICE)).mean()

        optim.zero_grad()
        loss.backward()
        if config.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(actor.parameters(), config.clip_grad)
        optim.step()

        print(loss.item())


if __name__ == "__main__":
    train()
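
Since the config is wrapped with @pyrallis.wrap(), any TrainConfig field can be overridden from the command line. A minimal launch sketch, assuming the script above is saved as bc.py (the actual file path is not shown in this view):

python bc.py --batch_size=256 --seq_len=64 --clip_grad=1.0 --checkpoints_path=./checkpoints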
Review conversation (on the states.permute call in the training loop):

- Maybe einops would be better?
- Honestly, I don't really understand it, but if everyone is in favor, then maybe.
- It just has a more intuitive form; here it would be "b t x y c -> b t c x y", which reads more clearly.
- That is, you don't need to understand what einops does under the hood to grasp the input -> output semantics.
- But are you sure it doesn't add overhead?
- Not sure, need to check.
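
For reference, a minimal sketch of the two forms being compared, assuming the [batch, time, height, width, channels] layout produced by the dataloader above; the dummy tensor shape is made up for illustration:

import torch
from einops import rearrange

# Dummy batch with the layout used in the training loop: [b, t, x, y, c]
states = torch.randn(4, 8, 24, 80, 3)

# Current form in the PR: index-based permute to channels-first
out_permute = states.permute(0, 1, 4, 2, 3)

# Suggested einops form: the same operation, with the layout spelled out
out_einops = rearrange(states, "b t x y c -> b t c x y")

assert torch.equal(out_permute, out_einops)

For a pure permutation like this, rearrange should reduce to the same permute under the hood, so the extra cost is essentially pattern parsing (cached after the first call); still worth profiling, as noted above.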