
[torch_mlir] missing match for the avg_pool1d operator #4349

@vfdff

Description

  • I am trying to generate a torch.aten.avg_pool1d test case using fx.export_and_import, so I created a test file, avg_pool1d.py, which uses F.avg_pool1d(x, kernel_size=3, stride=2):
def forward(self, x):
    return F.avg_pool1d(x, kernel_size=3, stride=2)
  • Actual output from torch-mlir: the result uses a torch.aten.avg_pool2d operator, wrapped in unsqueeze/squeeze
(py311-source) root@998ee80b761b:/home/zhongyunde/torch-mlir/test/python/fx_importer# python avg_pool1d.py 
test_import_frozen_exported_program
-----------------------------------
module {
  func.func @main(%arg0: !torch.vtensor<[2,3,10],f32>) -> !torch.vtensor<[2,3,4],f32> {
    %int-2 = torch.constant.int -2
    %0 = torch.aten.unsqueeze %arg0, %int-2 : !torch.vtensor<[2,3,10],f32>, !torch.int -> !torch.vtensor<[2,3,1,10],f32>
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %1 = torch.prim.ListConstruct %int1, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %int1_0 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %2 = torch.prim.ListConstruct %int1_0, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    %int0 = torch.constant.int 0
    %int0_1 = torch.constant.int 0
    %3 = torch.prim.ListConstruct %int0, %int0_1 : (!torch.int, !torch.int) -> !torch.list<int>
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %none = torch.constant.none
    %4 = torch.aten.avg_pool2d %0, %1, %2, %3, %false, %true, %none : !torch.vtensor<[2,3,1,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2,3,1,4],f32>
    %int-2_2 = torch.constant.int -2
    %5 = torch.aten.squeeze.dim %4, %int-2_2 : !torch.vtensor<[2,3,1,4],f32>, !torch.int -> !torch.vtensor<[2,3,4],f32>
    return %5 : !torch.vtensor<[2,3,4],f32>
  }
}
  • Expected output: the result uses torch.aten.avg_pool1d directly
module {
  func.func @main(%arg0: !torch.vtensor<[2,3,10],f32>) -> !torch.vtensor<[2,3,4],f32> {
    %true = torch.constant.bool true
    %false = torch.constant.bool false
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
    %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
    %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
    %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %true : !torch.vtensor<[2,3,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[2,3,4],f32>
    return %3 : !torch.vtensor<[2,3,4],f32>
  }
}
  • torch-mlir version
(py311-source) root@998ee80b761b:/home/zhongyunde/torch-mlir/test/python/fx_importer# pip list | grep mlir  
torch-mlir               20241002.240
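
For reference, the two IR dumps above describe the same computation: the actual output inserts a height-1 dimension, pools with kernel [1, 3] and stride [1, 2], then removes that dimension again. The following is a minimal pure-Python sketch of that equivalence and of the output length 4 seen in both signatures. It does not use torch or torch-mlir, it assumes zero padding and ceil_mode=False as in the dumps, and the helper names (pool_out_len, avg_pool1d, avg_pool2d) are hypothetical, chosen only for illustration.

```python
# Pure-Python illustration; these helpers are hypothetical and are not
# the torch or torch-mlir implementations.

def pool_out_len(n, kernel, stride, padding=0):
    """Output length of a pooling window sweep with ceil_mode=False."""
    return (n + 2 * padding - kernel) // stride + 1


def avg_pool1d(row, kernel, stride):
    """Average-pool a 1-D list of floats (no padding)."""
    n_out = pool_out_len(len(row), kernel, stride)
    return [
        sum(row[i * stride : i * stride + kernel]) / kernel
        for i in range(n_out)
    ]


def avg_pool2d(grid, kernel, stride):
    """Average-pool a 2-D list of lists (no padding).

    kernel is (kh, kw) and stride is (sh, sw).
    """
    kh, kw = kernel
    sh, sw = stride
    h_out = pool_out_len(len(grid), kh, sh)
    w_out = pool_out_len(len(grid[0]), kw, sw)
    return [
        [
            sum(
                grid[r * sh + dr][c * sw + dc]
                for dr in range(kh)
                for dc in range(kw)
            ) / (kh * kw)
            for c in range(w_out)
        ]
        for r in range(h_out)
    ]


# One channel of length 10, matching the last dimension of [2,3,10].
row = [float(v) for v in range(10)]

# Direct 1-D pooling, as in the expected IR.
direct = avg_pool1d(row, kernel=3, stride=2)

# "unsqueeze" to a 1 x 10 grid, pool with kernel (1, 3) and stride (1, 2),
# then "squeeze" the height-1 dimension, as in the actual IR.
via_2d = avg_pool2d([row], kernel=(1, 3), stride=(1, 2))[0]

print(pool_out_len(10, 3, 2))  # 4, the last dimension of [2,3,4]
print(direct == via_2d)        # True
```

So the two forms should agree numerically; the difference reported here is only in which operator appears in the imported IR.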
