Skip to content

Commit 3492f44

Browse files
authored
Fix snapshotter ordering (#2327)
1 parent 6461a07 commit 3492f44

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/garage/experiment/snapshotter.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def load(self, load_dir, itr='last'):
159159
raise FileNotFoundError(errno.ENOENT,
160160
os.strerror(errno.ENOENT),
161161
'*.pkl file in', load_dir)
162-
files.sort()
162+
files.sort(key=_extract_snapshot_itr)
163163
load_from_file = files[0] if itr == 'first' else files[-1]
164164
load_from_file = os.path.join(load_dir, load_from_file)
165165

@@ -170,5 +170,20 @@ def load(self, load_dir, itr='last'):
170170
return cloudpickle.load(file)
171171

172172

173+
def _extract_snapshot_itr(filename: str) -> int:
174+
"""Extracts the integer itr from a filename.
175+
176+
Args:
177+
filename(str): The snapshot filename.
178+
179+
Returns:
180+
int: The snapshot as an integer.
181+
182+
"""
183+
base = os.path.splitext(filename)[0]
184+
digits = base.split('itr_')[1]
185+
return int(digits)
186+
187+
173188
class NotAFileError(Exception):
174189
"""Raise when the snapshot is not a file."""

tests/garage/experiment/test_snapshotter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
class TestSnapshotter:
2323

2424
def setup_method(self):
25+
# pylint: disable=consider-using-with
2526
self.temp_dir = tempfile.TemporaryDirectory()
2627

2728
def teardown_method(self):
@@ -78,3 +79,10 @@ def test_conflicting_params(self):
7879
Snapshotter(snapshot_dir=self.temp_dir.name,
7980
snapshot_mode='gap_overwrite',
8081
snapshot_gap=1)
82+
83+
def test_sorts_correctly(self):
84+
snapshotter = Snapshotter(self.temp_dir.name, 'all', 2)
85+
snapshotter.save_itr_params(80, {'test_itr': 80})
86+
snapshotter.save_itr_params(120, {'test_itr': 120})
87+
last = snapshotter.load(self.temp_dir.name)
88+
assert last['test_itr'] == 120

0 commit comments

Comments
 (0)