-
Notifications
You must be signed in to change notification settings - Fork 2
Conversation
|
||
set_seed(config.train_seed) | ||
|
||
def env_fn(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's move it to a standalone function out of main
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then we would have to add character as an argument, which is not very nice
tmp_env = env_fn() | ||
eval_env = AsyncVectorEnv( | ||
env_fns=[env_fn for _ in range(config.eval_processes)], | ||
shared_memory=True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment explaining why this is needed
seed=config.train_seed, | ||
add_next_step=False | ||
) | ||
tp = ThreadPoolExecutor(max_workers=14) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should either be a config value or something automatic based on the number of cores/processors we have
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed, now as config value
katakomba/env.py
Outdated
Returns score normalized against AutoAscend bot scores achieved for this exact character. | ||
""" | ||
if self.character.count("-") != 2: | ||
raise ValueError("Reference score not provided for this character.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: Reference score is not provided...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
katakomba/env.py
Outdated
|
||
def get_dataset(self, scale: str = "small", **kwargs): | ||
if self.character.count("-") != 2: | ||
raise ValueError("Reference score not provided for this character.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks ok
|
||
|
||
@torch.no_grad() | ||
def filter_wd_params(model: nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add type for return value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
return no_decay, decay | ||
|
||
|
||
def dict_to_tensor(data, device): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
|
||
@torch.no_grad() | ||
def vec_evaluate(vec_env, actor, num_episodes, seed=0, device="cpu"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
|
||
class Actor(nn.Module): | ||
def __init__(self, action_dim, rnn_hidden_dim=512, rnn_layers=1, rnn_dropout=0.0, use_prev_action=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
pbar.close() | ||
result = { | ||
"reward_median": np.median(episode_rewards), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's rename to returns to be consistent
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return_median, return_mean, etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
too late, we have all logs in wandb in this format...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is consistent across algorithms tho
align: Optional[Alignment] = None, | ||
**kwargs | ||
) -> nld.TtyrecDataset: | ||
if not nld.db.exists(db_path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this original solution from DD is actually a bit problematic
if the db was not properly initialized for some reason (e.g., a wrong path that was later fixed), this will silently re-use the db
i think it's better to initialize the DB each time, as it does not take much time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agree
CACHE_PATH = os.environ.get('KATAKOMBA_CACHE_DIR', os.path.expanduser('~/.katakomba/cache')) | ||
|
||
|
||
def _flush_to_memmap(filename: str, array: np.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return type is missing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
gameid = self.gameids[idx] | ||
return dict(self.hdf5_file[gameid].attrs) | ||
|
||
def close(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add a flag for cleaning the memmap
sometimes people would like to work with just one dataset and rebuilding it every time is not desirable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
def close(self): | ||
self.hdf5_file.close() | ||
# remove memmap files from the disk upon closing | ||
if self.mode == "memmap": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add logging that this is happening
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
No description provided.