
Commit

add_npu_border_align
frh23333 committed Oct 30, 2024
1 parent c46684c commit 9622034
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/border_align_npu.cpp
@@ -0,0 +1,34 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

void border_align_forward_impl(const Tensor &input, const Tensor &boxes, Tensor output,
                               Tensor argmax_idx, const int pool_size);

void border_align_forward_npu(const Tensor &input, const Tensor &boxes, Tensor output,
                              Tensor argmax_idx, const int pool_size) {
  TORCH_CHECK(input.size(0) == boxes.size(0),
              "The batch sizes of feature map and rois must be the same.");
  TORCH_CHECK(input.size(1) % 4 == 0,
              "The number of channels must be divisible by 4.");
  TORCH_CHECK(pool_size >= 2, "The pool size should be at least 2.");
  // Input is NCHW; record its dimensions.
  int32_t batch_size = input.size(0);
  int32_t channels = input.size(1);
  int32_t height = input.size(2);
  int32_t width = input.size(3);
  // The NPU kernel consumes a channel-last (NHWC) feature map, so permute
  // before dispatch and make both inputs contiguous.
  at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous();
  at::Tensor rois_map = boxes.contiguous();
  // Intermediate buffer: pool_size + 1 sampled values per border position,
  // reduced below to a max value and its index.
  at::Tensor temp_tensor = at::zeros(
      {batch_size, height * width, pool_size + 1, channels}, input.options());
  EXEC_NPU_CMD(aclnnBorderAlign, feature_map, rois_map, pool_size, temp_tensor);

  // Reduce over the pool dimension: max value and its index per border point.
  auto max_result = temp_tensor.max(-2);
  at::Tensor output_ = std::get<0>(max_result).to(at::kFloat);
  // Restore the expected (batch, channels / 4, height * width, 4) layout.
  output_ = output_.reshape({batch_size, height * width, 4, channels / 4})
                .permute({0, 3, 1, 2})
                .contiguous();
  output.copy_(output_);

  at::Tensor argmax_idx_ = std::get<1>(max_result).to(at::kInt);
  argmax_idx_ =
      argmax_idx_.reshape({batch_size, height * width, 4, channels / 4})
          .permute({0, 3, 1, 2})
          .contiguous();
  argmax_idx.copy_(argmax_idx_);
}


REGISTER_NPU_IMPL(border_align_forward_impl, border_align_forward_npu);
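For context, a minimal sketch of how this forward path would typically be exercised from Python, assuming mmcv's existing border_align wrapper (mmcv.ops.border_align) and a torch_npu runtime; the shapes follow the checks above, while the concrete sizes, box values, and .npu() calls are illustrative and not part of this commit:

import torch
import torch_npu  # assumed NPU runtime; enables the "npu" device and .npu()
from mmcv.ops import border_align

n, c, h, w = 2, 16, 8, 8                           # channels (16) divisible by 4
pool_size = 4                                      # must be >= 2

feat = torch.rand(n, c, h, w).npu()                # NCHW feature map
xy1 = torch.rand(n, h * w, 2) * (h - 2)            # top-left corners
boxes = torch.cat([xy1, xy1 + 1.0], dim=-1).npu()  # (n, h*w, 4) boxes, x1y1x2y2
out = border_align(feat, boxes, pool_size)         # -> (n, c // 4, h*w, 4)

With the tensors on an NPU device, dispatch should reach border_align_forward_npu through the REGISTER_NPU_IMPL registration above.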
