-
Notifications
You must be signed in to change notification settings - Fork 473
Open
Labels
Description
Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
- Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
- Check that the issue hasn't already been reported, by checking the currently open issues.
- If there are steps to reproduce the problem, make sure to write them down below.
- If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.
Quick summary
The hls4ml FPGA backend fails to compile a model containing a Resize node with no RoI input, even with the #1266 fix applied.
Details
The hls4ml FPGA backend fails to compile a model containing a Resize node with no RoI input, even with the #1266 fix applied. The Resize node comes from converting a Brevitas `QuantUpsample` layer.
Steps to Reproduce
Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.
- Clone the hls4ml repository
- Checkout the master branch, with commit hash: [77b8331]
- Run the code below
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.export import export_qonnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
import qonnx.util.cleanup as qonnx_cleanup
import hls4ml
import numpy as np
import onnx
import onnx.helper as helper
from qonnx.transformation.base import Transformation
class FillEmptyRoI(Transformation):
    """Fill the empty RoI input of Resize nodes to avoid issues during shape inference.

    Some exporters emit Resize nodes whose optional ``roi`` input (input
    index 1) is the empty string ``''``. This transformation replaces that
    placeholder with a named, zero-length FLOAT initializer (plus a matching
    value_info entry).

    The Resize node is rewired in place rather than rebuilt, so every
    attribute it already carries (``mode``, ``coordinate_transformation_mode``,
    ``nearest_mode``, ...) is preserved exactly as exported.
    """

    def apply(self, model):
        """Fill empty RoI inputs on all Resize nodes of ``model``.

        Args:
            model: the wrapped model whose ``graph`` is edited in place.

        Returns:
            tuple: ``(model, graph_modified)``; ``graph_modified`` is True if
            at least one Resize node was updated.
        """
        graph_modified = False
        for i, node in enumerate(model.graph.node):
            if node.op_type != "Resize":
                continue
            # Resize inputs are (X, roi, scales, sizes); roi is input index 1.
            if len(node.input) < 2 or node.input[1] != "":
                continue

            # Unnamed nodes would all map to "_roi"; fall back to an
            # index-based base name so initializer names stay unique.
            base_name = node.name or f"Resize_{i}"
            roi_name = base_name + "_roi"

            roi = onnx.numpy_helper.from_array(np.empty([0], dtype=np.float32), roi_name)
            model.graph.initializer.append(roi)
            model.graph.value_info.append(
                helper.make_tensor_value_info(roi_name, onnx.TensorProto.FLOAT, [0])
            )

            # Rewire in place instead of removing and re-inserting a rebuilt
            # node. Rebuilding with hard-coded attributes overrode the
            # exported ones and produced mode='' when 'mode' was absent. It
            # also mutated the repeated field while it was being iterated.
            node.input[1] = roi_name
            graph_modified = True
        return (model, graph_modified)
