@@ -2,25 +2,25 @@
 import os
 import operator
 from datetime import datetime
-import distiller
-from distiller.data_loggers import *
+# import distiller
+# from distiller.data_loggers import *
 from config import opt
 import torch as t
 import models
 from data.dataset import DatasetFromFilename
 from torch.utils.data import DataLoader
 from utils.image_loader import image_loader
 from utils.utils import AverageMeter, accuracy, write_err_img, config_pylogger, check_date, get_scheduler
-from utils.sensitivity import sensitivity_analysis, val
+from utils.sensitivity import val
 from utils.progress_bar import ProgressBar
 from tqdm import tqdm
 import numpy as np
-import distiller.quantization as quantization
+# import distiller.quantization as quantization
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.utils import make_grid

 t.backends.cudnn.benchmark = True
-# If the input size changes on every iteration, benchmark mode re-selects the fastest algorithm each time, and that selection itself costs time,
+# If the input size changes on every iteration, benchmark mode re-selects the fastest algorithm each time, and that selection itself costs time, pip freeze > requirements.txt
 # so training can actually get slower. In other words, if the input size stays fixed across training steps, enabling this speeds up training:
 seed = 1000
 t.manual_seed(seed)  # random seed: re-running the process regenerates the same random numbers as last time
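The two comments above describe the cuDNN autotuner trade-off. As a reference point only (not part of this commit), a minimal sketch of how `cudnn.benchmark` interacts with seeding when full reproducibility is wanted; the `set_seed` helper is hypothetical:

```python
# Hypothetical helper, not from this repository: shows the usual trade-off between
# cudnn.benchmark (fast for fixed input sizes) and deterministic, reproducible runs.
import random
import numpy as np
import torch

def set_seed(seed: int, reproducible: bool = False) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds CUDA devices in current PyTorch releases
    if reproducible:
        torch.backends.cudnn.benchmark = False      # no autotuning
        torch.backends.cudnn.deterministic = True   # pick deterministic kernels
    else:
        torch.backends.cudnn.benchmark = True       # fastest when input sizes never change
```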
@@ -44,11 +44,11 @@ def test(**kwargs):
     total = 0
     msglogger.info('Test dataset size: %s', len(test_dataloader))
     # quantization
-    if opt.quantize_eval:
-        model.cpu()
-        quantizer = quantization.PostTrainLinearQuantizer.from_args(model, opt)  # quantize the model
-        quantizer.prepare_model()
-        model.to(opt.device)
+    # if opt.quantize_eval:
+    #     model.cpu()
+    #     quantizer = quantization.PostTrainLinearQuantizer.from_args(model, opt)  # quantize the model
+    #     quantizer.prepare_model()
+    #     model.to(opt.device)
     model.eval()  # switch the module to eval mode; affects Dropout and BatchNorm
     err_img = [('img_path', 'result', 'label')]
     for ii, (data, labels, img_path, tag) in tqdm(enumerate(test_dataloader)):
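With the distiller quantization path commented out above, reduced-precision evaluation can still be sketched with PyTorch's built-in dynamic quantization. This is an illustrative alternative under the assumption of a trained float `model`, not what this commit does:

```python
# Illustrative stand-in for the disabled PostTrainLinearQuantizer path (assumption:
# `model` is a trained float32 nn.Module). Dynamic quantization runs on the CPU and
# here converts only nn.Linear weights to 8-bit integers.
import torch

quantized_model = torch.quantization.quantize_dynamic(
    model.cpu(),          # dynamic quantization is CPU-only
    {torch.nn.Linear},    # module types to quantize
    dtype=torch.qint8,    # 8-bit weights
)
quantized_model.eval()    # then evaluate exactly like the float model
```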
@@ -125,7 +125,7 @@ def train(**kwargs):
     pylogger = PythonLogger(msglogger)
     # step3: configure model
     model = getattr(models, opt.model)()  # build the network architecture
-    compression_scheduler = distiller.CompressionScheduler(model)
+    # compression_scheduler = distiller.CompressionScheduler(model)
     optimizer = model.get_optimizer(lr, opt.weight_decay)  # optimizer
     if opt.load_model_path:
         # # load all tensors onto the CPU
@@ -143,10 +143,10 @@ def train(**kwargs):
         optimizer = checkpoint['optimizer']
     model.to(opt.device)  # move the model to the GPU

-    if opt.compress:
-        compression_scheduler = distiller.file_config(model, optimizer, opt.compress,
-                                                      compression_scheduler)  # load the model-pruning schedule
-        model.to(opt.device)
+    # if opt.compress:
+    #     compression_scheduler = distiller.file_config(model, optimizer, opt.compress,
+    #                                                   compression_scheduler)  # load the model-pruning schedule
+    #     model.to(opt.device)
     # learning-rate scheduler
     lr_scheduler = get_scheduler(optimizer, opt)
     # step4: data_image
@@ -157,8 +157,8 @@ def train(**kwargs):
     # train
     for epoch in range(start_epoch, opt.max_epoch):
         model.train()
-        if opt.pruning:
-            compression_scheduler.on_epoch_begin(epoch)  # prune at the start of the epoch
+        # if opt.pruning:
+        #     compression_scheduler.on_epoch_begin(epoch)  # prune at the start of the epoch
         train_losses.reset()  # reset the meter
         train_top1.reset()  # reset the meter
         # print('train dataset size', len(train_dataloader))
@@ -170,8 +170,8 @@ def train(**kwargs):
         lr = lr_scheduler.get_lr()[0]
         for ii, (data, labels, img_path, tag) in enumerate(train_dataloader):
             if not check_date(img_path, tag, msglogger): return
-            if opt.pruning:
-                compression_scheduler.on_minibatch_begin(epoch, ii, steps_per_epoch, optimizer)  # prune at the start of the batch
+            # if opt.pruning:
+            #     compression_scheduler.on_minibatch_begin(epoch, ii, steps_per_epoch, optimizer)  # prune at the start of the batch
             train_progressor.current = ii + 1  # current progress on the training set
             # train model
             input = data.to(opt.device)
@@ -182,21 +182,21 @@ def train(**kwargs):
             score = model(input)  # network output
             # compute the loss
             loss = criterion(score, target)
-            if opt.pruning:
-                # Before running the backward phase, we allow the scheduler to modify the loss
-                # (e.g. add regularization loss)
-                agg_loss = compression_scheduler.before_backward_pass(epoch, ii, steps_per_epoch, loss,
-                                                                      optimizer=optimizer,
-                                                                      return_loss_components=True)  # pruning loss
-                loss = agg_loss.overall_loss
+            # if opt.pruning:
+            #     # Before running the backward phase, we allow the scheduler to modify the loss
+            #     # (e.g. add regularization loss)
+            #     agg_loss = compression_scheduler.before_backward_pass(epoch, ii, steps_per_epoch, loss,
+            #                                                           optimizer=optimizer,
+            #                                                           return_loss_components=True)  # pruning loss
+            #     loss = agg_loss.overall_loss
             train_losses.update(loss.item(), input.size(0))
             # loss = criterion(score[0], target)  # loss for the Inception3 network
             optimizer.zero_grad()  # zero the parameter gradients
             loss.backward()  # backpropagation
             optimizer.step()  # update the parameters

-            if opt.pruning:
-                compression_scheduler.on_minibatch_end(epoch, ii, steps_per_epoch, optimizer)  # prune at the end of the batch
+            # if opt.pruning:
+            #     compression_scheduler.on_minibatch_end(epoch, ii, steps_per_epoch, optimizer)  # prune at the end of the batch

             precision1_train, precision5_train = accuracy(score, target, topk=(1, 5))  # top-1 and top-5 accuracy
@@ -219,14 +219,14 @@ def train(**kwargs):
                            'loss': train_losses.avg}, ii * (epoch + 1))
         # train_progressor.done()  # save the training results as a txt file
         # validate and visualize
-        if opt.pruning:
-            distiller.log_weights_sparsity(model, epoch, loggers=[pylogger])  # log the model-pruning results
-            compression_scheduler.on_epoch_end(epoch, optimizer)  # prune at the end of the epoch
+        # if opt.pruning:
+        #     distiller.log_weights_sparsity(model, epoch, loggers=[pylogger])  # log the model-pruning results
+        #     compression_scheduler.on_epoch_end(epoch, optimizer)  # prune at the end of the epoch
         val_loss, val_top1, val_top5 = val(model, criterion, val_dataloader, epoch, value_writer, lr)  # validate the model
-        sparsity = distiller.model_sparsity(model)
-        perf_scores_history.append(distiller.MutableNamedTuple({'sparsity': sparsity, 'top1': val_top1,
-                                                                'top5': val_top5, 'epoch': epoch + 1, 'lr': lr,
-                                                                'loss': val_loss}, ))
+        # sparsity = distiller.model_sparsity(model)
+        # perf_scores_history.append(distiller.MutableNamedTuple({'sparsity': sparsity, 'top1': val_top1,
+        #                                                         'top5': val_top5, 'epoch': epoch + 1, 'lr': lr,
+        #                                                         'loss': val_loss}, ))
         # keep the performance-score history sorted from best to worst
         # sparsity is the primary sort key, then top1, top5 and epoch
         perf_scores_history.sort(key=operator.attrgetter('sparsity', 'top1', 'top5', 'epoch'), reverse=True)
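A note on the sort above: `operator.attrgetter` with several attribute names returns a tuple, so `reverse=True` orders records by sparsity first and breaks ties by top1, then top5, then epoch. A self-contained illustration with made-up numbers:

```python
# Multi-key descending sort via attrgetter; the records and values are made up.
import operator
from types import SimpleNamespace

history = [
    SimpleNamespace(sparsity=0.5, top1=91.2, top5=98.0, epoch=3),
    SimpleNamespace(sparsity=0.5, top1=92.0, top5=98.3, epoch=5),
    SimpleNamespace(sparsity=0.7, top1=90.1, top5=97.5, epoch=4),
]
history.sort(key=operator.attrgetter('sparsity', 'top1', 'top5', 'epoch'), reverse=True)
# order: sparsity 0.7 first, then the 0.5 entries ranked by top1 (92.0 before 91.2)
```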
@@ -244,7 +244,7 @@ def train(**kwargs):
             "best_precision": best_precision,
             "optimizer": optimizer,
             "valid_loss": [val_loss, val_top1, val_top5],
-            'compression_scheduler': compression_scheduler.state_dict(),
+            # 'compression_scheduler': compression_scheduler.state_dict(),
         })  # save the model
         # update learning rate
         lr_scheduler.step(epoch)  # update the learning rate
@@ -260,18 +260,18 @@ def train(**kwargs):


 # model sensitivity analysis
-def sensitivity(**kwargs):
-    opt._parse(kwargs)
-    test_data = DatasetFromFilename(opt.data_root, flag='test')
-    test_dataloader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
-    criterion = t.nn.CrossEntropyLoss().to(opt.device)  # loss function
-    model = getattr(models, opt.model)()
-    if opt.load_model_path:
-        checkpoint = t.load(opt.load_model_path)
-        model.load_state_dict(checkpoint["state_dict"])
-    model.to(opt.device)
-    sensitivities = np.arange(opt.sensitivity_range[0], opt.sensitivity_range[1], opt.sensitivity_range[2])
-    return sensitivity_analysis(model, criterion, test_dataloader, opt, sensitivities, msglogger)
+# def sensitivity(**kwargs):
+#     opt._parse(kwargs)
+#     test_data = DatasetFromFilename(opt.data_root, flag='test')
+#     test_dataloader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
+#     criterion = t.nn.CrossEntropyLoss().to(opt.device)  # loss function
+#     model = getattr(models, opt.model)()
+#     if opt.load_model_path:
+#         checkpoint = t.load(opt.load_model_path)
+#         model.load_state_dict(checkpoint["state_dict"])
+#     model.to(opt.device)
+#     sensitivities = np.arange(opt.sensitivity_range[0], opt.sensitivity_range[1], opt.sensitivity_range[2])
+#     return sensitivity_analysis(model, criterion, test_dataloader, opt, sensitivities, msglogger)
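The disabled entry point above delegated to `utils.sensitivity.sensitivity_analysis`, whose import was also removed. For orientation, a generic sketch of what a pruning sensitivity sweep typically does: for each weight tensor, zero an increasing fraction of the smallest-magnitude weights and re-measure accuracy. `evaluate` is a hypothetical stand-in for the project's validation loop:

```python
# Generic magnitude-pruning sensitivity sweep (illustration only; `evaluate` is a
# hypothetical callable returning an accuracy score for a model).
import copy

def sensitivity_sweep(model, evaluate, sparsities):
    results = {}
    for name, param in model.named_parameters():
        if param.dim() < 2:  # skip biases and BatchNorm parameters
            continue
        results[name] = []
        for s in sparsities:
            pruned = copy.deepcopy(model)
            w = dict(pruned.named_parameters())[name].data
            k = int(s * w.numel())
            if k > 0:
                threshold = w.abs().flatten().kthvalue(k).values
                w[w.abs() <= threshold] = 0.0  # zero the k smallest-magnitude weights
            results[name].append((s, evaluate(pruned)))  # accuracy at this sparsity
    return results
```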


 def help():