Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d90969b

Browse files
Annarinemomo609
authored andcommittedJun 13, 2024
chamfer_distance fp16->fp32
1 parent 737b5b4 commit d90969b

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed
 

‎mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,38 @@
1-
21
#include "pytorch_npu_helper.hpp"
32

43
using namespace NPU_NAME_SPACE;
54
using namespace std;
65

76
void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
87
Tensor dist2, Tensor idx1, Tensor idx2) {
8+
bool is_half = input.scalar_type() == at::kHalf;
99
at::Tensor xyz1 = at::ones_like(XYZ1);
1010
at::Tensor xyz2 = at::ones_like(XYZ2);
11+
at::Tensor distf1 = at::ones_like(dist1);
12+
at::Tensor distf2 = at::ones_like(dist2);
1113
xyz1 = XYZ1.transpose(1, 2).transpose(0, 1);
1214
xyz2 = XYZ2.transpose(1, 2).transpose(0, 1);
15+
if (is_half) {
16+
xyz1 = xyz1.to(at::kFloat);
17+
xyz2 = xyz2.to(at::kFloat);
18+
distf1 = dist1.to(at::kFloat);
19+
distf2 = dist2.to(at::kFloat);
20+
}
1321
OpCommand cmd;
1422
cmd.Name("ChamferDistance")
1523
.Input(xyz1)
1624
.Input(xyz2)
17-
.Output(dist1)
18-
.Output(dist2)
25+
.Output(distf1)
26+
.Output(distf2)
1927
.Output(idx1)
2028
.Output(idx2)
2129
.Run();
30+
if (is_half) {
31+
distf1 = distf1.to(at::kHalf);
32+
distf2 = distf2.to(at::kHalf);
33+
}
34+
dist1.copy_(distf1);
35+
dist2.copy_(distf2);
2236
}
2337

2438
void chamfer_distance_backward_npu(Tensor xyz1, Tensor xyz2, Tensor idx1,

0 commit comments

Comments
 (0)
Please sign in to comment.