
Commit d4c0691

Merge pull request #1 from tinkoff-ai/encoder
encoder dataloader
2 parents 8269131 + 4e5ca92 commit d4c0691

24 files changed: +516 −393 lines

.gitignore

+1
@@ -29,6 +29,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+.idea/
 
 # PyInstaller
 # Usually these files are written by a python script from a template

bc_dummy.py

+60 −57
@@ -1,60 +1,57 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-from dataclasses import asdict, dataclass
-from collections import defaultdict
 import os
-from pathlib import Path
 import random
 import uuid
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
-import gym
 import numpy as np
 import pyrallis
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import wandb
-
-from d5rl.tasks import make_task_builder, NetHackEnvBuilder
-from d5rl.utils.roles import Role, Alignment, Race, Sex
 from torch.utils.data import DataLoader
 
+from d5rl.tasks import NetHackEnvBuilder, make_task_builder
+from d5rl.utils.roles import Alignment, Race, Role, Sex
+
 TensorBatch = List[torch.Tensor]
 
 
 @dataclass
 class TrainConfig:
     # NetHack
-    env : str = "NetHackScore-v0-tty-bot-v0"
-    character : str = "mon-hum-neutral-male"
+    env: str = "NetHackScore-v0-tty-bot-v0"
+    character: str = "mon-hum-neutral-male"
     eval_seeds: Optional[Tuple[int]] = (228, 1337, 1307, 2, 10000)
 
     # Training
-    device : str = "cpu"
-    seed : int = 0
-    eval_freq : int = int(1000)
-    n_episodes : int = 10
-    max_timesteps : int = int(1e6)
+    device: str = "cpu"
+    seed: int = 0
+    eval_freq: int = int(1000)
+    n_episodes: int = 10
+    max_timesteps: int = int(1e6)
     checkpoints_path: Optional[str] = None
-    load_model : str = ""
-    batch_size : int = 512
+    load_model: str = ""
+    batch_size: int = 512
 
     # Wandb logging
     project: str = "NeuralNetHack"
-    group : str = "DummyBC"
-    name : str = "DummyBC"
+    group: str = "DummyBC"
+    name: str = "DummyBC"
     version: str = "v0"
 
     def __post_init__(self):
         self.group = f"{self.env}-{self.name}-{self.version}"
-        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
+        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
 
         if self.checkpoints_path is not None:
             self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
 
 
-def set_seed(
-    seed: int, deterministic_torch: bool = False
-):
+def set_seed(seed: int, deterministic_torch: bool = False):
     os.environ["PYTHONHASHSEED"] = str(seed)
     np.random.seed(seed)
     random.seed(seed)
@@ -64,21 +61,21 @@ def set_seed(
 
 def wandb_init(config: dict) -> None:
     wandb.init(
-        config = config,
-        project = config["project"],
-        group = config["group"],
-        name = config["name"],
-        id = str(uuid.uuid4()),
+        config=config,
+        project=config["project"],
+        group=config["group"],
+        name=config["name"],
+        id=str(uuid.uuid4()),
     )
     wandb.run.save()
 
 
 @torch.no_grad()
 def eval_actor(
     env_builder: NetHackEnvBuilder,
-    actor : nn.Module,
-    device : str,
-    n_episodes : int,
+    actor: nn.Module,
+    device: str,
+    n_episodes: int,
 ) -> Dict[str, Dict[int, float]]:
     actor.eval()
     eval_stats = defaultdict(dict)
@@ -96,7 +93,6 @@ def eval_actor(
             episode_rewards.append(episode_reward)
         eval_stats[character][seed] = np.mean(episode_rewards)
 
-
     actor.train()
 
     return eval_stats
@@ -112,55 +108,59 @@ def __init__(self, action_dim: int):
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.colors_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.cursor_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.head = nn.Sequential(
-            nn.Linear(256*3, 256),
+            nn.Linear(256 * 3, 256),
             nn.ReLU(),
             nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(128, action_dim)
+            nn.Linear(128, action_dim),
         )
 
     def forward(self, state: torch.Tensor) -> torch.Tensor:
         batch_size = state.shape[0]
-        state = state.view(batch_size, -1, 3) / 255.0
+        state = state.view(batch_size, -1, 3) / 255.0
 
-        chars_encoded = self.chars_encoder(state[:, :, 0])
+        chars_encoded = self.chars_encoder(state[:, :, 0])
         colors_encoded = self.colors_encoder(state[:, :, 1])
         cursor_encoded = self.cursor_encoder(state[:, :, 2])
 
-        return self.head(torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1))
+        return self.head(
+            torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1)
+        )
 
     @torch.no_grad()
     def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
-        state = torch.tensor(np.expand_dims(state, axis=0), device=device, dtype=torch.float32)
+        state = torch.tensor(
+            np.expand_dims(state, axis=0), device=device, dtype=torch.float32
+        )
         logits = self(state)
         return torch.argmax(logits).cpu().item()
 
 
 class BC: # noqa
     def __init__(
         self,
-        actor : nn.Module,
+        actor: nn.Module,
         actor_optimizer: torch.optim.Optimizer,
-        device : str = "cpu",
+        device: str = "cpu",
     ):
         self.actor = actor
         self.actor_optimizer = actor_optimizer
@@ -176,7 +176,12 @@ def train(self, batch: TensorBatch) -> Dict[str, float]:
 
         # Compute actor loss
         pi = self.actor(state.squeeze())
-        actor_loss = F.cross_entropy(pi, action.view(-1,))
+        actor_loss = F.cross_entropy(
+            pi,
+            action.view(
+                -1,
+            ),
+        )
         log_dict["actor_loss"] = actor_loss.item()
         # Optimize the actor
         self.actor_optimizer.zero_grad()
@@ -203,26 +208,24 @@ def train(config: TrainConfig):
     # NetHack builders
     env_builder, dataset_builder = make_task_builder(config.env)
     env_builder = (
-        env_builder
-        .roles([Role.MONK])
+        env_builder.roles([Role.MONK])
         .races([Race.HUMAN])
         .alignments([Alignment.NEUTRAL])
         .sex([Sex.MALE])
         .eval_seeds(list(config.eval_seeds))
     )
     dataset = (
-        dataset_builder
-        .roles([Role.MONK])
+        dataset_builder.roles([Role.MONK])
        .races([Race.HUMAN])
        .alignments([Alignment.NEUTRAL])
        .sex([Sex.MALE])
        .build(batch_size=config.batch_size, seq_len=1, n_prefetched_batches=100)
     )
     loader = DataLoader(
-        dataset = dataset,
+        dataset=dataset,
         # Disable automatic batching
-        batch_sampler = None,
-        batch_size = None
+        batch_sampler=None,
+        batch_size=None,
     )
 
     # Get number of actions for the task of interest
@@ -263,7 +266,7 @@ def train(config: TrainConfig):
 
     evaluations = []
     for t, batch in enumerate(loader):
-        batch = [b.to(config.device) for b in batch]
+        batch = [b.to(config.device) for b in batch]
         log_dict = trainer.train(batch)
 
         # Log train
@@ -274,10 +277,10 @@ def train(config: TrainConfig):
             print(f"Time steps: {t + 1}")
 
             eval_stats = eval_actor(
-                env_builder = env_builder,
-                actor = actor,
-                device = config.device,
-                n_episodes = config.n_episodes,
+                env_builder=env_builder,
+                actor=actor,
+                device=config.device,
+                n_episodes=config.n_episodes,
             )
 
             print(eval_stats)
@@ -299,4 +302,4 @@ def train(config: TrainConfig):
 
 
 if __name__ == "__main__":
-    train()
\ No newline at end of file
+    train()
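
A note on the new DataLoader wiring above: the dataset builder already emits full batches (build(batch_size=config.batch_size, ...)), and passing batch_size=None together with batch_sampler=None tells PyTorch to disable automatic batching, so the loader yields each pre-collated batch unchanged instead of stacking single samples. A minimal sketch of the pattern, with a hypothetical PrebatchedDataset standing in for the repo's builder (shapes follow the 24x80x3 tty observations the Actor expects; the action range is a toy placeholder):

import torch
from torch.utils.data import DataLoader, IterableDataset


class PrebatchedDataset(IterableDataset):
    # Hypothetical stand-in for the d5rl dataset builder's output:
    # an iterable that yields already-collated (state, action) batches.
    def __init__(self, batch_size: int, n_batches: int):
        self.batch_size = batch_size
        self.n_batches = n_batches

    def __iter__(self):
        for _ in range(self.n_batches):
            # One full batch per item: tty screens (24x80x3) and toy actions.
            states = torch.randint(0, 256, (self.batch_size, 24, 80, 3))
            actions = torch.randint(0, 8, (self.batch_size, 1))
            yield [states, actions]


# batch_size=None (with batch_sampler=None, the default) disables automatic
# batching: each item from the dataset passes through as a ready-made batch.
loader = DataLoader(dataset=PrebatchedDataset(512, 10), batch_size=None)
for state, action in loader:
    assert state.shape == (512, 24, 80, 3)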

d5rl/__init__.py

+1 −1
@@ -1,3 +1,3 @@
 """
 D5RL: Dungeon Datasets for Deep Data-Driven Reinforcement Learning
-"""
\ No newline at end of file
+"""

d5rl/datasets/__init__.py

+3 −2
@@ -1,2 +1,3 @@
-from d5rl.datasets.autoascend import AutoAscendTTYDataset
-from d5rl.datasets.builder import AutoAscendDatasetBuilder
+from d5rl.datasets.base import BaseAutoAscend
+from d5rl.datasets.builder import AutoAscendDatasetBuilder
+from d5rl.datasets.sars_autoascend import SARSAutoAscendTTYDataset

0 commit comments