Skip to content

Commit 8caccce

Browse files
authored
fixbug: dist test in torch 2.0 (#1602)
* fixbug: torch 2.0 distributed training failed due to an incorrect local rank * fixbug: distributed test in torch 2.0
1 parent 839d8dd commit 8caccce

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

tools/scripts/torch_test.sh

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env bash
# Launch a distributed PyTorch test run via torchrun on NGPUS GPUs, using a
# randomly chosen free TCP port as the rendezvous endpoint so concurrent
# jobs on the same host do not collide.
#
# Usage: torch_test.sh NGPUS [extra args forwarded to test.py ...]

set -x

NGPUS=$1
# Capture the remaining arguments as an ARRAY so arguments containing spaces
# survive intact. The original flat-string form `PY_ARGS=${@:2}` word-split
# every argument when it was later expanded unquoted.
PY_ARGS=("${@:2}")

# Pick a random port in [10000, 59151] and retry until nothing listens on it.
# (RANDOM is only 15 bits, so two draws are combined for a wider range.)
while true
do
    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
    # `nc -z` exits 0 when something is already listening on the port.
    status="$(nc -z 127.0.0.1 "$PORT" < /dev/null &>/dev/null; echo $?)"
    if [ "${status}" != "0" ]; then
        break
    fi
done
echo "$PORT"

torchrun --nproc_per_node="${NGPUS}" --rdzv_endpoint=localhost:"${PORT}" test.py --launcher pytorch "${PY_ARGS[@]}"
18+

tools/test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def parse_config():
2929
parser.add_argument('--pretrained_model', type=str, default=None, help='pretrained_model')
3030
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none')
3131
parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distrbuted training')
32-
parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
32+
parser.add_argument('--local_rank', type=int, default=None, help='local rank for distributed training')
3333
parser.add_argument('--set', dest='set_cfgs', default=None, nargs=argparse.REMAINDER,
3434
help='set extra config keys if needed')
3535

@@ -145,6 +145,9 @@ def main():
145145
dist_test = False
146146
total_gpus = 1
147147
else:
148+
if args.local_rank is None:
149+
args.local_rank = int(os.environ.get('LOCAL_RANK', '0'))
150+
148151
total_gpus, cfg.LOCAL_RANK = getattr(common_utils, 'init_dist_%s' % args.launcher)(
149152
args.tcp_port, args.local_rank, backend='nccl'
150153
)

0 commit comments

Comments
 (0)