Skip to content

Commit 0675ed6

Browse files
refactor: Improve load most recent model method
1 parent 8ecf485 commit 0675ed6

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

urnai/trainers/stablebaselines3_trainer.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -30,23 +30,17 @@ def load_model(self, model_path):
3030
self.model = self.model.load(model_path, env = self.train_env)
3131

3232
def load_most_recent_model(self, model_path):
33-
directory = os.fsencode(model_path)
34-
most_recent_model_filename = None
35-
greatest_timestamp = 0
36-
37-
for file in os.listdir(directory):
38-
filename = os.fsdecode(file)
39-
if ".save" in filename:
40-
timestep = int(filename.split(".")[0])
41-
if timestep > greatest_timestamp:
42-
greatest_timestamp = timestep
43-
most_recent_model_filename = filename
33+
save_files = list(filter(lambda filename : ".save" in filename,
34+
os.listdir(model_path)))
4435

45-
if most_recent_model_filename is None:
36+
if len(save_files) == 0:
4637
raise Exception(f"No models found in {model_path}")
47-
48-
self.load_model(f"{model_path}/{most_recent_model_filename}")
49-
38+
else:
39+
def only_digits(filename):
40+
return ''.join(c for c in filename if c.isdigit())
41+
save_files.sort(reverse=True, key=only_digits)
42+
self.load_model(f"{model_path}/{save_files[0]}")
43+
5044
def train_model(
5145
self, timesteps: int = 10000, log_interval: int = 1,
5246
reset_num_timesteps: bool = False, progress_bar: bool = False,

0 commit comments

Comments
 (0)