
[ERROR] [Qnn ExecuTorch]: Filters in[1] dimension 3 at index 2 not equal to channel_in 200 / groups 1. #16619

@U-c207Pr4f57t9

Description


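🐛 Describe the bug

Lowering the small Conv2d + ReLU model below with `QnnPartitioner` (fp16 HTP backend, `QcomChipset.SM8550`) fails QNN op validation for both `aten.convolution.default` nodes, so the QNN partitioner accepts only the ReLU nodes. Reproduction script: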

```python
import torch
from torch import nn
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.exir import to_edge_transform_and_lower

from executorch.backends.qualcomm.utils.utils import (
    generate_qnn_executorch_compiler_spec,
    generate_htp_compiler_spec,
    QcomChipset,
)


class TestModel(nn.Module):
    def __init__(self, num_in_ch, num_out_ch, feature_channels=32):
        super().__init__()
        in_channels = num_in_ch
        out_channels = num_out_ch

        self.conv_act = nn.Sequential(
            nn.Conv2d(in_channels, feature_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.conv_act2 = nn.Sequential(
            nn.Conv2d(feature_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        out_feature = self.conv_act(x)
        output = self.conv_act2(out_feature)
        return output


# HTP Compiler Configuration
backend_options = generate_htp_compiler_spec(
    use_fp16=True,  # False for quantized models
)

# QNN Compiler Spec
compile_spec = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8550,  # Your target SoC
    backend_options=backend_options,
)


model = TestModel(num_in_ch=3, num_out_ch=3, feature_channels=32).eval()
sample_inputs = (torch.randn(1, 3, 100, 200), )
exported_program = torch.export.export(model, sample_inputs)

executorch_program = to_edge_transform_and_lower(
    exported_program,
    partitioner=[QnnPartitioner(compile_spec),
                 XnnpackPartitioner()]
).to_executorch()

with open("/home/jiang/extorch/TestModel.pte", "wb") as f:
    f.write(executorch_program.buffer)
```
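For reference, the tensor shapes involved can be worked out by hand. The sketch below is plain Python shape arithmetic (no ExecuTorch calls), assuming PyTorch's NCHW activation and OIHW weight layouts, stride 1, and the `kernel_size=3, padding=1` used in `TestModel`; `conv2d_out_hw` and `trace_shapes` are illustrative helpers, not library functions. It shows that the two convolutions take 3 and 32 input channels, the same numbers reported as "dimension 3" and "dimension 32" in the log, while 200 is the width of the sample input.

```python
# Illustrative shape bookkeeping for TestModel on a (1, 3, 100, 200) input.
# Assumes PyTorch conventions: activations NCHW, Conv2d weights OIHW,
# kernel_size=3, padding=1, stride=1, dilation=1 (as in the repro).


def conv2d_out_hw(h, w, kernel=3, padding=1, stride=1):
    """Spatial output size of a 2D convolution with dilation 1."""
    out_h = (h + 2 * padding - kernel) // stride + 1
    out_w = (w + 2 * padding - kernel) // stride + 1
    return out_h, out_w


def trace_shapes(input_nchw, channels):
    """Return [(activation NCHW, weight OIHW), ...] plus the final NCHW shape.

    `channels` lists the output channels of each 3x3 conv layer, in order.
    """
    n, c, h, w = input_nchw
    layers = []
    for c_out in channels:
        weight_oihw = (c_out, c, 3, 3)  # (out_ch, in_ch, kH, kW)
        layers.append(((n, c, h, w), weight_oihw))
        h, w = conv2d_out_hw(h, w)
        c = c_out
    return layers, (n, c, h, w)


layers, final_shape = trace_shapes((1, 3, 100, 200), channels=[32, 3])
for i, (act, weight) in enumerate(layers):
    print(f"conv{i}: input {act}, weight {weight}")
print("output", final_shape)
```

Because `padding=1` preserves the spatial size, every activation keeps H=100 and W=200; only the channel count changes (3 → 32 → 3).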
Output:

```
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 1
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
[QNN Partitioner Op Support]: aten.relu.default | True
[ERROR] [Qnn ExecuTorch]: Filters in[1] dimension 32 at index 2 not equal to channel_in 200 / groups 1.

[ERROR] [Qnn ExecuTorch]: Op specific validation failed.

[ERROR] [Qnn ExecuTorch]:  <E> validateNativeOps master op validator aten_convolution_default_1:qti.aisw:Conv2d failed 3110

[ERROR] [Qnn ExecuTorch]:  <E> QnnBackend_validateOpConfig failed 3110

[ERROR] [Qnn ExecuTorch]:  <E> Failed to validate op aten_convolution_default_1 with error 0xc26

[WARNING] [Qnn ExecuTorch]: Qnn Backend op validation failed with error: 3110
[QNN Partitioner Op Support]: aten.convolution.default | False
[QNN Partitioner Op Support]: aten.relu.default | True
[ERROR] [Qnn ExecuTorch]: Filters in[1] dimension 3 at index 2 not equal to channel_in 200 / groups 1.

[ERROR] [Qnn ExecuTorch]: Op specific validation failed.

[ERROR] [Qnn ExecuTorch]:  <E> validateNativeOps master op validator aten_convolution_default:qti.aisw:Conv2d failed 3110

[ERROR] [Qnn ExecuTorch]:  <E> QnnBackend_validateOpConfig failed 3110

[ERROR] [Qnn ExecuTorch]:  <E> Failed to validate op aten_convolution_default with error 0xc26

[WARNING] [Qnn ExecuTorch]: Qnn Backend op validation failed with error: 3110
[QNN Partitioner Op Support]: aten.convolution.default | False
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
INFO:executorch.backends.qualcomm.partition.qnn_partitioner:Qnn partitioner will delegate torch mutable buffer with the same I/O address during the runtime, so if your model contains mutable buffer, then you can get the better performance with skip_mutable_buffer=False. If you encounter accuracy issue during the runtime, then please set `skip_mutable_buffer=True` and try again.
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 1
[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in SAVE MODE.
[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
INFO:executorch.backends.qualcomm.qnn_preprocess:Processing Method(0): (1/2)
INFO:executorch.backends.qualcomm.qnn_preprocess:Visiting: aten_relu_default, aten.relu.default
```
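One way to read the validator message: QNN's Conv2d expects its input as [batch, height, width, channel_in] and its filters as [filter_height, filter_width, channel_in / groups, channel_out]. The hypothetical check below (`check_conv2d`, `nchw_to_nhwc` and `oihw_to_hwio` are illustrative names, not ExecuTorch or QNN APIs) assumes the filters had already been permuted to that layout but the activation reached the validator still in NCHW order. Under that assumption it reproduces both logged messages word for word, and the check passes once the activation is also converted to NHWC.

```python
# Hypothetical re-creation of the Conv2d shape check named in the log.
# Assumed QNN layouts: input NHWC [batch, height, width, channel_in],
# filters HWIO [filter_height, filter_width, channel_in / groups, channel_out].


def nchw_to_nhwc(shape):
    """Permute a PyTorch NCHW activation shape to NHWC."""
    n, c, h, w = shape
    return (n, h, w, c)


def oihw_to_hwio(shape):
    """Permute a PyTorch OIHW Conv2d weight shape to HWIO."""
    o, i, kh, kw = shape
    return (kh, kw, i, o)


def check_conv2d(input_shape, filter_shape, groups=1):
    """Return None if the shapes agree, else an error in the log's wording."""
    channel_in = input_shape[3]  # the last axis is treated as channels
    expected = channel_in // groups
    if filter_shape[2] != expected:
        return (f"Filters in[1] dimension {filter_shape[2]} at index 2 not equal "
                f"to channel_in {channel_in} / groups {groups}.")
    return None


# Activation left in NCHW, filters permuted: reproduces the logged errors.
print(check_conv2d((1, 3, 100, 200), oihw_to_hwio((32, 3, 3, 3))))
print(check_conv2d((1, 32, 100, 200), oihw_to_hwio((3, 32, 3, 3))))

# Activation converted to NHWC as well: no error.
print(check_conv2d(nchw_to_nhwc((1, 3, 100, 200)), oihw_to_hwio((32, 3, 3, 3))))
```

The second convolution's message is the telling one: in the PyTorch OIHW weight (3, 32, 3, 3), index 2 is the kernel height 3, so "dimension 32 at index 2" suggests the filter did arrive in HWIO order, while "channel_in 200" is the last axis (the width) of an activation that was never converted from NCHW.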

Versions

ExecuTorch 1.0.1
Ubuntu 22.04 (x86, WSL)

cc @cccclai @winskuo-quic @shewu-quic @haowhsu-quic @DannyYuyang-quic @cbilgin

Metadata

Assignees: No one assigned

Labels: module: qnn (Issues related to Qualcomm's QNN delegate and code under backends/qualcomm/), partner: qualcomm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Qualcomm)