@@ -1,60 +1,57 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
-from dataclasses import asdict, dataclass
-from collections import defaultdict
 import os
-from pathlib import Path
 import random
 import uuid
+from collections import defaultdict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
-import gym
 import numpy as np
 import pyrallis
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import wandb
-
-from d5rl.tasks import make_task_builder, NetHackEnvBuilder
-from d5rl.utils.roles import Role, Alignment, Race, Sex
 from torch.utils.data import DataLoader
 
+from d5rl.tasks import NetHackEnvBuilder, make_task_builder
+from d5rl.utils.roles import Alignment, Race, Role, Sex
+
 TensorBatch = List[torch.Tensor]
 
 
 @dataclass
 class TrainConfig:
     # NetHack
-    env: str = "NetHackScore-v0-tty-bot-v0"
-    character: str = "mon-hum-neutral-male"
+    env: str = "NetHackScore-v0-tty-bot-v0"
+    character: str = "mon-hum-neutral-male"
     eval_seeds: Optional[Tuple[int]] = (228, 1337, 1307, 2, 10000)
 
     # Training
-    device: str = "cpu"
-    seed: int = 0
-    eval_freq: int = int(1000)
-    n_episodes: int = 10
-    max_timesteps: int = int(1e6)
+    device: str = "cpu"
+    seed: int = 0
+    eval_freq: int = int(1000)
+    n_episodes: int = 10
+    max_timesteps: int = int(1e6)
     checkpoints_path: Optional[str] = None
-    load_model: str = ""
-    batch_size: int = 512
+    load_model: str = ""
+    batch_size: int = 512
 
     # Wandb logging
     project: str = "NeuralNetHack"
-    group: str = "DummyBC"
-    name: str = "DummyBC"
+    group: str = "DummyBC"
+    name: str = "DummyBC"
     version: str = "v0"
 
     def __post_init__(self):
         self.group = f"{self.env}-{self.name}-{self.version}"
-        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
+        self.name = f"{self.group}-{str(uuid.uuid4())[:8]}"
 
         if self.checkpoints_path is not None:
             self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
 
 
-def set_seed(
-    seed: int, deterministic_torch: bool = False
-):
+def set_seed(seed: int, deterministic_torch: bool = False):
     os.environ["PYTHONHASHSEED"] = str(seed)
     np.random.seed(seed)
     random.seed(seed)
@@ -64,21 +61,21 @@ def set_seed(
 
 def wandb_init(config: dict) -> None:
     wandb.init(
-        config=config,
-        project=config["project"],
-        group=config["group"],
-        name=config["name"],
-        id=str(uuid.uuid4()),
+        config=config,
+        project=config["project"],
+        group=config["group"],
+        name=config["name"],
+        id=str(uuid.uuid4()),
     )
     wandb.run.save()
 
 
 @torch.no_grad()
 def eval_actor(
     env_builder: NetHackEnvBuilder,
-    actor: nn.Module,
-    device: str,
-    n_episodes: int,
+    actor: nn.Module,
+    device: str,
+    n_episodes: int,
 ) -> Dict[str, Dict[int, float]]:
     actor.eval()
     eval_stats = defaultdict(dict)
@@ -96,7 +93,6 @@ def eval_actor(
             episode_rewards.append(episode_reward)
         eval_stats[character][seed] = np.mean(episode_rewards)
 
-
     actor.train()
 
     return eval_stats
@@ -112,55 +108,59 @@ def __init__(self, action_dim: int):
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.colors_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.cursor_encoder = nn.Sequential(
             nn.Linear(24 * 80, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
             nn.ReLU(),
             nn.Linear(256, 256),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.head = nn.Sequential(
-            nn.Linear(256 * 3, 256),
+            nn.Linear(256 * 3, 256),
             nn.ReLU(),
             nn.Linear(256, 128),
             nn.ReLU(),
-            nn.Linear(128, action_dim)
+            nn.Linear(128, action_dim),
         )
 
     def forward(self, state: torch.Tensor) -> torch.Tensor:
         batch_size = state.shape[0]
-        state = state.view(batch_size, -1, 3) / 255.0
+        state = state.view(batch_size, -1, 3) / 255.0
 
-        chars_encoded = self.chars_encoder(state[:, :, 0])
+        chars_encoded = self.chars_encoder(state[:, :, 0])
         colors_encoded = self.colors_encoder(state[:, :, 1])
         cursor_encoded = self.cursor_encoder(state[:, :, 2])
 
-        return self.head(torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1))
+        return self.head(
+            torch.concat([chars_encoded, colors_encoded, cursor_encoded], dim=-1)
+        )
 
     @torch.no_grad()
     def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
-        state = torch.tensor(np.expand_dims(state, axis=0), device=device, dtype=torch.float32)
+        state = torch.tensor(
+            np.expand_dims(state, axis=0), device=device, dtype=torch.float32
+        )
         logits = self(state)
         return torch.argmax(logits).cpu().item()
 
 
 class BC:  # noqa
     def __init__(
         self,
-        actor: nn.Module,
+        actor: nn.Module,
         actor_optimizer: torch.optim.Optimizer,
-        device: str = "cpu",
+        device: str = "cpu",
     ):
         self.actor = actor
         self.actor_optimizer = actor_optimizer
@@ -176,7 +176,12 @@ def train(self, batch: TensorBatch) -> Dict[str, float]:
 
         # Compute actor loss
         pi = self.actor(state.squeeze())
-        actor_loss = F.cross_entropy(pi, action.view(-1,))
+        actor_loss = F.cross_entropy(
+            pi,
+            action.view(
+                -1,
+            ),
+        )
         log_dict["actor_loss"] = actor_loss.item()
         # Optimize the actor
         self.actor_optimizer.zero_grad()
@@ -203,26 +208,24 @@ def train(config: TrainConfig):
     # NetHack builders
     env_builder, dataset_builder = make_task_builder(config.env)
     env_builder = (
-        env_builder
-        .roles([Role.MONK])
+        env_builder.roles([Role.MONK])
         .races([Race.HUMAN])
         .alignments([Alignment.NEUTRAL])
         .sex([Sex.MALE])
         .eval_seeds(list(config.eval_seeds))
     )
     dataset = (
-        dataset_builder
-        .roles([Role.MONK])
+        dataset_builder.roles([Role.MONK])
         .races([Race.HUMAN])
         .alignments([Alignment.NEUTRAL])
         .sex([Sex.MALE])
         .build(batch_size=config.batch_size, seq_len=1, n_prefetched_batches=100)
     )
     loader = DataLoader(
-        dataset=dataset,
+        dataset=dataset,
         # Disable automatic batching
-        batch_sampler=None,
-        batch_size=None
+        batch_sampler=None,
+        batch_size=None,
     )
 
     # Get number of actions for the task of interest
@@ -263,7 +266,7 @@ def train(config: TrainConfig):
 
     evaluations = []
     for t, batch in enumerate(loader):
-        batch = [b.to(config.device) for b in batch]
+        batch = [b.to(config.device) for b in batch]
         log_dict = trainer.train(batch)
 
         # Log train
@@ -274,10 +277,10 @@ def train(config: TrainConfig):
             print(f"Time steps: {t + 1}")
 
             eval_stats = eval_actor(
-                env_builder=env_builder,
-                actor=actor,
-                device=config.device,
-                n_episodes=config.n_episodes,
+                env_builder=env_builder,
+                actor=actor,
+                device=config.device,
+                n_episodes=config.n_episodes,
             )
 
             print(eval_stats)
@@ -299,4 +302,4 @@ def train(config: TrainConfig):
 
 
 if __name__ == "__main__":
-    train()
+    train()
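
Note on the Actor's input contract: the forward pass above reshapes each observation into a 24 * 80 grid with 3 channels (chars, colors, cursor), encodes each channel with its own MLP, and concatenates the three 256-dim encodings before the head. A minimal shape-check sketch, assuming the Actor class from this file is in scope; the batch size of 4 and action_dim of 8 are placeholders, not values from the commit:

import torch

# Fake batch of TTY observations: 24x80 screens, 3 channels, byte-valued.
actor = Actor(action_dim=8)
obs = torch.randint(0, 256, (4, 24, 80, 3)).float()
logits = actor(obs)  # forward() flattens to (4, 1920, 3) and splits channels
assert logits.shape == (4, 8)  # one logit vector per observation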
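The bare train() call at the bottom, together with import pyrallis and the train(config: TrainConfig) signature, suggests the function is decorated with @pyrallis.wrap() somewhere outside the hunks shown here. A sketch under that assumption (the decorator and the file name bc.py are not shown in this diff):

import pyrallis

@pyrallis.wrap()
def train(config: TrainConfig):
    # pyrallis builds TrainConfig from CLI flags, falling back to the
    # dataclass defaults above.
    ...

# Hypothetical invocation, overriding a few defaults from the shell:
#   python bc.py --device=cuda --batch_size=256 --eval_freq=5000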