Labels: module: qnn (Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/), partner: qualcomm (backend delegation, kernels, demo, etc. from the 3rd-party partner, Qualcomm)
Description
🐛 Describe the bug
import torch
from torch import nn as nn
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
)


class TestModel(nn.Module):
    def __init__(
        self,
        num_in_ch,
        num_out_ch,
        feature_channels=32,
    ):
        super(TestModel, self).__init__()
        in_channels = num_in_ch
        out_channels = num_out_ch
        self.conv_act = nn.Sequential(
            nn.Conv2d(in_channels, feature_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.conv_act2 = nn.Sequential(
            nn.Conv2d(feature_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        out_feature = self.conv_act(x)
        output = self.conv_act2(out_feature)
        return output


# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=True,  # False for quantized models
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8550,  # Your target SoC
    backend_options=backend_options,
)

model = TestModel(num_in_ch=3, num_out_ch=3, feature_channels=32).eval()
sample_inputs = (torch.randn(1, 3, 100, 200),)

exported_program = torch.export.export(model, sample_inputs)
executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[QnnPartitioner(compile_spec), XnnpackPartitioner()],
).to_executorch()

with open("/home/jiang/extorch/TestModel.pte", "wb") as f:
    f.write(executorch_program.buffer)

[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 1
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
[QNN Partitioner Op Support]: aten.relu.default | True
[ERROR] [Qnn ExecuTorch]: Filters in[1] dimension 32 at index 2 not equal to channel_in 200 / groups 1.
[ERROR] [Qnn ExecuTorch]: Op specific validation failed.
[ERROR] [Qnn ExecuTorch]: <E> validateNativeOps master op validator aten_convolution_default_1:qti.aisw:Conv2d failed 3110
[ERROR] [Qnn ExecuTorch]: <E> QnnBackend_validateOpConfig failed 3110
[ERROR] [Qnn ExecuTorch]: <E> Failed to validate op aten_convolution_default_1 with error 0xc26
[WARNING] [Qnn ExecuTorch]: Qnn Backend op validation failed with error: 3110
[QNN Partitioner Op Support]: aten.convolution.default | False
[QNN Partitioner Op Support]: aten.relu.default | True
[ERROR] [Qnn ExecuTorch]: Filters in[1] dimension 3 at index 2 not equal to channel_in 200 / groups 1.
[ERROR] [Qnn ExecuTorch]: Op specific validation failed.
[ERROR] [Qnn ExecuTorch]: <E> validateNativeOps master op validator aten_convolution_default:qti.aisw:Conv2d failed 3110
[ERROR] [Qnn ExecuTorch]: <E> QnnBackend_validateOpConfig failed 3110
[ERROR] [Qnn ExecuTorch]: <E> Failed to validate op aten_convolution_default with error 0xc26
[WARNING] [Qnn ExecuTorch]: Qnn Backend op validation failed with error: 3110
[QNN Partitioner Op Support]: aten.convolution.default | False
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
INFO:executorch.backends.qualcomm.partition.qnn_partitioner:Qnn partitioner will delegate torch mutable buffer with the same I/O address during the runtime, so if your model contains mutable buffer, then you can get the better performance with skip_mutable_buffer=False. If you encounter accuracy issue during the runtime, then please set `skip_mutable_buffer=True` and try again.
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 1
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
INFO:executorch.backends.qualcomm.qnn_preprocess:Processing Method(0): (1/2)
INFO:executorch.backends.qualcomm.qnn_preprocess:Visiting: aten_relu_default, aten.relu.default
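Both aten.convolution.default ops are rejected by the QNN op validator; note that the reported channel_in of 200 matches the width of the 1x3x100x200 NCHW input rather than the actual channel count (3 and 32), so only the ReLU ops are delegated to QNN and the convolutions fall back. As a minimal sketch for confirming which ops end up delegated to which backend (assuming the executorch.devtools.backend_debug helpers are available in this ExecuTorch version; variable names follow the repro above):

# Hedged sketch: inspect delegation results before calling to_executorch().
from executorch.devtools.backend_debug import get_delegation_info

edge_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[QnnPartitioner(compile_spec), XnnpackPartitioner()],
)
delegation_info = get_delegation_info(edge_program.exported_program().graph_module)
print(delegation_info.get_summary())                          # delegated vs. non-delegated op counts
print(delegation_info.get_operator_delegation_dataframe())    # per-operator breakdown
executorch_program = edge_program.to_executorch()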
Versions
executorch 1.0.1
Ubuntu 22.04, x86, WSL
cc @cccclai @winskuo-quic @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin