@@ -29,7 +29,7 @@ def parse_config():
29
29
parser .add_argument ('--pretrained_model' , type = str , default = None , help = 'pretrained_model' )
30
30
parser .add_argument ('--launcher' , choices = ['none' , 'pytorch' , 'slurm' ], default = 'none' )
31
31
parser .add_argument ('--tcp_port' , type = int , default = 18888 , help = 'tcp port for distrbuted training' )
32
- parser .add_argument ('--local_rank' , type = int , default = 0 , help = 'local rank for distributed training' )
32
+ parser .add_argument ('--local_rank' , type = int , default = None , help = 'local rank for distributed training' )
33
33
parser .add_argument ('--set' , dest = 'set_cfgs' , default = None , nargs = argparse .REMAINDER ,
34
34
help = 'set extra config keys if needed' )
35
35
@@ -145,6 +145,9 @@ def main():
145
145
dist_test = False
146
146
total_gpus = 1
147
147
else :
148
+ if args .local_rank is None :
149
+ args .local_rank = int (os .environ .get ('LOCAL_RANK' , '0' ))
150
+
148
151
total_gpus , cfg .LOCAL_RANK = getattr (common_utils , 'init_dist_%s' % args .launcher )(
149
152
args .tcp_port , args .local_rank , backend = 'nccl'
150
153
)
0 commit comments