Commit e4e8f50

Annarinemomo609 authored and committed
chamfer_distance fp16->fp32
1 parent d90969b commit e4e8f50

3 files changed: +89 -23 lines changed
mmcv/ops/csrc/pytorch/npu/chamfer_distance_npu.cpp (+1 -1)

@@ -5,7 +5,7 @@ using namespace std;
 
 void chamfer_distance_forward_npu(Tensor XYZ1, Tensor XYZ2, Tensor dist1,
                                   Tensor dist2, Tensor idx1, Tensor idx2) {
-  bool is_half = input.scalar_type() == at::kHalf;
+  bool is_half = XYZ1.scalar_type() == at::kHalf;
   at::Tensor xyz1 = at::ones_like(XYZ1);
   at::Tensor xyz2 = at::ones_like(XYZ2);
   at::Tensor distf1 = at::ones_like(dist1);
mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp (+85 -20)

@@ -4,6 +4,21 @@ using namespace std;
 
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     target_y = at::add(target_y, 1.0);
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
   target_y = target_y.to(at::kInt);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
 }
 
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
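
Every function in this file now follows the same upcast/compute/downcast pattern: make fp32 working copies of the fp16 tensors, run the NPU operator in fp32, cast the result back to fp16, and copy_ it into the caller's preallocated output buffer (a plain reassignment of the local handle would not be visible to the Python side). A rough Python analogue of that dtype handling, purely illustrative; the helper name and callable argument are not part of mmcv:

import torch
from typing import Callable

def run_in_fp32(op: Callable[[torch.Tensor], torch.Tensor],
                input_t: torch.Tensor, output_t: torch.Tensor) -> None:
    # Upcast fp16 inputs so the op computes in fp32.
    is_half = input_t.dtype == torch.half
    input_y = input_t.float() if is_half else input_t
    output_y = op(input_y)
    # Downcast to the caller's dtype and fill the preallocated
    # buffer in place, mirroring output.copy_(output_y) above.
    if is_half:
        output_y = output_y.half()
    output_t.copy_(output_y)

# e.g. run_in_fp32(torch.sigmoid, x, out) with fp16 tensors x and out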
@@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
     target_y = at::reshape(target, input.sizes());
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   }
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
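
The other recurring change is how per-class weights reach the operator: instead of broadcasting weight across all classes, the code now selects each sample's own class weight via the one-hot target and broadcasts that value back to the input shape. In PyTorch terms, the three added weight_y lines compute roughly the following (self-contained sketch; shapes and weight values are made up):

import torch
import torch.nn.functional as F

N, C = 8, 4
weight = torch.tensor([1.0, 2.0, 1.0, 0.5])     # per-class weights (assumed)
target = torch.randint(0, C, (N,))

one_hot = F.one_hot(target, C).float()          # (N, C)
weight_y = weight.expand(N, C)                  # broadcast_to(weight, input.sizes())
weight_y = (weight_y * one_hot).sum(1, True)    # keep weight[target] per row
weight_y = weight_y.expand(N, C)                # broadcast back to input shape

assert torch.equal(weight_y[:, 0], weight[target])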
@@ -74,26 +108,40 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
-  at::Tensor op_output = at::ones_like(input);
+  at::Tensor op_output = at::ones_like(input_y);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
       .Output(op_output)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
   int64_t n_batch = input.size(0);
   c10::SmallVector<int64_t, 2> offsets = {0, 0};
   c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
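
Since the operator itself now always runs in fp32, the fp16 results should track the fp32 results up to rounding. A hedged sanity check (same NPU-build assumptions as above; the tolerances are guesses, not values from the commit):

import torch
from mmcv.ops import softmax_focal_loss

pred = torch.randn(8, 4, device='npu')           # fp32 reference
target = torch.randint(0, 4, (8,), device='npu')

loss_fp32 = softmax_focal_loss(pred, target)
loss_fp16 = softmax_focal_loss(pred.half(), target)
torch.testing.assert_close(loss_fp16.float(), loss_fp32,
                           atol=1e-3, rtol=1e-3)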
@@ -124,27 +172,44 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
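
Because the backward kernels apply the same downcast-and-copy step to grad_input, fp16 training graphs keep fp16 gradients end to end. An illustrative autograd round trip (same NPU-build assumptions as above):

import torch
from mmcv.ops import sigmoid_focal_loss

pred = torch.randn(8, 4, dtype=torch.half, device='npu',
                   requires_grad=True)
target = torch.randint(0, 4, (8,), device='npu')

loss = sigmoid_focal_loss(pred, target, gamma=2.0, alpha=0.25)
loss.backward()
# The gradient is computed by the fp32 path, then cast back to fp16.
assert pred.grad.dtype == torch.half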

mmcv/ops/points_in_boxes.py (+3 -2)

@@ -47,8 +47,9 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor:
     points_device = points.get_device()
     assert points_device == boxes.get_device(), \
         'Points and boxes should be put on the same device'
-    if torch.cuda.current_device() != points_device:
-        torch.cuda.set_device(points_device)
+    if points.device.type != 'npu':
+        if torch.cuda.current_device() != points_device:
+            torch.cuda.set_device(points_device)
 
     ext_module.points_in_boxes_part_forward(boxes.contiguous(),
                                             points.contiguous(),
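
The Python-side change guards the CUDA device switch so that NPU tensors no longer hit torch.cuda.current_device() on machines where CUDA is absent or uninitialized. Usage is unchanged; a short sketch following the function's docstring (NPU device assumed as above):

import torch
from mmcv.ops import points_in_boxes_part

# points: (B, M, 3) in LiDAR coordinates;
# boxes:  (B, T, 7) as (x, y, z, x_size, y_size, z_size, rz).
points = torch.rand(1, 128, 3, device='npu')
boxes = torch.tensor([[[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 0.0]]],
                     device='npu')

# Returns (B, M) box indices; -1 marks points inside no box.
idx = points_in_boxes_part(points, boxes)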