class TestModel(nn.Module):
    """Minimal reproduction network: 8-bit quantized identity, then upsampling.

    Exported with ``export_qonnx``, it yields a Quant node feeding a Resize
    node, which is the pattern that triggers the reported failure.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 8-bit input quantizer; return_quant_tensor=True hands the next
        # layer a QuantTensor instead of a plain tensor.
        self.quant_input = qnn.QuantIdentity(bit_width=8, return_quant_tensor=True)
        # Quantized upsampling with scale_factor=2.
        self.upsample = qnn.QuantUpsample(scale_factor=2)

    def forward(self, input):
        """Quantize ``input`` and pass the result through the upsampler."""
        return self.upsample(self.quant_input(input))
# --- Build the PyTorch/Brevitas model and export it to QONNX ---
model = TestModel().eval()  # nn.Module.eval() returns the module itself
dummy_input = torch.randn(1, 5, 25, 25)
export_qonnx(model, input_shape=dummy_input.shape, export_path="test_model.onnx")

# --- Clean up the exported graph with QONNX ---
model_wrapped = qonnx_cleanup.cleanup_model(ModelWrapper("test_model.onnx"))
model_wrapped = model_wrapped.transform(
    ConvertToChannelsLastAndClean(make_input_channels_last=True)
)
model_wrapped = model_wrapped.transform(GemmToMatMul())
model_wrapped = qonnx_cleanup.cleanup_model(model_wrapped)
# Give the Resize node an explicit (empty) RoI tensor before conversion.
model_wrapped = model_wrapped.transform(FillEmptyRoI())
model_wrapped.save("test_model_clean.onnx")

# --- Convert to hls4ml: build the config first, then the HLS model ---
config = hls4ml.utils.config_from_onnx_model(
    model_wrapped,
    granularity="name",
    backend="Vivado",
    default_precision="ap_fixed<16,6>",
)
hls_model = hls4ml.converters.convert_from_onnx_model(
    model_wrapped,
    hls_config=config,
    io_type="io_stream",
    output_dir="my-hls4ml",
    backend="Vivado",
)

# --- Compile the HLS model (the step that fails in the logs below) ---
hls_model.compile()
Expected behavior
Successful creation of the config from the model, and successful compilation of the resulting HLS model.
Actual behavior
Output layers: ['Resize_0']
Input shape: [25, 25, 5]
Topology:
Looking for : ['global_in']
Layer name: Transpose_0, layer type: Transpose, current shape: [[1, 25, 25, 5]]
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Layer name: Quant_0, layer type: Quant, current shape: [[1, 5, 25, 25]]
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Layer name: Resize_0, layer type: Resize, current shape: [[1, 5, 25, 25], [0], [4]]
Interpreting Model ...
Output layers: ['Resize_0']
Input shape: [25, 25, 5]
Topology:
Looking for : ['global_in']
Layer name: Transpose_0, layer type: Transpose, current shape: [[1, 25, 25, 5]]
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Layer name: Quant_0, layer type: Quant, current shape: [[1, 5, 25, 25]]
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Layer name: Resize_0, layer type: Resize, current shape: [[1, 5, 25, 25], [0], [4]]
Creating HLS model
WARNING: Config parameter "algorithm" overwrites an existing attribute in layer "Resize_0" (Resize)
Writing HLS project
Done
firmware/myproject.cpp: In function ‘void myproject(hls::stream<nnet::array<ap_fixed<16, 6>, 5> >&, hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >&)’:
firmware/myproject.cpp:38:45: error: no matching function for call to ‘resize_nearest<layer10_t, config9>(hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >&, hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >&)’
38 | nnet::resize_nearest<layer10_t, config9>(layer10_out, layer9_out); // Resize_0
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
In file included from firmware/parameters.h:12,
from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_image.h:19:6: note: candidate: ‘void nnet::resize_nearest(data_T*, data_T*) [with data_T = nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>; CONFIG_T = config9]’
19 | void resize_nearest(data_T image[CONFIG_T::height * CONFIG_T::width * CONFIG_T::n_chan],
| ^~~~~~~~~~~~~~
firmware/nnet_utils/nnet_image.h:19:28: note: no known conversion for argument 1 from ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >’ to ‘nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>*’
19 | void resize_nearest(data_T image[CONFIG_T::height * CONFIG_T::width * CONFIG_T::n_chan],
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from firmware/parameters.h:13,
from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_image_stream.h:9:49: note: candidate: ‘void nnet::resize_nearest(hls::stream<srcType>&, hls::stream<srcType>&) [with data_T = nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>; CONFIG_T = config9]’
9 | template <class data_T, typename CONFIG_T> void resize_nearest(hls::stream<data_T> &image, hls::stream<data_T> &resized) {
| ^~~~~~~~~~~~~~
firmware/nnet_utils/nnet_image_stream.h:9:113: note: no known conversion for argument 2 from ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >’ to ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >&’
9 | template <class data_T, typename CONFIG_T> void resize_nearest(hls::stream<data_T> &image, hls::stream<data_T> &resized) {
| ~~~~~~~~~~~~~~~~~~~~~^~~~~~~
Traceback (most recent call last):
File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/upsamp.py", line 104, in <module>
hls_model.compile()
File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/model/graph.py", line 691, in compile
self._compile()
File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/model/graph.py", line 694, in _compile
lib_name = self.config.backend.compile(self)
File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/backends/fpga/fpga_backend.py", line 174, in compile
raise Exception(f'Failed to compile project "{model.config.get_project_name()}"')
Exception: Failed to compile project "myproject"