@@ -30,23 +30,17 @@ def load_model(self, model_path):
30
30
self .model = self .model .load (model_path , env = self .train_env )
31
31
32
32
def load_most_recent_model (self , model_path ):
33
- directory = os .fsencode (model_path )
34
- most_recent_model_filename = None
35
- greatest_timestamp = 0
36
-
37
- for file in os .listdir (directory ):
38
- filename = os .fsdecode (file )
39
- if ".save" in filename :
40
- timestep = int (filename .split ("." )[0 ])
41
- if timestep > greatest_timestamp :
42
- greatest_timestamp = timestep
43
- most_recent_model_filename = filename
33
+ save_files = list (filter (lambda filename : ".save" in filename ,
34
+ os .listdir (model_path )))
44
35
45
- if most_recent_model_filename is None :
36
+ if len ( save_files ) == 0 :
46
37
raise Exception (f"No models found in { model_path } " )
47
-
48
- self .load_model (f"{ model_path } /{ most_recent_model_filename } " )
49
-
38
+ else :
39
+ def only_digits (filename ):
40
+ return '' .join (c for c in filename if c .isdigit ())
41
+ save_files .sort (reverse = True , key = only_digits )
42
+ self .load_model (f"{ model_path } /{ save_files [0 ]} " )
43
+
50
44
def train_model (
51
45
self , timesteps : int = 10000 , log_interval : int = 1 ,
52
46
reset_num_timesteps : bool = False , progress_bar : bool = False ,
0 commit comments