Description
Hi.
I want to create a YOLO model for the QCM6490.
However, I'm encountering an error and am stuck because I don't know how to resolve it. The model I want to use is the following YOLOv9.
https://github.com/MultimediaTechLab/YOLO
If it's a basic mistake, I apologize. I'm having trouble, so any support would be greatly appreciated.
Environment
- Ubuntu 22.04.5 LTS
- Docker Engine 29.1.4
Steps
Since the QCM6490 is not currently supported in the ExecuTorch release distributed via pip, I built ExecuTorch from source following the instructions.
git clone --depth 1 https://github.com/pytorch/executorch.git
cd /executorch
git fetch --depth 1 origin 8e8d97eb3802f8ec0261684048f18001b3a3e668
sh ./install_executorch.sh
We are currently in the verification phase, so we are using Docker. Below is the Dockerfile we are using.
# syntax=docker/dockerfile:1
FROM pytorch/pytorch:2.9.1-cuda12.6-cudnn9-devel AS build
WORKDIR /repo
ARG CMAKE_ARGS="-DEXECUTORCH_BUILD_QNN=ON -DCMAKE_BUILD_TYPE=Release"
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
    apt update && apt --no-install-recommends install -y git
RUN git clone --depth 1 https://github.com/pytorch/executorch.git
WORKDIR /repo/executorch
RUN git fetch --depth 1 origin 8e8d97eb3802f8ec0261684048f18001b3a3e668
RUN ./install_executorch.sh
RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=bind,source=requirements.txt,target=requirements.txt \
    pip install -r requirements.txt
The build and installation were successful, and the following sample code also runs correctly, outputting the converted .pte file.
from typing import Tuple
import executorch
import torch
import torchvision
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torchao.quantization.pt2e.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)

def main():
    inputs = (torch.randn(1, 3, 224, 224),)
    model = torchvision.models.convnext_small().eval()
    quantizer = QnnQuantizer()
    m = torch.export.export(model, inputs, strict=True).module()
    # PTQ (Post-Training Quantization)
    prepared_model = prepare_pt2e(m, quantizer)
    # Calibration loop would go here
    prepared_model(*inputs)
    # Convert to quantized model
    quantized_model = convert_pt2e(prepared_model)
    backend_options = generate_htp_compiler_spec(
        use_fp16=False,
    )
    compile_spec = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.QCM6490,
        backend_options=backend_options,
    )
    model = to_edge_transform_and_lower_to_qnn(quantized_model, inputs, compile_spec).to_executorch()
    # Save the compiled model
    model_name = "convext_qnn.pte"
    with open(model_name, "wb") as f:
        f.write(model.buffer)
    print(f"Model successfully exported to {model_name}")

if __name__ == "__main__":
    main()
However, running it with YOLOv9 results in an error.
I created code to convert it to QNN, referencing deploy_utils.py. Below is a sample.
https://github.com/MultimediaTechLab/YOLO/blob/main/yolo/utils/deploy_utils.py
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
from torchao.quantization.pt2e.quantize_pt2e import (
    prepare_pt2e,
    convert_pt2e,
)
from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)

def _load_qnn_model(self):
    model = create_model(
        self.cfg.model, class_num=self.class_num, weight_path=self.cfg.weight
    ).eval()
    example_inputs = (torch.randn(1, 3, *self.cfg.image_size),)
    quantizer = QnnQuantizer()
    m = torch.export.export(model, example_inputs, strict=True).module()
    # PTQ (Post-Training Quantization)
    prepared_model = prepare_pt2e(m, quantizer)
    # Calibration loop would go here
    prepared_model(*example_inputs)
    # Convert to quantized model
    quantized_model = convert_pt2e(prepared_model)
    # HTP Compiler Configuration
    backend_options = generate_htp_compiler_spec(
        use_fp16=False  # False for quantized models
    )
    # QNN Compiler Spec
    compile_spec = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.QCM6490,  # Your target SoC
        backend_options=backend_options,
    )
    # Lower to QNN backend
    delegated_program = to_edge_transform_and_lower_to_qnn(
        quantized_model, example_inputs, compile_spec
    )
    # Export to ExecuTorch format
    executorch_program = delegated_program.to_executorch()
    # Save the compiled model
    model_name = "qnn.pte"
    with open(model_name, "wb") as f:
        f.write(executorch_program.buffer)
    print(f"Model successfully exported to {model_name}")
Error Message
File "/app/src/sackville/yolo/utils/deploy_utils.py", line 171, in _load_qnn_model
delegated_program = to_edge_transform_and_lower_to_qnn(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/utils/utils.py", line 442, in to_edge_transform_and_lower_to_qnn
return to_edge_transform_and_lower(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 115, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 1379, in to_edge_transform_and_lower
edge_manager = edge_manager.to_backend(method_to_partitioner)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 115, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/program/_program.py", line 1681, in to_backend
new_edge_programs = to_backend(method_to_programs_and_partitioners)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/functools.py", line 909, in wrapper
return dispatch(args[0].__class__)(*args, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/backend/backend_api.py", line 721, in _
partitioner_result = partitioner_instance(fake_edge_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/backend/partitioner.py", line 66, in __call__
return self.partition(exported_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 194, in partition
partitions = self.generate_partitions(edge_program)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 159, in generate_partitions
return generate_partitions_from_list_of_nodes(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/exir/backend/canonical_partitioners/pattern_op_partitioner.py", line 54, in generate_partitions_from_list_of_nodes
partition_list = capability_partitioner.propose_partitions()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py", line 226, in propose_partitions
if self._is_node_supported(node) and node not in assignment:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/fx/passes/infra/partitioner.py", line 87, in _is_node_supported
return self.operator_support.is_node_supported(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 98, in is_node_supported
op_wrapper = self.node_visitors[node.target.__name__].define_node(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/builders/op_stack.py", line 34, in define_node
input_tensor = self.get_tensor(self.get_node(input_node), node)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/executorch/backends/qualcomm/builders/node_visitor.py", line 147, in get_tensor
tensor = tensor.permute(dims=op_node.meta[QCOM_AXIS_ORDER]).contiguous()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 29, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 899, in __torch_dispatch__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 850, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 54, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1209, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 29, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1397, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2155, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1544, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2823, in _dispatch_impl
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 850, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 5
Setting other QcomChipset values will result in the same error.
cc @cccclai @winskuo-quic @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin