
hls4ml fpga_backend fails when using a Resize node with no ROI using #1266 fix (Brevitas -> QONNX) #1268

@The-Padi

Description


Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

The hls4ml fpga_backend fails on a Resize node that has no ROI input, even when the #1266 fix is applied.

Details

The Resize node without an ROI input comes from exporting a Brevitas QuantUpsample layer to QONNX. Even with the #1266 fix applied, the fpga_backend fails when compiling the generated project.
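For context, the ONNX Resize operator (opset 11 and later) takes its inputs positionally as X, roi, scales and sizes, and an omitted optional input is written as an empty name (''). The stdlib-only sketch below spells out the condition that FillEmptyRoI in the script checks and the `<node name>_roi` name it assigns. `has_empty_roi` and `fill_roi_name` are hypothetical helpers working on plain lists of input names, not part of ONNX, QONNX or hls4ml, and the "before" list is an assumed input list reconstructed from the post-fix log line.

```python
# Resize (opset 11+) inputs, by position: X, roi, scales, sizes.


def has_empty_roi(inputs):
    """True if a Resize input-name list omits the ROI (empty name in slot 1).

    Same condition as FillEmptyRoI: more than two inputs and inputs[1] == ''.
    """
    return len(inputs) > 2 and inputs[1] == ""


def fill_roi_name(inputs, node_name):
    """Return a copy of `inputs` with an empty ROI slot renamed to <node_name>_roi.

    Hypothetical helper: it only renames the slot. A real graph transformation
    must also add a matching (empty) initializer, as FillEmptyRoI does.
    Unlike FillEmptyRoI, any trailing `sizes` input is kept.
    """
    if not has_empty_roi(inputs):
        return list(inputs)
    return [inputs[0], node_name + "_roi", *inputs[2:]]


# Assumed input names before the fix (the log only shows them after it).
before = ["Quant_0_out0", "", "Resize_0_param0"]
print(fill_roi_name(before, "Resize_0"))
# ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
```

The printed list matches the input names hls4ml reports for Resize_0 in the log below.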

Steps to Reproduce

Add what needs to be done to reproduce the bug. Add commented code examples and make sure to include the original model files / code, and the commit hash you are working on.

  1. Clone the hls4ml repository
  2. Check out the master branch at commit hash: [77b8331]
  3. Run the code below
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.export import export_qonnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.channels_last import ConvertToChannelsLastAndClean
from qonnx.transformation.gemm_to_matmul import GemmToMatMul
import qonnx.util.cleanup as qonnx_cleanup
import hls4ml
import numpy as np
import onnx
import onnx.helper as helper
from qonnx.transformation.base import Transformation


class FillEmptyRoI(Transformation):
    """Replace an empty RoI input of Resize nodes with an empty initializer to avoid issues during shape inference."""

    def apply(self, model):
        graph_modified = False
        for i, node in enumerate(model.graph.node):
            if node.op_type == 'Resize':
                # Assuming 'roi' is the second input 
                if len(node.input) > 2 and node.input[1] == '':
                    roi = onnx.numpy_helper.from_array(np.empty([0], dtype=np.float32), node.name + "_roi")
                    model.graph.initializer.append(roi)
                    roi_value_info = helper.make_tensor_value_info(node.name + "_roi", onnx.TensorProto.FLOAT, [0])
                    model.graph.value_info.append(roi_value_info)
                    inputs = [node.input[0], node.name + "_roi", node.input[2]]
                    mode_string = ''
                    for attr in model.graph.node[i].attribute:
                        if attr.name == 'mode':
                            mode_string = attr.s
                    new_node = onnx.helper.make_node(
                        "Resize",
                        name=node.name,
                        coordinate_transformation_mode="asymmetric",
                        cubic_coeff_a=-0.75,
                        mode=mode_string,
                        nearest_mode="floor",
                        inputs=inputs,
                        outputs=node.output
                    )
                    model.graph.node.remove(node)
                    model.graph.node.insert(i, new_node)
                    graph_modified = True

        return (model, graph_modified)

class TestModel(nn.Module):
    def __init__(self, *args, **kwargs):
        
        super().__init__(*args, **kwargs)
        
        self.quant_input = qnn.QuantIdentity(
            bit_width=8, return_quant_tensor=True
        )
        
        self.upsample = qnn.QuantUpsample(scale_factor=2)
        
    def forward(self, input):
        
        out = self.quant_input(input)
        out1 = self.upsample(out)
        
        return out1


model = TestModel()
model.eval()

dummy_input = torch.randn(1, 5, 25, 25)
export_qonnx(model, input_shape=dummy_input.shape, export_path="test_model.onnx")

# Clean up with QONNX
model_wrapped = ModelWrapper("test_model.onnx")
model_wrapped = qonnx_cleanup.cleanup_model(model_wrapped)
model_wrapped = model_wrapped.transform(
    ConvertToChannelsLastAndClean(make_input_channels_last=True)
)
model_wrapped = model_wrapped.transform(GemmToMatMul())
model_wrapped = qonnx_cleanup.cleanup_model(model_wrapped)
model_wrapped = model_wrapped.transform(FillEmptyRoI())
model_wrapped.save("test_model_clean.onnx")

# Conversion to hls4ml
config = hls4ml.utils.config_from_onnx_model(
    model_wrapped,
    granularity="name",
    backend="Vivado",
    default_precision="ap_fixed<16,6>",
)

# Convert
hls_model = hls4ml.converters.convert_from_onnx_model(
    model_wrapped,
    hls_config=config,
    io_type="io_stream",
    output_dir="my-hls4ml",
    backend="Vivado",
)

# Compile the model
hls_model.compile()

Expected behavior

Successful creation of the config from the model, and successful compilation of the generated HLS project.
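To check the result once compilation succeeds, a NumPy reference for the Resize node may help. This is a sketch under assumptions: integer scale factors, mode "nearest" (assumed to be what QuantUpsample exports), and the attributes FillEmptyRoI sets (coordinate_transformation_mode="asymmetric", nearest_mode="floor"). Under those settings output index i reads input index floor(i / scale), which is exactly what `np.repeat` produces. The channels-last (N, H, W, C) layout is assumed from `make_input_channels_last=True`, and `reference_upsample_nhwc` and `max_abs_error` are hypothetical helpers.

```python
import numpy as np


def reference_upsample_nhwc(x, scale=2):
    """Nearest-neighbour upsampling of an (N, H, W, C) array by an integer factor.

    With coordinate_transformation_mode="asymmetric" and nearest_mode="floor",
    output pixel (i, j) reads input pixel (i // scale, j // scale), which is
    what repeating each row and each column `scale` times produces.
    """
    if x.ndim != 4:
        raise ValueError("expected an (N, H, W, C) array")
    return np.repeat(np.repeat(x, scale, axis=1), scale, axis=2)


def max_abs_error(hls_out, x, scale=2):
    """Largest absolute deviation of a (possibly flattened) output from the reference."""
    ref = reference_upsample_nhwc(x, scale)
    return float(np.max(np.abs(np.reshape(hls_out, ref.shape) - ref)))


x = np.random.rand(1, 25, 25, 5).astype(np.float32)
y = reference_upsample_nhwc(x)
print(y.shape)  # (1, 50, 50, 5)
```

After a successful `hls_model.compile()`, something like `max_abs_error(hls_model.predict(x), x)` could quantify the deviation. Since the model quantizes its input to 8 bits first, the difference would also include that quantization error unless `x` is quantized the same way beforehand.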

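Actual behavior

The config is created and the HLS project is written, but compilation fails: the generated call to nnet::resize_nearest receives an input stream of 25-wide arrays and an output stream of 5-wide arrays, so neither the array-based nor the stream-based overload matches.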

Output layers:  ['Resize_0']
Input shape: [25, 25, 5]
Topology:
Looking for : ['global_in']
Layer name: Transpose_0, layer type: Transpose, current shape: [[1, 25, 25, 5]]
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Layer name: Quant_0, layer type: Quant, current shape: [[1, 5, 25, 25]]
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Layer name: Resize_0, layer type: Resize, current shape: [[1, 5, 25, 25], [0], [4]]
Interpreting Model ...
Output layers:  ['Resize_0']
Input shape: [25, 25, 5]
Topology:
Looking for : ['global_in']
Layer name: Transpose_0, layer type: Transpose, current shape: [[1, 25, 25, 5]]
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Looking for : ['Transpose_0_out0', 'global_out_1', 'global_out_2', 'global_out_3']
Layer name: Quant_0, layer type: Quant, current shape: [[1, 5, 25, 25]]
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Looking for : ['Quant_0_out0', 'Resize_0_roi', 'Resize_0_param0']
Layer name: Resize_0, layer type: Resize, current shape: [[1, 5, 25, 25], [0], [4]]
Creating HLS model
WARNING: Config parameter "algorithm" overwrites an existing attribute in layer "Resize_0" (Resize)
Writing HLS project
Done
firmware/myproject.cpp: In function ‘void myproject(hls::stream<nnet::array<ap_fixed<16, 6>, 5> >&, hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >&)’:
firmware/myproject.cpp:38:45: error: no matching function for call to ‘resize_nearest<layer10_t, config9>(hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >&, hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >&)’
   38 |     nnet::resize_nearest<layer10_t, config9>(layer10_out, layer9_out); // Resize_0
      |     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
In file included from firmware/parameters.h:12,
                 from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_image.h:19:6: note: candidate: ‘void nnet::resize_nearest(data_T*, data_T*) [with data_T = nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>; CONFIG_T = config9]’
   19 | void resize_nearest(data_T image[CONFIG_T::height * CONFIG_T::width * CONFIG_T::n_chan],
      |      ^~~~~~~~~~~~~~
firmware/nnet_utils/nnet_image.h:19:28: note:   no known conversion for argument 1 from ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >’ to ‘nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>*’
   19 | void resize_nearest(data_T image[CONFIG_T::height * CONFIG_T::width * CONFIG_T::n_chan],
      |                     ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In file included from firmware/parameters.h:13,
                 from firmware/myproject.cpp:4:
firmware/nnet_utils/nnet_image_stream.h:9:49: note: candidate: ‘void nnet::resize_nearest(hls::stream<srcType>&, hls::stream<srcType>&) [with data_T = nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25>; CONFIG_T = config9]’
    9 | template <class data_T, typename CONFIG_T> void resize_nearest(hls::stream<data_T> &image, hls::stream<data_T> &resized) {
      |                                                 ^~~~~~~~~~~~~~
firmware/nnet_utils/nnet_image_stream.h:9:113: note:   no known conversion for argument 2 from ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 5> >’ to ‘hls::stream<nnet::array<ap_fixed<8, 1, AP_RND_CONV, AP_SAT, 0>, 25> >&’
    9 | template <class data_T, typename CONFIG_T> void resize_nearest(hls::stream<data_T> &image, hls::stream<data_T> &resized) {
      |                                                                                            ~~~~~~~~~~~~~~~~~~~~~^~~~~~~

Traceback (most recent call last):
  File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/upsamp.py", line 104, in <module>
    hls_model.compile()
  File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/model/graph.py", line 691, in compile
    self._compile()
  File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/model/graph.py", line 694, in _compile
    lib_name = self.config.backend.compile(self)
  File "/home/padi/Nextcloud/2 - École/1 - HEPIA/Année 3/Projet de Semestre/Source/hls4ml/hls4ml/backends/fpga/fpga_backend.py", line 174, in compile
    raise Exception(f'Failed to compile project "{model.config.get_project_name()}"')
Exception: Failed to compile project "myproject"
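One possible reading of the compiler error, offered as a hypothesis rather than a diagnosis: in io_stream, each stream element is assumed here to carry the channels of one pixel, i.e. the last axis of a channels-last tensor. The stream overload of `nnet::resize_nearest` uses the same `data_T` for both streams, so input and output must carry the same number of channels. The log reports the Resize data input (Quant_0 output) as [1, 5, 25, 25], still channels-first after Transpose_0; read as (N, H, W, C) that gives 25 channels, matching the 25-wide input stream, while the real tensor has 5 channels, matching the 5-wide output stream. The sketch below, with hypothetical helpers `nhwc_channels` and `resize_streams_compatible` and an assumed 2x channels-last output shape, only spells out that arithmetic.

```python
def nhwc_channels(shape):
    """Channel count of a 4-D shape read as (N, H, W, C), whatever its real layout."""
    if len(shape) != 4:
        raise ValueError("expected a 4-D shape")
    return shape[3]


def resize_streams_compatible(in_shape, out_shape):
    """The stream overload of nnet::resize_nearest takes one data_T for both
    streams, so input and output must carry the same number of channels."""
    return nhwc_channels(in_shape) == nhwc_channels(out_shape)


quant_out = [1, 5, 25, 25]   # Quant_0 output shape from the log (channels-first)
resize_out = [1, 50, 50, 5]  # assumed channels-last output after a 2x upsample

print(nhwc_channels(quant_out))                               # 25
print(resize_streams_compatible(quant_out, resize_out))       # False
print(resize_streams_compatible([1, 25, 25, 5], resize_out))  # True
```

This only restates what the log shows; it has not been checked against the hls4ml converter code.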

Metadata


Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
