-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutil.py
59 lines (53 loc) · 1.97 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from collections import defaultdict, deque
import pickle
from attrdict import AttrDict
import os
import numpy as np
import torch
from torch import nn
from torch import optim
from tensorboardX import SummaryWriter
class Checkpointer:
    """Keep a rolling window of the most recent training checkpoints.

    Checkpoints are written to ``path`` as ``model_<epoch>.pth`` files. A
    pickled list of their file names, ordered oldest to newest, is stored
    next to them in ``model_list.pkl`` and is used to decide which file to
    prune on save and which file to restore on load.

    NOTE: both the list file and the checkpoints are deserialized with
    pickle-based loaders, so only point this class at directories you trust.
    """

    def __init__(self, path, max_num=3):
        """Prepare the checkpoint directory and its bookkeeping file.

        Args:
            path: Directory that holds the checkpoints; created if missing.
            max_num: Maximum number of checkpoint files to keep on disk.
        """
        self.max_num = max_num
        self.path = path
        # exist_ok avoids the check-then-create race when several processes
        # start against the same directory.
        os.makedirs(path, exist_ok=True)
        self.listfile = os.path.join(path, 'model_list.pkl')
        if not os.path.exists(self.listfile):
            self._write_list([])

    def _read_list(self):
        """Return the stored list of checkpoint file names (oldest first)."""
        with open(self.listfile, 'rb') as f:
            return pickle.load(f)

    def _write_list(self, model_list):
        """Replace the list file with ``model_list`` in a single step."""
        # Write to a sibling file and rename it over the old one: the list is
        # then never left half-written, and no stale bytes from a previous,
        # longer pickle survive (which a non-truncating 'rb+' rewrite allows).
        tmp_file = self.listfile + '.tmp'
        with open(tmp_file, 'wb') as f:
            pickle.dump(model_list, f)
        os.replace(tmp_file, self.listfile)

    def save(self, model, optimizer, epoch):
        """Write a checkpoint for ``epoch`` and prune the oldest ones.

        Args:
            model: Object exposing ``state_dict()``.
            optimizer: Object exposing ``state_dict()``.
            epoch: Integer stored in the checkpoint and used in its file name.
        """
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        filename = os.path.join(self.path, 'model_{:05}.pth'.format(epoch))

        # Write the checkpoint before touching the list, so the list never
        # references a file that failed to be written.
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)

        # Drop any earlier entry for this file: saving the same epoch twice
        # must not create a duplicate that a later prune would delete while
        # the list still points at it.
        model_list = [name for name in self._read_list() if name != filename]
        model_list.append(filename)

        # Always keep at least the checkpoint that was just written, even if
        # max_num is zero or negative.
        keep = max(self.max_num, 1)
        stale = model_list[:-keep]
        model_list = model_list[-keep:]
        self._write_list(model_list)

        for old in stale:
            # Skip names that are still listed (duplicates left behind by
            # older list files) and files that are already gone.
            if old not in model_list and os.path.exists(old):
                os.remove(old)

    def load(self, model, optimizer, map_location=None):
        """Restore the newest checkpoint into ``model`` and ``optimizer``.

        Args:
            model: Object exposing ``load_state_dict()``.
            optimizer: Object exposing ``load_state_dict()``.
            map_location: Forwarded to ``torch.load``; the default ``None``
                keeps torch's default device mapping.

        Returns:
            The ``epoch`` value stored in the newest checkpoint, or 0 when no
            checkpoint has been saved yet.
        """
        model_list = self._read_list()
        if not model_list:
            print('No checkpoint found. Starting from scratch')
            return 0

        latest = model_list[-1]
        checkpoint = torch.load(latest, map_location=map_location)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('Load checkpoint from {}.'.format(latest))
        return checkpoint['epoch']