diff --git a/test_models.py b/test_models.py index 52b0e940..eb877770 100755 --- a/test_models.py +++ b/test_models.py @@ -191,11 +191,10 @@ def parse_shift_option_from_log_name(log_name): ) if args.gpus is not None: - devices = [args.gpus[i] for i in range(args.workers)] + devices = args.gpus + net = torch.nn.DataParallel(net.cuda(), devices) else: - devices = list(range(args.workers)) - - net = torch.nn.DataParallel(net.cuda()) + net = torch.nn.DataParallel(net.cuda()) net.eval() data_gen = enumerate(data_loader)