Skip to content

Commit 411912e

Browse files
shengliangxu authored and kevalmorabia97 committed
[NVBug 5659126] dummy inputs are kwargs maps (#676)
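## What does this PR do?

Bug fix.

**Overview:** [NVBug 5659126] The dummy inputs are kwargs maps (dicts), so callers must not index into them as if they were tuples. This PR also renames `generate_dummy_inputs_and_dynamic_axes_and_shapes` to `generate_dummy_kwargs_and_dynamic_axes_and_shapes` so that the function name explicitly reflects the type of its return value.

## Testing

`python diffusion_trt.py --model flux-dev --override-model-path /models/FLUX.1-dev --torch --benchmark --skip-image`

Signed-off-by: Shengliang Xu <[email protected]>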
1 parent 818f17c commit 411912e

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

examples/diffusers/quantization/diffusion_trt.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from onnx_utils.export import (
3434
_create_trt_dynamic_shapes,
35-
generate_dummy_inputs_and_dynamic_axes_and_shapes,
35+
generate_dummy_kwargs_and_dynamic_axes_and_shapes,
3636
get_io_shapes,
3737
remove_nesting,
3838
update_dynamic_axes,
@@ -92,18 +92,18 @@ def benchmark_backbone_standalone(
9292
backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
9393

9494
# Generate dummy inputs for the backbone
95-
dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
95+
dummy_kwargs, _, _ = generate_dummy_kwargs_and_dynamic_axes_and_shapes(model_name, backbone)
9696

9797
# Move every tensor in the kwargs dict to cuda; leave non-tensor values as-is
98-
dummy_inputs_dict = {
99-
k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
98+
dummy_kwargs_cuda = {
99+
k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_kwargs.items()
100100
}
101101

102102
# Warmup
103103
print(f"Warming up: {num_warmup} iterations")
104104
for _ in tqdm(range(num_warmup), desc="Warmup"):
105105
with context:
106-
_ = backbone(**dummy_inputs_dict)
106+
_ = backbone(**dummy_kwargs_cuda)
107107

108108
# Benchmark
109109
torch.cuda.synchronize()
@@ -116,7 +116,7 @@ def benchmark_backbone_standalone(
116116
with context:
117117
torch.cuda.profiler.cudart().cudaProfilerStart()
118118
start_event.record()
119-
_ = backbone(**dummy_inputs_dict)
119+
_ = backbone(**dummy_kwargs_cuda)
120120
end_event.record()
121121
torch.cuda.synchronize()
122122
torch.cuda.profiler.cudart().cudaProfilerStop()
@@ -241,7 +241,7 @@ def main():
241241
backbone.to("cuda")
242242

243243
# Generate dummy inputs for the backbone
244-
dummy_inputs, dynamic_axes, dynamic_shapes = generate_dummy_inputs_and_dynamic_axes_and_shapes(
244+
dummy_inputs, dynamic_axes, dynamic_shapes = generate_dummy_kwargs_and_dynamic_axes_and_shapes(
245245
args.model, backbone
246246
)
247247

examples/diffusers/quantization/onnx_utils/export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _create_trt_dynamic_shapes(dynamic_shapes):
381381
}
382382

383383

384-
def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
384+
def generate_dummy_kwargs_and_dynamic_axes_and_shapes(model_id, backbone):
385385
"""Generate dummy inputs, dynamic axes, and dynamic shapes for the given model."""
386386
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
387387
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
@@ -474,7 +474,7 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
474474
configure_linear_module_onnx_quantizers(backbone) if precision == "fp4" else nullcontext()
475475
)
476476

477-
dummy_kwargs, dynamic_axes, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(
477+
dummy_kwargs, dynamic_axes, _ = generate_dummy_kwargs_and_dynamic_axes_and_shapes(
478478
model_name, backbone
479479
)
480480

0 commit comments

Comments
 (0)