Skip to content

Commit 6d10db0

Browse files
committed
Added tips for conversion to TensorRT format
1 parent 177bcce commit 6d10db0

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,14 @@ To run with TensorRT, it is necessary to install it properly. Please, follow the
8585
```
8686
4. [Install](https://github.com/NVIDIA-AI-IOT/torch2trt) `torch2trt`.
8787
88-
8988
Convert checkpoint to TensorRT format:
9089
```
9190
python scripts/convert_to_trt.py --checkpoint-path human-pose-estimation-3d.pth
9291
```
92+
> TensorRT does not support dynamic reshaping of the network input size.
93+
Make sure you have set the proper network input height and width with the `--height` and `--width` options during conversion (otherwise there will be no detections).
94+
Default values work for a usual video with 16:9 aspect ratio (1280x720, 1920x1080).
95+
You can check the network input size with `print(scaled_img.shape)` in demo.py
9396
9497
To run the demo with TensorRT inference, pass `--use-tensorrt` option:
9598
```

models/with_mobilenet.py

-1
Original file line numberDiff line numberDiff line change
@@ -192,4 +192,3 @@ def forward(self, x):
192192
out = self.Pose3D(backbone_features, torch.cat([stages_output[-2], stages_output[-1]], dim=1))
193193

194194
return out, keypoints2d_maps, paf_maps
195-

scripts/convert_to_trt.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,28 @@
77
from modules.load_state import load_state
88

99

10-
def convert_to_trt(net, output_name):
10+
def convert_to_trt(net, output_name, height, width):
1111
net.eval()
12-
input = torch.randn(1, 3, 256, 448).cuda()
12+
input = torch.randn(1, 3, height, width).cuda()
1313
net_trt = torch2trt(net, [input])
1414
torch.save(net_trt.state_dict(), output_name)
1515

1616

1717
if __name__ == '__main__':
1818
parser = argparse.ArgumentParser()
1919
parser.add_argument('--checkpoint-path', type=str, required=True, help='path to the checkpoint')
20+
parser.add_argument('--height', type=int, default=256, help='network input height')
21+
parser.add_argument('--width', type=int, default=448, help='network input width')
2022
parser.add_argument('--output-name', type=str, default='human-pose-estimation-3d-trt.pth',
2123
help='name of output model in TensorRT format')
2224
args = parser.parse_args()
25+
print('TensorRT does not support dynamic network input size reshape.\n'
26+
'Make sure you have set proper network input height, width. If not, there will be no detections.\n'
27+
'Default values work for a usual video with 16:9 aspect ratio (1280x720, 1920x1080).\n'
28+
'You can check the network input size with \'print(scaled_img.shape)\' in demo.py')
2329

2430
net = PoseEstimationWithMobileNet().cuda()
2531
checkpoint = torch.load(args.checkpoint_path)
2632
load_state(net, checkpoint)
2733

28-
convert_to_trt(net, args.output_name)
34+
convert_to_trt(net, args.output_name, args.height, args.width)

0 commit comments

Comments
 (0)