@@ -4,6 +4,21 @@ using namespace std;
void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, float gamma, float alpha) {
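+  // Run the NPU op in fp32 when the input is half precision; the result is
+  // cast back to half at the end.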
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
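+  // An empty weight tensor means no per-class weighting; default to all-ones.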
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
  int64_t n_class = input.size(1);
  at::Tensor target_y = at::ones_like(input);
  if (n_class == 1) {
@@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
    target_y = at::add(target_y, 1.0);
  } else {
    target_y = at::one_hot(target, n_class);
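+    // Multiplying by the one-hot target and summing over classes picks out
+    // each sample's own class weight, which is then broadcast row-wise.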
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
  }
  target_y = target_y.to(at::kInt);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
  OpCommand cmd;
  string reduction = "none";
  cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
      .Input(target_y)
      .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", reduction)
      .Run();
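+  // Cast the fp32 result back to half precision and copy it into the
+  // caller's output tensor.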
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor grad_input, float gamma,
                                     float alpha) {
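+  // As in the forward pass, run the NPU op in fp32 when inputs are half.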
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
  int64_t n_class = input.size(1);
  at::Tensor target_y = at::ones_like(input);
  if (n_class == 1) {
    target_y = at::reshape(target, input.sizes());
  } else {
    target_y = at::one_hot(target, n_class);
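+    // Reduce the broadcast weight to each sample's class weight, as in the
+    // forward pass.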
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
    target_y = at::mul(target_y, -1.0);
    target_y = at::add(target_y, 1.0);
  }
  target_y = target_y.to(at::kInt);
  at::Tensor grad_up = at::ones_like(input);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
  OpCommand cmd;
  string reduction = "none";
  cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", reduction)
      .Run();
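+  // Downcast the fp32 gradient and copy it into the caller's grad_input.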
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
}

void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
@@ -74,26 +108,40 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                    Tensor output, float gamma, float alpha) {
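+  // Upcast half-precision input; the op result is cast back after Run().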
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
  int64_t n_class = input.size(1);
  at::Tensor target_y = at::one_hot(target, n_class);
  target_y = target_y.to(at::kInt);
  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
  if (weight_size > 0) {
    weight_y = at::broadcast_to(weight, input.sizes());
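+    // Upcast the weight if needed, then keep only each sample's class weight.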
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
  }
-  at::Tensor op_output = at::ones_like(input);
+  at::Tensor op_output = at::ones_like(input_y);
  OpCommand cmd;
  string reduction = "none";
  cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
      .Input(target_y)
      .Input(weight_y)
      .Output(op_output)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", reduction)
      .Run();
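+  // Cast the fp32 op result back to half precision.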
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
  int64_t n_batch = input.size(0);
  c10::SmallVector<int64_t, 2> offsets = {0, 0};
  c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
@@ -124,27 +172,44 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor buff, Tensor grad_input,
                                     float gamma, float alpha) {
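+  // Same half-precision handling as the kernels above.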
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
  int64_t n_class = input.size(1);
  at::Tensor target_y = at::one_hot(target, n_class);
  target_y = target_y.to(at::kInt);
  at::Tensor grad_up = at::ones_like(input);
  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
  if (weight_size > 0) {
    weight_y = at::broadcast_to(weight, input.sizes());
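+    // Reduce the broadcast weight to the target-class weight per sample.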
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
  }
  OpCommand cmd;
  string reduction = "none";
  cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", reduction)
      .Run();
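+  // Downcast and copy the gradient back into grad_input.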
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
}

void softmax_focal_loss_backward_impl(Tensor input, Tensor target,