11
11
import numpy as np
12
12
13
13
from backbone import Network_D
14
- from sphere_loss import SphereLoss
14
+ from sphere_loss import SphereLoss , OhemSphereLoss
15
15
from market1501 import Market1501
16
16
from balanced_sampler import BalancedSampler
17
17
@@ -66,17 +66,17 @@ def train():
66
66
67
67
## network and loss
68
68
logger .info ('setup model and loss' )
69
- sphereloss = SphereLoss (1024 , num_classes )
70
- sphereloss .cuda ()
69
+ # criteria = SphereLoss(1024, num_classes)
70
+ criteria = OhemSphereLoss (1024 , num_classes , thresh = 0.8 )
71
+ criteria .cuda ()
71
72
net = Network_D ()
72
- net = nn .DataParallel (net )
73
73
net .train ()
74
74
net .cuda ()
75
75
76
76
## optimizer
77
77
logger .info ('creating optimizer' )
78
78
params = list (net .parameters ())
79
- params += list (sphereloss .parameters ())
79
+ params += list (criteria .parameters ())
80
80
optim = torch .optim .Adam (params , lr = 1e-3 )
81
81
82
82
## training
@@ -90,24 +90,24 @@ def train():
90
90
lbs = lbs .cuda ()
91
91
92
92
embs = net (imgs )
93
- loss = sphereloss (embs , lbs )
93
+ loss = criteria (embs , lbs )
94
94
optim .zero_grad ()
95
95
loss .backward ()
96
96
optim .step ()
97
97
98
98
loss_it .append (loss .detach ().cpu ().numpy ())
99
- if it % 10 == 0 and it != 0 :
100
- t_end = time .time ()
101
- t_interval = t_end - t_start
102
- log_loss = sum (loss_it ) / len (loss_it )
103
- msg = 'epoch: {}, iter: {}, loss: {:4f}, lr: {}, time: {:4f}' .format (ep ,
104
- it , log_loss , lrs , t_interval )
105
- logger .info (msg )
106
- loss_it = []
107
- t_start = t_end
99
+ ## print log
100
+ t_end = time .time ()
101
+ t_interval = t_end - t_start
102
+ log_loss = sum (loss_it ) / len (loss_it )
103
+ msg = 'epoch: {}, iter: {}, loss: {:. 4f}, lr: {}, time: {:. 4f}' .format (ep ,
104
+ it , log_loss , lrs , t_interval )
105
+ logger .info (msg )
106
+ loss_it = []
107
+ t_start = t_end
108
108
109
109
## save model
110
- torch .save (net .module . state_dict (), './res/model_final.pkl' )
110
+ torch .save (net .state_dict (), './res/model_final.pkl' )
111
111
logger .info ('\n Training done, model saved to {}\n \n ' .format ('./res/model_final.pkl' ))
112
112
113
113
0 commit comments