Skip to content

Commit 41aaec5

Browse files
gcunhase authored and kevalmorabia97 committed
[5763448][ONNX][Autocast] Fix Resize input type mismatch error (#757)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** This PR fixes an input type mismatch in Resize layers when being converted to FP16.

## Usage

```python
$ python -m modelopt.onnx.autocast --onnx_path=$MODEL_NAME.onnx
```

## Testing

Added unittest.

## Before your PR is "*Ready for review*"

<!-- If you haven't finished some of the above items you can still open `Draft` PR. -->

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

This issue is also fixed by using the standalone type inference logic from #719.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

* **Improvements**
  * Enhanced the graph sanitization process to automatically duplicate shared constants during optimization, ensuring improved model handling and consistency.
* **Tests**
  * Added test coverage for mixed precision conversion of Conv-Resize model architectures.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: gcunhase <[email protected]>
1 parent 4c499ab commit 41aaec5

File tree

5 files changed

+124
-3
lines changed

5 files changed

+124
-3
lines changed

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def sanitize(self) -> None:
6767
self.convert_opset()
6868
self.replace_layernorm_pattern()
6969
self.ensure_graph_name_exists()
70+
self.duplicate_shared_constants()
7071
onnx_utils.name_onnx_nodes(self.model.graph)
7172
self.replace_custom_domain_nodes()
7273
self.sanitize_io_casts()
@@ -254,6 +255,12 @@ def ensure_graph_name_exists(self) -> None:
254255
if not self.model.graph.name:
255256
self.model.graph.name = "model"
256257

258+
def duplicate_shared_constants(self) -> None:
    """Duplicate shared constant tensors in the model.

    Runs ``onnx_utils.duplicate_shared_constants`` on ``self.model``, stores the
    returned model back on ``self.model`` and emits a warning when the helper
    reports that a duplication took place.
    """
    result = onnx_utils.duplicate_shared_constants(self.model)
    self.model, any_duplicated = result
    if not any_duplicated:
        return
    logger.warning("Shared constants were detected and duplicated accordingly.")
263+
257264
def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
258265
"""Match the sequence of operations that constitute a LayerNorm.
259266

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,6 +1419,11 @@ def _sanitize_model(self):
14191419
graph_sanitizer.sanitize()
14201420
self.model = graph_sanitizer.model
14211421

1422+
# Update value_info_map and initializer_map after sanitizing model
1423+
self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings(
1424+
self.model
1425+
)
1426+
14221427
def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
14231428
"""Create mapping of op types to indices of inputs that should not be converted to low precision."""
14241429
skip_inputs_map = {}

modelopt/onnx/quantization/fp8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _convert(node: onnx.NodeProto):
102102
)
103103
zero_point = initializers[zero_point_idx]
104104
dtype = onnx.helper.tensor_dtype_to_np_dtype(zero_point.data_type)
105-
vals = np.array(zero_point.int32_data, dtype=dtype).tobytes()
105+
vals = np.array(zero_point.int32_data, dtype=dtype).tobytes() or zero_point.raw_data
106106

107107
np_zero_point = onnx.helper.make_tensor(
108108
zero_point_name, onnx.TensorProto.FLOAT8E4M3FN, zero_point.dims, vals, raw=True

tests/_test_utils/onnx/lib_test_models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,3 +924,88 @@ def build_conv_isinf_model(opset_version=13):
924924
onnx.checker.check_model(model_inferred)
925925

926926
return model_inferred
927+
928+
929+
def build_conv_resize_model():
    """Build a two-node Conv -> Resize ONNX model used by the autocast tests.

    The Resize node deliberately receives the same initializer
    (``resize_roi_scales``) as both its ``roi`` and ``scales`` inputs so that the
    shared-constant duplication fix (PR #757) is exercised.

    Returns:
        The shape-inferred model, after it has been validated with
        ``onnx.checker.check_model``.
    """
    conv_output = "conv1_conv/Conv2D:0"
    shared_roi_scales = "resize_roi_scales"
    weight_dims = (16, 288, 1, 1)

    # Graph inputs and outputs (single FP32 input, single FP32 output)
    graph_inputs = [
        helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, (1, 288, 32, 32)),
    ]
    graph_outputs = [
        helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, (1, 16, 64, 64)),
    ]

    # 1x1 convolution without bias
    conv_node = helper.make_node(
        "Conv",
        inputs=["input_0", "weights_1"],
        outputs=[conv_output],
        name="conv1_conv/Conv2D",
        dilations=[1, 1],
        group=1,
        kernel_shape=[1, 1],
        pads=[0, 0, 0, 0],
        strides=[1, 1],
    )

    # Note: the same tensor name is wired to both the roi and the scales inputs
    # on purpose, see the docstring.
    resize_node = helper.make_node(
        "Resize",
        inputs=[conv_output, shared_roi_scales, shared_roi_scales, "resize_sizes"],
        outputs=["output_0"],
        name="resize1_resize/Resize",
        coordinate_transformation_mode="asymmetric",
        cubic_coeff_a=-0.75,
        mode="nearest",
        nearest_mode="floor",
    )

    # Initializers: random Conv weights, the empty shared roi/scales tensor and
    # the explicit output sizes for Resize
    conv_weights = helper.make_tensor(
        name="weights_1",
        data_type=onnx.TensorProto.FLOAT,
        dims=weight_dims,
        vals=np.random.uniform(low=0.5, high=1.0, size=int(np.prod(weight_dims))),
    )
    roi_scales = helper.make_tensor(
        name=shared_roi_scales,
        data_type=onnx.TensorProto.FLOAT,
        dims=(0,),
        vals=[],
    )
    resize_sizes = helper.make_tensor(
        name="resize_sizes",
        data_type=onnx.TensorProto.INT64,
        dims=(4,),
        vals=[1, 16, 64, 64],
    )

    # Assemble graph and model
    graph = helper.make_graph(
        [conv_node, resize_node],
        "conv_resize",
        graph_inputs,
        graph_outputs,
        initializer=[conv_weights, roi_scales, resize_sizes],
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 13
    model.ir_version = 10

    # Infer shapes, then validate the resulting model
    model_inferred = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(model_inferred)

    return model_inferred

tests/unit/onnx/autocast/test_autocast.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import onnx
2121
import onnx_graphsurgeon as gs
2222
import pytest
23-
from _test_utils.onnx.lib_test_models import build_conv_isinf_model
23+
from _test_utils.onnx.lib_test_models import build_conv_isinf_model, build_conv_resize_model
2424

2525
import modelopt.onnx.autocast.utils as utils
2626
import modelopt.onnx.utils as onnx_utils
@@ -174,7 +174,7 @@ def test_conv_isinf_conversion(tmp_path, opset_version):
174174
output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx")
175175
onnx.save(converted_model, output_onnx_path)
176176

177-
# Load the output model and check QDQ node placements
177+
# Load the output model
178178
graph = gs.import_onnx(converted_model)
179179

180180
# Check that Conv is converted
@@ -190,6 +190,30 @@ def test_conv_isinf_conversion(tmp_path, opset_version):
190190
assert assert_input_precision(isinf_nodes, dtype=supported_dtype)
191191

192192

193+
def test_conv_resize_conversion(tmp_path):
    """Check mixed-precision conversion of the Conv -> Resize test model.

    The model built by ``build_conv_resize_model`` feeds one shared constant to
    both the ``roi`` and ``scales`` inputs of Resize. After conversion:

    - the data and ROI inputs (indices 0 and 1) must be FP16;
    - the scales input (index 2) must stay FP32 and the sizes input (index 3)
      must stay INT64, which are the only types the ONNX Resize operator accepts
      for those inputs.

    Args:
        tmp_path: pytest-provided temporary directory used for the saved models.
    """
    onnx_model = build_conv_resize_model()
    onnx_path = os.path.join(tmp_path, "conv_resize_model.onnx")
    onnx.save(onnx_model, onnx_path)

    # Convert the model
    converted_model = convert_to_mixed_precision(onnx_path=onnx_path)

    # Save the converted model next to the input model in tmp_path
    output_onnx_path = onnx_path.replace(".onnx", ".fp16.onnx")
    onnx.save(converted_model, output_onnx_path)

    # Load the output model
    graph = gs.import_onnx(converted_model)

    # Fail with a readable message (instead of a bare StopIteration) if the
    # Resize node went missing during conversion.
    resize_nodes = [n for n in graph.nodes if n.op == "Resize"]
    assert resize_nodes, "Converted model should still contain a Resize node"
    resize_node = resize_nodes[0]
    assert len(resize_node.inputs) == 4, "Resize should keep its data, roi, scales and sizes inputs"

    # Data and ROI inputs (indices 0 and 1) should be FP16
    assert all(inp.dtype == np.float16 for inp in resize_node.inputs[0:2]), (
        "Resize data and ROI inputs should be FP16"
    )

    # The remaining inputs (scales/sizes) should be kept in their original precisions
    assert resize_node.inputs[2].dtype == np.float32, "Resize scales input should remain FP32"
    assert resize_node.inputs[3].dtype == np.int64, "Resize sizes input should remain INT64"
215+
216+
193217
@pytest.mark.parametrize("target_opset", [13, 17, 19, 21])
194218
def test_opset_parameter(temp_model_path, target_opset):
195219
"""Test that the opset parameter correctly sets the output model's opset version."""

0 commit comments

Comments
 (0)