Skip to content

Commit a3d5274

Browse files
authored
[NPU] fix the npu code for knn and three nn ops (#3269)
1 parent 6d33b9f commit a3d5274

File tree

4 files changed

+29
-12
lines changed

4 files changed

+29
-12
lines changed

mmcv/ops/csrc/pytorch/npu/knn_npu.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using namespace std;
88
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
99
const Tensor new_xyz, Tensor idx, Tensor dist2) {
1010
// transpose xyz from [B, N, 3] to [B, 3, N]
11-
at::Tensor source = xyz.transpose(1, 2).contiguous();
11+
at::Tensor source = xyz.transpose(2, 1).contiguous();
1212
at::Tensor target = new_xyz.contiguous();
1313

1414
bool is_from_knn = true;

mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp

+2-11
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,12 @@ using namespace std;
77

88
void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
99
const Tensor known, Tensor dist2, Tensor idx) {
10-
// transpose known [B, N, 3] -> [B, 3, N]
11-
at::Tensor source = known.transpose(1, 2).contiguous();
10+
at::Tensor source = known.contiguous();
1211
at::Tensor target = unknown.contiguous();
13-
auto originDtype = source.scalar_type();
14-
if (originDtype == at::kHalf) {
15-
source = source.to(at::kFloat);
16-
target = target.to(at::kFloat);
17-
}
1812

1913
bool is_from_knn = false;
20-
uint32_t nsample = 3;
14+
int nsample = 3;
2115
EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, nsample, dist2, idx);
22-
if (originDtype == at::kHalf) {
23-
dist2 = dist2.to(at::kHalf);
24-
}
2516
}
2617

2718
void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,

mmcv/ops/knn.py

+11
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,17 @@ def forward(ctx,
6666
B, npoint, _ = center_xyz.shape
6767
N = xyz.shape[1]
6868

69+
if xyz.device.type == 'npu':
70+
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
71+
idx = center_xyz.new_zeros((B, npoint, k)).int()
72+
ext_module.knn_forward(
73+
xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
74+
zeros_idx = torch.zeros(
75+
xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
76+
idx.where(dist2 >= 1e10, zeros_idx)
77+
idx = idx.transpose(2, 1).contiguous() # [B, k, npoint]
78+
return idx.int()
79+
6980
idx = center_xyz.new_zeros((B, npoint, k)).int()
7081
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
7182

mmcv/ops/three_nn.py

+15
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor,
3434

3535
B, N, _ = target.size()
3636
m = source.size(1)
37+
if source.device.type == 'npu':
38+
# restrict to fp32: float16 inputs are cast to float below
39+
source = source.transpose(2, 1).contiguous()
40+
dtype_ = source.dtype
41+
if dtype_ == torch.float16:
42+
target = target.float()
43+
source = source.float()
44+
dist2 = target.new_empty(B, N, 3)
45+
idx = target.new_empty(B, N, 3, dtype=torch.int32)
46+
ext_module.three_nn_forward(
47+
target, source, dist2, idx, b=B, n=N, m=m)
48+
dist2 = torch.sqrt(dist2)
49+
if dtype_ == torch.float16:
50+
dist2 = dist2.half()
51+
return dist2, idx.int()
3752
dist2 = target.new_empty(B, N, 3)
3853
idx = target.new_empty(B, N, 3, dtype=torch.int32)
3954

0 commit comments

Comments
 (0)