encoder dataloader #1

Merged: 7 commits, Jan 24, 2023

Changes from all commits
1 change: 1 addition & 0 deletions .gitignore
@@ -29,6 +29,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+.idea/
 
 # PyInstaller
 # Usually these files are written by a python script from a template
117 changes: 60 additions & 57 deletions bc_dummy.py
@@ -1,60 +1,57 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-from dataclasses import asdict, dataclass
-from collections import defaultdict
 import os
-from pathlib import Path
 import random
 import uuid
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
 import gym
 import numpy as np
 import pyrallis
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import wandb
 
-from d5rl.tasks import make_task_builder, NetHackEnvBuilder
-from d5rl.utils.roles import Role, Alignment, Race, Sex
 from torch.utils.data import DataLoader
 
+from d5rl.tasks import NetHackEnvBuilder, make_task_builder
+from d5rl.utils.roles import Alignment, Race, Role, Sex
+
 TensorBatch = List[torch.Tensor]
 
 
 @dataclass
 class TrainConfig:
     # NetHack
-    env : str = "NetHackScore-v0-tty-bot-v0"
-    character : str = "mon-hum-neutral-male"
+    env: str = "NetHackScore-v0-tty-bot-v0"
+    character: str = "mon-hum-neutral-male"
     eval_seeds: Optional[Tuple[int]] = (228, 1337, 1307, 2, 10000)
 
     # Training
-    device : str = "cpu"
-    seed : int = 0
-    eval_freq : int = int(1000)
-    n_episodes : int = 10
-    max_timesteps : int = int(1e6)
+    device: str = "cpu"
+    seed: int = 0
+    eval_freq: int = int(1000)
+    n_episodes: int = 10
+    max_timesteps: int = int(1e6)
     checkpoints_path: Optional[str] = None
-    load_model : str = ""
-    batch_size : int = 512
+    load_model: str = ""
+    batch_size: int = 512
 
     # Wandb logging
     project: str = "NeuralNetHack"
-    group : str = "DummyBC"
-    name : str = "DummyBC"
+    group: str = "DummyBC"
+    name: str = "DummyBC"
     version: str = "v0"
 
     def __post_init__(self):
         self.group = f"{self.env}-{self.name}-{self.version}"
-        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
+        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
 
         if self.checkpoints_path is not None:
             self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
 
 
-def set_seed(
-    seed: int, deterministic_torch: bool = False
-):
+def set_seed(seed: int, deterministic_torch: bool = False):
     os.environ["PYTHONHASHSEED"] = str(seed)
     np.random.seed(seed)
     random.seed(seed)
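The rest of set_seed is collapsed in this view. As a reference only, a common shape for such a helper, consistent with the deterministic_torch flag in the new one-line signature, is sketched below; the torch calls are an assumption about the hidden lines, not part of this diff.

    import os
    import random

    import numpy as np
    import torch


    def set_seed_sketch(seed: int, deterministic_torch: bool = False):
        # Visible prefix, as in the hunk above.
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        random.seed(seed)
        # Plausible continuation (assumed): seed torch's RNG and optionally
        # force deterministic kernels.
        torch.manual_seed(seed)
        torch.use_deterministic_algorithms(deterministic_torch)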
@@ -64,21 +61,21 @@ def set_seed(
 
 def wandb_init(config: dict) -> None:
     wandb.init(
-        config = config,
-        project = config["project"],
-        group = config["group"],
-        name = config["name"],
-        id = str(uuid.uuid4()),
+        config=config,
+        project=config["project"],
+        group=config["group"],
+        name=config["name"],
+        id=str(uuid.uuid4()),
     )
     wandb.run.save()
 
 
 @torch.no_grad()
 def eval_actor(
     env_builder: NetHackEnvBuilder,
-    actor : nn.Module,
-    device : str,
-    n_episodes : int,
+    actor: nn.Module,
+    device: str,
+    n_episodes: int,
 ) -> Dict[str, Dict[int, float]]:
     actor.eval()
     eval_stats = defaultdict(dict)
@@ -96,7 +93,6 @@ def eval_actor(
             episode_rewards.append(episode_reward)
         eval_stats[character][seed] = np.mean(episode_rewards)
 
-
     actor.train()
 
     return eval_stats
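As the assignment in this hunk shows, eval_actor returns a nested mapping of character to per-seed mean episode reward. For the defaults in TrainConfig, the result would look like this (reward values invented for illustration):

    eval_stats = {
        "mon-hum-neutral-male": {
            228: 412.0,  # mean reward over n_episodes rollouts on seed 228
            1337: 350.5,
            1307: 298.0,
            2: 510.2,
            10000: 275.8,
        }
    }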
@@ -112,55 +108,59 @@ def __init__(self, action_dim: int):
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.colors_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.cursor_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.head = nn.Sequential(
-            nn.Linear(256*3, 256),
+            nn.Linear(256 * 3, 256),
             nn.ReLU(),
             nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(128, action_dim)
+            nn.Linear(128, action_dim),
         )
 
     def forward(self, state: torch.Tensor) -> torch.Tensor:
         batch_size = state.shape[0]
-        state = state.view(batch_size, -1, 3) / 255.0
+        state = state.view(batch_size, -1, 3) / 255.0
 
-        chars_encoded = self.chars_encoder(state[:, :, 0])
+        chars_encoded = self.chars_encoder(state[:, :, 0])
         colors_encoded = self.colors_encoder(state[:, :, 1])
         cursor_encoded = self.cursor_encoder(state[:, :, 2])
 
-        return self.head(torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1))
+        return self.head(
+            torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1)
+        )
 
     @torch.no_grad()
     def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
-        state = torch.tensor(np.expand_dims(state, axis=0), device=device, dtype=torch.float32)
+        state = torch.tensor(
+            np.expand_dims(state, axis=0), device=device, dtype=torch.float32
+        )
         logits = self(state)
         return torch.argmax(logits).cpu().item()
 
 
 class BC:  # noqa
     def __init__(
         self,
-        actor : nn.Module,
+        actor: nn.Module,
         actor_optimizer: torch.optim.Optimizer,
-        device : str = "cpu",
+        device: str = "cpu",
     ):
         self.actor = actor
         self.actor_optimizer = actor_optimizer
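A quick shape check for Actor.forward above: a tty observation here is a (24, 80, 3) screen whose planes hold characters, colors, and the cursor, so view(batch_size, -1, 3) yields (batch, 1920, 3), each plane slice feeds an nn.Linear(24 * 80, 256) encoder, and the head consumes the concatenated 3 x 256 features. A standalone sketch under those layout assumptions:

    import torch

    # Assumed observation layout: (batch, 24, 80, 3) uint8 tty screens.
    state = torch.randint(0, 256, (2, 24, 80, 3)).float()

    batch_size = state.shape[0]
    flat = state.view(batch_size, -1, 3) / 255.0  # -> (2, 1920, 3)
    chars = flat[:, :, 0]  # (2, 1920), matches nn.Linear(24 * 80, 256)
    colors = flat[:, :, 1]
    cursor = flat[:, :, 2]
    print(chars.shape)  # torch.Size([2, 1920])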
@@ -176,7 +176,12 @@ def train(self, batch: TensorBatch) -> Dict[str, float]:
 
         # Compute actor loss
         pi = self.actor(state.squeeze())
-        actor_loss = F.cross_entropy(pi, action.view(-1,))
+        actor_loss = F.cross_entropy(
+            pi,
+            action.view(
+                -1,
+            ),
+        )
         log_dict["actor_loss"] = actor_loss.item()
         # Optimize the actor
         self.actor_optimizer.zero_grad()
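The reformatted loss above is plain behavioral cloning: pi holds (batch, action_dim) logits, and action.view(-1) flattens the stored integer actions to shape (batch,), the class-index target format F.cross_entropy expects. A self-contained check with assumed sizes:

    import torch
    import torch.nn.functional as F

    batch_size, action_dim = 4, 8  # assumed sizes for illustration
    pi = torch.randn(batch_size, action_dim)  # actor logits
    action = torch.randint(0, action_dim, (batch_size, 1))  # dataset actions, (B, 1)

    # view(-1) turns (B, 1) into (B,) class indices.
    loss = F.cross_entropy(pi, action.view(-1))
    print(loss.item())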
@@ -203,26 +208,24 @@ def train(config: TrainConfig):
     # NetHack builders
     env_builder, dataset_builder = make_task_builder(config.env)
     env_builder = (
-        env_builder
-        .roles([Role.MONK])
+        env_builder.roles([Role.MONK])
         .races([Race.HUMAN])
         .alignments([Alignment.NEUTRAL])
         .sex([Sex.MALE])
         .eval_seeds(list(config.eval_seeds))
     )
     dataset = (
-        dataset_builder
-        .roles([Role.MONK])
+        dataset_builder.roles([Role.MONK])
         .races([Race.HUMAN])
         .alignments([Alignment.NEUTRAL])
         .sex([Sex.MALE])
         .build(batch_size=config.batch_size, seq_len=1, n_prefetched_batches=100)
     )
     loader = DataLoader(
-        dataset = dataset,
+        dataset=dataset,
         # Disable automatic batching
-        batch_sampler = None,
-        batch_size = None
+        batch_sampler=None,
+        batch_size=None,
     )
 
     # Get number of actions for the task of interest
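Passing batch_size=None together with batch_sampler=None is PyTorch's documented way to disable automatic batching: the DataLoader then yields each item exactly as the dataset produces it, which suits a builder that already emits (batch_size, seq_len, ...) tensors, as the build(...) call above does. A minimal sketch with a stand-in dataset (the PrebatchedDataset name is hypothetical, not the d5rl API):

    import torch
    from torch.utils.data import DataLoader, IterableDataset


    class PrebatchedDataset(IterableDataset):
        # Stand-in for a dataset that yields already-batched tensors.
        def __iter__(self):
            for _ in range(3):
                yield torch.zeros(512, 24, 80, 3)


    loader = DataLoader(PrebatchedDataset(), batch_size=None, batch_sampler=None)
    first = next(iter(loader))
    print(first.shape)  # torch.Size([512, 24, 80, 3]) -- no extra batch dim added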
@@ -263,7 +266,7 @@ def train(config: TrainConfig):
 
     evaluations = []
     for t, batch in enumerate(loader):
-        batch = [b.to(config.device) for b in batch]
+        batch = [b.to(config.device) for b in batch]
         log_dict = trainer.train(batch)
 
         # Log train
@@ -274,10 +277,10 @@ def train(config: TrainConfig):
             print(f"Time steps: {t + 1}")
 
             eval_stats = eval_actor(
-                env_builder = env_builder,
-                actor = actor,
-                device = config.device,
-                n_episodes = config.n_episodes,
+                env_builder=env_builder,
+                actor=actor,
+                device=config.device,
+                n_episodes=config.n_episodes,
             )
 
             print(eval_stats)
@@ -299,4 +302,4 @@ def train(config: TrainConfig):
 
 
 if __name__ == "__main__":
-    train()
\ No newline at end of file
+    train()
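One note on the entry point: train(config: TrainConfig) is invoked with no arguments, which implies a config-parsing decorator (collapsed in this view); given the pyrallis import at the top of the file, the usual pattern would be pyrallis.wrap(), sketched here as an assumption rather than a quote of the hidden line:

    from dataclasses import dataclass

    import pyrallis


    @dataclass
    class TrainConfig:
        device: str = "cpu"
        seed: int = 0


    # Assumed pattern: pyrallis.wrap() parses CLI flags (e.g. --seed=1) into
    # the dataclass and passes it in, so the call site stays train().
    @pyrallis.wrap()
    def train(config: TrainConfig):
        print(config)


    if __name__ == "__main__":
        train()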
2 changes: 1 addition & 1 deletion d5rl/__init__.py
@@ -1,3 +1,3 @@
 """
 D5RL: Dungeon Datasets for Deep Data-Driven Reinforcement Learning
-"""
\ No newline at end of file
+"""
5 changes: 3 additions & 2 deletions d5rl/datasets/__init__.py
@@ -1,2 +1,3 @@
-from d5rl.datasets.autoascend import AutoAscendTTYDataset
-from d5rl.datasets.builder import AutoAscendDatasetBuilder
+from d5rl.datasets.base import BaseAutoAscend
+from d5rl.datasets.builder import AutoAscendDatasetBuilder
+from d5rl.datasets.sars_autoascend import SARSAutoAscendTTYDataset