Commit dfabcec
three_interpolate_npu_init
1 parent df2dadb commit dfabcec

File tree: 2 files changed, +139 -4 lines changed

2 files changed

+139
-4
lines changed

mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp (+52 -4)
@@ -6,24 +6,72 @@ using namespace std;
 void three_interpolate_forward_npu(int b, int c, int m, int n,
                                    const Tensor points, const Tensor idx,
                                    const Tensor weight, Tensor out) {
-  auto point_c_trans = points.transpose(1, 2);
+  auto originDtype = points.scalar_type();
+  at::Tensor pointsCast = points;
+  at::Tensor weightCast = weight;
+  at::Tensor outCast = out;
+
+  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
+              "three_interpolate_forward ascend only supports fp32 and fp16.");
+
+  if (originDtype == at::ScalarType::Half) {
+    pointsCast = points.to(at::kFloat);
+    weightCast = weight.to(at::kFloat);
+    outCast = out.to(at::kFloat);
+  }
+
+  auto point_c_trans = pointsCast.transpose(1, 2);
 
   OpCommand cmd;
   cmd.Name("ThreeInterpolate")
       .Input(point_c_trans)
       .Input(idx)
-      .Input(weight)
-      .Output(out)
+      .Input(weightCast)
+      .Output(outCast)
       .Run();
 
-  auto output = out.view({b, n, c}).transpose(1, 2);
+  auto output = outCast.view({b, n, c}).transpose(1, 2);
   auto res = NpuUtils::format_contiguous(output);
   out.copy_(res);
 }
 
+void three_interpolate_backward_npu(int b, int c, int n, int m,
+                                    const Tensor grad_out, const Tensor idx,
+                                    const Tensor weight, Tensor grad_points) {
+  auto originDtype = grad_out.scalar_type();
+  at::Tensor gradOutCast = grad_out;
+  at::Tensor weightCast = weight;
+  at::Tensor gradPointsCast = grad_points;
+
+  TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
+              "three_interpolate_backward ascend only supports fp32 and fp16.");
+
+  if (originDtype == at::ScalarType::Half) {
+    gradOutCast = grad_out.to(at::kFloat);
+    weightCast = weight.to(at::kFloat);
+    gradPointsCast = grad_points.to(at::kFloat);
+  }
+
+  OpCommand cmd;
+  cmd.Name("ThreeInterpolateBackward")
+      .Input(gradOutCast)
+      .Input(idx)
+      .Input(weightCast)
+      .Output(gradPointsCast)
+      .Attr("m", m)
+      .Run();
+
+  // Unlike the forward pass, the fp32 result must be copied back here as
+  // well; otherwise a half-precision grad_points never receives the result.
+  if (originDtype == at::ScalarType::Half) {
+    grad_points.copy_(gradPointsCast);
+  }
+}
+
 void three_interpolate_forward_impl(int b, int c, int m, int n,
                                     const Tensor points, const Tensor idx,
                                     const Tensor weight, Tensor out);
 
+void three_interpolate_backward_impl(int b, int c, int n, int m,
+                                     const Tensor grad_out, const Tensor idx,
+                                     const Tensor weight, Tensor grad_points);
+
 REGISTER_NPU_IMPL(three_interpolate_forward_impl,
                   three_interpolate_forward_npu);
+
+REGISTER_NPU_IMPL(three_interpolate_backward_impl,
+                  three_interpolate_backward_npu);
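Both kernels follow the same dtype pattern: fp16 inputs are upcast to fp32, the Ascend op runs in fp32, and the result is copied back into the original tensor, which suggests the underlying ThreeInterpolate kernels accept fp32 only. A minimal Python sketch of that pattern, assuming a generic in-place `op` callable (the helper name `run_in_fp32` is hypothetical, not part of this commit):

import torch

def run_in_fp32(op, out, *inputs):
    # Hypothetical helper mirroring the C++ pattern above: upcast fp16
    # floating-point inputs to fp32, run the kernel, then copy the result
    # back into the original (possibly fp16) output tensor.
    orig_dtype = out.dtype
    assert orig_dtype in (torch.float, torch.half), \
        'ascend only supports fp32 and fp16'
    if orig_dtype == torch.half:
        inputs = tuple(
            t.float() if t.is_floating_point() else t for t in inputs)
        out_cast = out.float()
    else:
        out_cast = out
    op(*inputs, out_cast)    # the kernel writes into out_cast in place
    if out_cast is not out:
        out.copy_(out_cast)  # copy_ casts the fp32 result back to fp16
    return out

# Toy usage: a doubling "kernel" run through the wrapper on fp16 data.
out = torch.empty(4, dtype=torch.half)
run_in_fp32(lambda x, o: o.copy_(x * 2), out, torch.ones(4, dtype=torch.half))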

tests/test_ops/test_three_interpolate.py (+87)
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
 import pytest
 import torch
 
 from mmcv.ops import three_interpolate
 from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE
@@ -95,3 +96,89 @@ def test_three_interpolate(dtype, device):
         device=device)
 
     assert torch.allclose(output, expected_output, 1e-3, 1e-4)
+
+
+def three_interpolate_forward_golden(features, idx, weight):
+    bs, cs, ms = features.shape
+    ns = idx.shape[1]
+
+    dtype = features.dtype
+    if dtype == np.float16:
+        features = features.astype(np.float32)
+        weight = weight.astype(np.float32)
+
+    output = np.zeros((bs, cs, ns), dtype=np.float32)
+    for b in range(bs):
+        for c in range(cs):
+            for n in range(ns):
+                output[b][c][n] = \
+                    features[b][c][idx[b][n][0]] * weight[b][n][0] \
+                    + features[b][c][idx[b][n][1]] * weight[b][n][1] \
+                    + features[b][c][idx[b][n][2]] * weight[b][n][2]
+    return output
+
+
+def three_interpolate_backward_golden(grad_output, idx, weight, features):
+    bs, cs, ns = grad_output.shape
+    ms = features.shape[2]
+
+    dtype = features.dtype
+    if dtype == np.float16:
+        features = features.astype(np.float32)
+        weight = weight.astype(np.float32)
+
+    grad_point = np.zeros((bs, cs, ms), dtype=np.float32)
+    for b in range(bs):
+        for c in range(cs):
+            for n in range(ns):
+                grad_point[b][c][idx[b][n][0]] += \
+                    grad_output[b][c][n] * weight[b][n][0]
+                grad_point[b][c][idx[b][n][1]] += \
+                    grad_output[b][c][n] * weight[b][n][1]
+                grad_point[b][c][idx[b][n][2]] += \
+                    grad_output[b][c][n] * weight[b][n][2]
+    return grad_point
+
+
+def torch_type_trans(dtype):
+    if dtype == torch.half:
+        return np.float16
+    elif dtype == torch.float:
+        return np.float32
+    else:
+        return np.float64
+
+
+@pytest.mark.parametrize('dtype', [torch.half, torch.float])
+@pytest.mark.parametrize('device', [
+    pytest.param(
+        'npu',
+        marks=pytest.mark.skipif(
+            not IS_NPU_AVAILABLE, reason='requires NPU support'))
+])
+@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
+                                   (20, 21, 13, 4), (2, 10, 2, 18),
+                                   (10, 602, 910, 200), (600, 100, 300, 101)])
+def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
+    bs, cs, ms, ns = shape
+
+    features = np.random.uniform(
+        -10.0, 10.0, (bs, cs, ms)).astype(torch_type_trans(dtype))
+    # np.random.uniform has no dtype argument and indices must be integers,
+    # so draw them with randint instead.
+    idx = np.random.randint(0, ms, size=(bs, ns, 3), dtype=np.int32)
+    weight = np.random.uniform(
+        -10.0, 10.0, (bs, ns, 3)).astype(torch_type_trans(dtype))
+
+    features_npu = torch.tensor(features, dtype=dtype).to(device)
+    idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
+    weight_npu = torch.tensor(weight, dtype=dtype).to(device)
+
+    expected_output = three_interpolate_forward_golden(features, idx, weight)
+    output = three_interpolate(features_npu, idx_npu, weight_npu)
+    assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
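The golden reference computes output[b, c, n] = sum over k of features[b, c, idx[b, n, k]] * weight[b, n, k] with three nested Python loops, which gets slow for the larger parametrized shapes such as (10, 602, 910, 200). A vectorized NumPy equivalent, as a sketch (the helper name is hypothetical and not part of this commit):

import numpy as np

def three_interpolate_vectorized(features, idx, weight):
    # Sketch, not part of the commit.
    # features: (B, C, M), idx: (B, N, 3) int, weight: (B, N, 3)
    # returns (B, C, N) with out[b, c, n] = sum_k f[b, c, idx[b, n, k]] * w[b, n, k]
    B, C, M = features.shape
    N = idx.shape[1]
    flat_idx = np.broadcast_to(idx.reshape(B, 1, N * 3), (B, C, N * 3))
    gathered = np.take_along_axis(features, flat_idx, axis=2)  # (B, C, N*3)
    gathered = gathered.reshape(B, C, N, 3)
    return (gathered * weight[:, None, :, :]).sum(axis=-1)

On random inputs this agrees with three_interpolate_forward_golden up to floating-point round-off, e.g. np.allclose(three_interpolate_vectorized(f, i, w), three_interpolate_forward_golden(f, i, w)).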
