Enhance support for Paddle-TensorRT #2817
Changes from all commits
@@ -14,6 +14,7 @@
 import os
 from abc import abstractmethod
+from pathlib import Path
 import lazy_paddle as paddle
 import numpy as np
@@ -22,27 +23,50 @@
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BaseComponent
 
+CACHE_DIR = ".cache"
+
 
-def collect_trt_shapes(
-    model_file, model_params, gpu_id, shape_range_info_path, trt_dynamic_shapes
+def collect_trt_shape_range_info(
+    model_file,
+    model_params,
+    gpu_id,
+    shape_range_info_path,
+    dynamic_shapes,
+    dynamic_shape_input_data,
 ):
+    dynamic_shape_input_data = dynamic_shape_input_data or {}
+
     config = paddle.inference.Config(model_file, model_params)
     config.enable_use_gpu(100, gpu_id)
-    config.collect_shape_range_info(shape_range_info_path)
     config.disable_glog_info()
-    predictor = paddle.inference.create_predictor(config)
 
     min_arrs, opt_arrs, max_arrs = {}, {}, {}
-    for name, candidate_shapes in trt_dynamic_shapes.items():
+    for name, candidate_shapes in dynamic_shapes.items():
         min_shape, opt_shape, max_shape = candidate_shapes
-        min_arrs[name] = np.ones(min_shape, dtype=np.float32)
-        opt_arrs[name] = np.ones(opt_shape, dtype=np.float32)
-        max_arrs[name] = np.ones(max_shape, dtype=np.float32)
+        # HACK: Currently the data type is hard-coded
+        if name in dynamic_shape_input_data:
+            min_arrs[name] = np.array(
+                dynamic_shape_input_data[name][0], dtype=np.float32
+            ).reshape(min_shape)
+            opt_arrs[name] = np.array(
+                dynamic_shape_input_data[name][1], dtype=np.float32
+            ).reshape(opt_shape)
+            max_arrs[name] = np.array(
+                dynamic_shape_input_data[name][2], dtype=np.float32
+            ).reshape(max_shape)
+        else:
+            min_arrs[name] = np.ones(min_shape, dtype=np.float32)
+            opt_arrs[name] = np.ones(opt_shape, dtype=np.float32)
+            max_arrs[name] = np.ones(max_shape, dtype=np.float32)
 
+    config.collect_shape_range_info(shape_range_info_path)
+    predictor = paddle.inference.create_predictor(config)
     # opt_arrs would be used twice to simulate the most common situations
     for arrs in [min_arrs, opt_arrs, opt_arrs, max_arrs]:
         for name, arr in arrs.items():
-            input_handler = predictor.get_input_handle(name)
-            input_handler.reshape(arr.shape)
-            input_handler.copy_from_cpu(arr)
+            handle = predictor.get_input_handle(name)
+            handle.reshape(arr.shape)
+            handle.copy_from_cpu(arr)
         predictor.run()
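For reference, both dictionaries consumed by this helper are keyed by input name and follow a [min, opt, max] convention; `dynamic_shape_input_data` is optional and supplies concrete values (reshaped in the loop above) for inputs where all-ones dummies are not meaningful. Below is a minimal, hypothetical call; the input names, shapes, values, and file paths are illustrative only and not taken from the PR:

```python
# Hypothetical example of the argument structure expected by
# collect_trt_shape_range_info(); names, shapes, and paths are made up.
dynamic_shapes = {
    # input name -> [min_shape, opt_shape, max_shape]
    "x": [[1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224]],
    "im_shape": [[1, 2], [1, 2], [8, 2]],
}
dynamic_shape_input_data = {
    # input name -> [min_data, opt_data, max_data]; flat values reshaped to the
    # shapes above. Inputs not listed here fall back to np.ones().
    "im_shape": [[224, 224], [224, 224], [224, 224] * 8],
}

collect_trt_shape_range_info(
    "inference.pdmodel",
    "inference.pdiparams",
    gpu_id=0,
    shape_range_info_path="shape_range_info.pbtxt",
    dynamic_shapes=dynamic_shapes,
    dynamic_shape_input_data=dynamic_shape_input_data,
)
```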
@@ -146,33 +170,78 @@ def _create(self):
                 "trt_fp16": Config.Precision.Half,
             }
             if self.option.run_mode in precision_map.keys():
+                config.set_optim_cache_dir(str(self.model_dir / CACHE_DIR))
+
                 config.enable_tensorrt_engine(
-                    workspace_size=(1 << 25) * self.option.batch_size,
-                    max_batch_size=self.option.batch_size,
-                    min_subgraph_size=self.option.min_subgraph_size,
+                    workspace_size=self.option.trt_max_workspace_size,
+                    max_batch_size=self.option.trt_max_batch_size,
+                    min_subgraph_size=self.option.trt_min_subgraph_size,
                     precision_mode=precision_map[self.option.run_mode],
                     use_static=self.option.trt_use_static,
-                    use_calib_mode=self.option.trt_calib_mode,
+                    use_calib_mode=self.option.trt_use_calib_mode,
                 )
 
-                if not os.path.exists(self.option.shape_info_filename):
-                    logging.info(
-                        f"Dynamic shape info is collected into: {self.option.shape_info_filename}"
-                    )
-                    collect_trt_shapes(
-                        model_file,
-                        params_file,
-                        self.option.device_id,
-                        self.option.shape_info_filename,
-                        self.option.trt_dynamic_shapes,
-                    )
-                else:
-                    logging.info(
-                        f"A dynamic shape info file ( {self.option.shape_info_filename} ) already exists. No need to collect again."
-                    )
-                config.enable_tuned_tensorrt_dynamic_shape(
-                    self.option.shape_info_filename, True
-                )
+                if self.option.trt_use_dynamic_shapes:
+                    if self.option.trt_collect_shape_range_info:
+                        # NOTE: We always use a shape range info file.
+                        if self.option.trt_shape_range_info_path is not None:
+                            trt_shape_range_info_path = Path(
+                                self.option.trt_shape_range_info_path
+                            )
+                        else:
+                            trt_shape_range_info_path = (
+                                self.model_dir
+                                / CACHE_DIR
+                                / "shape_range_info.pbtxt"
+                            )
+                        should_collect_shape_range_info = True
+                        if not trt_shape_range_info_path.exists():
+                            trt_shape_range_info_path.parent.mkdir(
+                                parents=True, exist_ok=True
+                            )
+                            logging.info(
+                                f"Shape range info will be collected into {trt_shape_range_info_path}"
+                            )
+                        elif self.option.trt_discard_cached_shape_range_info:
+                            trt_shape_range_info_path.unlink()
+                            logging.info(
+                                f"The shape range info file ({trt_shape_range_info_path}) has been removed, and the shape range info will be re-collected."
+                            )
+                        else:
+                            logging.info(
+                                f"A shape range info file ({trt_shape_range_info_path}) already exists. There is no need to collect the info again."
+                            )
+                            should_collect_shape_range_info = False
+                        if should_collect_shape_range_info:
+                            collect_trt_shape_range_info(
+                                model_file,
+                                params_file,
+                                self.option.device_id,
+                                str(trt_shape_range_info_path),
+                                self.option.trt_dynamic_shapes,
+                                self.option.trt_dynamic_shape_input_data,
+                            )
+                        config.enable_tuned_tensorrt_dynamic_shape(
+                            str(trt_shape_range_info_path),
+                            self.option.trt_allow_build_at_runtime,
+                        )
+                    else:
+                        if self.option.trt_dynamic_shapes is not None:
+                            min_shapes, opt_shapes, max_shapes = {}, {}, {}
+                            for (
+                                key,
+                                shapes,
+                            ) in self.option.trt_dynamic_shapes.items():
+                                min_shapes[key] = shapes[0]
+                                opt_shapes[key] = shapes[1]
+                                max_shapes[key] = shapes[2]
+                            config.set_trt_dynamic_shape_info(
+                                min_shapes, max_shapes, opt_shapes
+                            )
+                        else:
+                            raise RuntimeError(
+                                "No dynamic shape information provided"
+                            )
 
         elif self.option.device == "npu":
             config.enable_custom_device("npu")
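When shape range collection is disabled, the branch above feeds `trt_dynamic_shapes` directly to `config.set_trt_dynamic_shape_info()`. The following standalone sketch only illustrates that transformation; the input name and shapes are made up, and note the min/max/opt argument order used in the diff:

```python
# Illustration only: splitting the per-input [min, opt, max] lists into the
# three dicts handed to set_trt_dynamic_shape_info(). Names/shapes are made up.
trt_dynamic_shapes = {
    "x": [[1, 3, 224, 224], [1, 3, 224, 224], [8, 3, 224, 224]],
}

min_shapes = {k: v[0] for k, v in trt_dynamic_shapes.items()}
opt_shapes = {k: v[1] for k, v in trt_dynamic_shapes.items()}
max_shapes = {k: v[2] for k, v in trt_dynamic_shapes.items()}

# As in the diff above, the call order is (min, max, opt):
# config.set_trt_dynamic_shape_info(min_shapes, max_shapes, opt_shapes)
```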
@@ -217,8 +286,8 @@ def _create(self):
         if hasattr(config, "disable_mkldnn"):
             config.disable_mkldnn()
 
-        # Disable paddle inference logging
-        config.disable_glog_info()
+        if self.option.disable_glog_info:
+            config.disable_glog_info()
Comment on lines +289 to +290

I changed this part so that it is now controlled via PADDLE_PDX_DEBUG.

Is that an environment variable, or something else?

Yes, it is. When PADDLE_PDX_DEBUG=1, glog info is no longer disabled.
 
         config.set_cpu_math_library_num_threads(self.option.cpu_threads)
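The thread above describes switching this behavior to an environment variable, PADDLE_PDX_DEBUG, where a value of 1 keeps Paddle Inference's glog output enabled. The helper below is only a hedged sketch of that described behavior; the `_debug_enabled` name and the exact parsing convention are assumptions, not code from the PR:

```python
import os


def _debug_enabled() -> bool:
    # Assumed convention: PADDLE_PDX_DEBUG=1 means debug mode is on.
    return os.environ.get("PADDLE_PDX_DEBUG", "0") == "1"


# Inside _create(): keep glog output when debugging, silence it otherwise.
if not _debug_enabled():
    config.disable_glog_info()
```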
Since this part also needs to support PIR-TRT, it may still have to change.

Sounds good. In that case, could we split the TRT-related logic into a separate method, e.g. _configure_trt? This block is getting a bit too complex.

Yes, there are a lot of TRT-related settings, so they should be encapsulated separately.

OK. Then how about we wait for your PR to be merged first, and I adapt and adjust this part in one pass afterwards?
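As a rough sketch of the refactor proposed in this thread (not part of this PR), the TensorRT branch of `_create()` could be pulled into a dedicated method. The method name, signature, and condensed control flow below are assumptions; it reuses `CACHE_DIR`, `Path`, and `collect_trt_shape_range_info` from the module shown in the diff:

```python
def _configure_trt(self, config, model_file, params_file, precision_mode):
    """Group all TensorRT-related settings applied to a paddle.inference.Config."""
    config.set_optim_cache_dir(str(self.model_dir / CACHE_DIR))
    config.enable_tensorrt_engine(
        workspace_size=self.option.trt_max_workspace_size,
        max_batch_size=self.option.trt_max_batch_size,
        min_subgraph_size=self.option.trt_min_subgraph_size,
        precision_mode=precision_mode,
        use_static=self.option.trt_use_static,
        use_calib_mode=self.option.trt_use_calib_mode,
    )
    if not self.option.trt_use_dynamic_shapes:
        return
    if self.option.trt_collect_shape_range_info:
        # Collect (or reuse) a shape range info file, then load it as tuned shapes.
        path = (
            Path(self.option.trt_shape_range_info_path)
            if self.option.trt_shape_range_info_path is not None
            else self.model_dir / CACHE_DIR / "shape_range_info.pbtxt"
        )
        if path.exists() and self.option.trt_discard_cached_shape_range_info:
            path.unlink()
        if not path.exists():
            path.parent.mkdir(parents=True, exist_ok=True)
            collect_trt_shape_range_info(
                model_file,
                params_file,
                self.option.device_id,
                str(path),
                self.option.trt_dynamic_shapes,
                self.option.trt_dynamic_shape_input_data,
            )
        config.enable_tuned_tensorrt_dynamic_shape(
            str(path), self.option.trt_allow_build_at_runtime
        )
    elif self.option.trt_dynamic_shapes is not None:
        # Hand the min/opt/max shapes to Paddle Inference directly.
        min_shapes = {k: v[0] for k, v in self.option.trt_dynamic_shapes.items()}
        opt_shapes = {k: v[1] for k, v in self.option.trt_dynamic_shapes.items()}
        max_shapes = {k: v[2] for k, v in self.option.trt_dynamic_shapes.items()}
        config.set_trt_dynamic_shape_info(min_shapes, max_shapes, opt_shapes)
    else:
        raise RuntimeError("No dynamic shape information provided")
```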