Skip to content

Commit 01c58ac

Browse files
author
Jianfeng Wang
authored
fix(detection): fix learning rate schedule (#114)
1 parent deecc27 commit 01c58ac

File tree

7 files changed

+6
-10
lines changed

7 files changed

+6
-10
lines changed

official/vision/classification/resnet/train.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -165,11 +165,11 @@ def valid_step(image, label):
165165

166166
# multi-step learning rate scheduler with warmup
167167
def adjust_learning_rate(step):
168-
lr = args.lr * 0.1 ** bisect.bisect_right(
168+
lr = args.lr * dist.get_world_size() * 0.1 ** bisect.bisect_right(
169169
[30 * steps_per_epoch, 60 * steps_per_epoch, 80 * steps_per_epoch], step
170170
)
171171
if step < 5 * steps_per_epoch: # warmup
172-
lr = args.lr * (step / (5 * steps_per_epoch))
172+
lr = args.lr * dist.get_world_size() * (step / (5 * steps_per_epoch))
173173
for param_group in opt.param_groups:
174174
param_group["lr"] = lr
175175
return lr

official/vision/classification/shufflenet/train.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ def valid_step(image, label):
176176

177177
# linear learning rate scheduler
178178
def adjust_learning_rate(step):
179-
lr = args.lr * (1 - step / (args.epochs * steps_per_epoch))
179+
lr = args.lr * dist.get_world_size() * (1 - step / (args.epochs * steps_per_epoch))
180180
for param_group in opt.param_groups:
181181
param_group["lr"] = lr
182182
return lr

official/vision/detection/README.md

+1 -1
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343

4444
## 安装和环境配置
4545

46-
本目录下代码基于MegEngine v1.2,在开始运行本目录下的代码之前,请确保按照[README](../../../README.md)进行了正确的环境配置。
46+
本目录下代码基于MegEngine v1.6,在开始运行本目录下的代码之前,请确保按照[README](../../../README.md)进行了正确的环境配置。
4747

4848
## 如何使用
4949

official/vision/detection/tools/test.py

-1
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020

2121
logger = mge.get_logger(__name__)
2222
logger.setLevel("INFO")
23-
mge.device.set_prealloc_config(1024, 1024, 256 * 1024 * 1024, 4.0)
2423

2524

2625
def make_parser():

official/vision/detection/tools/test_random.py

-1
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@
2424

2525
logger = mge.get_logger(__name__)
2626
logger.setLevel("INFO")
27-
mge.device.set_prealloc_config(1024, 1024, 256 * 1024 * 1024, 4.0)
2827

2928

3029
def make_parser():

official/vision/detection/tools/train.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030

3131
logger = mge.get_logger(__name__)
3232
logger.setLevel("INFO")
33-
mge.device.set_prealloc_config(1024, 1024, 256 * 1024 * 1024, 4.0)
3433

3534

3635
def make_parser():
@@ -183,7 +182,7 @@ def train_func(image, im_info, gt_boxes):
183182

184183
def adjust_learning_rate(optimizer, epoch, step, cfg, args):
185184
base_lr = (
186-
cfg.basic_lr * args.batch_size * (
185+
cfg.basic_lr * args.batch_size * dist.get_world_size() * (
187186
cfg.lr_decay_rate
188187
** bisect.bisect_right(cfg.lr_decay_stages, epoch)
189188
)

official/vision/detection/tools/train_random.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929

3030
logger = mge.get_logger(__name__)
3131
logger.setLevel("INFO")
32-
mge.device.set_prealloc_config(1024, 1024, 256 * 1024 * 1024, 4.0)
3332

3433

3534
def make_parser():
@@ -182,7 +181,7 @@ def train_func(image, im_info, gt_boxes):
182181

183182
def adjust_learning_rate(optimizer, epoch, step, cfg, args):
184183
base_lr = (
185-
cfg.basic_lr * args.batch_size * (
184+
cfg.basic_lr * args.batch_size * dist.get_world_size() * (
186185
cfg.lr_decay_rate
187186
** bisect.bisect_right(cfg.lr_decay_stages, epoch)
188187
)

0 commit comments

Comments
 (0)