@@ -33,7 +33,7 @@ def parse_config():
33
33
parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to use sync bn')
34
34
parser.add_argument('--fix_random_seed', action='store_true', default=False, help='')
35
35
parser.add_argument('--ckpt_save_interval', type=int, default=1, help='number of training epochs')
36
- parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
36
+ parser.add_argument('--local_rank', type=int, default=None, help='local rank for distributed training')
37
37
parser.add_argument('--max_ckpt_save_num', type=int, default=30, help='max number of saved checkpoint')
38
38
parser.add_argument('--merge_all_iters_to_one_epoch', action='store_true', default=False, help='')
39
39
parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
@@ -71,6 +71,9 @@ def main():
71
71
dist_train = False
72
72
total_gpus = 1
73
73
else:
74
+ if args.local_rank is None:
75
+     args.local_rank = int(os.environ.get('LOCAL_RANK', '0'))
76
+
74
77
total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
75
78
args.tcp_port, args.local_rank, backend='nccl'
76
79
)
0 commit comments