|
 from modules.load_state import load_state


-def convert_to_trt(net, output_name):
+def convert_to_trt(net, output_name, height, width):
     net.eval()
-    input = torch.randn(1, 3, 256, 448).cuda()
+    input = torch.randn(1, 3, height, width).cuda()
     net_trt = torch2trt(net, [input])
     torch.save(net_trt.state_dict(), output_name)
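torch2trt serializes the built engine as an ordinary PyTorch state dict, so the converted model is loaded back through torch2trt's TRTModule rather than the original network class. A minimal sketch of the round trip (the file name is just this script's default output, and the input shape must match the one the engine was built with):

    import torch
    from torch2trt import TRTModule

    # Load the serialized TensorRT engine saved by convert_to_trt().
    net_trt = TRTModule()
    net_trt.load_state_dict(torch.load('human-pose-estimation-3d-trt.pth'))

    # The engine is fixed-shape: it only accepts the exact size it was built with.
    dummy = torch.randn(1, 3, 256, 448).cuda()
    out = net_trt(dummy)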
|
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--checkpoint-path', type=str, required=True, help='path to the checkpoint')
+    parser.add_argument('--height', type=int, default=256, help='network input height')
+    parser.add_argument('--width', type=int, default=448, help='network input width')
     parser.add_argument('--output-name', type=str, default='human-pose-estimation-3d-trt.pth',
                         help='name of output model in TensorRT format')
     args = parser.parse_args()
+    print('TensorRT does not support dynamic network input size reshape.\n'
+          'Make sure you have set proper network input height, width. If not, there will be no detections.\n'
+          'Default values work for a usual video with 16:9 aspect ratio (1280x720, 1920x1080).\n'
+          'You can check the network input size with \'print(scaled_img.shape)\' in demo.py')
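The warning above is the point of this change: the engine bakes in one input size, so the height and width passed here must match what demo.py actually feeds the network. A sketch of how that size can be derived, under the assumption that demo.py scales each frame to a fixed input height and then trims the width down to a multiple of the network stride (base_height=256 and stride=8 are assumptions; they are consistent with 448 being the expected width for 16:9 sources):

    def network_input_size(frame_height, frame_width, base_height=256, stride=8):
        # Hypothetical helper mirroring the assumed preprocessing in demo.py:
        # scale the frame so its height equals base_height, then trim the
        # width down to a multiple of the network stride.
        scale = base_height / frame_height
        width = int(frame_width * scale)
        return base_height, width - width % stride

    # Both common 16:9 resolutions map to the 256x448 default:
    print(network_input_size(720, 1280))   # -> (256, 448)
    print(network_input_size(1080, 1920))  # -> (256, 448)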
|
     net = PoseEstimationWithMobileNet().cuda()
     checkpoint = torch.load(args.checkpoint_path)
     load_state(net, checkpoint)

-    convert_to_trt(net, args.output_name)
+    convert_to_trt(net, args.output_name, args.height, args.width)
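With the new flags, a source with a different aspect ratio just needs the matching size passed explicitly. For a 640x480 (4:3) video, the sketch above gives 256x336, so a hypothetical invocation (script path and checkpoint name assumed) would be:

    python scripts/convert_to_trt.py --checkpoint-path human-pose-estimation-3d.pth --height 256 --width 336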