Refactoring #19

Merged
merged 21 commits into from
Jun 14, 2023

@Howuhh commented May 23, 2023

No description provided.


set_seed(config.train_seed)

def env_fn():
Contributor
Let's move it to a standalone function out of main

Contributor Author
then we'd have to pass character as an argument, which is not very nice
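
One possible middle ground between the two comments above, given only as a rough sketch and not as what the PR ended up doing: a module-level factory takes the character explicitly, and `functools.partial` binds it inside `main`, so the vectorized environment still receives zero-argument callables. `DummyEnv`, `make_env`, and the character string are hypothetical stand-ins, not names from the repository.

```python
from functools import partial


class DummyEnv:
    """Hypothetical stand-in for the real environment class."""

    def __init__(self, character: str):
        self.character = character


def make_env(character: str) -> DummyEnv:
    # Standalone factory: everything it needs arrives as an argument.
    return DummyEnv(character)


def main(character: str = "mon-hum-neu", num_envs: int = 2) -> list:
    # Bind the argument once; the result is a zero-argument callable again.
    env_fn = partial(make_env, character)
    return [env_fn() for _ in range(num_envs)]


envs = main()
```

A partial of a module-level function can also be pickled with the standard `pickle` module, which a function nested inside `main` cannot.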

tmp_env = env_fn()
eval_env = AsyncVectorEnv(
    env_fns=[env_fn for _ in range(config.eval_processes)],
    shared_memory=True,
Contributor
Add a comment explaining why this is needed

    seed=config.train_seed,
    add_next_step=False
)
tp = ThreadPoolExecutor(max_workers=14)
Contributor
this should either be a config value or something automatic based on the number of cores/processors we have

Contributor Author
fixed, now as config value
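
A minimal sketch of how the two options from the review comment could be combined, assuming a config field that falls back to the machine's core count when unset. The `Config` class, the `num_workers` field name, and `resolve_workers` are assumptions for illustration; the thread does not show the actual config.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional


@dataclass
class Config:
    # Hypothetical field name; None means "derive from the machine".
    num_workers: Optional[int] = None


def resolve_workers(config: Config) -> int:
    if config.num_workers is not None:
        return config.num_workers
    # os.cpu_count() may return None, so fall back to a single worker.
    return os.cpu_count() or 1


config = Config()
with ThreadPoolExecutor(max_workers=resolve_workers(config)) as tp:
    squares = list(tp.map(lambda x: x * x, range(4)))
```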

katakomba/env.py Outdated
Returns score normalized against AutoAscend bot scores achieved for this exact character.
"""
if self.character.count("-") != 2:
    raise ValueError("Reference score not provided for this character.")
Contributor
typo: Reference score is not provided...

Contributor Author
fixed

katakomba/env.py Outdated

def get_dataset(self, scale: str = "small", **kwargs):
    if self.character.count("-") != 2:
        raise ValueError("Reference score not provided for this character.")
Contributor
typo

Contributor Author
fixed

@vkurenkov (Contributor) left a comment
looks ok



@torch.no_grad()
def filter_wd_params(model: nn.Module):
Contributor
add type for return value

Contributor Author
done

return no_decay, decay


def dict_to_tensor(data, device):
Contributor
typings

Contributor Author
done



@torch.no_grad()
def vec_evaluate(vec_env, actor, num_episodes, seed=0, device="cpu"):
Contributor
typings

Contributor Author
fixed



class Actor(nn.Module):
    def __init__(self, action_dim, rnn_hidden_dim=512, rnn_layers=1, rnn_dropout=0.0, use_prev_action=True):
Contributor
typings

Contributor Author
done


pbar.close()
result = {
    "reward_median": np.median(episode_rewards),
Contributor
let's rename to returns to be consistent

Contributor
return_median, return_mean, etc

Contributor Author
too late, we have all logs in wandb in this format...

Contributor Author
it is consistent across algorithms tho

    align: Optional[Alignment] = None,
    **kwargs
) -> nld.TtyrecDataset:
    if not nld.db.exists(db_path):
Contributor
this original solution from DD is actually a bit problematic:
if the db was not properly initialized for some reason (e.g., a wrong path that was later fixed), this will silently re-use the stale db.
i think it's better to initialize the DB each time, as it does not take much time

Contributor Author
agree
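
The failure mode described above can be illustrated with a generic sketch. It uses the standard-library `sqlite3` module as a stand-in and does not call the `nld` API at all; `init_db`, `read_root`, the table layout, and the paths are all hypothetical. The point is only that rebuilding on every call means a database created with a wrong path never survives a corrected run.

```python
import os
import sqlite3
import tempfile


def init_db(db_path: str, data_root: str) -> None:
    # Rebuild from scratch on every call so a database created with a
    # wrong data_root is never silently reused.
    if os.path.exists(db_path):
        os.remove(db_path)
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("CREATE TABLE roots (path TEXT)")
        conn.execute("INSERT INTO roots VALUES (?)", (data_root,))
        conn.commit()
    finally:
        conn.close()


def read_root(db_path: str) -> str:
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute("SELECT path FROM roots").fetchone()[0]
    finally:
        conn.close()


with tempfile.TemporaryDirectory() as tmp:
    db_path = os.path.join(tmp, "ttyrecs.db")
    init_db(db_path, "/wrong/path")    # first run with a bad path
    init_db(db_path, "/correct/path")  # fixed path replaces the stale DB
    stored_root = read_root(db_path)
```

With an "only create if missing" check, the second call would have been skipped and the wrong path kept.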

CACHE_PATH = os.environ.get('KATAKOMBA_CACHE_DIR', os.path.expanduser('~/.katakomba/cache'))


def _flush_to_memmap(filename: str, array: np.ndarray):
Contributor
return type is missing

Contributor Author
fixed

gameid = self.gameids[idx]
return dict(self.hdf5_file[gameid].attrs)

def close(self):
Contributor
let's add a flag for cleaning the memmap
sometimes people would like to work with just one dataset and rebuilding it every time is not desirable

Contributor Author
fixed
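
A sketch of what such a flag might look like, using plain files in place of real memmaps so it stays dependency-free. `MemmapCache` and the `clean_on_close` flag name are assumptions; the thread does not show how the flag was actually named or wired in.

```python
import os
import tempfile


class MemmapCache:
    """Hypothetical stand-in for a dataset that caches arrays on disk."""

    def __init__(self, cache_dir: str, clean_on_close: bool = True):
        self.cache_dir = cache_dir
        self.clean_on_close = clean_on_close  # assumed flag name
        self.files = []

    def add(self, name: str, payload: bytes) -> str:
        path = os.path.join(self.cache_dir, name)
        with open(path, "wb") as f:
            f.write(payload)
        self.files.append(path)
        return path

    def close(self) -> None:
        # Only delete the cached files when the caller asked for it.
        if self.clean_on_close:
            for path in self.files:
                if os.path.exists(path):
                    os.remove(path)
            self.files.clear()


with tempfile.TemporaryDirectory() as tmp:
    keep = MemmapCache(tmp, clean_on_close=False)
    kept_path = keep.add("obs.bin", b"\x00" * 8)
    keep.close()
    kept_exists = os.path.exists(kept_path)

    drop = MemmapCache(tmp, clean_on_close=True)
    dropped_path = drop.add("act.bin", b"\x00" * 8)
    drop.close()
    dropped_exists = os.path.exists(dropped_path)
```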

def close(self):
    self.hdf5_file.close()
    # remove memmap files from the disk upon closing
    if self.mode == "memmap":
Contributor
let's add logging that this is happening

Contributor Author
fixed
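
For the logging request, a small sketch with the standard `logging` module, assuming the cleanup is factored into a helper. `remove_cache_files` and the logger name are hypothetical; the actual log messages added in the PR are not shown in this thread.

```python
import logging
import os
import tempfile

logger = logging.getLogger("memmap_cleanup")  # hypothetical logger name


def remove_cache_files(paths: list) -> int:
    removed = 0
    for path in paths:
        if os.path.exists(path):
            # Announce each deletion so users can see that files disappear.
            logger.info("Removing cached memmap file: %s", path)
            os.remove(path)
            removed += 1
    logger.info("Removed %d cached file(s)", removed)
    return removed


logging.basicConfig(level=logging.INFO)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "obs.bin")
    with open(path, "wb") as f:
        f.write(b"\x00" * 8)
    removed_count = remove_cache_files([path, os.path.join(tmp, "missing.bin")])
```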

@vkurenkov merged commit 61a7c77 into main on Jun 14, 2023
2 participants