Skip to content

Commit e02c7b9

Browse files
committed
refactoring
1 parent 87dd857 commit e02c7b9

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,6 @@ venv.bak/
106106

107107
## Coin
108108
dataset/
109+
dataset
109110
res/
111+
adj.md

sphere_loss.py

+26
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,32 @@
66
import torch.nn as nn
77

88

9+
class OhemSphereLoss(nn.Module):
10+
def __init__(self, in_feats, n_classes, thresh=0.7, scale=14, *args, **kwargs):
11+
super(OhemSphereLoss, self).__init__(*args, **kwargs)
12+
self.thresh = thresh
13+
self.scale = scale
14+
self.cross_entropy = nn.CrossEntropyLoss(reduction='none')
15+
self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes),
16+
requires_grad = True)
17+
# nn.init.kaiming_normal_(self.W, a=1)
18+
nn.init.xavier_normal_(self.W, gain=1)
19+
20+
def forward(self, x, label):
21+
n_examples = x.size()[0]
22+
n_pick = int(n_examples*self.thresh)
23+
x_norm = torch.norm(x, 2, 1, True).clamp(min = 1e-12).expand_as(x)
24+
x_norm = x / x_norm
25+
w_norm = torch.norm(self.W, 2, 0, True).clamp(min = 1e-12).expand_as(self.W)
26+
w_norm = self.W / w_norm
27+
cos_th = torch.mm(x_norm, w_norm)
28+
s_cos_th = self.scale * cos_th
29+
loss = self.cross_entropy(s_cos_th, label)
30+
loss, _ = torch.sort(loss, descending=True)
31+
loss = torch.mean(loss[:n_pick])
32+
return loss
33+
34+
935
class SphereLoss(nn.Module):
1036
def __init__(self, in_feats, n_classes, scale = 14, *args, **kwargs):
1137
super(SphereLoss, self).__init__(*args, **kwargs)

train.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212

1313
from backbone import Network_D
14-
from sphere_loss import SphereLoss
14+
from sphere_loss import SphereLoss, OhemSphereLoss
1515
from market1501 import Market1501
1616
from balanced_sampler import BalancedSampler
1717

@@ -66,17 +66,17 @@ def train():
6666

6767
## network and loss
6868
logger.info('setup model and loss')
69-
sphereloss = SphereLoss(1024, num_classes)
70-
sphereloss.cuda()
69+
# criteria = SphereLoss(1024, num_classes)
70+
criteria = OhemSphereLoss(1024, num_classes, thresh=0.8)
71+
criteria.cuda()
7172
net = Network_D()
72-
net = nn.DataParallel(net)
7373
net.train()
7474
net.cuda()
7575

7676
## optimizer
7777
logger.info('creating optimizer')
7878
params = list(net.parameters())
79-
params += list(sphereloss.parameters())
79+
params += list(criteria.parameters())
8080
optim = torch.optim.Adam(params, lr = 1e-3)
8181

8282
## training
@@ -90,24 +90,24 @@ def train():
9090
lbs = lbs.cuda()
9191

9292
embs = net(imgs)
93-
loss = sphereloss(embs, lbs)
93+
loss = criteria(embs, lbs)
9494
optim.zero_grad()
9595
loss.backward()
9696
optim.step()
9797

9898
loss_it.append(loss.detach().cpu().numpy())
99-
if it % 10 == 0 and it != 0:
100-
t_end = time.time()
101-
t_interval = t_end - t_start
102-
log_loss = sum(loss_it) / len(loss_it)
103-
msg = 'epoch: {}, iter: {}, loss: {:4f}, lr: {}, time: {:4f}'.format(ep,
104-
it, log_loss, lrs, t_interval)
105-
logger.info(msg)
106-
loss_it = []
107-
t_start = t_end
99+
## print log
100+
t_end = time.time()
101+
t_interval = t_end - t_start
102+
log_loss = sum(loss_it) / len(loss_it)
103+
msg = 'epoch: {}, iter: {}, loss: {:.4f}, lr: {}, time: {:.4f}'.format(ep,
104+
it, log_loss, lrs, t_interval)
105+
logger.info(msg)
106+
loss_it = []
107+
t_start = t_end
108108

109109
## save model
110-
torch.save(net.module.state_dict(), './res/model_final.pkl')
110+
torch.save(net.state_dict(), './res/model_final.pkl')
111111
logger.info('\nTraining done, model saved to {}\n\n'.format('./res/model_final.pkl'))
112112

113113

0 commit comments

Comments
 (0)