diff --git a/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh b/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
new file mode 100644
index 0000000000..e20ac68c76
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
@@ -0,0 +1,91 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef POINT_IN_BOXES_MUSA_KERNEL_MUH
+#define POINT_IN_BOXES_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
+                                             T &local_x, T &local_y) {
+  T cosa = cos(-rz), sina = sin(-rz);
+  local_x = shift_x * cosa + shift_y * (-sina);
+  local_y = shift_x * sina + shift_y * cosa;
+}
+
+template <typename T>
+__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
+                                        T &local_y) {
+  // param pt: (x, y, z)
+  // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
+  // cz in the bottom center
+  T x = pt[0], y = pt[1], z = pt[2];
+  T cx = box3d[0], cy = box3d[1], cz = box3d[2];
+  T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
+  cz += z_size /
+        2.0;  // shift to the center since cz in box3d is the bottom center
+
+  if (fabsf(z - cz) > z_size / 2.0) return 0;
+  lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
+  float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
+                  (local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
+  return in_flag;
+}
+
+template <typename T>
+__global__ void points_in_boxes_part_forward_musa_kernel(
+    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
+    int *box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate, z is the bottom center, each box DO NOT overlaps params pts:
+  // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
+  // (B, npoints), default -1
+
+  int bs_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (bs_idx >= batch_size) return;
+
+    boxes += bs_idx * boxes_num * 7;
+    pts += bs_idx * pts_num * 3 + pt_idx * 3;
+    box_idx_of_points += bs_idx * pts_num + pt_idx;
+
+    T local_x = 0, local_y = 0;
+    int cur_in_flag = 0;
+    for (int k = 0; k < boxes_num; k++) {
+      cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
+      if (cur_in_flag) {
+        box_idx_of_points[0] = k;
+        break;
+      }
+    }
+  }
+}
+
+template <typename T>
+__global__ void points_in_boxes_all_forward_musa_kernel(
+    int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
+    int *box_idx_of_points) {
+  // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
+  // coordinate, z is the bottom center, each box DO NOT overlaps params pts:
+  // (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
+  // (B, npoints), default -1
+
+  int bs_idx = blockIdx.y;
+  MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
+    if (bs_idx >= batch_size) return;
+
+    boxes += bs_idx * boxes_num * 7;
+    pts += bs_idx * pts_num * 3 + pt_idx * 3;
+    box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;
+
+    T local_x = 0, local_y = 0;
+    for (int k = 0; k < boxes_num; k++) {
+      const int cur_in_flag =
+          check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
+      if (cur_in_flag) {
+        box_idx_of_points[k] = 1;
+      }
+    }
+  }
+}
+
+#endif  // POINT_IN_BOXES_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh b/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
new file mode 100644
index 0000000000..714c889bd6
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
@@ -0,0 +1,75 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#ifndef POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
+#define POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+struct point {
+  float x, y;
+};
+
+template <typename scalar_t>
+__global__ void points_in_polygons_forward_musa_kernel(
+    const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
+    const int rows, const int cols, scalar_t *inside_flag) {
+  MUSA_1D_KERNEL_LOOP(index, nthreads) {
+    int row = index / cols;
+    int col = index % cols;
+
+    const scalar_t *offset_vertex1 = vertex1 + row * 2;
+    const scalar_t *offset_vertex2 = vertex2 + col * 8;
+
+    point point_[1];
+    point polygon[4];
+
+    point_[0].x = offset_vertex1[0];
+    point_[0].y = offset_vertex1[1];
+
+    polygon[0].x = offset_vertex2[0];
+    polygon[0].y = offset_vertex2[1];
+    polygon[1].x = offset_vertex2[2];
+    polygon[1].y = offset_vertex2[3];
+    polygon[2].x = offset_vertex2[4];
+    polygon[2].y = offset_vertex2[5];
+    polygon[3].x = offset_vertex2[6];
+    polygon[3].y = offset_vertex2[7];
+
+    int nCross = 0;
+    int i, j;
+    float sx, sy, tx, ty, px, py, x;
+    for (i = 0, j = 3; i < 4; j = i, i++) {
+      sx = polygon[i].x;
+      sy = polygon[i].y;
+      tx = polygon[j].x;
+      ty = polygon[j].y;
+
+      px = point_[0].x;
+      py = point_[0].y;
+
+      if (py < min(sy, ty)) continue;
+      if (py > max(sy, ty)) continue;
+
+      if ((sx == px && sy == py) || (tx == px && ty == py)) {
+        break;
+      } else {
+        if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
+          x = sx + (py - sy) * (tx - sx) / (ty - sy);
+          if (x == px) {
+            break;
+          }
+          if (x > px) {
+            nCross++;
+          }
+        }
+      }
+    }
+    if (nCross % 2 == 1) {
+      inside_flag[index] = 1.0;
+    } else {
+      inside_flag[index] = 0.0;
+    }
+    return;
+  }
+}
+
+#endif  // POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
diff --git a/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh b/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh
new file mode 100644
index 0000000000..9394b7e89d
--- /dev/null
+++ b/mmcv/ops/csrc/common/musa/prroi_pool_musa_kernel.muh
@@ -0,0 +1,377 @@
+// Copyright (c) OpenMMLab. All rights reserved
+// Modified from
+// https://github.com/vacancy/PreciseRoIPooling/blob/master/src/prroi_pooling_gpu_impl.cu
+// Distributed under terms of the MIT license.
+#ifndef PRROI_POOL_MUSA_KERNEL_MUH
+#define PRROI_POOL_MUSA_KERNEL_MUH
+
+#include "pytorch_musa_helper.hpp"
+
+template <typename T>
+__device__ static __forceinline__ T PrRoIPoolingGetData(const T *data,
+                                                        const int h,
+                                                        const int w,
+                                                        const int height,
+                                                        const int width) {
+  bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width);
+  T retVal = overflow ?
0.0f : data[h * width + w]; + return retVal; +} + +template +__device__ static __forceinline__ T PrRoIPoolingGetCoeff(T dh, T dw) { + return (1.0f - abs(dh)) * (1.0f - abs(dw)); +} + +template +__device__ static __forceinline__ T PrRoIPoolingSingleCoorIntegral(T s, T t, + T c1, T c2) { + return 0.5 * (t * t - s * s) * (c2 - c1) + (t - s) * c1; +} + +template +__device__ static T PrRoIPoolingInterpolation(const T *data, const T h, + const T w, const int height, + const int width) { + T retVal = 0.0f; + int h1 = floorf(h); + int w1 = floorf(w); + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - T(h1), w - T(w1)); + h1 = floorf(h) + 1; + w1 = floorf(w); + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - T(h1), w - T(w1)); + h1 = floorf(h); + w1 = floorf(w) + 1; + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - T(h1), w - T(w1)); + h1 = floorf(h) + 1; + w1 = floorf(w) + 1; + retVal += PrRoIPoolingGetData(data, h1, w1, height, width) * + PrRoIPoolingGetCoeff(h - T(h1), w - T(w1)); + return retVal; +} + +template +__device__ static T PrRoIPoolingMatCalculation(const T *this_data, + const int s_h, const int s_w, + const int e_h, const int e_w, + const T y0, const T x0, + const T y1, const T x1, + const int h0, const int w0) { + T alpha, beta, lim_alpha, lim_beta, tmp; + T sum_out = 0; + + alpha = x0 - T(s_w); + beta = y0 - T(s_h); + lim_alpha = x1 - T(s_w); + lim_beta = y1 - T(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, s_w, h0, w0) * tmp; + + alpha = T(e_w) - x1; + lim_alpha = T(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, s_h, e_w, h0, w0) * tmp; + + alpha = x0 - T(s_w); + beta = T(e_h) - y1; + lim_alpha = x1 - T(s_w); + lim_beta = T(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, s_w, h0, w0) * tmp; + + alpha = T(e_w) - x1; + lim_alpha = T(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + sum_out += PrRoIPoolingGetData(this_data, e_h, e_w, h0, w0) * tmp; + + return sum_out; +} + +template +__device__ static void PrRoIPoolingDistributeDiff(T *diff, const T top_diff, + const int h, const int w, + const int height, + const int width, + const T coeff) { + bool overflow = (h < 0) || (w < 0) || (h >= height) || (w >= width); + if (!overflow) atomicAdd(diff + h * width + w, top_diff * coeff); +} + +template +__device__ static void PrRoIPoolingMatDistributeDiff( + T *diff, const T top_diff, const int s_h, const int s_w, const int e_h, + const int e_w, const T y0, const T x0, const T y1, const T x1, const int h0, + const int w0) { + T alpha, beta, lim_alpha, lim_beta, tmp; + + alpha = x0 - T(s_w); + beta = y0 - T(s_h); + lim_alpha = x1 - T(s_w); + lim_beta = y1 - T(s_h); + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, 
s_w, h0, w0, tmp); + + alpha = T(e_w) - x1; + lim_alpha = T(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, s_h, e_w, h0, w0, tmp); + + alpha = x0 - T(s_w); + beta = T(e_h) - y1; + lim_alpha = x1 - T(s_w); + lim_beta = T(e_h) - y0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, s_w, h0, w0, tmp); + + alpha = T(e_w) - x1; + lim_alpha = T(e_w) - x0; + tmp = (lim_alpha - 0.5f * lim_alpha * lim_alpha - alpha + + 0.5f * alpha * alpha) * + (lim_beta - 0.5f * lim_beta * lim_beta - beta + 0.5f * beta * beta); + PrRoIPoolingDistributeDiff(diff, top_diff, e_h, e_w, h0, w0, tmp); +} + +template +__global__ void prroi_pool_forward_musa_kernel( + const int nthreads, const T *input, const T *rois, T *output, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T *offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + T roi_x1 = offset_rois[1] * spatial_scale; + T roi_y1 = offset_rois[2] * spatial_scale; + T roi_x2 = offset_rois[3] * spatial_scale; + T roi_y2 = offset_rois[4] * spatial_scale; + + T roi_width = max(roi_x2 - roi_x1, ((T)0.0)); + T roi_height = max(roi_y2 - roi_y1, ((T)0.0)); + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + const T *this_data = + input + (roi_batch_ind * channels + c) * height * width; + T *this_out = output + index; + + T bin_x1 = roi_x1 + bin_size_w * pw; + T bin_y1 = roi_y1 + bin_size_h * ph; + T bin_x2 = bin_x1 + bin_size_w; + T bin_y2 = bin_y1 + bin_size_h; + + T bin_size = max(T(0.0), bin_size_w * bin_size_h); + if (bin_size == 0) { + *this_out = 0; + continue; + } + + T sum_out = 0; + + int start_x, start_y, end_x, end_y; + + start_x = floorf(bin_x1); + end_x = ceilf(bin_x2); + start_y = floorf(bin_y1); + end_y = ceilf(bin_y2); + + for (int bin_x = start_x; bin_x < end_x; ++bin_x) + for (int bin_y = start_y; bin_y < end_y; ++bin_y) + sum_out += PrRoIPoolingMatCalculation( + this_data, bin_y, bin_x, bin_y + 1, bin_x + 1, + max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)), + min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height, + width); + *this_out = sum_out / bin_size; + } +} + +template +__global__ void prroi_pool_backward_musa_kernel( + const int nthreads, const T *grad_output, const T *rois, T *grad_input, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + auto rois_cur = rois + n * 5; + + int roi_batch_ind = rois_cur[0]; + T roi_x1 = rois_cur[1] * spatial_scale; + T roi_y1 = rois_cur[2] * 
spatial_scale; + T roi_x2 = rois_cur[3] * spatial_scale; + T roi_y2 = rois_cur[4] * spatial_scale; + + T roi_width = max(roi_x2 - roi_x1, (T)0); + T roi_height = max(roi_y2 - roi_y1, (T)0); + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + const T *this_out_grad = grad_output + index; + T *this_data_grad = + grad_input + (roi_batch_ind * channels + c) * height * width; + + T bin_x1 = roi_x1 + bin_size_w * pw; + T bin_y1 = roi_y1 + bin_size_h * ph; + T bin_x2 = bin_x1 + bin_size_w; + T bin_y2 = bin_y1 + bin_size_h; + + T bin_size = max(T(0.0), bin_size_w * bin_size_h); + + T sum_out = bin_size == T(0) ? T(0) : *this_out_grad / bin_size; + + int start_x, start_y, end_x, end_y; + + start_x = floorf(bin_x1); + end_x = ceilf(bin_x2); + start_y = floorf(bin_y1); + end_y = ceilf(bin_y2); + + for (int bin_x = start_x; bin_x < end_x; ++bin_x) + for (int bin_y = start_y; bin_y < end_y; ++bin_y) + PrRoIPoolingMatDistributeDiff( + this_data_grad, sum_out, bin_y, bin_x, bin_y + 1, bin_x + 1, + max(bin_y1, T(bin_y)), max(bin_x1, T(bin_x)), + min(bin_y2, T(bin_y) + 1.0f), min(bin_x2, T(bin_x + 1.0f)), height, + width); + } +} + +template +__global__ void prroi_pool_coor_backward_musa_kernel( + const int nthreads, const T *output, const T *grad_output, const T *input, + const T *rois, T *grad_rois, const int pooled_height, + const int pooled_width, const T spatial_scale, const int channels, + const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + auto rois_cur = rois + n * 5; + + int roi_batch_ind = rois_cur[0]; + T roi_x1 = rois_cur[1] * spatial_scale; + T roi_y1 = rois_cur[2] * spatial_scale; + T roi_x2 = rois_cur[3] * spatial_scale; + T roi_y2 = rois_cur[4] * spatial_scale; + + T roi_width = max(roi_x2 - roi_x1, (T)0); + T roi_height = max(roi_y2 - roi_y1, (T)0); + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + const T output_grad_val = grad_output[index]; + const T *this_input_data = + input + (roi_batch_ind * channels + c) * height * width; + const T output_val = output[index]; + T *this_rois_grad = grad_rois + n * 5; + + T bin_x1 = roi_x1 + bin_size_w * pw; + T bin_y1 = roi_y1 + bin_size_h * ph; + T bin_x2 = bin_x1 + bin_size_w; + T bin_y2 = bin_y1 + bin_size_h; + + T bin_size = max(T(0.0), bin_size_w * bin_size_h); + + T sum_out = bin_size == T(0) ? 
T(0) : output_grad_val / bin_size; + + // WARNING: to be discussed + if (sum_out == 0) continue; + + int start_x, start_y, end_x, end_y; + + start_x = floorf(bin_x1); + end_x = ceilf(bin_x2); + start_y = floorf(bin_y1); + end_y = ceilf(bin_y2); + + T grad_x1_y = 0, grad_x2_y = 0, grad_x_y1 = 0, grad_x_y2 = 0; + for (int bin_y = start_y; bin_y < end_y; ++bin_y) { + grad_x1_y += PrRoIPoolingSingleCoorIntegral( + max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y, + PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x1, + height, width), + PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x1, + height, width)); + + grad_x2_y += PrRoIPoolingSingleCoorIntegral( + max(bin_y1, T(bin_y)) - bin_y, min(bin_y2, T(bin_y + 1)) - bin_y, + PrRoIPoolingInterpolation(this_input_data, float(bin_y), bin_x2, + height, width), + PrRoIPoolingInterpolation(this_input_data, float(bin_y + 1), bin_x2, + height, width)); + } + + for (int bin_x = start_x; bin_x < end_x; ++bin_x) { + grad_x_y1 += PrRoIPoolingSingleCoorIntegral( + max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x, + PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x), + height, width), + PrRoIPoolingInterpolation(this_input_data, bin_y1, float(bin_x + 1), + height, width)); + + grad_x_y2 += PrRoIPoolingSingleCoorIntegral( + max(bin_x1, T(bin_x)) - bin_x, min(bin_x2, T(bin_x + 1)) - bin_x, + PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x), + height, width), + PrRoIPoolingInterpolation(this_input_data, bin_y2, float(bin_x + 1), + height, width)); + } + + T partial_x1 = -grad_x1_y + (bin_y2 - bin_y1) * output_val; + T partial_y1 = -grad_x_y1 + (bin_x2 - bin_x1) * output_val; + T partial_x2 = grad_x2_y - (bin_y2 - bin_y1) * output_val; + T partial_y2 = grad_x_y2 - (bin_x2 - bin_x1) * output_val; + + partial_x1 = partial_x1 / bin_size * spatial_scale; + partial_x2 = partial_x2 / bin_size * spatial_scale; + partial_y1 = partial_y1 / bin_size * spatial_scale; + partial_y2 = partial_y2 / bin_size * spatial_scale; + + // (index, x1, y1, x2, y2) + this_rois_grad[0] = 0; + atomicAdd(this_rois_grad + 1, + (partial_x1 * (1.0f - T(pw) / pooled_width) + + partial_x2 * (1.0f - T(pw + 1) / pooled_width)) * + output_grad_val); + atomicAdd(this_rois_grad + 2, + (partial_y1 * (1.0f - T(ph) / pooled_height) + + partial_y2 * (1.0f - T(ph + 1) / pooled_height)) * + output_grad_val); + atomicAdd(this_rois_grad + 3, (partial_x2 * T(pw + 1) / pooled_width + + partial_x1 * T(pw) / pooled_width) * + output_grad_val); + atomicAdd(this_rois_grad + 4, (partial_y2 * T(ph + 1) / pooled_height + + partial_y1 * T(ph) / pooled_height) * + output_grad_val); + } +} + +#endif // ROI_POOL_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh b/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh new file mode 100644 index 0000000000..75091ea4b1 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/psamask_musa_kernel.muh @@ -0,0 +1,137 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#ifndef PSAMASK_MUSA_KERNEL_MUH +#define PSAMASK_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +// MUSA: grid stride looping +#ifndef MUSA_KERNEL_LOOP +#define MUSA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) +#endif + +template +__global__ void psamask_collect_forward_musa( + const int nthreads, const int h_feature, const int w_feature, + const int h_mask, const int w_mask, const int half_h_mask, + const int half_w_mask, const T* mask_data, T* buffer_data) { + MUSA_KERNEL_LOOP(index, nthreads) { + const int w = index % w_feature; + const int h = (index / w_feature) % h_feature; + const int n = index / w_feature / h_feature; + // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed + const int hstart = max(0, half_h_mask - h); + const int hend = min(h_mask, h_feature + half_h_mask - h); + const int wstart = max(0, half_w_mask - w); + const int wend = min(w_mask, w_feature + half_w_mask - w); + // (hidx, widx ) with mask-indexed + // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed + for (int hidx = hstart; hidx < hend; hidx++) { + for (int widx = wstart; widx < wend; widx++) { + buffer_data[(n * h_feature * w_feature + + (hidx + h - half_h_mask) * w_feature + + (widx + w - half_w_mask)) * + h_feature * w_feature + + h * w_feature + w] = mask_data + [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) * + w_feature + + w]; + } + } + } +} + +template +__global__ void psamask_distribute_forward_musa( + const int nthreads, const int h_feature, const int w_feature, + const int h_mask, const int w_mask, const int half_h_mask, + const int half_w_mask, const T* mask_data, T* buffer_data) { + MUSA_KERNEL_LOOP(index, nthreads) { + const int w = index % w_feature; + const int h = (index / w_feature) % h_feature; + const int n = index / w_feature / h_feature; + // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed + const int hstart = max(0, half_h_mask - h); + const int hend = min(h_mask, h_feature + half_h_mask - h); + const int wstart = max(0, half_w_mask - w); + const int wend = min(w_mask, w_feature + half_w_mask - w); + // (hidx, widx ) with mask-indexed + // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed + for (int hidx = hstart; hidx < hend; hidx++) { + for (int widx = wstart; widx < wend; widx++) { + buffer_data[(n * h_feature * w_feature + h * w_feature + w) * + h_feature * w_feature + + (hidx + h - half_h_mask) * w_feature + + (widx + w - half_w_mask)] = mask_data + [((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + h) * + w_feature + + w]; + } + } + } +} + +template +__global__ void psamask_collect_backward_musa( + const int nthreads, const int h_feature, const int w_feature, + const int h_mask, const int w_mask, const int half_h_mask, + const int half_w_mask, const T* buffer_diff, T* mask_diff) { + MUSA_KERNEL_LOOP(index, nthreads) { + const int w = index % w_feature; + const int h = (index / w_feature) % h_feature; + const int n = index / w_feature / h_feature; + // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed + const int hstart = max(0, half_h_mask - h); + const int hend = min(h_mask, h_feature + half_h_mask - h); + const int wstart = max(0, half_w_mask - w); + const int wend = min(w_mask, w_feature + half_w_mask - w); + // (hidx, widx ) with mask-indexed + // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed + for (int hidx = 
hstart; hidx < hend; hidx++) { + for (int widx = wstart; widx < wend; widx++) { + mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + + h) * + w_feature + + w] = buffer_diff[(n * h_feature * w_feature + + (hidx + h - half_h_mask) * w_feature + + (widx + w - half_w_mask)) * + h_feature * w_feature + + h * w_feature + w]; + } + } + } +} + +template +__global__ void psamask_distribute_backward_musa( + const int nthreads, const int h_feature, const int w_feature, + const int h_mask, const int w_mask, const int half_h_mask, + const int half_w_mask, const T* buffer_diff, T* mask_diff) { + MUSA_KERNEL_LOOP(index, nthreads) { + const int w = index % w_feature; + const int h = (index / w_feature) % h_feature; + const int n = index / w_feature / h_feature; + // effective mask region : [hstart, hend) x [wstart, wend) with mask-indexed + const int hstart = max(0, half_h_mask - h); + const int hend = min(h_mask, h_feature + half_h_mask - h); + const int wstart = max(0, half_w_mask - w); + const int wend = min(w_mask, w_feature + half_w_mask - w); + // (hidx, widx ) with mask-indexed + // (hidx + h - half_h_mask, widx + w - half_w_mask) with feature-indexed + for (int hidx = hstart; hidx < hend; hidx++) { + for (int widx = wstart; widx < wend; widx++) { + mask_diff[((n * h_mask * w_mask + hidx * w_mask + widx) * h_feature + + h) * + w_feature + + w] = + buffer_diff[(n * h_feature * w_feature + h * w_feature + w) * + h_feature * w_feature + + (hidx + h - half_h_mask) * w_feature + + (widx + w - half_w_mask)]; + } + } + } +} + +#endif // PSAMASK_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh b/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh new file mode 100644 index 0000000000..b5124798bc --- /dev/null +++ b/mmcv/ops/csrc/common/musa/riroi_align_rotated_musa_kernel.muh @@ -0,0 +1,238 @@ +// Modified from +// https://github.com/csuhan/ReDet/blob/master/mmdet/ops/riroi_align/src/riroi_align_kernel.cu +#ifndef RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH +#define RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH + +#include +#include "pytorch_musa_helper.hpp" + +/*** Forward ***/ +template +__global__ void riroi_align_rotated_forward_musa_kernel( + const int nthreads, const scalar_t *bottom_data, + const scalar_t *bottom_rois, const scalar_t spatial_scale, + const int num_samples, const bool clockwise, const int channels, + const int height, const int width, const int pooled_height, + const int pooled_width, const int num_orientations, scalar_t *top_data) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int o = (index / pooled_width / pooled_height) % num_orientations; + int c = + (index / pooled_width / pooled_height / num_orientations) % channels; + int n = index / pooled_width / pooled_height / num_orientations / channels; + + const scalar_t *offset_bottom_rois = bottom_rois + n * 6; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale; + scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale; + scalar_t roi_width = offset_bottom_rois[3] * spatial_scale; + scalar_t roi_height = offset_bottom_rois[4] * spatial_scale; + // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0; + scalar_t theta = offset_bottom_rois[5]; + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, 
(scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + scalar_t bin_size_h = static_cast(roi_height) / + static_cast(pooled_height); + scalar_t bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); + + // find aligned index + scalar_t ind_float = theta * num_orientations / (2 * M_PI); + int ind = floorf(ind_float); + scalar_t l_var = ind_float - (scalar_t)ind; + scalar_t r_var = 1.0 - l_var; + // correct start channel + ind = (ind + num_orientations) % num_orientations; + // rotated channel + int ind_rot = (o - ind + num_orientations) % num_orientations; + int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations; + const scalar_t *offset_bottom_data = + bottom_data + (roi_batch_ind * channels * num_orientations + + c * num_orientations + ind_rot) * + height * width; + + const scalar_t *offset_bottom_data_plus = + bottom_data + (roi_batch_ind * channels * num_orientations + + c * num_orientations + ind_rot_plus) * + height * width; + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (num_samples > 0) + ? num_samples + : ceilf(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + if (clockwise) { + theta = -theta; // If clockwise, the angle needs to be reversed. + } + scalar_t roi_start_h = -roi_height / 2.0; + scalar_t roi_start_w = -roi_width / 2.0; + scalar_t cosscalar_theta = cos(theta); + scalar_t sinscalar_theta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + scalar_t output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const scalar_t yy = + roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const scalar_t xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta (counterclockwise) around the center and translate + scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; + scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; + + scalar_t val = bilinear_interpolate( + offset_bottom_data, height, width, y, x, index); + scalar_t val_plus = bilinear_interpolate( + offset_bottom_data_plus, height, width, y, x, index); + output_val += r_var * val + l_var * val_plus; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + +/*** Backward ***/ +template +__global__ void riroi_align_rotated_backward_musa_kernel( + const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois, + const scalar_t spatial_scale, const int num_samples, const bool clockwise, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, const int num_orientations, + scalar_t *bottom_diff) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int o = (index / pooled_width / pooled_height) % num_orientations; + int c = + (index / pooled_width / pooled_height / num_orientations) % channels; + int n = index / pooled_width / pooled_height / num_orientations / channels; + + const scalar_t 
*offset_bottom_rois = bottom_rois + n * 6; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not round + scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale; + scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale; + scalar_t roi_width = offset_bottom_rois[3] * spatial_scale; + scalar_t roi_height = offset_bottom_rois[4] * spatial_scale; + // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0; + scalar_t theta = offset_bottom_rois[5]; + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + + scalar_t bin_size_h = static_cast(roi_height) / + static_cast(pooled_height); + scalar_t bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); + + // find aligned index + scalar_t ind_float = theta * num_orientations / (2 * M_PI); + int ind = floorf(ind_float); + scalar_t l_var = ind_float - (scalar_t)ind; + scalar_t r_var = 1.0 - l_var; + // correct start channel + ind = (ind + num_orientations) % num_orientations; + // rotated channel + int ind_rot = (o - ind + num_orientations) % num_orientations; + int ind_rot_plus = (ind_rot + 1 + num_orientations) % num_orientations; + scalar_t *offset_bottom_diff = + bottom_diff + (roi_batch_ind * channels * num_orientations + + c * num_orientations + ind_rot) * + height * width; + scalar_t *offset_bottom_diff_plus = + bottom_diff + (roi_batch_ind * channels * num_orientations + + c * num_orientations + ind_rot_plus) * + height * width; + int top_offset = + (n * channels * num_orientations + c * num_orientations + o) * + pooled_height * pooled_width; + const scalar_t *offset_top_diff = top_diff + top_offset; + const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (num_samples > 0) + ? num_samples + : ceilf(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (num_samples > 0) ? num_samples : ceilf(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + if (clockwise) { + theta = -theta; // If clockwise, the angle needs to be reversed. + } + scalar_t roi_start_h = -roi_height / 2.0; + scalar_t roi_start_w = -roi_width / 2.0; + scalar_t cosTheta = cos(theta); + scalar_t sinTheta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const scalar_t yy = + roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const scalar_t xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h; + scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w; + + scalar_t w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, + w4, x_low, x_high, y_low, + y_high, index); + + scalar_t g1 = top_diff_this_bin * w1 / count; + scalar_t g2 = top_diff_this_bin * w2 / count; + scalar_t g3 = top_diff_this_bin * w3 / count; + scalar_t g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_bottom_diff + y_low * width + x_low, g1 * r_var); + atomicAdd(offset_bottom_diff + y_low * width + x_high, g2 * r_var); + atomicAdd(offset_bottom_diff + y_high * width + x_low, g3 * r_var); + atomicAdd(offset_bottom_diff + y_high * width + x_high, g4 * r_var); + + atomicAdd(offset_bottom_diff_plus + y_low * width + x_low, + g1 * l_var); + atomicAdd(offset_bottom_diff_plus + y_low * width + x_high, + g2 * l_var); + atomicAdd(offset_bottom_diff_plus + y_high * width + x_low, + g3 * l_var); + atomicAdd(offset_bottom_diff_plus + y_high * width + x_high, + g4 * l_var); + + } // if + } // ix + } // iy + } // MUSA_1D_KERNEL_LOOP +} // RiRoIAlignBackward + +#endif // RIROI_ALIGN_ROTATED_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh new file mode 100644 index 0000000000..afbc1de686 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/roi_align_musa_kernel.muh @@ -0,0 +1,205 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROI_ALIGN_MUSA_KERNEL_MUH +#define ROI_ALIGN_MUSA_KERNEL_MUH + +#include +#include "pytorch_musa_helper.hpp" + + +/*** Forward ***/ +template +__global__ void roi_align_forward_musa_kernel( + const int nthreads, const T* input, const T* rois, T* output, T* argmax_y, + T* argmax_x, const int pooled_height, const int pooled_width, + const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not using rounding; this implementation detail is critical + T offset = aligned ? 
(T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + if (pool_mode == 0) { + // We do max pooling inside a bin + T maxval = -FLT_MAX; + T maxidx_y = -1.f, maxidx_x = -1.f; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = + bilinear_interpolate(offset_input, height, width, y, x, index); + if (val > maxval) { + maxval = val; + maxidx_y = y; + maxidx_x = x; + } + } + } + output[index] = maxval; + argmax_y[index] = maxidx_y; + argmax_x[index] = maxidx_x; + } else if (pool_mode == 1) { + // We do average pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + T val = + bilinear_interpolate(offset_input, height, width, y, x, index); + output_val += val; + } + } + output[index] = output_val / count; + } + } +} + +/*** Backward ***/ +template +__global__ void roi_align_backward_musa_kernel( + const int nthreads, const T* grad_output, const T* rois, const T* argmax_y, + const T* argmax_x, T* grad_input, const int pooled_height, + const int pooled_width, const T spatial_scale, const int sampling_ratio, + const int pool_mode, // 0 - max pool, 1 - avg pool + const bool aligned, const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T grad_output_this_bin = grad_output[index]; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + if (pool_mode == 0) { + T y = argmax_y[index], x = argmax_x[index]; + if (y != -1.f) { + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && 
x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4); + } + } + } else if (pool_mode == 1) { + // Do not using rounding; this implementation detail is critical + T offset = aligned ? (T)0.5 : (T)0.0; + T roi_start_w = offset_rois[1] * spatial_scale - offset; + T roi_start_h = offset_rois[2] * spatial_scale - offset; + T roi_end_w = offset_rois[3] * spatial_scale - offset; + T roi_end_h = offset_rois[4] * spatial_scale - offset; + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { // for backward-compatibility only + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_height / pooled_height)); + int roi_bin_grid_w = + (sampling_ratio > 0) + ? sampling_ratio + : static_cast(ceilf(roi_width / pooled_width)); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4, + x_low, x_high, y_low, y_high, index); + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_grad_input + y_low * width + x_low, + grad_output_this_bin * w1 / count); + atomicAdd(offset_grad_input + y_low * width + x_high, + grad_output_this_bin * w2 / count); + atomicAdd(offset_grad_input + y_high * width + x_low, + grad_output_this_bin * w3 / count); + atomicAdd(offset_grad_input + y_high * width + x_high, + grad_output_this_bin * w4 / count); + } + } + } + } + } +} + +#endif // ROI_ALIGN_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh new file mode 100644 index 0000000000..76249a1229 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/roi_align_rotated_musa_kernel.muh @@ -0,0 +1,194 @@ +// Modified from +// https://github.com/facebookresearch/detectron2/tree/master/detectron2/layers/csrc/ROIAlignRotated +// Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +#ifndef ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH +#define ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH + +#include +#include "pytorch_musa_helper.hpp" + +/*** Forward ***/ +template +__global__ void roi_align_rotated_forward_musa_kernel( + const int nthreads, const scalar_t *bottom_data, + const scalar_t *bottom_rois, const scalar_t spatial_scale, + const int sampling_ratio, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, scalar_t *top_data) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const scalar_t *offset_bottom_rois = bottom_rois + n * 6; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not using rounding; this implementation detail is critical + scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0; + scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset; + scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset; + scalar_t roi_width = offset_bottom_rois[3] * spatial_scale; + scalar_t roi_height = offset_bottom_rois[4] * spatial_scale; + // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0; + scalar_t theta = offset_bottom_rois[5]; + if (clockwise) { + theta = -theta; // If clockwise, the angle needs to be reversed. + } + if (!aligned) { // for backward-compatibility only + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + } + scalar_t bin_size_h = static_cast(roi_height) / + static_cast(pooled_height); + scalar_t bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); + + const scalar_t *offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceilf(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + scalar_t roi_start_h = -roi_height / 2.0; + scalar_t roi_start_w = -roi_width / 2.0; + scalar_t cosscalar_theta = cos(theta); + scalar_t sinscalar_theta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + scalar_t output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const scalar_t yy = + roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const scalar_t xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta (counterclockwise) around the center and translate + scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; + scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; + + scalar_t val = bilinear_interpolate( + offset_bottom_data, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + +/*** Backward ***/ +template +__global__ void roi_align_rotated_backward_musa_kernel( + const int nthreads, const scalar_t *top_diff, const scalar_t *bottom_rois, + const scalar_t spatial_scale, const int sampling_ratio, const bool aligned, + const bool clockwise, const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, scalar_t *bottom_diff) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const scalar_t *offset_bottom_rois = bottom_rois + n * 6; + int roi_batch_ind = offset_bottom_rois[0]; + + // Do not round + scalar_t offset = aligned ? (scalar_t)0.5 : (scalar_t)0.0; + scalar_t roi_center_w = offset_bottom_rois[1] * spatial_scale - offset; + scalar_t roi_center_h = offset_bottom_rois[2] * spatial_scale - offset; + scalar_t roi_width = offset_bottom_rois[3] * spatial_scale; + scalar_t roi_height = offset_bottom_rois[4] * spatial_scale; + // scalar_t theta = offset_bottom_rois[5] * M_PI / 180.0; + scalar_t theta = offset_bottom_rois[5]; + if (clockwise) { + theta = -theta; // If clockwise, the angle needs to be reversed. + } + if (!aligned) { // for backward-compatibility only + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + } + scalar_t bin_size_h = static_cast(roi_height) / + static_cast(pooled_height); + scalar_t bin_size_w = + static_cast(roi_width) / static_cast(pooled_width); + + scalar_t *offset_bottom_diff = + bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const scalar_t *offset_top_diff = top_diff + top_offset; + const scalar_t top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceilf(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + scalar_t roi_start_h = -roi_height / 2.0; + scalar_t roi_start_w = -roi_width / 2.0; + scalar_t cosTheta = cos(theta); + scalar_t sinTheta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const scalar_t yy = + roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const scalar_t xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + scalar_t y = yy * cosTheta - xx * sinTheta + roi_center_h; + scalar_t x = yy * sinTheta + xx * cosTheta + roi_center_w; + + scalar_t w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, + w4, x_low, x_high, y_low, + y_high, index); + + scalar_t g1 = top_diff_this_bin * w1 / count; + scalar_t g2 = top_diff_this_bin * w2 / count; + scalar_t g3 = top_diff_this_bin * w3 / count; + scalar_t g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_bottom_diff + y_low * width + x_low, g1); + atomicAdd(offset_bottom_diff + y_low * width + x_high, g2); + atomicAdd(offset_bottom_diff + y_high * width + x_low, g3); + atomicAdd(offset_bottom_diff + y_high * width + x_high, g4); + } // if + } // ix + } // iy + } // MUSA_1D_KERNEL_LOOP +} // RoIAlignBackward + +#endif // ROI_ALIGN_ROTATED_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh new file mode 100644 index 0000000000..ec7738d2c4 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/roi_pool_musa_kernel.muh @@ -0,0 +1,89 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROI_POOL_MUSA_KERNEL_MUH +#define ROI_POOL_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__global__ void roi_pool_forward_musa_kernel( + const int nthreads, const T* input, const T* rois, T* output, int* argmax, + const int pooled_height, const int pooled_width, const T spatial_scale, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + // calculate the roi region on feature maps + T roi_x1 = offset_rois[1] * spatial_scale; + T roi_y1 = offset_rois[2] * spatial_scale; + T roi_x2 = (offset_rois[3] + 1) * spatial_scale; + T roi_y2 = (offset_rois[4] + 1) * spatial_scale; + + // force malformed rois to be 1x1 + T roi_w = roi_x2 - roi_x1; + T roi_h = roi_y2 - roi_y1; + if (roi_w <= 0 || roi_h <= 0) continue; + + T bin_size_w = roi_w / static_cast(pooled_width); + T bin_size_h = roi_h / static_cast(pooled_height); + + // the corresponding bin region + int bin_x1 = floorf(static_cast(pw) * bin_size_w + roi_x1); + int bin_y1 = floorf(static_cast(ph) * bin_size_h + roi_y1); + int bin_x2 = ceilf(static_cast(pw + 1) * bin_size_w + roi_x1); + int bin_y2 = ceilf(static_cast(ph + 1) * bin_size_h + roi_y1); + + // add roi offsets and clip to input boundaries + bin_x1 = min(max(bin_x1, 0), width); + bin_y1 = min(max(bin_y1, 0), height); + bin_x2 = min(max(bin_x2, 0), width); + bin_y2 = min(max(bin_y2, 0), height); + bool is_empty = (bin_y2 <= bin_y1) || (bin_x2 <= bin_x1); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * 
height * width; + // Define an empty pooling region to be zero + // If nothing is pooled, argmax = -1 causes nothing to be backprop'd + T max_val = is_empty ? 0 : -FLT_MAX; + int max_idx = -1; + for (int h = bin_y1; h < bin_y2; ++h) { + for (int w = bin_x1; w < bin_x2; ++w) { + int offset = h * width + w; + if (offset_input[offset] > max_val) { + max_val = offset_input[offset]; + max_idx = offset; + } + } + } + output[index] = max_val; + if (argmax != NULL) argmax[index] = max_idx; + } +} + +template +__global__ void roi_pool_backward_musa_kernel( + const int nthreads, const T* grad_output, const T* rois, const int* argmax, + T* grad_input, const int pooled_height, const int pooled_width, + const int channels, const int height, const int width) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c) is an element in the pooled output + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + int roi_batch_ind = rois[n * 5]; + T* grad_input_offset = + grad_input + ((roi_batch_ind * channels + c) * height * width); + int argmax_index = argmax[index]; + + if (argmax_index != -1) { + atomicAdd(grad_input_offset + argmax_index, grad_output[index]); + } + } +} + +#endif // ROI_POOL_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh b/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh new file mode 100644 index 0000000000..d6de6a01c9 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/roiaware_pool3d_musa_kernel.muh @@ -0,0 +1,256 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROIAWARE_POOL3D_MUSA_KERNEL_MUH +#define ROIAWARE_POOL3D_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate, + // cz in the bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6]; + cz += z_size / + 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > z_size / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) & + (local_y > -y_size / 2.0) & (local_y < y_size / 2.0); + return in_flag; +} + +template +__global__ void generate_pts_mask_for_box3d(int boxes_num, int pts_num, + int out_x, int out_y, int out_z, + const T *rois, const T *pts, + int *pts_mask) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] params pts_mask: (N, + // npoints): -1 means point does not in this box, otherwise: encode (x_idxs, + // y_idxs, z_idxs) by binary bit + int box_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) { + if (box_idx >= boxes_num) return; + + pts += pt_idx * 3; + rois += box_idx * 7; + pts_mask += box_idx * pts_num + pt_idx; + + T local_x = 0, local_y = 0; + int cur_in_flag = check_pt_in_box3d(pts, rois, local_x, local_y); + + pts_mask[0] = -1; + if (cur_in_flag > 0) { + T local_z = pts[2] - rois[2]; + T x_size = rois[3], y_size = 
rois[4], z_size = rois[5]; + + T x_res = x_size / out_x; + T y_res = y_size / out_y; + T z_res = z_size / out_z; + + unsigned int x_idx = int((local_x + x_size / 2) / x_res); + unsigned int y_idx = int((local_y + y_size / 2) / y_res); + unsigned int z_idx = int(local_z / z_res); + + x_idx = min(max(x_idx, 0), out_x - 1); + y_idx = min(max(y_idx, 0), out_y - 1); + z_idx = min(max(z_idx, 0), out_z - 1); + + unsigned int idx_encoding = (x_idx << 16) + (y_idx << 8) + z_idx; + + pts_mask[0] = idx_encoding; + } + } +} + +template +__global__ void collect_inside_pts_for_box3d(int boxes_num, int pts_num, + int max_pts_each_voxel, int out_x, + int out_y, int out_z, + const int *pts_mask, + T *pts_idx_of_voxels) { + // params pts_mask: (N, npoints) 0 or 1 + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + MUSA_1D_KERNEL_LOOP(box_idx, boxes_num) { + int max_num_pts = max_pts_each_voxel - 1; // index 0 is the counter + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel; + + for (int k = 0; k < pts_num; k++) { + if (pts_mask[box_idx * pts_num + k] != -1) { + unsigned int idx_encoding = pts_mask[box_idx * pts_num + k]; + unsigned int x_idx = (idx_encoding >> 16) & 0xFF; + unsigned int y_idx = (idx_encoding >> 8) & 0xFF; + unsigned int z_idx = idx_encoding & 0xFF; + unsigned int base_offset = x_idx * out_y * out_z * max_pts_each_voxel + + y_idx * out_z * max_pts_each_voxel + + z_idx * max_pts_each_voxel; + unsigned int cnt = pts_idx_of_voxels[base_offset]; + if (cnt < max_num_pts) { + pts_idx_of_voxels[base_offset + cnt + 1] = k; + pts_idx_of_voxels[base_offset]++; + } + } + } + } +} + +template +__global__ void roiaware_maxpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features, int *argmax) { + // params pts_feature: (npoints, C) + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) { + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels) return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int argmax_idx = -1; + float max_val = -1e50; + + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + if (pts_feature[pts_idx_of_voxels[k] * channels + channel_idx] > + max_val) { + max_val = pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + argmax_idx = pts_idx_of_voxels[k]; + } + } + + if (argmax_idx != -1) { + pooled_features[0] = max_val; + } + argmax[0] = argmax_idx; + } +} + +template +__global__ void roiaware_avgpool3d(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const T *pts_feature, + const int *pts_idx_of_voxels, + T *pooled_features) { + // params pts_feature: (npoints, C) + 
// params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel), + // index 0 is the counter params pooled_features: (N, out_x, out_y, out_z, C) + // params argmax: (N, out_x, out_y, out_z, C) + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) { + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels) return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + pooled_features += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + float sum_val = 0; + int total_pts = pts_idx_of_voxels[0]; + + for (int k = 1; k <= total_pts; k++) { + sum_val += pts_feature[pts_idx_of_voxels[k] * channels + channel_idx]; + } + + if (total_pts > 0) { + pooled_features[0] = sum_val / total_pts; + } + } +} + +template +__global__ void roiaware_maxpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + const int *argmax, + const T *grad_out, T *grad_in) { + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) { + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels) return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + argmax += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + if (argmax[0] == -1) return; + + atomicAdd(grad_in + argmax[0] * channels + channel_idx, grad_out[0] * 1); + } +} + +template +__global__ void roiaware_avgpool3d_backward(int boxes_num, int channels, + int out_x, int out_y, int out_z, + int max_pts_each_voxel, + const int *pts_idx_of_voxels, + const T *grad_out, T *grad_in) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + + int box_idx = blockIdx.z; + int channel_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(voxel_idx_flat, out_x * out_y * out_z) { + int x_idx = voxel_idx_flat / (out_y * out_z); + int y_idx = (voxel_idx_flat - x_idx * (out_y * out_z)) / out_z; + int z_idx = voxel_idx_flat % out_z; + if (box_idx >= boxes_num || channel_idx >= channels) return; + + int offset_base = x_idx * out_y * out_z + y_idx * out_z + z_idx; + pts_idx_of_voxels += box_idx * out_x * out_y * out_z * max_pts_each_voxel + + offset_base * max_pts_each_voxel; + grad_out += box_idx * out_x * out_y * out_z * channels + + offset_base * channels + channel_idx; + + int total_pts = pts_idx_of_voxels[0]; + float cur_grad = 1 / fmaxf(float(total_pts), 1.0); + for (int k = 1; k <= total_pts; k++) { + atomicAdd(grad_in + pts_idx_of_voxels[k] * channels + channel_idx, + grad_out[0] * cur_grad); + } + } +} + +#endif // ROIAWARE_POOL3D_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh 
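Note on the roiaware_pool3d file just closed: the mask-building code at the top packs each point's voxel coordinates into one 32-bit word, (x_idx << 16) + (y_idx << 8) + z_idx, and collect_inside_pts_for_box3d reads the fields back with 8-bit masks, so every pooled output dimension must stay below 256 for the round trip to be lossless. A minimal host-side sketch of that encoding (the helper names are illustrative, not part of this diff):

#include <cassert>
#include <cstdint>

// Same packing as the mask-building kernel above.
static inline std::uint32_t encode_voxel(std::uint32_t x_idx, std::uint32_t y_idx,
                                         std::uint32_t z_idx) {
  return (x_idx << 16) + (y_idx << 8) + z_idx;
}

// Same masks as collect_inside_pts_for_box3d: each field comes back as 8 bits,
// so out_x, out_y and out_z are effectively limited to 256.
static inline void decode_voxel(std::uint32_t code, std::uint32_t &x_idx,
                                std::uint32_t &y_idx, std::uint32_t &z_idx) {
  x_idx = (code >> 16) & 0xFF;
  y_idx = (code >> 8) & 0xFF;
  z_idx = code & 0xFF;
}

int main() {
  std::uint32_t x, y, z;
  decode_voxel(encode_voxel(12, 7, 3), x, y, z);
  assert(x == 12 && y == 7 && z == 3);
  return 0;
}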
b/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh new file mode 100644 index 0000000000..0a8d1ba69e --- /dev/null +++ b/mmcv/ops/csrc/common/musa/roipoint_pool3d_musa_kernel.muh @@ -0,0 +1,130 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef ROIPOINT_POOL3D_MUSA_KERNEL_MUH +#define ROIPOINT_POOL3D_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz, + T &local_x, T &local_y) { + T cosa = cos(-rz), sina = sin(-rz); + local_x = shift_x * cosa + shift_y * (-sina); + local_y = shift_x * sina + shift_y * cosa; +} + +template +__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x, + T &local_y) { + // param pt: (x, y, z) + // param box3d: (cx, cy, cz, dx, dy, dz, rz) in LiDAR coordinate, cz in the + // bottom center + T x = pt[0], y = pt[1], z = pt[2]; + T cx = box3d[0], cy = box3d[1], cz = box3d[2]; + T dx = box3d[3], dy = box3d[4], dz = box3d[5], rz = box3d[6]; + cz += dz / 2.0; // shift to the center since cz in box3d is the bottom center + + if (fabsf(z - cz) > dz / 2.0) return 0; + lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y); + T in_flag = (local_x > -dx / 2.0) & (local_x < dx / 2.0) & + (local_y > -dy / 2.0) & (local_y < dy / 2.0); + return in_flag; +} + +template +__global__ void assign_pts_to_box3d(int batch_size, int pts_num, int boxes_num, + const T *xyz, const T *boxes3d, + int *pts_assign) { + // params xyz: (B, N, 3) + // params boxes3d: (B, M, 7) + // params pts_assign: (B, N, M): idx of the corresponding box3d, -1 means + // background points + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) { + if (box_idx >= boxes_num || bs_idx >= batch_size) return; + + int assign_idx = + bs_idx * pts_num * boxes_num + pt_idx * boxes_num + box_idx; + pts_assign[assign_idx] = 0; + + int box_offset = bs_idx * boxes_num * 7 + box_idx * 7; + int pt_offset = bs_idx * pts_num * 3 + pt_idx * 3; + + T local_x = 0, local_y = 0; + int cur_in_flag = check_pt_in_box3d(xyz + pt_offset, boxes3d + box_offset, + local_x, local_y); + pts_assign[assign_idx] = cur_in_flag; + } +} + +__global__ void get_pooled_idx(int batch_size, int pts_num, int boxes_num, + int sampled_pts_num, const int *pts_assign, + int *pts_idx, int *pooled_empty_flag) { + // params xyz: (B, N, 3) + // params pts_feature: (B, N, C) + // params pts_assign: (B, N) + // params pts_idx: (B, M, 512) + // params pooled_empty_flag: (B, M) + MUSA_1D_KERNEL_LOOP(boxes_idx, boxes_num) { + int bs_idx = blockIdx.y; + + int cnt = 0; + for (int k = 0; k < pts_num; k++) { + if (pts_assign[bs_idx * pts_num * boxes_num + k * boxes_num + + boxes_idx]) { + if (cnt < sampled_pts_num) { + pts_idx[bs_idx * boxes_num * sampled_pts_num + + boxes_idx * sampled_pts_num + cnt] = k; + cnt++; + } else + break; + } + } + + if (cnt == 0) { + pooled_empty_flag[bs_idx * boxes_num + boxes_idx] = 1; + } else if (cnt < sampled_pts_num) { + // duplicate same points for sampling + for (int k = cnt; k < sampled_pts_num; k++) { + int duplicate_idx = k % cnt; + int base_offset = + bs_idx * boxes_num * sampled_pts_num + boxes_idx * sampled_pts_num; + pts_idx[base_offset + k] = pts_idx[base_offset + duplicate_idx]; + } + } + } +} + +template +__global__ void roipoint_pool3d_forward( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const T *xyz, const int *pts_idx, const T *pts_feature, + T *pooled_features, int *pooled_empty_flag) { + // params xyz: (B, 
N, 3) + // params pts_idx: (B, M, 512) + // params pts_feature: (B, N, C) + // params pooled_features: (B, M, 512, 3+C) + // params pooled_empty_flag: (B, M) + int box_idx = blockIdx.y; + int bs_idx = blockIdx.z; + MUSA_1D_KERNEL_LOOP(sample_pt_idx, sampled_pts_num) { + if (box_idx >= boxes_num || bs_idx >= batch_size) return; + if (pooled_empty_flag[bs_idx * boxes_num + box_idx]) return; + + int temp_idx = bs_idx * boxes_num * sampled_pts_num + + box_idx * sampled_pts_num + sample_pt_idx; + int src_pt_idx = pts_idx[temp_idx]; + int dst_feature_offset = temp_idx * (3 + feature_in_len); + + for (int j = 0; j < 3; j++) + pooled_features[dst_feature_offset + j] = + xyz[bs_idx * pts_num * 3 + src_pt_idx * 3 + j]; + + int src_feature_offset = + bs_idx * pts_num * feature_in_len + src_pt_idx * feature_in_len; + memcpy(pooled_features + dst_feature_offset + 3, + pts_feature + src_feature_offset, feature_in_len * sizeof(T)); + } +} + +#endif // ROIPOINT_POOL3D_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh b/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh new file mode 100644 index 0000000000..b1d8785ea4 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/rotated_feature_align_musa_kernel.muh @@ -0,0 +1,125 @@ +// Copyright (c) OpenMMLab. All rights reserved. +// Modified from +// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu +#ifndef ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH +#define ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__global__ void rotated_feature_align_forward_kernel( + const int nthreads, const int points, const scalar_t* bottom_data, + const scalar_t* best_bboxes, const scalar_t spatial_scale, + const int channels, const int height, const int width, scalar_t* top_data) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + int w = index % width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + const scalar_t* bbox_offset = + best_bboxes + ((n * height + h) * width + w) * 5; + scalar_t roi_y = bbox_offset[0] * spatial_scale; + scalar_t roi_x = bbox_offset[1] * spatial_scale; + + scalar_t px[5] = {roi_x, 0, 0, 0, 0}; + scalar_t py[5] = {roi_y, 0, 0, 0, 0}; + + if (points > 1) { + scalar_t roi_w = bbox_offset[2] * spatial_scale; + scalar_t roi_h = bbox_offset[3] * spatial_scale; + scalar_t roi_a = bbox_offset[4]; + + scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2; + scalar_t cosa = cosf(roi_a), sina = sinf(roi_a); + scalar_t wx = cosa * w_2, wy = sina * w_2; + scalar_t hx = -sina * h_2, hy = cosa * h_2; + + px[1] = roi_x + wx + hx; + py[1] = roi_y + wy + hy; + px[2] = roi_x - wx + hx; + py[2] = roi_y - wy + hy; + px[3] = roi_x - wx - hx; + py[3] = roi_y - wy - hy; + px[4] = roi_x + wx - hx; + py[4] = roi_y + wy - hy; + } + + const scalar_t* offset_bottom_data = + bottom_data + (n * channels + c) * height * width; + + scalar_t output_val = bottom_data[index]; + for (int i = 0; i < points; i++) { + output_val += bilinear_interpolate(offset_bottom_data, height, + width, py[i], px[i], i); + } + top_data[index] = output_val; + } +} + +template +__global__ void rotated_feature_align_backward_kernel( + const int nthreads, const int points, const scalar_t* top_diff, + const scalar_t* best_bboxes, const scalar_t spatial_scale, + const int channels, const int height, const int width, + scalar_t* bottom_diff) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + int w = index % 
width; + int h = (index / width) % height; + int c = (index / width / height) % channels; + int n = index / width / height / channels; + + const scalar_t* bbox_offset = + best_bboxes + ((n * height + h) * width + w) * 5; + scalar_t roi_y = bbox_offset[0] * spatial_scale; + scalar_t roi_x = bbox_offset[1] * spatial_scale; + + scalar_t px[5] = {roi_x, 0, 0, 0, 0}; + scalar_t py[5] = {roi_y, 0, 0, 0, 0}; + + if (points > 1) { + scalar_t roi_w = bbox_offset[2] * spatial_scale; + scalar_t roi_h = bbox_offset[3] * spatial_scale; + scalar_t roi_a = bbox_offset[4]; + + scalar_t w_2 = roi_w / 2, h_2 = roi_h / 2; + scalar_t cosa = cosf(roi_a), sina = sinf(roi_a); + scalar_t wx = cosa * w_2, wy = sina * w_2; + scalar_t hx = -sina * h_2, hy = cosa * h_2; + + px[1] = roi_x + wx + hx; + py[1] = roi_y + wy + hy; + px[2] = roi_x - wx + hx; + py[2] = roi_y - wy + hy; + px[3] = roi_x - wx - hx; + py[3] = roi_y - wy - hy; + px[4] = roi_x + wx - hx; + py[4] = roi_y + wy - hy; + } + + scalar_t* offset_bottom_diff = + bottom_diff + (n * channels + c) * height * width; + scalar_t value_top_diff = top_diff[index]; + + atomicAdd(bottom_diff + index, value_top_diff); + for (int i = 0; i < points; i++) { + scalar_t w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient(height, width, py[i], px[i], w1, + w2, w3, w4, x_low, x_high, y_low, + y_high, i); + scalar_t g1 = value_top_diff * w1; + scalar_t g2 = value_top_diff * w2; + scalar_t g3 = value_top_diff * w3; + scalar_t g4 = value_top_diff * w4; + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd(offset_bottom_diff + y_low * width + x_low, g1); + atomicAdd(offset_bottom_diff + y_low * width + x_high, g2); + atomicAdd(offset_bottom_diff + y_high * width + x_low, g3); + atomicAdd(offset_bottom_diff + y_high * width + x_high, g4); + } + } + } +} +#endif // ROTATED_FEATURE_ALIGN_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh b/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh new file mode 100644 index 0000000000..ba418eceba --- /dev/null +++ b/mmcv/ops/csrc/common/musa/scatter_points_musa_kernel.muh @@ -0,0 +1,137 @@ +// Copyright (c) OpenMMLab. 
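For the rotated_feature_align kernels, both forward and backward sample the feature map at the box centre and, when points > 1, at its four rotated corners built from (roi_x, roi_y, roi_w, roi_h, roi_a). A small host-side sketch of that corner construction, assuming coordinates are already multiplied by spatial_scale as in the kernels (the struct and function names are illustrative only):

#include <cmath>

struct SamplePoints { float px[5], py[5]; };

// Mirrors the px[]/py[] computation in the kernels: centre first, then the four
// corners obtained by adding/subtracting the rotated half-width and half-height
// vectors.
inline SamplePoints rotated_sample_points(float cx, float cy, float w, float h,
                                          float angle) {
  const float w2 = w / 2.f, h2 = h / 2.f;
  const float cosa = std::cos(angle), sina = std::sin(angle);
  const float wx = cosa * w2, wy = sina * w2;   // rotated half-width vector
  const float hx = -sina * h2, hy = cosa * h2;  // rotated half-height vector
  return {{cx, cx + wx + hx, cx - wx + hx, cx - wx - hx, cx + wx - hx},
          {cy, cy + wy + hy, cy - wy + hy, cy - wy - hy, cy + wy - hy}};
}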
All rights reserved +#ifndef SCATTER_POINTS_MUSA_KERNEL_MUH +#define SCATTER_POINTS_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; +int const maxGridDim = 50000; + +__device__ __forceinline__ static void reduceMax(float *address, float val) { + int *address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS(address_as_i, assumed, + __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old || __int_as_float(old) < val); +} + +__device__ __forceinline__ static void reduceMax(double *address, double val) { + unsigned long long *address_as_ull = + reinterpret_cast(address); + unsigned long long old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_ull, assumed, + __double_as_longlong(fmax(val, __longlong_as_double(assumed)))); + } while (assumed != old || __longlong_as_double(old) < val); +} + +__device__ __forceinline__ static void reduceAdd(float *address, float val) { + atomicAdd(address, val); +} + +__device__ __forceinline__ static void reduceAdd(double *address, double val) { + atomicAdd(address, val); + +} + +template +__global__ void feats_reduce_kernel( + const T *feats, const int32_t *coors_map, + T *reduced_feats, // shall be 0 at initialization + const int num_input, const int num_feats, const reduce_t reduce_type) { + MUSA_1D_KERNEL_LOOP(x, num_input) { + int32_t reduce_to = coors_map[x]; + if (reduce_to == -1) continue; + + const T *feats_offset = feats + x * num_feats; + T *reduced_feats_offset = reduced_feats + reduce_to * num_feats; + if (reduce_type == reduce_t::MAX) { + for (int i = 0; i < num_feats; i++) { + reduceMax(&reduced_feats_offset[i], feats_offset[i]); + } + } else { + for (int i = 0; i < num_feats; i++) { + reduceAdd(&reduced_feats_offset[i], feats_offset[i]); + } + } + } +} + +template +__global__ void add_reduce_traceback_grad_kernel( + T *grad_feats, const T *grad_reduced_feats, const int32_t *coors_map, + const int32_t *reduce_count, const int num_input, const int num_feats, + const reduce_t reduce_type) { + MUSA_1D_KERNEL_LOOP(x, num_input) { + int32_t reduce_to = coors_map[x]; + if (reduce_to == -1) { + continue; + } + + const int input_offset = x * num_feats; + T *grad_feats_offset = grad_feats + input_offset; + const int reduced_offset = reduce_to * num_feats; + const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset; + + if (reduce_type == reduce_t::SUM) { + for (int i = 0; i < num_feats; i++) { + grad_feats_offset[i] = grad_reduced_feats_offset[i]; + } + } else if (reduce_type == reduce_t::MEAN) { + for (int i = 0; i < num_feats; i++) { + grad_feats_offset[i] = grad_reduced_feats_offset[i] / + static_cast(reduce_count[reduce_to]); + } + } + } +} + +template +__global__ void max_reduce_traceback_scatter_idx_kernel( + const T *feats, const T *reduced_feats, int32_t *reduce_from, + const int32_t *coors_map, const int num_input, const int num_feats) { + MUSA_1D_KERNEL_LOOP(x, num_input) { + int32_t reduce_to = coors_map[x]; + + const int input_offset = x * num_feats; + const T *feats_offset = feats + input_offset; + + if (reduce_to == -1) { + continue; + } + + const int reduced_offset = reduce_to * num_feats; + const T *reduced_feats_offset = reduced_feats + reduced_offset; + int32_t *reduce_from_offset = reduce_from + reduced_offset; + + for (int i = 0; i < num_feats; i++) { + if (feats_offset[i] == reduced_feats_offset[i]) { + atomicMin(&reduce_from_offset[i], 
static_cast(x)); + } + } + } +} + +template +__global__ void max_reduce_scatter_grad_kernel(T *grad_feats, + const T *grad_reduced_feats, + const int32_t *reduce_from, + const int num_reduced, + const int num_feats) { + MUSA_1D_KERNEL_LOOP(x, num_reduced) { + const int reduced_offset = x * num_feats; + const int32_t *scatter_to_offset = reduce_from + reduced_offset; + const T *grad_reduced_feats_offset = grad_reduced_feats + reduced_offset; + + for (int i = 0; i < num_feats; i++) { + grad_feats[scatter_to_offset[i] * num_feats + i] = + grad_reduced_feats_offset[i]; + } + } +} + +#endif // SCATTER_POINTS_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh b/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh new file mode 100644 index 0000000000..7eb5e03826 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/sync_bn_musa_kernel.muh @@ -0,0 +1,327 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef SYNCBN_MUSA_KERNEL_MUH +#define SYNCBN_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__global__ void sync_bn_forward_mean_musa_kernel(const T *input, float *mean, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer[tid] += input[index]; + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + mean[c] = buffer[0] / total; + } +} + +template <> +__global__ void sync_bn_forward_mean_musa_kernel(const phalf *input, + float *mean, int num, + int channels, int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer[tid] += static_cast(input[index]); + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + mean[c] = buffer[0] / total; + } +} + +template +__global__ void sync_bn_forward_var_musa_kernel(const T *input, + const float *mean, float *var, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + float td = input[index] - mean[c]; + buffer[tid] += td * td; + } + __syncthreads(); + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + var[c] = buffer[0] / total; + } +} + +template <> +__global__ void sync_bn_forward_var_musa_kernel(const phalf *input, + const float *mean, float *var, + int num, int channels, + int spatial) { + __shared__ float buffer[THREADS_PER_BLOCK]; + int tid = threadIdx.x; + int c = blockIdx.x; + buffer[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + float td = static_cast(input[index]) - mean[c]; + buffer[tid] += td * td; + } + __syncthreads(); + 
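The sync_bn mean/var kernels in this file all share the same block-level pattern: each thread first accumulates a strided partial sum into shared memory, then the loop that follows halves the number of active lanes each iteration until buffer[0] holds the block total. A minimal CPU sketch of that folding step, assuming a power-of-two buffer length as THREADS_PER_BLOCK is (illustrative only):

#include <cstddef>
#include <vector>

// Repeatedly fold the upper half of the buffer onto the lower half until
// element 0 holds the sum of all entries.
inline float tree_reduce_sum(std::vector<float> buf) {
  for (std::size_t s = buf.size() / 2; s > 0; s >>= 1) {
    for (std::size_t tid = 0; tid < s; ++tid) buf[tid] += buf[tid + s];
  }
  return buf[0];
}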
for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer[tid] += buffer[tid + s]; + } + __syncthreads(); + } + int total = num * spatial; + if (tid == 0) { + var[c] = buffer[0] / total; + } +} + +template +__global__ void sync_bn_forward_output_musa_kernel( + const T *input, const float *mean, const float *var, float *running_mean, + float *running_var, const float *weight, const float *bias, float *norm, + float *std, T *output, int num, int channels, int spatial, float eps, + float momentum, int group_size) { + int tid = threadIdx.x; + int c = blockIdx.x; + float mean_value = mean[c]; + float std_value = sqrt(var[c] + eps); + + if (weight != nullptr) { + float weight_value = weight[c]; + float bias_value = bias[c]; + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = (input[index] - mean_value) / std_value; + output[index] = norm[index] * weight_value + bias_value; + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = + (input[index] - mean_value) / std_value * weight_value + bias_value; + } + } + } else { + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = norm[index] = (input[index] - mean_value) / std_value; + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = (input[index] - mean_value) / std_value; + } + } + } + if (tid == 0) { + if (std != nullptr) std[c] = std_value; + if (running_mean != nullptr) { + running_mean[c] = + momentum * mean_value + (1 - momentum) * running_mean[c]; + int count = num * spatial * group_size; + float var_unbias = count > 1 ? 
var[c] * count / (count - 1) : var[c]; + running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c]; + } + } +} + +template <> +__global__ void sync_bn_forward_output_musa_kernel( + const phalf *input, const float *mean, const float *var, + float *running_mean, float *running_var, const float *weight, + const float *bias, float *norm, float *std, phalf *output, int num, + int channels, int spatial, float eps, float momentum, int group_size) { + int tid = threadIdx.x; + int c = blockIdx.x; + float mean_value = mean[c]; + float std_value = sqrt(var[c] + eps); + if (weight != nullptr) { + float weight_value = weight[c]; + float bias_value = bias[c]; + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = + (static_cast(input[index]) - mean_value) / std_value; + output[index] = + static_cast(norm[index] * weight_value + bias_value); + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = + static_cast((static_cast(input[index]) - mean_value) / + std_value * weight_value + + bias_value); + } + } + } else { + if (norm != nullptr) { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + norm[index] = + (static_cast(input[index]) - mean_value) / std_value; + output[index] = static_cast(norm[index]); + } + } else { + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = + (i / spatial) * channels * spatial + c * spatial + i % spatial; + output[index] = static_cast( + (static_cast(input[index]) - mean_value) / std_value); + } + } + } + if (tid == 0) { + if (std != nullptr) std[c] = std_value; + if (running_mean != nullptr) { + running_mean[c] = + momentum * mean_value + (1 - momentum) * running_mean[c]; + int count = num * spatial * group_size; + float var_unbias = count > 1 ? 
var[c] * count / (count - 1) : var[c]; + running_var[c] = momentum * var_unbias + (1 - momentum) * running_var[c]; + } + } +} + +template +__global__ void sync_bn_backward_param_musa_kernel(const T *grad_output, + const float *norm, + float *grad_weight, + float *grad_bias, int num, + int channels, int spatial) { + __shared__ float buffer1[THREADS_PER_BLOCK]; + __shared__ float buffer2[THREADS_PER_BLOCK]; + + int tid = threadIdx.x; + int c = blockIdx.x; + buffer1[tid] = buffer2[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer1[tid] += grad_output[index] * norm[index]; + buffer2[tid] += grad_output[index]; + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer1[tid] += buffer1[tid + s]; + buffer2[tid] += buffer2[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + grad_weight[c] = buffer1[0]; + grad_bias[c] = buffer2[0]; + } +} + +template <> +__global__ void sync_bn_backward_param_musa_kernel(const phalf *grad_output, + const float *norm, + float *grad_weight, + float *grad_bias, int num, + int channels, int spatial) { + __shared__ float buffer1[THREADS_PER_BLOCK]; + __shared__ float buffer2[THREADS_PER_BLOCK]; + + int tid = threadIdx.x; + int c = blockIdx.x; + buffer1[tid] = buffer2[tid] = 0; + for (int i = tid; i < num * spatial; i += blockDim.x) { + int index = (i / spatial) * channels * spatial + c * spatial + i % spatial; + buffer1[tid] += static_cast(grad_output[index]) * norm[index]; + buffer2[tid] += static_cast(grad_output[index]); + } + __syncthreads(); + + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + buffer1[tid] += buffer1[tid + s]; + buffer2[tid] += buffer2[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + grad_weight[c] = buffer1[0]; + grad_bias[c] = buffer2[0]; + } +} + +template +__global__ void sync_bn_backward_data_musa_kernel( + int output_size, const T *grad_output, const float *weight, + const float *grad_weight, const float *grad_bias, const float *norm, + const float *std, T *grad_input, int num, int channels, int spatial) { + int factor = num * spatial; + MUSA_1D_KERNEL_LOOP(index, output_size) { + int c = (index / spatial) % channels; + grad_input[index] = + weight[c] * + (grad_output[index] - + (grad_weight[c] * norm[index] + grad_bias[c]) / factor) / + std[c]; + } +} + +template <> +__global__ void sync_bn_backward_data_musa_kernel( + int output_size, const phalf *grad_output, const float *weight, + const float *grad_weight, const float *grad_bias, const float *norm, + const float *std, phalf *grad_input, int num, int channels, int spatial) { + int factor = num * spatial; + MUSA_1D_KERNEL_LOOP(index, output_size) { + int c = (index / spatial) % channels; + grad_input[index] = static_cast( + weight[c] * + (static_cast(grad_output[index]) - + (grad_weight[c] * norm[index] + grad_bias[c]) / factor) / + std[c]); + } +} + +#endif // SYNCBN_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh b/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh new file mode 100644 index 0000000000..4d5086ffda --- /dev/null +++ b/mmcv/ops/csrc/common/musa/three_interpolate_musa_kernel.muh @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. 
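For reference, the backward-data kernel of the sync_bn file just closed combines the three reduced quantities per channel into the input gradient. A scalar CPU reference for one element, with factor = num * spatial as in the kernel (the function name is illustrative):

// grad_input = weight/std * (grad_output - (grad_weight * norm + grad_bias) / factor)
inline float sync_bn_backward_data_ref(float grad_output, float weight_c,
                                       float grad_weight_c, float grad_bias_c,
                                       float norm, float std_c, int factor) {
  return weight_c *
         (grad_output - (grad_weight_c * norm + grad_bias_c) / factor) /
         std_c;
}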
All rights reserved +#ifndef THREE_INTERPOLATE_MUSA_KERNEL_MUH +#define THREE_INTERPOLATE_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__global__ void three_interpolate_forward_musa_kernel( + int b, int c, int m, int n, const T *points, const int *__restrict__ idx, + const T *weight, T *out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(pt_idx, n) { + if (bs_idx >= b || c_idx >= c) return; + + weight += bs_idx * n * 3 + pt_idx * 3; + points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + out += bs_idx * c * n + c_idx * n; + + out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + + weight[2] * points[idx[2]]; + } +} + +template +__global__ void three_interpolate_backward_musa_kernel( + int b, int c, int n, int m, const T *grad_out, const int *__restrict__ idx, + const T *weight, T *grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + int bs_idx = blockIdx.z; + int c_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(pt_idx, n) { + if (bs_idx >= b || c_idx >= c) return; + + grad_out += bs_idx * c * n + c_idx * n + pt_idx; + weight += bs_idx * n * 3 + pt_idx * 3; + grad_points += bs_idx * c * m + c_idx * m; + idx += bs_idx * n * 3 + pt_idx * 3; + + atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); + atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); + atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); + } +} + +#endif // THREE_INTERPOLATE_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh b/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh new file mode 100644 index 0000000000..c25af06230 --- /dev/null +++ b/mmcv/ops/csrc/common/musa/three_nn_musa_kernel.muh @@ -0,0 +1,63 @@ +// Copyright (c) OpenMMLab. 
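The three_interpolate forward just closed is, per output element, a weighted sum of three neighbour features, and the backward pass scatters grad * weight back to the same three source columns with atomicAdd. A one-line CPU reference for the forward (illustrative name, points_row being one channel of one batch):

inline float three_interpolate_ref(const float *points_row,  // (M,)
                                   const int idx[3], const float weight[3]) {
  return weight[0] * points_row[idx[0]] +
         weight[1] * points_row[idx[1]] +
         weight[2] * points_row[idx[2]];
}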
All rights reserved +#ifndef THREE_NN_MUSA_KERNEL_MUH +#define THREE_NN_MUSA_KERNEL_MUH + + +#include "pytorch_musa_helper.hpp" +template +__global__ void three_nn_forward_musa_kernel(int b, int n, int m, + const T *unknown, const T *known, + T *dist2, int *__restrict__ idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + int bs_idx = blockIdx.y; + MUSA_1D_KERNEL_LOOP(pt_idx, n) { + if (bs_idx >= b) return; + + unknown += bs_idx * n * 3 + pt_idx * 3; + known += bs_idx * m * 3; + dist2 += bs_idx * n * 3 + pt_idx * 3; + idx += bs_idx * n * 3 + pt_idx * 3; + + T ux = unknown[0]; + T uy = unknown[1]; + T uz = unknown[2]; + + double best1 = 1e40, best2 = 1e40, best3 = 1e40; + int besti1 = 0, besti2 = 0, besti3 = 0; + for (int k = 0; k < m; ++k) { + T x = known[k * 3 + 0]; + T y = known[k * 3 + 1]; + T z = known[k * 3 + 2]; + T d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best1) { + best3 = best2; + besti3 = besti2; + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + } else if (d < best2) { + best3 = best2; + besti3 = besti2; + best2 = d; + besti2 = k; + } else if (d < best3) { + best3 = d; + besti3 = k; + } + } + dist2[0] = best1; + dist2[1] = best2; + dist2[2] = best3; + idx[0] = besti1; + idx[1] = besti2; + idx[2] = besti3; + } +} + +#endif // THREE_NN_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh b/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh new file mode 100644 index 0000000000..ba460883cb --- /dev/null +++ b/mmcv/ops/csrc/common/musa/tin_shift_musa_kernel.muh @@ -0,0 +1,57 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef TIN_SHIFT_MUSA_KERNEL_MUH +#define TIN_SHIFT_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +template +__global__ void tin_shift_forward_musa_kernel( + const int nthreads, const T* input, const int* shift, T* output, + const int batch_size, const int channels, const int t_size, + const int hw_size, const int group_size, const int group_channel) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + const int hw_index = index % hw_size; + const int j = (index / hw_size) % channels; + + const int n_index = (index / hw_size / channels) % batch_size; + int group_id = j / group_channel; + int t_shift = shift[n_index * group_size + group_id]; + int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index; + for (int i = 0; i < t_size; i++) { + int now_t = i + t_shift; + int data_id = i * hw_size * channels + offset; + if (now_t < 0 || now_t >= t_size) { + continue; + } + int out_id = now_t * hw_size * channels + offset; + output[out_id] = input[data_id]; + } + } +} + +template +__global__ void tin_shift_backward_musa_kernel( + const int nthreads, const T* input, const int* shift, T* output, + const int batch_size, const int channels, const int t_size, + const int hw_size, const int group_size, const int group_channel) { + MUSA_1D_KERNEL_LOOP(index, nthreads) { + const int hw_index = index % hw_size; + const int j = (index / hw_size) % channels; + + const int n_index = (index / hw_size / channels) % batch_size; + int group_id = j / group_channel; + int t_shift = shift[n_index * group_size + group_id]; + int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index; + for (int i = 0; i < t_size; i++) { + int now_t = i + t_shift; + int data_id = i * hw_size * channels + offset; + if (now_t < 0 || now_t >= t_size) { + continue; + } + int out_id = now_t * hw_size * channels + offset; + 
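The tin_shift kernels move each temporal column by a per-(batch, group) offset; frames shifted outside [0, t_size) are simply never written, so the launcher is presumably expected to zero-fill the output beforehand (not shown in this diff). A CPU sketch of the mapping for a single (batch, channel, hw) column (illustrative helper):

#include <vector>

void tin_shift_column(const std::vector<float> &in, std::vector<float> &out,
                      int t_size, int t_shift) {
  for (int i = 0; i < t_size; ++i) {
    int now_t = i + t_shift;
    if (now_t < 0 || now_t >= t_size) continue;  // dropped frame, slot stays as-is
    out[now_t] = in[i];
  }
}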
output[out_id] = input[data_id]; + } + } +} + +#endif // TIN_SHIFT_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh b/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh new file mode 100644 index 0000000000..24bc770f5a --- /dev/null +++ b/mmcv/ops/csrc/common/musa/voxelization_musa_kernel.muh @@ -0,0 +1,212 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef VOXELIZATION_MUSA_KERNEL_MUH +#define VOXELIZATION_MUSA_KERNEL_MUH + +#include "pytorch_musa_helper.hpp" + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; + +template +__global__ void dynamic_voxelize_kernel( + const T* points, T_int* coors, const float voxel_x, const float voxel_y, + const float voxel_z, const float coors_x_min, const float coors_y_min, + const float coors_z_min, const float coors_x_max, const float coors_y_max, + const float coors_z_max, const int grid_x, const int grid_y, + const int grid_z, const int num_points, const int num_features, + const int NDim) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + MUSA_1D_KERNEL_LOOP(index, num_points) { + // To save some computation + auto points_offset = points + index * num_features; + auto coors_offset = coors + index * NDim; + int c_x = floorf((points_offset[0] - coors_x_min) / voxel_x); + if (c_x < 0 || c_x >= grid_x) { + coors_offset[0] = -1; + continue; + } + + int c_y = floorf((points_offset[1] - coors_y_min) / voxel_y); + if (c_y < 0 || c_y >= grid_y) { + coors_offset[0] = -1; + coors_offset[1] = -1; + continue; + } + + int c_z = floorf((points_offset[2] - coors_z_min) / voxel_z); + if (c_z < 0 || c_z >= grid_z) { + coors_offset[0] = -1; + coors_offset[1] = -1; + coors_offset[2] = -1; + } else { + coors_offset[0] = c_z; + coors_offset[1] = c_y; + coors_offset[2] = c_x; + } + } +} + +template +__global__ void assign_point_to_voxel(const int nthreads, const T* points, + T_int* point_to_voxelidx, + T_int* coor_to_voxelidx, T* voxels, + const int max_points, + const int num_features, + const int num_points, const int NDim) { + MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + int index = thread_idx / num_features; + + int num = point_to_voxelidx[index]; + int voxelidx = coor_to_voxelidx[index]; + if (num > -1 && voxelidx > -1) { + auto voxels_offset = + voxels + voxelidx * max_points * num_features + num * num_features; + + int k = thread_idx % num_features; + voxels_offset[k] = points[thread_idx]; + } + } +} + +template +__global__ void assign_voxel_coors(const int nthreads, T_int* coor, + T_int* point_to_voxelidx, + T_int* coor_to_voxelidx, T_int* voxel_coors, + const int num_points, const int NDim) { + MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) { + // const int index = blockIdx.x * threadsPerBlock + threadIdx.x; + // if (index >= num_points) return; + int index = thread_idx / NDim; + int num = point_to_voxelidx[index]; + int voxelidx = coor_to_voxelidx[index]; + if (num == 0 && voxelidx > -1) { + auto coors_offset = voxel_coors + voxelidx * NDim; + int k = thread_idx % NDim; + coors_offset[k] = coor[thread_idx]; + } + } +} + +template +__global__ void point_to_voxelidx_kernel(const T_int* coor, + T_int* point_to_voxelidx, + T_int* point_to_pointidx, + const int max_points, + const int max_voxels, + const int num_points, const int NDim) { + MUSA_1D_KERNEL_LOOP(index, num_points) { + auto coor_offset = coor + index * NDim; + // skip invalid points + if (coor_offset[0] == -1) continue; + + int num = 0; + int coor_x = coor_offset[0]; + 
int coor_y = coor_offset[1]; + int coor_z = coor_offset[2]; + // only calculate the coors before this coor[index] + for (int i = 0; i < index; ++i) { + auto prev_coor = coor + i * NDim; + if (prev_coor[0] == -1) continue; + + // Find all previous points that have the same coors + // if find the same coor, record it + if ((prev_coor[0] == coor_x) && (prev_coor[1] == coor_y) && + (prev_coor[2] == coor_z)) { + num++; + if (num == 1) { + // point to the same coor that first show up + point_to_pointidx[index] = i; + } else if (num >= max_points) { + // out of boundary + break; + } + } + } + if (num == 0) { + point_to_pointidx[index] = index; + } + if (num < max_points) { + point_to_voxelidx[index] = num; + } + } +} + +template +__global__ void determin_voxel_num( + // const T_int* coor, + T_int* num_points_per_voxel, T_int* point_to_voxelidx, + T_int* point_to_pointidx, T_int* coor_to_voxelidx, T_int* voxel_num, + const int max_points, const int max_voxels, const int num_points) { + // only calculate the coors before this coor[index] + for (int i = 0; i < num_points; ++i) { + int point_pos_in_voxel = point_to_voxelidx[i]; + // record voxel + if (point_pos_in_voxel == -1) { + // out of max_points or invalid point + continue; + } else if (point_pos_in_voxel == 0) { + // record new voxel + int voxelidx = voxel_num[0]; + if (voxel_num[0] >= max_voxels) continue; + voxel_num[0] += 1; + coor_to_voxelidx[i] = voxelidx; + num_points_per_voxel[voxelidx] = 1; + } else { + int point_idx = point_to_pointidx[i]; + int voxelidx = coor_to_voxelidx[point_idx]; + if (voxelidx != -1) { + coor_to_voxelidx[i] = voxelidx; + num_points_per_voxel[voxelidx] += 1; + } + } + } +} + +__global__ void nondeterministic_get_assign_pos( + const int nthreads, const int32_t* coors_map, int32_t* pts_id, + int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) { + MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) { + int coors_idx = coors_map[thread_idx]; + if (coors_idx > -1) { + int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1); + pts_id[thread_idx] = coors_pts_pos; + if (coors_pts_pos == 0) { + coors_order[coors_idx] = atomicAdd(coors_count, 1); + } + } + } +} + +template +__global__ void nondeterministic_assign_point_voxel( + const int nthreads, const T* points, const int32_t* coors_map, + const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count, + const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count, + const int max_voxels, const int max_points, const int num_features, + const int NDim) { + MUSA_1D_KERNEL_LOOP(thread_idx, nthreads) { + int coors_idx = coors_map[thread_idx]; + int coors_pts_pos = pts_id[thread_idx]; + if (coors_idx > -1 && coors_pts_pos < max_points) { + int coors_pos = coors_order[coors_idx]; + if (coors_pos < max_voxels) { + auto voxels_offset = + voxels + (coors_pos * max_points + coors_pts_pos) * num_features; + auto points_offset = points + thread_idx * num_features; + for (int k = 0; k < num_features; k++) { + voxels_offset[k] = points_offset[k]; + } + if (coors_pts_pos == 0) { + pts_count[coors_pos] = min(reduce_count[coors_idx], max_points); + auto coors_offset = coors + coors_pos * NDim; + auto coors_in_offset = coors_in + coors_idx * NDim; + for (int k = 0; k < NDim; k++) { + coors_offset[k] = coors_in_offset[k]; + } + } + } + } + } +} + +#endif // VOXELIZATION_MUSA_KERNEL_MUH diff --git a/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu b/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu new file mode 100644 index 0000000000..b8994a7688 --- /dev/null 
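In the voxelization file above, dynamic_voxelize_kernel quantises each axis independently and invalidates the whole coordinate (writing -1) as soon as one axis falls outside the grid; note that the valid result is stored in (z, y, x) order. A per-axis CPU sketch of that quantisation (illustrative helper, -1 meaning "outside the grid"):

#include <cmath>

inline int voxel_index(float coord, float coord_min, float voxel_size,
                       int grid_size) {
  int c = static_cast<int>(std::floor((coord - coord_min) / voxel_size));
  return (c < 0 || c >= grid_size) ? -1 : c;
}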
+++ b/mmcv/ops/csrc/pytorch/musa/filtered_lrelu.mu @@ -0,0 +1,2052 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. +#include +#include +#include + +#include + +#include "pytorch_musa_helper.hpp" +#include "pytorch_device_registry.hpp" + +//------------------------------------------------------------------------ +// MUSA kernel parameters. + +struct filtered_lrelu_kernel_params { + // These parameters decide which kernel to use. + int up; // upsampling ratio (1, 2, 4) + int down; // downsampling ratio (1, 2, 4) + int2 fuShape; // [size, 1] | [size, size] + int2 fdShape; // [size, 1] | [size, size] + + int _dummy; // Alignment. + + // Rest of the parameters. + const void *x; // Input tensor. + void *y; // Output tensor. + const void *b; // Bias tensor. + unsigned char *s; // Sign tensor in/out. NULL if unused. + const float *fu; // Upsampling filter. + const float *fd; // Downsampling filter. + + int2 pad0; // Left/top padding. + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + int flip; // Filter kernel flip for gradient computation. + + int tilesXdim; // Original number of horizontal output tiles. + int tilesXrep; // Number of horizontal tiles per CTA. + int blockZofs; // Block z offset to support large minibatch, channel + // dimensions. + + int4 xShape; // [width, height, channel, batch] + int4 yShape; // [width, height, channel, batch] + int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if + // unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. + int swLimit; // Active width of sign tensor in bytes. + + longlong4 xStride; // Strides of all tensors except signs, same component + // order as shapes. + longlong4 yStride; // + int64_t bStride; // + longlong3 fuStride; // + longlong3 fdStride; // +}; + +struct filtered_lrelu_act_kernel_params { + void *x; // Input/output, modified in-place. + unsigned char *s; // Sign tensor in/out. NULL if unused. + + float gain; // Additional gain factor. + float slope; // Leaky ReLU slope on negative side. + float clamp; // Clamp after nonlinearity. + + int4 xShape; // [width, height, channel, batch] + longlong4 xStride; // Input/output tensor strides, same order as in shape. + int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if + // unused. + int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. +}; + +//------------------------------------------------------------------------ +// MUSA kernel specialization. + +struct filtered_lrelu_kernel_spec { + void *setup; // Function for filter kernel setup. + void *exec; // Function for main operation. + int2 tileOut; // Width/height of launch tile. + int numWarps; // Number of warps per thread block, determines launch block + // size. + int xrep; // For processing multiple horizontal tiles per thread block. + int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. +}; + +//------------------------------------------------------------------------ +// MUSA kernel selection. 
+ +template +filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel( + const filtered_lrelu_kernel_params &p, int sharedKB); +template +void *choose_filtered_lrelu_act_kernel(void); + +//------------------------------------------------------------------------ +// Helpers. + +enum // Filter modes. +{ MODE_SUSD = 0, // Separable upsampling, separable downsampling. + MODE_FUSD = 1, // Full upsampling, separable downsampling. + MODE_SUFD = 2, // Separable upsampling, full downsampling. + MODE_FUFD = 3, // Full upsampling, full downsampling. +}; + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; + typedef double2 vec2_t; + typedef double4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_double2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_double4(0, 0, 0, 0); + } + __device__ __forceinline__ static double clamp(double x, double c) { + return fmin(fmax(x, -c), c); + } +}; +template <> +struct InternalType { + typedef float scalar_t; + typedef float2 vec2_t; + typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_float2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_float4(0, 0, 0, 0); + } + __device__ __forceinline__ static float clamp(float x, float c) { + return fminf(fmaxf(x, -c), c); + } +}; +template <> +struct InternalType { + typedef float scalar_t; + typedef float2 vec2_t; + typedef float4 vec4_t; + __device__ __forceinline__ static vec2_t zero_vec2(void) { + return make_float2(0, 0); + } + __device__ __forceinline__ static vec4_t zero_vec4(void) { + return make_float4(0, 0, 0, 0); + } + __device__ __forceinline__ static float clamp(float x, float c) { + return fminf(fmaxf(x, -c), c); + } +}; + +#define MIN(A, B) ((A) < (B) ? (A) : (B)) +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define CEIL_DIV(A, B) \ + (((B) == 1) \ + ? (A) \ + : ((B) == 2) ? ((int)((A) + 1) >> 1) \ + : ((B) == 4) ? ((int)((A) + 3) >> 2) \ + : (((A) + ((A) > 0 ? (B)-1 : 0)) / (B))) + +// This works only up to blocks of size 256 x 256 and for all N that are powers +// of two. +template +__device__ __forceinline__ void fast_div_mod(int &x, int &y, unsigned int i) { + if ((N & (N - 1)) && N <= 256) + y = (i * ((1 << 24) / N + 1)) >> 24; // Assumes N <= 256, i < N*256. + else + y = i / N; + + x = i - y * N; +} + +// Type cast stride before reading it. +template +__device__ __forceinline__ T get_stride(const int64_t &x) { + return *reinterpret_cast(&x); +} + +//------------------------------------------------------------------------ +// Filters, setup kernel, copying function. + +#define MAX_FILTER_SIZE 32 + +// Combined up/down filter buffers so that transfer can be done with one copy. +__device__ float + g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, + // written by setup kernel. +__device__ __constant__ float + c_fbuf[2 * MAX_FILTER_SIZE * + MAX_FILTER_SIZE]; // Filters in constant memory, read by main + // kernel. + +// Accessors to combined buffers to index up/down filters individually. +#define c_fu (c_fbuf) +#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) +#define g_fu (g_fbuf) +#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE) + +// Set up filters into global memory buffer. 
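A quick host-side check of the fixed-point divide that fast_div_mod uses on its fast path, (i * ((1 << 24) / N + 1)) >> 24, which matches i / N for non-power-of-two N <= 256 and i < N * 256 as the comment assumes. N = 24 below is an arbitrary illustrative choice:

#include <cassert>

int main() {
  const unsigned int N = 24;                  // any non-power-of-two divisor <= 256
  const unsigned int m = (1u << 24) / N + 1;  // same multiplier as fast_div_mod
  for (unsigned int i = 0; i < N * 256; ++i) {
    assert(((i * m) >> 24) == i / N);
  }
  return 0;
}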
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p) { + for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; + idx += blockDim.x) { + int x, y; + fast_div_mod(x, y, idx); + + int fu_x = p.flip ? x : (p.fuShape.x - 1 - x); + int fu_y = p.flip ? y : (p.fuShape.y - 1 - y); + if (p.fuShape.y > 0) + g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) + ? 0.0f + : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y]; + else + g_fu[idx] = + (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x]; + + int fd_x = p.flip ? x : (p.fdShape.x - 1 - x); + int fd_y = p.flip ? y : (p.fdShape.y - 1 - y); + if (p.fdShape.y > 0) + g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) + ? 0.0f + : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y]; + else + g_fd[idx] = + (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x]; + } +} + +// Host function to copy filters written by setup kernel into constant buffer +// for main kernel. +static musaError_t copy_filters(musaStream_t stream) { + void *src = 0; + musaError_t err = musaGetSymbolAddress(&src, g_fbuf); + if (err) return err; + return musaMemcpyToSymbolAsync( + c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, + musaMemcpyDeviceToDevice, stream); +} + +//------------------------------------------------------------------------ +// Coordinate spaces: +// - Relative to input tensor: inX, inY, tileInX, tileInY +// - Relative to input tile: relInX, relInY, tileInW, tileInH +// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH +// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH +// - Relative to output tensor: outX, outY, tileOutX, tileOutY +// +// Relationships between coordinate spaces: +// - inX = tileInX + relInX +// - inY = tileInY + relInY +// - relUpX = relInX * up + phaseInX +// - relUpY = relInY * up + phaseInY +// - relUpX = relOutX * down +// - relUpY = relOutY * down +// - outX = tileOutX + relOutX +// - outY = tileOutY + relOutY + +extern __shared__ char + s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically + // inside the kernel, otherwise use the externally allocated + // shared memory buffer. + +template +static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p) { + // Check that we don't try to support non-existing filter modes. 
+ static_assert(up == 1 || up == 2 || up == 4, + "only up=1, up=2, up=4 scales supported"); + static_assert(down == 1 || down == 2 || down == 4, + "only down=1, down=2, down=4 scales supported"); + static_assert(fuSize >= up, + "upsampling filter size must be at least upsampling factor"); + static_assert( + fdSize >= down, + "downsampling filter size must be at least downsampling factor"); + static_assert( + fuSize % up == 0, + "upsampling filter size must be divisible with upsampling factor"); + static_assert( + fdSize % down == 0, + "downsampling filter size must be divisible with downsampling factor"); + static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, + "filter size greater than MAX_FILTER_SIZE"); + static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || + filterMode == MODE_FUSD)), + "up=1 supported only for 1x1 full filters"); + static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || + filterMode == MODE_SUFD)), + "down=1 supported only for 1x1 full filters"); + static_assert( + !(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), + "full filters not supported for up=4"); + static_assert( + !(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), + "full filters not supported for down=4"); + + // Static definitions. + typedef typename InternalType::scalar_t scalar_t; + typedef typename InternalType::vec2_t vec2_t; + typedef typename InternalType::vec4_t vec4_t; + const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & + ~3; // Upsampled tile width, rounded up to multiple of 4. + const int tileUpH = + tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. + const int tileInW = + CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. + const int tileInH = + CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. + const int tileUpH_up = + CEIL_DIV(tileUpH, up) * + up; // Upsampled tile height rounded up to a multiple of up. + const int tileInH_up = + CEIL_DIV(tileUpH_up + (fuSize - 1), + up); // For allocations only, to avoid shared memory read + // overruns with up=2 and up=4. + + // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. + const bool downInline = + (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || + (up == 2 && filterMode == MODE_SUFD)); + + // Sizes of logical buffers. + const int szIn = tileInH_up * tileInW; + const int szUpX = tileInH_up * tileUpW; + const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); + const int szDownX = tileUpH * tileOutW; + + // Sizes for shared memory arrays. + const int s_buf0_size_base = + (filterMode == MODE_SUSD) + ? MAX(szIn, szUpXY) + : (filterMode == MODE_FUSD) + ? MAX(szIn, szDownX) + : (filterMode == MODE_SUFD) + ? MAX(szIn, szUpXY) + : (filterMode == MODE_FUFD) ? szIn : -1; + const int s_buf1_size_base = + (filterMode == MODE_SUSD) + ? MAX(szUpX, szDownX) + : (filterMode == MODE_FUSD) + ? szUpXY + : (filterMode == MODE_SUFD) + ? szUpX + : (filterMode == MODE_FUFD) ? szUpXY : -1; + + // Ensure U128 alignment. + const int s_buf0_size = (s_buf0_size_base + 3) & ~3; + const int s_buf1_size = (s_buf1_size_base + 3) & ~3; + + // Check at compile time that we don't use too much shared memory. + static_assert( + (s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), + "shared memory overflow"); + + // Declare shared memory arrays. + scalar_t *s_buf0; + scalar_t *s_buf1; + if (sharedKB <= 48) { + // Allocate shared memory arrays here. + __shared__ scalar_t + s_buf0_st[(sharedKB > 48) + ? 
(1 << 24) + : (s_buf0_size + + s_buf1_size)]; // Prevent launching if this isn't + // optimized away when unused. + s_buf0 = s_buf0_st; + s_buf1 = s_buf0 + s_buf0_size; + } else { + // Use the dynamically allocated shared memory array. + s_buf0 = (scalar_t *)s_buf_raw; + s_buf1 = s_buf0 + s_buf0_size; + } + + // Pointers to the buffers. + scalar_t * + s_tileIn; // Input tile: [relInX * tileInH + relInY] + scalar_t *s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + + // relUpX] + scalar_t *s_tileUpXY; // After upsampling: [relUpY * tileUpW + + // relUpX] + scalar_t *s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + // + relOutX] + if (filterMode == MODE_SUSD) { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + s_tileDownX = s_buf1; + } else if (filterMode == MODE_FUSD) { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + s_tileDownX = s_buf0; + } else if (filterMode == MODE_SUFD) { + s_tileIn = s_buf0; + s_tileUpX = s_buf1; + s_tileUpXY = s_buf0; + } else if (filterMode == MODE_FUFD) { + s_tileIn = s_buf0; + s_tileUpXY = s_buf1; + } + + // Allow large grids in z direction via per-launch offset. + int channelIdx = blockIdx.z + p.blockZofs; + int batchIdx = channelIdx / p.yShape.z; + channelIdx -= batchIdx * p.yShape.z; + + // Offset to output feature map. In bytes. + index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + + batchIdx * get_stride(p.yStride.w); + + // Sign shift amount. + uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; + +// Inner tile loop. +#pragma unroll 1 + for (int tileIdx = 0; + !enableXrep || + (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); + tileIdx++) { + // Locate output tile. + int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; + int tileOutX = tileX * tileOutW; + int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; + + // Locate input tile. + int tmpX = tileOutX * down - p.pad0.x; + int tmpY = tileOutY * down - p.pad0.y; + int tileInX = CEIL_DIV(tmpX, up); + int tileInY = CEIL_DIV(tmpY, up); + const int phaseInX = tileInX * up - tmpX; + const int phaseInY = tileInY * up - tmpY; + + // Extra sync if input and output buffers are the same and we are not on + // first tile. + if (enableXrep && tileIdx > 0 && + (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || + (filterMode == MODE_FUFD && downInline))) + __syncthreads(); + + // Load input tile & apply bias. Unrolled. + scalar_t b = + (scalar_t) * (const T *)((const char *)p.b + + (channelIdx * get_stride(p.bStride))); + index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + + batchIdx * get_stride(p.xStride.w); + int idx = threadIdx.x; + const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); +#pragma unroll + for (int loop = 0; loop < loopCountIN; loop++) { + int relInX, relInY; + fast_div_mod(relInX, relInY, idx); + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + + if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) + v = (scalar_t) * ((const T *)((const char *)p.x + + (inX * get_stride(p.xStride.x) + + inY * get_stride(p.xStride.y) + + mapOfsIn))) + + b; + + bool skip = (loop == loopCountIN - 1) && (idx >= tileInW * tileInH); + if (!skip) s_tileIn[idx] = v; + + idx += threadsPerBlock; + } + + if (filterMode == MODE_SUSD || + filterMode == MODE_SUFD) // Separable upsampling filter. + { + // Horizontal upsampling. 
+ __syncthreads(); + if (up == 4) { + for (int idx = threadIdx.x * up; idx < tileUpW * tileInH; + idx += blockDim.x * up) { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); + scalar_t a = s_tileIn[src0]; + if (phaseInX == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } else if (phaseInX == 1) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } else if (phaseInX == 2) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } else // (phaseInX == 3) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst + 0] = v.x; + s_tileUpX[dst + 1] = v.y; + s_tileUpX[dst + 2] = v.z; + s_tileUpX[dst + 3] = v.w; + } + } else if (up == 2) { + bool p0 = (phaseInX == 0); + for (int idx = threadIdx.x * up; idx < tileUpW * tileInH; + idx += blockDim.x * up) { + int relUpX0, relInY; + fast_div_mod(relUpX0, relInY, idx); + int relInX0 = relUpX0 / up; + int src0 = relInX0 + tileInW * relInY; + int dst = relInY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); + scalar_t a = s_tileIn[src0]; + if (p0) // (phaseInX == 0) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } else // (phaseInX == 1) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileIn[src0 + step + 1]; + } + } + s_tileUpX[dst + 0] = v.x; + s_tileUpX[dst + 1] = v.y; + } + } + + // Vertical upsampling & nonlinearity. + + __syncthreads(); + int groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + : 0; // Skip already written signs. + int sShapeMaxY = + MIN(p.sShape.y, + tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. + if (up == 4) { + minY -= 3; // Adjust according to block height. 
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; + idx += blockDim.x) { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec4_t v = InternalType::zero_vec4(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 3]; + v.z += a * (scalar_t)c_fu[step * up + 2]; + v.w += a * (scalar_t)c_fu[step * up + 1]; + } + } else if (phaseInY == 1) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.z += a * (scalar_t)c_fu[step * up + 3]; + v.w += a * (scalar_t)c_fu[step * up + 2]; + } + } else if (phaseInY == 2) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 2]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + v.z += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.w += a * (scalar_t)c_fu[step * up + 3]; + } + } else // (phaseInY == 3) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 3]; + v.y += a * (scalar_t)c_fu[step * up + 2]; + v.z += a * (scalar_t)c_fu[step * up + 1]; + v.w += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + index_t si2 = si0 + p.sShape.x * 2; + index_t si3 = si0 + p.sShape.x * 3; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (fabsf(v.z) > p.clamp) { + sz = 2 << 16; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (fabsf(v.w) > p.clamp) { + sw = 2 << 24; + v.w = InternalType::clamp(v.w, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; +#ifdef MMCV_WITH_HIP + s |= __shfl_xor(s, 1); + s |= __shfl_xor(s, 2); +#else + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); +#endif + + // Write signs. 
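+                // si0..si3 address four consecutive sign rows; each byte
+                // packs 2-bit sign/clamp codes for four horizontally
+                // adjacent outputs.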
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + if ((uint32_t)(signY + 2) < sShapeMaxY) { + p.s[si2] = (unsigned char)(s >> 16); + } + if ((uint32_t)(signY + 3) < sShapeMaxY) { + p.s[si3] = (unsigned char)(s >> 24); + } + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + int sz = __float_as_uint(v.z) >> 31 << 16; + int sw = __float_as_uint(v.w) >> 31 << 24; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (sz) v.z *= p.slope; + if (sw) v.w *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (fabsf(v.z) > p.clamp) { + sz = 2 << 16; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (fabsf(v.w) > p.clamp) { + sw = 2 << 24; + v.w = InternalType::clamp(v.w, p.clamp); + } + + // Combine signs. + uint32_t s = sx + sy + sw + sz; + s <<= (signX & 3) << 1; +#ifdef MMCV_WITH_HIP + s |= __shfl_xor(s, 1); + s |= __shfl_xor(s, 2); +#else + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); +#endif + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + if ((uint32_t)(signY + 2) < sShapeMaxY) { + p.s[si2] = (unsigned char)(s >> 16); + } + if ((uint32_t)(signY + 3) < sShapeMaxY) { + p.s[si3] = (unsigned char)(s >> 24); + } + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + } + } else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) { + int ss = (signX & 3) << 1; + if ((uint32_t)(signY + 0) < p.sShape.y) { + int s = p.s[si0] >> ss; + if (s & 1) v.x *= p.slope; + if (s & 2) v.x = 0.f; + } + if ((uint32_t)(signY + 1) < p.sShape.y) { + int s = p.s[si1] >> ss; + if (s & 1) v.y *= p.slope; + if (s & 2) v.y = 0.f; + } + if ((uint32_t)(signY + 2) < p.sShape.y) { + int s = p.s[si2] >> ss; + if (s & 1) v.z *= p.slope; + if (s & 2) v.z = 0.f; + } + if ((uint32_t)(signY + 3) < p.sShape.y) { + int s = p.s[si3] >> ss; + if (s & 1) v.w *= p.slope; + if (s & 2) v.w = 0.f; + } + } + } else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[dst + 0 * tileUpW] = v.x; + if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; + if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; + if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; + } + } else if (up == 2) { + minY -= 1; // Adjust according to block height. 
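+        // Same scheme as the up == 4 branch above, but with two rows per
+        // iteration (vec2).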
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; + idx += blockDim.x) { + int relUpX, relInY0; + fast_div_mod(relUpX, relInY0, idx); + int relUpY0 = relInY0 * up; + int src0 = relInY0 * tileUpW + relUpX; + int dst = relUpY0 * tileUpW + relUpX; + vec2_t v = InternalType::zero_vec2(); + + scalar_t a = s_tileUpX[src0]; + if (phaseInY == 0) { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + v.y += a * (scalar_t)c_fu[step * up + 1]; + } + } else // (phaseInY == 1) + { +#pragma unroll + for (int step = 0; step < fuSize / up; step++) { + v.x += a * (scalar_t)c_fu[step * up + 1]; + v.y += a * (scalar_t)c_fu[step * up + 0]; + a = s_tileUpX[src0 + (step + 1) * tileUpW]; + } + } + + int x = tileOutX * down + relUpX; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si0 = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + index_t si1 = si0 + p.sShape.x; + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + // Combine signs. + int s = sx + sy; + s <<= signXo; +#ifdef MMCV_WITH_HIP + s |= __shfl_xor(s, 1); + s |= __shfl_xor(s, 2); +#else + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); +#endif + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31 << 0; + int sy = __float_as_uint(v.y) >> 31 << 8; + if (sx) v.x *= p.slope; + if (sy) v.y *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2 << 0; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (fabsf(v.y) > p.clamp) { + sy = 2 << 8; + v.y = InternalType::clamp(v.y, p.clamp); + } + + // Combine signs. + int s = sx + sy; + s <<= signXo; +#ifdef MMCV_WITH_HIP + s |= __shfl_xor(s, 1); + s |= __shfl_xor(s, 2); +#else + s |= __shfl_xor_sync(groupMask, s, 1); + s |= __shfl_xor_sync(groupMask, s, 2); +#endif + + // Write signs. + if ((uint32_t)(signY + 0) < sShapeMaxY) { + p.s[si0] = (unsigned char)(s >> 0); + } + if ((uint32_t)(signY + 1) < sShapeMaxY) { + p.s[si1] = (unsigned char)(s >> 8); + } + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + } + } + } else if (signRead) // Read signs and apply. + { + if ((uint32_t)signXb < p.swLimit) { + if ((uint32_t)(signY + 0) < p.sShape.y) { + int s = p.s[si0] >> signXo; + if (s & 1) v.x *= p.slope; + if (s & 2) v.x = 0.f; + } + if ((uint32_t)(signY + 1) < p.sShape.y) { + int s = p.s[si1] >> signXo; + if (s & 1) v.y *= p.slope; + if (s & 2) v.y = 0.f; + } + } + } else // Forward pass with no sign write. 
+ { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + } + + if (!downInline) { + // Write into temporary buffer. + s_tileUpXY[dst] = v.x; + if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y; + } else { + // Write directly into output buffer. + if ((uint32_t)x < p.yShape.x) { + int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); + index_t ofs = x * get_stride(p.yStride.x) + + y * get_stride(p.yStride.y) + mapOfsOut; + if ((uint32_t)y + 0 < p.yShape.y) + *((T *)((char *)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); + if ((uint32_t)y + 1 < ymax) + *((T *)((char *)p.y + ofs + get_stride(p.yStride.y))) = + (T)(v.y * (scalar_t)c_fd[0]); + } + } + } + } + } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) { + // Full upsampling filter. + + if (up == 2) { + // 2 x 2-wide. + __syncthreads(); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y + : 0; // Skip already written signs. + for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; + idx += blockDim.x * 4) { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); + int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); + int src0 = relInX0 + tileInW * relInY0; + int tap0y = (relInY0 * up + phaseInY - relUpY0); + +#define X_LOOP(TAPY, PX) \ + for (int sx = 0; sx < fuSize / up; sx++) { \ + v.x += a * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.z += b * (scalar_t)c_fu[(sx * up + (((PX)-0) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + if ((PX) == 0) { \ + a = b; \ + b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \ + } \ + v.y += a * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + v.w += b * (scalar_t)c_fu[(sx * up + (((PX)-1) & (up - 1))) + \ + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ + if ((PX) == 1) { \ + a = b; \ + b = s_tileIn[src0 + 2 + sx + sy * tileInW]; \ + } \ + } + + vec4_t v = InternalType::zero_vec4(); + if (tap0y == 0 && phaseInX == 0) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(0, 0) + } + if (tap0y == 0 && phaseInX == 1) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(0, 1) + } + if (tap0y == 1 && phaseInX == 0) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(1, 0) + } + if (tap0y == 1 && phaseInX == 1) +#pragma unroll + for (int sy = 0; sy < fuSize / up; sy++) { + scalar_t a = s_tileIn[src0 + sy * tileInW]; + scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; +#pragma unroll + X_LOOP(1, 1) + } + +#undef X_LOOP + + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + + v.x *= (scalar_t)((float)up * (float)up * p.gain); + v.y *= (scalar_t)((float)up * (float)up * p.gain); + v.z *= (scalar_t)((float)up * (float)up * p.gain); + v.w *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if 
(!enableWriteSkip) { + // Determine and write signs. + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (sy) v.y *= p.slope; + if (fabsf(v.y) > p.clamp) { + sy = 2; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (sz) v.z *= p.slope; + if (fabsf(v.z) > p.clamp) { + sz = 2; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (sw) v.w *= p.slope; + if (fabsf(v.w) > p.clamp) { + sw = 2; + v.w = InternalType::clamp(v.w, p.clamp); + } + + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } + } else { + // Determine and write signs. + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + int sx = __float_as_uint(v.x) >> 31; + int sy = __float_as_uint(v.y) >> 31; + int sz = __float_as_uint(v.z) >> 31; + int sw = __float_as_uint(v.w) >> 31; + if (sx) v.x *= p.slope; + if (fabsf(v.x) > p.clamp) { + sx = 2; + v.x = InternalType::clamp(v.x, p.clamp); + } + if (sy) v.y *= p.slope; + if (fabsf(v.y) > p.clamp) { + sy = 2; + v.y = InternalType::clamp(v.y, p.clamp); + } + if (sz) v.z *= p.slope; + if (fabsf(v.z) > p.clamp) { + sz = 2; + v.z = InternalType::clamp(v.z, p.clamp); + } + if (sw) v.w *= p.slope; + if (fabsf(v.w) > p.clamp) { + sw = 2; + v.w = InternalType::clamp(v.w, p.clamp); + } + + p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); + } else { + // Just compute the values. + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + } + } else if (signRead) // Read sign and apply. + { + if ((uint32_t)signY < p.sShape.y) { + int s = 0; + if ((uint32_t)signXb < p.swLimit) s = p.s[si]; + if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; + s >>= (signX & 3) << 1; + if (s & 0x01) v.x *= p.slope; + if (s & 0x02) v.x = 0.f; + if (s & 0x04) v.y *= p.slope; + if (s & 0x08) v.y = 0.f; + if (s & 0x10) v.z *= p.slope; + if (s & 0x20) v.z = 0.f; + if (s & 0x40) v.w *= p.slope; + if (s & 0x80) v.w = 0.f; + } + } else // Forward pass with no sign write. + { + if (v.x < 0.f) v.x *= p.slope; + v.x = InternalType::clamp(v.x, p.clamp); + if (v.y < 0.f) v.y *= p.slope; + v.y = InternalType::clamp(v.y, p.clamp); + if (v.z < 0.f) v.z *= p.slope; + v.z = InternalType::clamp(v.z, p.clamp); + if (v.w < 0.f) v.w *= p.slope; + v.w = InternalType::clamp(v.w, p.clamp); + } + + s_tileUpXY[idx + 0] = v.x; + s_tileUpXY[idx + 1] = v.y; + s_tileUpXY[idx + 2] = v.z; + s_tileUpXY[idx + 3] = v.w; + } + } else if (up == 1) { + __syncthreads(); + uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); + int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + : 0; // Skip already written signs. + for (int idx = threadIdx.x; idx < tileUpW * tileUpH; + idx += blockDim.x) { + int relUpX0, relUpY0; + fast_div_mod(relUpX0, relUpY0, idx); + scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. 
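+          // up == 1: the "upsampled" tile is just the input scaled by the
+          // single filter tap; the rest of the loop applies gain, leaky
+          // ReLU, clamping and the optional sign read/write.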
+ + int x = tileOutX * down + relUpX0; + int y = tileOutY * down + relUpY0; + int signX = x + p.sOfs.x; + int signY = y + p.sOfs.y; + int signZ = blockIdx.z + p.blockZofs; + int signXb = signX >> 2; + index_t si = + signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); + v *= (scalar_t)((float)up * (float)up * p.gain); + + if (signWrite) { + if (!enableWriteSkip) { + // Determine and write sign. + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { +#ifdef MMCV_WITH_HIP + s += __shfl_xor(s, 1); // Coalesce. + s += __shfl_xor(s, 2); // Coalesce. +#else + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. +#endif + p.s[si] = s; // Write. + } + } else { + // Determine and write sign. + if ((uint32_t)signXb < p.swLimit && + (uint32_t)signY < p.sShape.y && signY >= minY) { + uint32_t s = 0; + uint32_t signXbit = (1u << signXo); + if (v < 0.f) { + s = signXbit; + v *= p.slope; + } + if (fabsf(v) > p.clamp) { + s = signXbit * 2; + v = InternalType::clamp(v, p.clamp); + } +#ifdef MMCV_WITH_HIP + s += __shfl_xor(s, 1); // Coalesce. + s += __shfl_xor(s, 2); // Coalesce. +#else + s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. + s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. +#endif + p.s[si] = s; // Write. + } else { + // Just compute the value. + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + } + } else if (signRead) { + // Read sign and apply if within sign tensor bounds. + if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) { + int s = p.s[si]; + s >>= signXo; + if (s & 1) v *= p.slope; + if (s & 2) v = 0.f; + } + } else // Forward pass with no sign write. + { + if (v < 0.f) v *= p.slope; + v = InternalType::clamp(v, p.clamp); + } + + if (!downInline) // Write into temporary buffer. + s_tileUpXY[idx] = v; + else if ((uint32_t)x < p.yShape.x && + (uint32_t)y < + p.yShape.y) // Write directly into output buffer + *((T *)((char *)p.y + (x * get_stride(p.yStride.x) + + y * get_stride(p.yStride.y) + + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); + } + } + } + + // Downsampling. + if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) { + // Horizontal downsampling. + __syncthreads(); + if (down == 4 && tileOutW % 4 == 0) { + // Calculate 4 pixels at a time. + for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; + idx += blockDim.x * 4) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec4_t v = InternalType::zero_vec4(); +#pragma unroll + for (int step = 0; step < fdSize; step++) { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; + v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; + v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx + 0] = v.x; + s_tileDownX[idx + 1] = v.y; + s_tileDownX[idx + 2] = v.z; + s_tileDownX[idx + 3] = v.w; + } + } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) { + // Calculate 2 pixels at a time. 
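+      // Taken when down == 2, or when down == 4 but tileOutW is only a
+      // multiple of 2; accumulates two output pixels per iteration (vec2).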
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; + idx += blockDim.x * 2) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src0 = relUpY * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); +#pragma unroll + for (int step = 0; step < fdSize; step++) { + v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; + v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; + } + s_tileDownX[idx + 0] = v.x; + s_tileDownX[idx + 1] = v.y; + } + } else { + // Calculate 1 pixel at a time. + for (int idx = threadIdx.x; idx < tileOutW * tileUpH; + idx += blockDim.x) { + int relOutX0, relUpY; + fast_div_mod(relOutX0, relUpY, idx); + int relUpX0 = relOutX0 * down; + int src = relUpY * tileUpW + relUpX0; + scalar_t v = 0.f; +#pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; + s_tileDownX[idx] = v; + } + } + + // Vertical downsampling & store output tile. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; + idx += blockDim.x) { + int relOutX, relOutY0; + fast_div_mod(relOutX, relOutY0, idx); + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileOutW + relOutX; + scalar_t v = 0; +#pragma unroll + for (int step = 0; step < fdSize; step++) + v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; + + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY0; + + if (outX < p.yShape.x & outY < p.yShape.y) + *((T *)((char *)p.y + (outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + + mapOfsOut))) = (T)v; + } + } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) { + // Full downsampling filter. + if (down == 2) { + // 2-wide. + __syncthreads(); + for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; + idx += blockDim.x * 2) { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + int relUpX0 = relOutX0 * down; + int relUpY0 = relOutY0 * down; + int src0 = relUpY0 * tileUpW + relUpX0; + vec2_t v = InternalType::zero_vec2(); +#pragma unroll + for (int sy = 0; sy < fdSize; sy++) +#pragma unroll + for (int sx = 0; sx < fdSize; sx++) { + v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * + (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * + (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; + } + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outY < p.yShape.y) { + index_t ofs = outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + mapOfsOut; + if (outX + 0 < p.yShape.x) *((T *)((char *)p.y + ofs)) = (T)v.x; + if (outX + 1 < p.yShape.x) + *((T *)((char *)p.y + ofs + get_stride(p.yStride.x))) = + (T)v.y; + } + } + } else if (down == 1 && !downInline) { + // Thread per pixel. + __syncthreads(); + for (int idx = threadIdx.x; idx < tileOutW * tileOutH; + idx += blockDim.x) { + int relOutX0, relOutY0; + fast_div_mod(relOutX0, relOutY0, idx); + scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. + + int outX = tileOutX + relOutX0; + int outY = tileOutY + relOutY0; + if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) + *((T *)((char *)p.y + (outX * get_stride(p.yStride.x) + + outY * get_stride(p.yStride.y) + + mapOfsOut))) = (T)v; + } + } + } + + if (!enableXrep) break; + } +} + +//------------------------------------------------------------------------ +// Compute activation function and signs for upsampled data tensor, modifying +// data tensor in-place. 
Used for accelerating the generic variant. Sign tensor +// is known to be contiguous, and p.x and p.s have the same z, w dimensions. +// 64-bit indexing is always used. + +template +static __global__ void filtered_lrelu_act_kernel( + filtered_lrelu_act_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + + // Indexing. + int32_t x = threadIdx.x + blockIdx.x * blockDim.x; + int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; + int32_t qmax = + p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. + + // Loop to accommodate oversized tensors. + for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) + for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) { + // Extract z and w (channel, minibatch index). + int32_t w = q / p.xShape.z; + int32_t z = q - w * p.xShape.z; + + // Choose behavior based on sign read/write mode. + if (signWrite) { + // Process value if in p.x. + uint32_t s = 0; + if (x < p.xShape.x && y < p.xShape.y) { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T *pv = ((T *)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + + // Gain, LReLU, clamp. + v *= p.gain; + if (v < 0.f) { + v *= p.slope; + s = 1; // Sign. + } + if (fabsf(v) > p.clamp) { + v = InternalType::clamp(v, p.clamp); + s = 2; // Clamp. + } + + *pv = (T)v; // Write value. + } + + // Coalesce into threads 0 and 16 of warp. + uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; + s <<= ((threadIdx.x & 15) << 1); // Shift into place. +#ifdef MMCV_WITH_HIP + s |= __shfl_xor(s, 1); // Distribute. + s |= __shfl_xor(s, 2); + s |= __shfl_xor(s, 4); + s |= __shfl_xor(s, 8); +#else + s |= __shfl_xor_sync(m, s, 1); // Distribute. + s |= __shfl_xor_sync(m, s, 2); + s |= __shfl_xor_sync(m, s, 4); + s |= __shfl_xor_sync(m, s, 8); +#endif + + // Write signs if leader and in p.s. + if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. + { + uint64_t is = + x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. + ((uint32_t *)p.s)[is >> 4] = s; + } + } else if (signRead) { + // Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T *pv = ((T *)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + + // Apply sign buffer offset. + uint32_t sx = x + p.sOfs.x; + uint32_t sy = y + p.sOfs.y; + + // Read and apply signs if we land inside valid region of sign buffer. + if (sx < p.sShape.x && sy < p.sShape.y) { + uint64_t is = + (sx >> 2) + (p.sShape.x >> 2) * + (sy + (uint64_t)p.sShape.y * q); // Contiguous. + unsigned char s = p.s[is]; + s >>= (sx & 3) << 1; // Shift into place. + if (s & 1) // Sign? + v *= p.slope; + if (s & 2) // Clamp? + v = 0.f; + } + + *pv = (T)v; // Write value. + } + } else { + // Forward pass with no sign write. Process value if in p.x. + if (x < p.xShape.x) // y is always in. + { + int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + + w * p.xStride.w; + T *pv = ((T *)p.x) + ix; + scalar_t v = (scalar_t)(*pv); + v *= p.gain; + if (v < 0.f) v *= p.slope; + if (fabsf(v) > p.clamp) v = InternalType::clamp(v, p.clamp); + *pv = (T)v; // Write value. + } + } + } +} + +template +void *choose_filtered_lrelu_act_kernel(void) { + return (void *)filtered_lrelu_act_kernel; +} + +//------------------------------------------------------------------------ +// MUSA kernel selection. 
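+// choose_filtered_lrelu_kernel returns the first CASE entry whose filter
+// mode, up/down factors, maximum filter taps and shared-memory requirement
+// match the launch parameters. Each CASE encodes: required shared memory
+// (KB), up factor / max upsampling taps, down factor / max downsampling
+// taps, filter mode, output tile width and height, warps per block,
+// horizontal tile repeat count (xrep) and the write-skip flag.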
+ +template +filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel( + const filtered_lrelu_kernel_params &p, int sharedKB) { + filtered_lrelu_kernel_spec s = {0}; + + // Return the first matching kernel. +#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ + if (sharedKB >= SH) \ + if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || \ + (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ + if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || \ + (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ + if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && \ + p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) { \ + static_assert((D * TW % 4) == 0, \ + "down * tileWidth must be divisible by 4"); \ + static_assert( \ + FU % U == 0, \ + "upscaling filter size must be multiple of upscaling factor"); \ + static_assert(FD % D == 0, \ + "downscaling filter size must be multiple of " \ + "downscaling factor"); \ + s.setup = (void *)setup_filters_kernel; \ + s.exec = (void *) \ + filtered_lrelu_kernel; \ + s.tileOut = make_int2(TW, TH); \ + s.numWarps = W; \ + s.xrep = XR; \ + s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ + return s; \ + } + + // Launch parameters for various kernel specializations. + // Small filters must be listed before large filters, otherwise the kernel for + // larger filter will always match first. Kernels that use more shared memory + // must be listed before those that use less, for the same reason. + + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 1, 1, /*mode*/ MODE_FUFD, + /*tw,th,warps,xrep,wskip*/ 64, 178, 32, 0, 0) // 1t-upf1-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 152, 95, 16, 0, 0) // 4t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 56, 22, 16, 0, 0) // 4t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 56, 29, 16, 11, 0) // 4t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 60, 28, 16, 0, 0) // 4t-upf2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 56, 28, 16, 0, 0) // 4t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 56, 31, 16, 11, 0) // 4t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 16, /*down,fd*/ 2, 8, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 56, 36, 16, 0, 0) // 4t-ups4-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 22, 16, 12, 0) // 4t-ups2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 8, /*down,fd*/ 4, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 29, 15, 16, 0, 0) // 4t-upf2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 96, 150, 28, 0, 0) // 6t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 32, 35, 24, 0, 0) // 6t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 46, 16, 10, 0) // 6t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 58, 28, 24, 8, 0) // 6t-upf2-downs2 + CASE(/*sharedKB*/ 
48, /*up,fu*/ 2, 12, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 52, 28, 16, 0, 0) // 6t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 51, 16, 5, 0) // 6t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 24, /*down,fd*/ 2, 12, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 56, 16, 6, 0) // 6t-ups4-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 18, 16, 12, 0) // 6t-ups2-downs4 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 12, /*down,fd*/ 4, 24, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 27, 13, 24, 0, 0) // 6t-upf2-downs4 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 1, 1, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 148, 89, 24, 0, 0) // 8t-ups2-downf1 + CASE(/*sharedKB*/ 48, /*up,fu*/ 1, 1, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 32, 31, 16, 5, 0) // 8t-upf1-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 41, 16, 9, 0) // 8t-ups2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 56, 26, 24, 0, 0) // 8t-upf2-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 40, 16, 0, 0) // 8t-ups2-downf2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 32, 46, 24, 5, 0) // 8t-ups4-downs2 + CASE(/*sharedKB*/ 48, /*up,fu*/ 4, 32, /*down,fd*/ 2, 16, /*mode*/ MODE_SUFD, + /*tw,th,warps,xrep,wskip*/ 32, 50, 16, 0, 0) // 8t-ups4-downf2 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_SUSD, + /*tw,th,warps,xrep,wskip*/ 16, 13, 16, 10, 1) // 8t-ups2-downs4 + CASE(/*sharedKB*/ 96, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB + CASE(/*sharedKB*/ 48, /*up,fu*/ 2, 16, /*down,fd*/ 4, 32, /*mode*/ MODE_FUSD, + /*tw,th,warps,xrep,wskip*/ 25, 10, 24, 0, 0) // 8t-upf2-downs4 + +#undef CASE + return s; // No kernel found. +} + +//------------------------------------------------------------------------ + +#define BUILD_FILTERED_LRELU_OP 1 + +#ifndef MMCV_WITH_HIP +#ifdef __GNUC__ +#if __GNUC__ < 6 +#undef BUILD_FILTERED_LRELU_OP +#define BUILD_FILTERED_LRELU_OP 0 +#endif +#endif + +std::tuple filtered_lrelu_op( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns) { + // Set MUSA device. + TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device"); + const at::musa::OptionalMUSAGuard device_guard(device_of(x)); + + // Validate arguments. 
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && + b.device() == x.device(), + "all input tensors must reside on the same device"); + TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, + "fu and fd must be float32"); + TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, + "x and b must be float16 or float32"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && + x.size(3) <= INT_MAX, + "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK( + (fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), + "fu and fd must be rank 1 or 2"); + TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, + "fu is too large"); + TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, + "fd is too large"); + TORCH_CHECK(fu.numel() > 0, "fu is empty"); + TORCH_CHECK(fd.numel() > 0, "fd is empty"); + TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), + "b must be a vector with the same number of channels as x"); + TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); + + // Figure out how much shared memory is available on the device. + int maxSharedBytes = 0; +#ifdef MMCV_WITH_HIP + musaDeviceGetAttribute(&maxSharedBytes, + hipDeviceAttributeSharedMemPerBlockOptin, + x.device().index()); +#else + AT_MUSA_CHECK(musaDeviceGetAttribute(&maxSharedBytes, + musaDevAttrMaxSharedMemoryPerBlockOptin, + x.device().index())); +#endif + int sharedKB = maxSharedBytes >> 10; + + // Populate enough launch parameters to check if a MUSA kernel exists. + filtered_lrelu_kernel_params p; + p.up = up; + p.down = down; + p.fuShape = + make_int2((int)fu.size(-1), + fu.dim() == 2 ? (int)fu.size(0) + : 0); // shape [n, 0] indicates separable filter. + p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); + filtered_lrelu_kernel_spec test_spec = + choose_filtered_lrelu_kernel(p, sharedKB); + if (!test_spec.exec) { + // No kernel found - return empty tensors and indicate missing kernel with + // return code of -1. + return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); + } + + // Input/output element size. + int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; + + // Input sizes. + int64_t xw = (int)x.size(3); + int64_t xh = (int)x.size(2); + int64_t fut_w = (int)fu.size(-1) - 1; + int64_t fut_h = (int)fu.size(0) - 1; + int64_t fdt_w = (int)fd.size(-1) - 1; + int64_t fdt_h = (int)fd.size(0) - 1; + + // Logical size of upsampled buffer. + int64_t cw = xw * up + (px0 + px1) - fut_w; + int64_t ch = xh * up + (py0 + py1) - fut_h; + TORCH_CHECK( + cw > fdt_w && ch > fdt_h, + "upsampled buffer must be at least the size of downsampling filter"); + TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); + + // Compute output size and allocate. + int64_t yw = (cw - fdt_w + (down - 1)) / down; + int64_t yh = (ch - fdt_h + (down - 1)) / down; + TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); + TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), + x.suggest_memory_format()); + + // Allocate sign tensor. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + int64_t sw_active = 0; // Active width of sign tensor. 
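+  // When writing signs, a fresh uint8 sign tensor is allocated to cover the
+  // active upsampled region (width rounded up to a multiple of 16, stored
+  // four 2-bit codes per byte); when reading, the caller-provided tensor is
+  // reused and its active width is recovered from its last dimension.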
+ if (writeSigns) { + sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. + int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. + int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, + // rounded up to multiple of 16. + TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); + s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, + x.options().dtype(torch::kUInt8), + at::MemoryFormat::Contiguous); + } else if (readSigns) + sw_active = s.size(3) << 2; + + // Validate sign tensor if in use. + if (readSigns || writeSigns) { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), + "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), + "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, + "signs is too large"); + } + + // Populate rest of MUSA kernel parameters. + p.x = x.data_ptr(); + p.y = y.data_ptr(); + p.b = b.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.fu = fu.data_ptr(); + p.fd = fd.data_ptr(); + p.pad0 = make_int2(px0, py0); + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.flip = (flip_filters) ? 1 : 0; + p.xShape = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.yShape = + make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.sShape = (readSigns || writeSigns) + ? make_int2((int)s.size(3), (int)s.size(2)) + : make_int2(0, 0); // Width is in bytes. Contiguous. + p.sOfs = make_int2(sx, sy); + p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. + + // x, y, b strides are in bytes. + p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), + sz * x.stride(1), sz * x.stride(0)); + p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), + sz * y.stride(1), sz * y.stride(0)); + p.bStride = sz * b.stride(0); + + // fu, fd strides are in elements. + p.fuStride = + make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); + p.fdStride = + make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); + + // Determine if indices don't fit in int32. Support negative strides although + // Torch currently never produces those. + bool index64b = false; + if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; + if (std::min(x.size(0) * p.xStride.w, 0ll) + + std::min(x.size(1) * p.xStride.z, 0ll) + + std::min(x.size(2) * p.xStride.y, 0ll) + + std::min(x.size(3) * p.xStride.x, 0ll) < + -INT_MAX) + index64b = true; + if (std::max(x.size(0) * p.xStride.w, 0ll) + + std::max(x.size(1) * p.xStride.z, 0ll) + + std::max(x.size(2) * p.xStride.y, 0ll) + + std::max(x.size(3) * p.xStride.x, 0ll) > + INT_MAX) + index64b = true; + if (std::min(y.size(0) * p.yStride.w, 0ll) + + std::min(y.size(1) * p.yStride.z, 0ll) + + std::min(y.size(2) * p.yStride.y, 0ll) + + std::min(y.size(3) * p.yStride.x, 0ll) < + -INT_MAX) + index64b = true; + if (std::max(y.size(0) * p.yStride.w, 0ll) + + std::max(y.size(1) * p.yStride.z, 0ll) + + std::max(y.size(2) * p.yStride.y, 0ll) + + std::max(y.size(3) * p.yStride.x, 0ll) > + INT_MAX) + index64b = true; + if (s.numel() > INT_MAX) index64b = true; + + // Choose MUSA kernel. 
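+  // Dispatch on scalar type, 32- vs 64-bit indexing and the sign
+  // read/write mode; each combination selects a different template
+  // instantiation of choose_filtered_lrelu_kernel.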
+ filtered_lrelu_kernel_spec spec = {0}; + AT_DISPATCH_FLOATING_TYPES( + x.scalar_type(), "filtered_lrelu_musa", [&] { + if constexpr (sizeof(scalar_t) <= + 4) // Exclude doubles. constexpr + // prevents template instantiation. + { + // Choose kernel based on index type, datatype and sign read/write + // modes. + if (!index64b && writeSigns && !readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (!index64b && !writeSigns && readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (!index64b && !writeSigns && !readSigns) + spec = + choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && writeSigns && !readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && !writeSigns && readSigns) + spec = choose_filtered_lrelu_kernel( + p, sharedKB); + else if (index64b && !writeSigns && !readSigns) + spec = + choose_filtered_lrelu_kernel( + p, sharedKB); + } + }); + TORCH_CHECK( + spec.exec, + "internal error - MUSA kernel not found") // This should not happen + // because we tested earlier + // that kernel exists. + + // Launch MUSA kernel. + void *args[] = {&p}; + int bx = spec.numWarps * 32; + int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; + int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; + int gz = p.yShape.z * p.yShape.w; + + // Repeat multiple horizontal tiles in a CTA? + if (spec.xrep) { + p.tilesXrep = spec.xrep; + p.tilesXdim = gx; + + gx = (gx + p.tilesXrep - 1) / p.tilesXrep; + std::swap(gx, gy); + } else { + p.tilesXrep = 0; + p.tilesXdim = 0; + } +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipLaunchKernel(spec.setup, 1, 1024, args, 0, + c10::musa::getCurrentMUSAStream())); +#else + // Launch filter setup kernel. + AT_MUSA_CHECK(musaLaunchKernel(spec.setup, 1, 1024, args, 0, + c10::musa::getCurrentMUSAStream())); +#endif + + // Copy kernels to constant memory. + if (writeSigns && !readSigns) + AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream()))); + else if (!writeSigns && readSigns) + AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream()))); + else if (!writeSigns && !readSigns) + AT_MUSA_CHECK((copy_filters(c10::musa::getCurrentMUSAStream()))); + + // Set cache and shared memory configurations for main kernel. + // FIXME:TODO FIX BUG + AT_MUSA_CHECK(musaFuncSetCacheConfig(spec.exec, musaFuncCachePreferShared)); + if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipFuncSetAttribute( + spec.exec, hipFuncAttributeMaxDynamicSharedMemorySize, + spec.dynamicSharedKB << 10)); +#else + AT_MUSA_CHECK(musaFuncSetAttribute( + spec.exec, musaFuncAttributeMaxDynamicSharedMemorySize, + spec.dynamicSharedKB << 10)); +#endif + // FIXME:TODO FIX BUG + AT_MUSA_CHECK( + musaFuncSetSharedMemConfig(spec.exec, musaSharedMemBankSizeFourByte)); + + // Launch main kernel. + const int maxSubGz = 65535; // MUSA maximum for block z dimension. + for (int zofs = 0; zofs < gz; + zofs += maxSubGz) // Do multiple launches if gz is too big. + { + p.blockZofs = zofs; + int subGz = std::min(maxSubGz, gz - zofs); +// FIXME:TODO FIX BUG +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, + spec.dynamicSharedKB << 10, + c10::musa::getCurrentMUSAStream())); +#else + AT_MUSA_CHECK(musaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, + spec.dynamicSharedKB << 10, + c10::musa::getCurrentMUSAStream())); +#endif + } + + // Done. 
+ return std::make_tuple(y, so, 0); +} + +std::tuple filtered_lrelu_op_impl( + torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, + torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, + int sx, int sy, float gain, float slope, float clamp, bool flip_filters, + bool writeSigns); + +REGISTER_DEVICE_IMPL(filtered_lrelu_op_impl, MUSA, filtered_lrelu_op); + +#else + +#pragma message( \ + "filtered_lrelu_op is not available. " \ + "Please update your compiler and musa version.") + +#endif +#undef BUILD_FILTERED_LRELU_OP + +//------------------------------------------------------------------------ + +torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, + float clamp, bool writeSigns) { + // Set MUSA device. + TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device"); + const at::musa::OptionalMUSAGuard device_guard(device_of(x)); + + // Validate arguments. + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && + x.size(3) <= INT_MAX, + "x is too large"); + TORCH_CHECK(x.numel() > 0, "x is empty"); + TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || + x.dtype() == torch::kDouble, + "x must be float16, float32 or float64"); + + // Output signs if we don't have sign input. + torch::Tensor so; + torch::Tensor s = si; + bool readSigns = !!s.numel(); + if (writeSigns) { + int64_t sw = x.size(3); + sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. + s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, + x.options().dtype(torch::kUInt8), + at::MemoryFormat::Contiguous); + } + + // Validate sign tensor if in use. + if (readSigns || writeSigns) { + TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); + TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); + TORCH_CHECK(s.device() == x.device(), + "signs must reside on the same device as x"); + TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); + TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), + "signs must have same batch & channels as x"); + TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, + "signs tensor is too large"); + } + + // Initialize MUSA kernel parameters. + filtered_lrelu_act_kernel_params p; + p.x = x.data_ptr(); + p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; + p.gain = gain; + p.slope = slope; + p.clamp = clamp; + p.xShape = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.xStride = + make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); + p.sShape = (readSigns || writeSigns) + ? make_int2((int)s.size(3) << 2, (int)s.size(2)) + : make_int2(0, 0); // Width is in elements. Contiguous. + p.sOfs = make_int2(sx, sy); + + // Choose MUSA kernel. + void *func = 0; + AT_DISPATCH_FLOATING_TYPES( + x.scalar_type(), "filtered_lrelu_act_musa", [&] { + if (writeSigns) + func = choose_filtered_lrelu_act_kernel(); + else if (readSigns) + func = choose_filtered_lrelu_act_kernel(); + else + func = choose_filtered_lrelu_act_kernel(); + }); + TORCH_CHECK(func, "internal error - MUSA kernel not found"); + + // Launch MUSA kernel. + void *args[] = {&p}; + int bx = 128; // 4 warps per block. + + // Logical size of launch = writeSigns ? p.s : p.x + uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; + uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; + uint32_t gz = + p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. 
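+  // One thread per element along x: round gx up to the number of bx-thread
+  // blocks needed to cover the row.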
+ gx = (gx - 1) / bx + 1; + + // Make sure grid y and z dimensions are within MUSA launch limits. Kernel + // loops internally to do the rest. + const uint32_t gmax = 65535; + gy = std::min(gy, gmax); + gz = std::min(gz, gmax); + + // Launch. +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, + c10::musa::getCurrentMUSAStream())); +#else + AT_MUSA_CHECK(musaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, + c10::musa::getCurrentMUSAStream())); +#endif + + return so; +} diff --git a/mmcv/ops/csrc/pytorch/musa/musabind.cpp b/mmcv/ops/csrc/pytorch/musa/musabind.cpp index ebbb692756..889574f957 100644 --- a/mmcv/ops/csrc/pytorch/musa/musabind.cpp +++ b/mmcv/ops/csrc/pytorch/musa/musabind.cpp @@ -540,6 +540,17 @@ torch::Tensor bias_act_op(const torch::Tensor &input, const torch::Tensor &bias, REGISTER_DEVICE_IMPL(bias_act_op_impl, MUSA, bias_act_op); +torch::Tensor filtered_lrelu_act_op_impl(torch::Tensor x, torch::Tensor si, + int sx, int sy, float gain, + float slope, float clamp, + bool writeSigns); + +torch::Tensor filtered_lrelu_act_op(torch::Tensor x, torch::Tensor si, int sx, + int sy, float gain, float slope, + float clamp, bool writeSigns); + +REGISTER_DEVICE_IMPL(filtered_lrelu_act_op_impl, MUSA, filtered_lrelu_act_op); + void GatherPointsForwardMUSAKernelLauncher(int b, int c, int n, int npoints, const Tensor points, const Tensor idx, Tensor out); @@ -869,6 +880,854 @@ Tensor nms_musa(Tensor boxes, Tensor scores, float iou_threshold, int offset) { Tensor nms_impl(Tensor boxes, Tensor scores, float iou_threshold, int offset); REGISTER_DEVICE_IMPL(nms_impl, MUSA, nms_musa); +void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_part_forward_musa(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesPartForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; + +void points_in_boxes_all_forward_musa(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + PointsInBoxesAllForwardMUSAKernelLauncher(batch_size, boxes_num, pts_num, + boxes, pts, box_idx_of_points); +}; + +void points_in_boxes_part_forward_impl(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); + +void points_in_boxes_all_forward_impl(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points); +REGISTER_DEVICE_IMPL(points_in_boxes_part_forward_impl, MUSA, + points_in_boxes_part_forward_musa); +REGISTER_DEVICE_IMPL(points_in_boxes_all_forward_impl, MUSA, + points_in_boxes_all_forward_musa); + +void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input, + Tensor output, const int num_, + const int h_feature, const int w_feature, + const int h_mask, const int w_mask, + const int half_h_mask, + const int half_w_mask); + +void PSAMaskBackwardMUSAKernelLauncher( + const int psa_type, const Tensor grad_output, Tensor grad_input, + const int num_, const int h_feature, const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, const int half_w_mask); + +void 
psamask_forward_musa(const int psa_type, const Tensor input, Tensor output, + const int num_, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask) { + PSAMaskForwardMUSAKernelLauncher(psa_type, input, output, num_, h_feature, + w_feature, h_mask, w_mask, half_h_mask, + half_w_mask); +} + +void psamask_backward_musa(const int psa_type, const Tensor grad_output, + Tensor grad_input, const int num_, + const int h_feature, const int w_feature, + const int h_mask, const int w_mask, + const int half_h_mask, const int half_w_mask) { + PSAMaskBackwardMUSAKernelLauncher(psa_type, grad_output, grad_input, num_, + h_feature, w_feature, h_mask, w_mask, + half_h_mask, half_w_mask); +} + +void psamask_forward_impl(const int psa_type, const Tensor input, Tensor output, + const int num_, const int h_feature, + const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, + const int half_w_mask); + +void psamask_backward_impl(const int psa_type, const Tensor grad_output, + Tensor grad_input, const int num_, + const int h_feature, const int w_feature, + const int h_mask, const int w_mask, + const int half_h_mask, const int half_w_mask); +REGISTER_DEVICE_IMPL(psamask_forward_impl, MUSA, psamask_forward_musa); +REGISTER_DEVICE_IMPL(psamask_backward_impl, MUSA, psamask_backward_musa); + +void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax_y, Tensor argmax_x, + Tensor grad_input, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, int pool_mode, + bool aligned); + +void roi_align_forward_musa(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + ROIAlignForwardMUSAKernelLauncher( + input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, + spatial_scale, sampling_ratio, pool_mode, aligned); +} + +void roi_align_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned) { + ROIAlignBackwardMUSAKernelLauncher( + grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height, + aligned_width, spatial_scale, sampling_ratio, pool_mode, aligned); +} + +void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax_y, Tensor argmax_x, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +void roi_align_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax_y, + Tensor argmax_x, Tensor grad_input, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + int pool_mode, bool aligned); + +REGISTER_DEVICE_IMPL(roi_align_forward_impl, MUSA, roi_align_forward_musa); +REGISTER_DEVICE_IMPL(roi_align_backward_impl, MUSA, roi_align_backward_musa); + +void ROIAlignRotatedForwardMUSAKernelLauncher( + const at::Tensor input, const at::Tensor rois, const float spatial_scale, + const int sampling_ratio, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, 
const int num_rois, + const int pooled_height, const int pooled_width, at::Tensor output); + +void ROIAlignRotatedBackwardMUSAKernelLauncher( + const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, + const int sampling_ratio, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, at::Tensor bottom_grad); + +void roi_align_rotated_forward_musa(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned, bool clockwise) { + // Number of ROIs + int num_rois = rois.size(0); + int size_rois = rois.size(1); + + if (size_rois != 6) { + AT_ERROR("wrong roi size"); + } + + int num_channels = input.size(1); + int data_height = input.size(2); + int data_width = input.size(3); + ROIAlignRotatedForwardMUSAKernelLauncher( + input, rois, spatial_scale, sampling_ratio, aligned, clockwise, + num_channels, data_height, data_width, num_rois, aligned_height, + aligned_width, output); +} + +void roi_align_rotated_backward_musa(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned, + bool clockwise) { + // Number of ROIs + int num_rois = rois.size(0); + int size_rois = rois.size(1); + if (size_rois != 6) { + AT_ERROR("wrong roi size"); + } + + int num_channels = bottom_grad.size(1); + int data_height = bottom_grad.size(2); + int data_width = bottom_grad.size(3); + ROIAlignRotatedBackwardMUSAKernelLauncher( + top_grad, rois, spatial_scale, sampling_ratio, aligned, clockwise, + num_channels, data_height, data_width, num_rois, aligned_height, + aligned_width, bottom_grad); +} + +void roi_align_rotated_forward_impl(Tensor input, Tensor rois, Tensor output, + int aligned_height, int aligned_width, + float spatial_scale, int sampling_ratio, + bool aligned, bool clockwise); + +void roi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int aligned_height, + int aligned_width, float spatial_scale, + int sampling_ratio, bool aligned, + bool clockwise); +REGISTER_DEVICE_IMPL(roi_align_rotated_forward_impl, MUSA, + roi_align_rotated_forward_musa); +REGISTER_DEVICE_IMPL(roi_align_rotated_backward_impl, MUSA, + roi_align_rotated_backward_musa); + +void RiROIAlignRotatedForwardMUSAKernelLauncher( + const at::Tensor features, const at::Tensor rois, const float spatial_scale, + const int num_samples, const bool clockwise, const int channels, + const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const int num_orientations, + at::Tensor output); + +void RiROIAlignRotatedBackwardMUSAKernelLauncher( + const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, + const int num_samples, const bool clockwise, const int channels, + const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, const int num_orientations, + at::Tensor bottom_grad); + +void riroi_align_rotated_forward_musa(Tensor features, Tensor rois, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int num_samples, int num_orientations, + bool clockwise) { + // Number of ROIs + int num_rois = rois.size(0); + int size_rois = rois.size(1); + if (size_rois != 6) { + AT_ERROR("wrong roi size"); + } + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(rois); + int num_channels = features.size(1) / 
num_orientations; + int data_height = features.size(2); + int data_width = features.size(3); + RiROIAlignRotatedForwardMUSAKernelLauncher( + features, rois, spatial_scale, num_samples, clockwise, num_channels, + data_height, data_width, num_rois, pooled_height, pooled_width, + num_orientations, output); +} + +void riroi_align_rotated_backward_musa(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int pooled_height, + int pooled_width, float spatial_scale, + int num_samples, int num_orientations, + bool clockwise) { + // Number of ROIs + int num_rois = rois.size(0); + int size_rois = rois.size(1); + if (size_rois != 6) { + AT_ERROR("wrong roi size"); + } + CHECK_CONTIGUOUS(top_grad); + CHECK_CONTIGUOUS(rois); + int num_channels = bottom_grad.size(1) / num_orientations; + int data_height = bottom_grad.size(2); + int data_width = bottom_grad.size(3); + RiROIAlignRotatedBackwardMUSAKernelLauncher( + top_grad, rois, spatial_scale, num_samples, clockwise, num_channels, + data_height, data_width, num_rois, pooled_height, pooled_width, + num_orientations, bottom_grad); +} + +void riroi_align_rotated_forward_impl(Tensor features, Tensor rois, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale, + int num_samples, int num_orientations, + bool clockwise); + +void riroi_align_rotated_backward_impl(Tensor top_grad, Tensor rois, + Tensor bottom_grad, int pooled_height, + int pooled_width, float spatial_scale, + int num_samples, int num_orientations, + bool clockwise); + +REGISTER_DEVICE_IMPL(riroi_align_rotated_forward_impl, MUSA, + riroi_align_rotated_forward_musa); +REGISTER_DEVICE_IMPL(riroi_align_rotated_backward_impl, MUSA, + riroi_align_rotated_backward_musa); + +void RoiawarePool3dForwardMUSAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void RoiawarePool3dBackwardMUSAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method); + +void roiaware_pool3d_forward_musa(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + RoiawarePool3dForwardMUSAKernelLauncher( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, out_z, + rois, pts, pts_feature, argmax, pts_idx_of_voxels, pooled_features, + pool_method); +}; + +void roiaware_pool3d_backward_musa(int boxes_num, int out_x, int out_y, + int out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method) { + RoiawarePool3dBackwardMUSAKernelLauncher( + boxes_num, out_x, out_y, out_z, channels, max_pts_each_voxel, + pts_idx_of_voxels, argmax, grad_out, grad_in, pool_method); +}; + +void roiaware_pool3d_forward_impl(int boxes_num, int pts_num, int channels, + int max_pts_each_voxel, int out_x, int out_y, + int out_z, const Tensor rois, + const Tensor pts, const Tensor pts_feature, + Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method); + +void roiaware_pool3d_backward_impl(int boxes_num, int out_x, int out_y, + int 
out_z, int channels, + int max_pts_each_voxel, + const Tensor pts_idx_of_voxels, + const Tensor argmax, const Tensor grad_out, + Tensor grad_in, int pool_method); + +REGISTER_DEVICE_IMPL(roiaware_pool3d_forward_impl, MUSA, + roiaware_pool3d_forward_musa); +REGISTER_DEVICE_IMPL(roiaware_pool3d_backward_impl, MUSA, + roiaware_pool3d_backward_musa); + +void RoIPointPool3dForwardMUSAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, Tensor pooled_empty_flag); + +void roipoint_pool3d_forward_musa(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag) { + RoIPointPool3dForwardMUSAKernelLauncher( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz, + boxes3d, pts_feature, pooled_features, pooled_empty_flag); +}; + +void roipoint_pool3d_forward_impl(int batch_size, int pts_num, int boxes_num, + int feature_in_len, int sampled_pts_num, + const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, + Tensor pooled_features, + Tensor pooled_empty_flag); +REGISTER_DEVICE_IMPL(roipoint_pool3d_forward_impl, MUSA, + roipoint_pool3d_forward_musa); + +void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, + int pooled_width, float spatial_scale); + +void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor argmax, Tensor grad_input, + int pooled_height, int pooled_width, + float spatial_scale); + +void roi_pool_forward_musa(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale) { + ROIPoolForwardMUSAKernelLauncher(input, rois, output, argmax, pooled_height, + pooled_width, spatial_scale); +} + +void roi_pool_backward_musa(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + ROIPoolBackwardMUSAKernelLauncher(grad_output, rois, argmax, grad_input, + pooled_height, pooled_width, spatial_scale); +} + +void roi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + Tensor argmax, int pooled_height, int pooled_width, + float spatial_scale); +void roi_pool_backward_impl(Tensor grad_output, Tensor rois, Tensor argmax, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); +REGISTER_DEVICE_IMPL(roi_pool_forward_impl, MUSA, roi_pool_forward_musa); +REGISTER_DEVICE_IMPL(roi_pool_backward_impl, MUSA, roi_pool_backward_musa); + +typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t; + +std::vector DynamicPointToVoxelForwardMUSAKernelLauncher( + const at::Tensor &feats, const at::Tensor &coors, + const reduce_t reduce_type); + +void DynamicPointToVoxelBackwardMUSAKernelLauncher( + at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats, + const at::Tensor &feats, const at::Tensor &reduced_feats, + const at::Tensor &coors_map, const at::Tensor &reduce_count, + const reduce_t reduce_type); + +std::vector dynamic_point_to_voxel_forward_musa( + const torch::Tensor &feats, const torch::Tensor &coors, + const reduce_t reduce_type) { + return DynamicPointToVoxelForwardMUSAKernelLauncher(feats, coors, + reduce_type); +}; + +void dynamic_point_to_voxel_backward_musa( + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const 
torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, + const reduce_t reduce_type) { + DynamicPointToVoxelBackwardMUSAKernelLauncher(grad_feats, grad_reduced_feats, + feats, reduced_feats, coors_idx, + reduce_count, reduce_type); +}; + +std::vector dynamic_point_to_voxel_forward_impl( + const torch::Tensor &feats, const torch::Tensor &coors, + const reduce_t reduce_type); + +void dynamic_point_to_voxel_backward_impl( + torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats, + const torch::Tensor &feats, const torch::Tensor &reduced_feats, + const torch::Tensor &coors_idx, const torch::Tensor &reduce_count, + const reduce_t reduce_type); + +REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_forward_impl, MUSA, + dynamic_point_to_voxel_forward_musa); +REGISTER_DEVICE_IMPL(dynamic_point_to_voxel_backward_impl, MUSA, + dynamic_point_to_voxel_backward_musa); + +void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean); + +void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean, + Tensor var); + +void SyncBNForwardOutputMUSAKernelLauncher( + const Tensor input, const Tensor mean, const Tensor var, + Tensor running_mean, Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps, + float momentum, int group_size); + +void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output, + const Tensor norm, + Tensor grad_weight, + Tensor grad_bias); + +void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output, + const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input); + +void sync_bn_forward_mean_musa(const Tensor input, Tensor mean) { + SyncBNForwardMeanMUSAKernelLauncher(input, mean); +} + +void sync_bn_forward_var_musa(const Tensor input, const Tensor mean, + Tensor var) { + SyncBNForwardVarMUSAKernelLauncher(input, mean, var); +} + +void sync_bn_forward_output_musa(const Tensor input, const Tensor mean, + const Tensor var, Tensor running_mean, + Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size) { + SyncBNForwardOutputMUSAKernelLauncher(input, mean, var, running_mean, + running_var, weight, bias, norm, std, + output, eps, momentum, group_size); +} + +void sync_bn_backward_param_musa(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias) { + SyncBNBackwardParamMUSAKernelLauncher(grad_output, norm, grad_weight, + grad_bias); +} + +void sync_bn_backward_data_musa(const Tensor grad_output, const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, const Tensor norm, + const Tensor std, Tensor grad_input) { + SyncBNBackwardDataMUSAKernelLauncher(grad_output, weight, grad_weight, + grad_bias, norm, std, grad_input); +} + +void sync_bn_forward_mean_impl(const Tensor input, Tensor mean); + +void sync_bn_forward_var_impl(const Tensor input, const Tensor mean, + Tensor var); + +void sync_bn_forward_output_impl(const Tensor input, const Tensor mean, + const Tensor var, Tensor running_mean, + Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, + Tensor output, float eps, float momentum, + int group_size); + +void sync_bn_backward_param_impl(const Tensor grad_output, const Tensor norm, + Tensor grad_weight, Tensor grad_bias); + +void sync_bn_backward_data_impl(const Tensor 
grad_output, const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, const Tensor norm, + const Tensor std, Tensor grad_input); + +REGISTER_DEVICE_IMPL(sync_bn_forward_mean_impl, MUSA, + sync_bn_forward_mean_musa); +REGISTER_DEVICE_IMPL(sync_bn_forward_var_impl, MUSA, sync_bn_forward_var_musa); +REGISTER_DEVICE_IMPL(sync_bn_forward_output_impl, MUSA, + sync_bn_forward_output_musa); +REGISTER_DEVICE_IMPL(sync_bn_backward_param_impl, MUSA, + sync_bn_backward_param_musa); +REGISTER_DEVICE_IMPL(sync_bn_backward_data_impl, MUSA, + sync_bn_backward_data_musa); + +void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n, + const Tensor points, + const Tensor idx, + const Tensor weight, Tensor out); + +void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m, + const Tensor grad_out, + const Tensor idx, + const Tensor weight, + Tensor grad_points); + +void three_interpolate_forward_musa(int b, int c, int m, int n, + const Tensor points, const Tensor idx, + const Tensor weight, Tensor out) { + ThreeInterpolateForwardMUSAKernelLauncher(b, c, m, n, points, idx, weight, + out); +}; + +void three_interpolate_backward_musa(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points) { + ThreeInterpolateBackwardMUSAKernelLauncher(b, c, n, m, grad_out, idx, weight, + grad_points); +}; + +void three_interpolate_forward_impl(int b, int c, int m, int n, + const Tensor points, const Tensor idx, + const Tensor weight, Tensor out); + +void three_interpolate_backward_impl(int b, int c, int n, int m, + const Tensor grad_out, const Tensor idx, + const Tensor weight, Tensor grad_points); +REGISTER_DEVICE_IMPL(three_interpolate_forward_impl, MUSA, + three_interpolate_forward_musa); +REGISTER_DEVICE_IMPL(three_interpolate_backward_impl, MUSA, + three_interpolate_backward_musa); + +void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, + Tensor idx); + +void three_nn_forward_musa(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx) { + ThreeNNForwardMUSAKernelLauncher(b, n, m, unknown, known, dist2, idx); +}; + +void three_nn_forward_impl(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, Tensor idx); +REGISTER_DEVICE_IMPL(three_nn_forward_impl, MUSA, three_nn_forward_musa); + +void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift, + Tensor output); + +void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift, + Tensor grad_input); + +void tin_shift_forward_musa(Tensor input, Tensor shift, Tensor output) { + TINShiftForwardMUSAKernelLauncher(input, shift, output); +} + +void tin_shift_backward_musa(Tensor grad_output, Tensor shift, + Tensor grad_input) { + TINShiftBackwardMUSAKernelLauncher(grad_output, shift, grad_input); +} + +void tin_shift_forward_impl(Tensor input, Tensor shift, Tensor output); +void tin_shift_backward_impl(Tensor grad_output, Tensor shift, + Tensor grad_input); +REGISTER_DEVICE_IMPL(tin_shift_forward_impl, MUSA, tin_shift_forward_musa); +REGISTER_DEVICE_IMPL(tin_shift_backward_impl, MUSA, tin_shift_backward_musa); + +#if ((!defined(MUSA_ARCH)) || (defined(MUSA_ARCH)) && (MUSA_ARCH > 21)) +torch::Tensor upfirdn2d_op(torch::Tensor input, torch::Tensor filter, int upx, + int upy, int downx, int downy, int padx0, int padx1, + int pady0, int pady1, bool flip, float gain); + +torch::Tensor upfirdn2d_op_impl(torch::Tensor input, torch::Tensor 
filter, + int upx, int upy, int downx, int downy, + int padx0, int padx1, int pady0, int pady1, + bool flip, float gain); +REGISTER_DEVICE_IMPL(upfirdn2d_op_impl, MUSA, upfirdn2d_op); +#endif + +int HardVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3); + +int NondeterministicHardVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3); + +void DynamicVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, const std::vector coors_range, + const int NDim = 3); + +int hard_voxelize_forward_musa(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim) { + return HardVoxelizeForwardMUSAKernelLauncher( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +}; + +int nondeterministic_hard_voxelize_forward_musa( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim) { + return NondeterministicHardVoxelizeForwardMUSAKernelLauncher( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +}; + +void dynamic_voxelize_forward_musa(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim) { + DynamicVoxelizeForwardMUSAKernelLauncher(points, coors, voxel_size, + coors_range, NDim); +}; + +int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim); + +int nondeterministic_hard_voxelize_forward_impl( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim); + +void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, + const std::vector coors_range, + const int NDim); + +REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MUSA, + hard_voxelize_forward_musa); +REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, MUSA, + nondeterministic_hard_voxelize_forward_musa); +REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, MUSA, + dynamic_voxelize_forward_musa); + +void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features, + const Tensor best_bboxes, + const float spatial_scale, + const int points, + Tensor output); + +void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, + Tensor bottom_grad); + +void rotated_feature_align_forward_musa(const Tensor features, + const Tensor best_bboxes, + 
const float spatial_scale, + const int points, Tensor output) { + RotatedFeatureAlignForwardMUSAKernelLauncher(features, best_bboxes, + spatial_scale, points, output); +}; + +void rotated_feature_align_backward_musa(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor bottom_grad) { + RotatedFeatureAlignBackwardMUSAKernelLauncher( + top_grad, best_bboxes, spatial_scale, points, bottom_grad); +}; + +void rotated_feature_align_forward_impl(const Tensor features, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor output); + +void rotated_feature_align_backward_impl(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, Tensor bottom_grad); + +REGISTER_DEVICE_IMPL(rotated_feature_align_forward_impl, MUSA, + rotated_feature_align_forward_musa); +REGISTER_DEVICE_IMPL(rotated_feature_align_backward_impl, MUSA, + rotated_feature_align_backward_musa); + +void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points, + const at::Tensor polygons, + const int rows, const int cols, + at::Tensor output); + +void points_in_polygons_forward_musa(const Tensor points, const Tensor polygons, + Tensor output, const int rows, + const int cols) { + PointsInPolygonsForwardMUSAKernelLauncher(points, polygons, rows, cols, + output); +}; + +void points_in_polygons_forward_impl(const Tensor points, const Tensor polygons, + Tensor output, const int rows, + const int cols); + +REGISTER_DEVICE_IMPL(points_in_polygons_forward_impl, MUSA, + points_in_polygons_forward_musa); + +torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numAct); + +torch::Tensor indice_maxpool_forward_musa(torch::Tensor features, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numAct) { + return IndiceMaxpoolForwardMUSAKernelLauncher(features, indicePairs, + indiceNum, numAct); +}; + +torch::Tensor indice_maxpool_forward_impl(torch::Tensor features, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numAct); +REGISTER_DEVICE_IMPL(indice_maxpool_forward_impl, MUSA, + indice_maxpool_forward_musa); + +torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features, + torch::Tensor outFeatures, + torch::Tensor outGrad, + torch::Tensor indicePairs, + torch::Tensor indiceNum); + +torch::Tensor indice_maxpool_backward_musa(torch::Tensor features, + torch::Tensor outFeatures, + torch::Tensor outGrad, + torch::Tensor indicePairs, + torch::Tensor indiceNum) { + return IndiceMaxpoolBackwardMUSAKernelLauncher(features, outFeatures, outGrad, + indicePairs, indiceNum); +}; + +torch::Tensor indice_maxpool_backward_impl(torch::Tensor features, + torch::Tensor outFeatures, + torch::Tensor outGrad, + torch::Tensor indicePairs, + torch::Tensor indiceNum); + +REGISTER_DEVICE_IMPL(indice_maxpool_backward_impl, MUSA, + indice_maxpool_backward_musa) + +torch::Tensor IndiceConvForwardMUSAKernelLauncher( + torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs, + torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse, + int64_t _subM); + +torch::Tensor indice_conv_forward_musa(torch::Tensor features, + torch::Tensor filters, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numActOut, int64_t _inverse, + int64_t _subM) { + return IndiceConvForwardMUSAKernelLauncher( + features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM); +}; + +torch::Tensor 
indice_conv_forward_impl(torch::Tensor features, + torch::Tensor filters, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numActOut, int64_t _inverse, + int64_t _subM); + +REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MUSA, indice_conv_forward_musa); + +std::vector IndiceConvBackwardMUSAKernelLauncher( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM); + +std::vector indice_conv_backward_musa( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM) { + return IndiceConvBackwardMUSAKernelLauncher( + features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM); +}; + +std::vector indice_conv_backward_impl( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM); + +REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MUSA, + indice_conv_backward_musa); + +torch::Tensor FusedIndiceConvBatchnormMUSAKernelLauncher( + torch::Tensor features, torch::Tensor filters, torch::Tensor bias, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut, + int64_t _inverse, int64_t _subM); + +torch::Tensor fused_indice_conv_batchnorm_forward_musa( + torch::Tensor features, torch::Tensor filters, torch::Tensor bias, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut, + int64_t _inverse, int64_t _subM) { + return FusedIndiceConvBatchnormMUSAKernelLauncher(features, filters, bias, + indicePairs, indiceNum, + numActOut, _inverse, _subM); +}; + +torch::Tensor fused_indice_conv_batchnorm_forward_impl( + torch::Tensor features, torch::Tensor filters, torch::Tensor bias, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t numActOut, + int64_t _inverse, int64_t _subM); + +REGISTER_DEVICE_IMPL(fused_indice_conv_batchnorm_forward_impl, MUSA, + fused_indice_conv_batchnorm_forward_musa) + void MinAreaPolygonsMUSAKernelLauncher(const Tensor pointsets, Tensor polygons); void min_area_polygons_musa(const Tensor pointsets, Tensor polygons) { @@ -990,6 +1849,57 @@ REGISTER_DEVICE_IMPL(chamfer_distance_backward_impl, MUSA, chamfer_distance_backward_musa); #endif +void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale); + +void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); + +void PrROIPoolCoorBackwardMUSAKernelLauncher( + Tensor output, Tensor grad_output, Tensor input, Tensor rois, + Tensor grad_rois, int pooled_height, int pooled_width, float spatial_scale); + +void prroi_pool_forward_musa(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale) { + PrROIPoolForwardMUSAKernelLauncher(input, rois, output, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_backward_musa(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale) { + PrROIPoolBackwardMUSAKernelLauncher(grad_output, rois, grad_input, + pooled_height, pooled_width, + spatial_scale); +} + +void prroi_pool_coor_backward_musa(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, Tensor grad_rois, + int pooled_height, int pooled_width, + float spatial_scale) { + 
PrROIPoolCoorBackwardMUSAKernelLauncher(output, grad_output, input, rois, + grad_rois, pooled_height, + pooled_width, spatial_scale); +} + +void prroi_pool_forward_impl(Tensor input, Tensor rois, Tensor output, + int pooled_height, int pooled_width, + float spatial_scale); +void prroi_pool_backward_impl(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, float spatial_scale); +void prroi_pool_coor_backward_impl(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, Tensor grad_rois, + int pooled_height, int pooled_width, + float spatial_scale); +REGISTER_DEVICE_IMPL(prroi_pool_forward_impl, MUSA, prroi_pool_forward_musa); +REGISTER_DEVICE_IMPL(prroi_pool_backward_impl, MUSA, prroi_pool_backward_musa); +REGISTER_DEVICE_IMPL(prroi_pool_coor_backward_impl, MUSA, + prroi_pool_coor_backward_musa); + void BezierAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output, int aligned_height, int aligned_width, diff --git a/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu b/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu new file mode 100644 index 0000000000..5330aeaac5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/points_in_boxes_musa.mu @@ -0,0 +1,62 @@ +// Modified from +// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu +// Written by Shaoshuai Shi +// All Rights Reserved 2019. + +#include + +#include "points_in_boxes_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void PointsInBoxesPartForwardMUSAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is + // the bottom center, each box DO NOT overlaps params pts: (B, npoints, 3) [x, + // y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), default + // -1 + + c10::musa::MUSAGuard device_guard(boxes.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES( + boxes.scalar_type(), "points_in_boxes_part_forward_musa_kernel", [&] { + points_in_boxes_part_forward_musa_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void PointsInBoxesAllForwardMUSAKernelLauncher(int batch_size, int boxes_num, + int pts_num, const Tensor boxes, + const Tensor pts, + Tensor box_idx_of_points) { + // params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate, z is the bottom center, each box params pts: (B, npoints, 3) + // [x, y, z] in LiDAR coordinate params boxes_idx_of_points: (B, npoints), + // default -1 + + c10::musa::MUSAGuard device_guard(boxes.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES( + boxes.scalar_type(), "points_in_boxes_all_forward_musa_kernel", [&] { + points_in_boxes_all_forward_musa_kernel + <<>>( + batch_size, boxes_num, pts_num, boxes.data_ptr(), + pts.data_ptr(), box_idx_of_points.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu b/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu new file mode 100644 index 
0000000000..307cb38ea3 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/points_in_polygons_musa.mu @@ -0,0 +1,28 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/ming71/MUSA/blob/master/point_justify/points_justify_kernel.cu + +#include + +#include "points_in_polygons_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void PointsInPolygonsForwardMUSAKernelLauncher(const at::Tensor points, + const at::Tensor polygons, + const int rows, const int cols, + at::Tensor output) { + const int output_size = rows * cols; + c10::musa::MUSAGuard device_guard(points.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + points.scalar_type(), "points_in_polygons_forward_musa_kernel", ([&] { + const scalar_t *vertex1 = points.data_ptr(); + const scalar_t *vertex2 = polygons.data_ptr(); + scalar_t *inside_flag = output.data_ptr(); + + points_in_polygons_forward_musa_kernel + <<>>( + output_size, vertex1, vertex2, rows, cols, inside_flag); + })); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu b/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu new file mode 100644 index 0000000000..fb71317762 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/prroi_pool_musa.mu @@ -0,0 +1,65 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "prroi_pool_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void PrROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, + Tensor output, int pooled_height, + int pooled_width, float spatial_scale) { + int output_size = output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + prroi_pool_forward_musa_kernel + <<>>( + output_size, input.data_ptr(), rois.data_ptr(), + output.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), channels, height, width); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void PrROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois, + Tensor grad_input, int pooled_height, + int pooled_width, + float spatial_scale) { + int output_size = grad_output.numel(); + int channels = grad_input.size(1); + int height = grad_input.size(2); + int width = grad_input.size(3); + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + prroi_pool_backward_musa_kernel + <<>>( + output_size, grad_output.data_ptr(), rois.data_ptr(), + grad_input.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), channels, height, width); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void PrROIPoolCoorBackwardMUSAKernelLauncher(Tensor output, Tensor grad_output, + Tensor input, Tensor rois, + Tensor grad_rois, + int pooled_height, + int pooled_width, + float spatial_scale) { + int output_size = grad_output.numel(); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + prroi_pool_coor_backward_musa_kernel + <<>>( + output_size, output.data_ptr(), grad_output.data_ptr(), + input.data_ptr(), rois.data_ptr(), + grad_rois.data_ptr(), pooled_height, pooled_width, + static_cast(spatial_scale), channels, height, width); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu 
b/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu new file mode 100644 index 0000000000..d432954fac --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/psamask_musa.mu @@ -0,0 +1,60 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Modified from +// https://github.com/hszhao/semseg/blob/master/lib/psa/src + +#include + +#include "psamask_musa_kernel.muh" +#include "pytorch_musa_helper.hpp" + +void PSAMaskForwardMUSAKernelLauncher(const int psa_type, const Tensor input, + Tensor output, const int num_, + const int h_feature, const int w_feature, + const int h_mask, const int w_mask, + const int half_h_mask, + const int half_w_mask) { + int nthreads = num_ * h_feature * w_feature; + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + if (psa_type == 0) + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "psamask_collect_forward_musa", [&] { + psamask_collect_forward_musa<<>>( + nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask, + half_w_mask, input.data_ptr(), + output.data_ptr()); + }); + else + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "psamask_distribute_forward_musa", [&] { + psamask_distribute_forward_musa + <<>>( + nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask, + half_w_mask, input.data_ptr(), + output.data_ptr()); + }); +} + +void PSAMaskBackwardMUSAKernelLauncher( + const int psa_type, const Tensor grad_output, Tensor grad_input, + const int num_, const int h_feature, const int w_feature, const int h_mask, + const int w_mask, const int half_h_mask, const int half_w_mask) { + int nthreads = num_ * h_feature * w_feature; + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + if (psa_type == 0) + AT_DISPATCH_FLOATING_TYPES( + grad_input.scalar_type(), "psamask_collect_backward_musa", [&] { + psamask_collect_backward_musa<<>>( + nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask, + half_w_mask, grad_output.data_ptr(), + grad_input.data_ptr()); + }); + else + AT_DISPATCH_FLOATING_TYPES( + grad_input.scalar_type(), "psamask_distribute_backward_musa", [&] { + psamask_distribute_backward_musa + <<>>( + nthreads, h_feature, w_feature, h_mask, w_mask, half_h_mask, + half_w_mask, grad_output.data_ptr(), + grad_input.data_ptr()); + }); +} diff --git a/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu new file mode 100644 index 0000000000..575071e335 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/riroi_align_rotated_musa.mu @@ -0,0 +1,53 @@ +// Copyright (c) OpenMMLab. 
All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "riroi_align_rotated_musa_kernel.muh"
+
+void RiROIAlignRotatedForwardMUSAKernelLauncher(
+    const at::Tensor features, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor output) {
+  const int output_size =
+      num_rois * pooled_height * pooled_width * channels * num_orientations;
+  c10::musa::MUSAGuard device_guard(features.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      features.scalar_type(), "riroi_align_rotated_forward_musa_kernel", ([&] {
+        const scalar_t *bottom_data = features.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *top_data = output.data_ptr<scalar_t>();
+
+        riroi_align_rotated_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, bottom_data, rois_data, scalar_t(spatial_scale),
+                num_samples, clockwise, channels, height, width, pooled_height,
+                pooled_width, num_orientations, top_data);
+      }));
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void RiROIAlignRotatedBackwardMUSAKernelLauncher(
+    const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale,
+    const int num_samples, const bool clockwise, const int channels,
+    const int height, const int width, const int num_rois,
+    const int pooled_height, const int pooled_width, const int num_orientations,
+    at::Tensor bottom_grad) {
+  const int output_size =
+      num_rois * pooled_height * pooled_width * channels * num_orientations;
+  c10::musa::MUSAGuard device_guard(top_grad.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES(
+      top_grad.scalar_type(), "riroi_align_rotated_backward_musa_kernel", ([&] {
+        const scalar_t *top_diff = top_grad.data_ptr<scalar_t>();
+        const scalar_t *rois_data = rois.data_ptr<scalar_t>();
+        scalar_t *bottom_diff = bottom_grad.data_ptr<scalar_t>();
+        riroi_align_rotated_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, top_diff, rois_data, spatial_scale, num_samples,
+                clockwise, channels, height, width, pooled_height, pooled_width,
+                num_orientations, bottom_diff);
+      }));
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu
new file mode 100644
index 0000000000..f44d8e7439
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roi_align_musa.mu
@@ -0,0 +1,58 @@
+// Copyright (c) OpenMMLab. All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "roi_align_musa_kernel.muh"
+
+void ROIAlignForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                       Tensor argmax_y, Tensor argmax_x,
+                                       int aligned_height, int aligned_width,
+                                       float spatial_scale, int sampling_ratio,
+                                       int pool_mode, bool aligned) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "roi_align_forward_musa_kernel", [&] {
+        roi_align_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
+                aligned_height, aligned_width,
+                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
+                aligned, channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void ROIAlignBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                        Tensor argmax_y, Tensor argmax_x,
+                                        Tensor grad_input, int aligned_height,
+                                        int aligned_width, float spatial_scale,
+                                        int sampling_ratio, int pool_mode,
+                                        bool aligned) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "roi_align_backward_musa_kernel", [&] {
+        roi_align_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
+                argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
+                aligned_height, aligned_width,
+                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
+                aligned, channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu
new file mode 100644
index 0000000000..793744e243
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roi_align_rotated_musa.mu
@@ -0,0 +1,45 @@
+// Copyright (c) OpenMMLab.
All rights reserved +#include "pytorch_musa_helper.hpp" +#include "roi_align_rotated_musa_kernel.muh" + +void ROIAlignRotatedForwardMUSAKernelLauncher( + const at::Tensor input, const at::Tensor rois, const float spatial_scale, + const int sampling_ratio, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, at::Tensor output) { + const int output_size = num_rois * pooled_height * pooled_width * channels; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ROIAlignRotatedLaucherForward", ([&] { + const scalar_t *bottom_data = input.data_ptr(); + const scalar_t *rois_data = rois.data_ptr(); + scalar_t *top_data = output.data_ptr(); + + roi_align_rotated_forward_musa_kernel + <<>>( + output_size, bottom_data, rois_data, scalar_t(spatial_scale), + sampling_ratio, aligned, clockwise, channels, height, width, + pooled_height, pooled_width, top_data); + })); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void ROIAlignRotatedBackwardMUSAKernelLauncher( + const at::Tensor top_grad, const at::Tensor rois, const float spatial_scale, + const int sampling_ratio, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, const int num_rois, + const int pooled_height, const int pooled_width, at::Tensor bottom_grad) { + const int output_size = num_rois * pooled_height * pooled_width * channels; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + top_grad.scalar_type(), "ROIAlignLaucherBackward", ([&] { + const scalar_t *top_diff = top_grad.data_ptr(); + const scalar_t *rois_data = rois.data_ptr(); + scalar_t *bottom_diff = bottom_grad.data_ptr(); + roi_align_rotated_backward_musa_kernel + <<>>( + output_size, top_diff, rois_data, spatial_scale, sampling_ratio, + aligned, clockwise, channels, height, width, pooled_height, + pooled_width, bottom_diff); + })); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu b/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu new file mode 100644 index 0000000000..f6cb3b6999 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/roi_pool_musa.mu @@ -0,0 +1,50 @@ +// Copyright (c) OpenMMLab. 
All rights reserved
+#include "pytorch_musa_helper.hpp"
+#include "roi_pool_musa_kernel.muh"
+
+void ROIPoolForwardMUSAKernelLauncher(Tensor input, Tensor rois, Tensor output,
+                                      Tensor argmax, int pooled_height,
+                                      int pooled_width, float spatial_scale) {
+  int output_size = output.numel();
+  int channels = input.size(1);
+  int height = input.size(2);
+  int width = input.size(3);
+
+  c10::musa::MUSAGuard device_guard(input.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "roi_pool_forward_musa_kernel", [&] {
+        roi_pool_forward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, input.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                argmax.data_ptr<int>(), pooled_height, pooled_width,
+                static_cast<scalar_t>(spatial_scale), channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
+
+void ROIPoolBackwardMUSAKernelLauncher(Tensor grad_output, Tensor rois,
+                                       Tensor argmax, Tensor grad_input,
+                                       int pooled_height, int pooled_width,
+                                       float spatial_scale) {
+  int output_size = grad_output.numel();
+  int channels = grad_input.size(1);
+  int height = grad_input.size(2);
+  int width = grad_input.size(3);
+
+  c10::musa::MUSAGuard device_guard(grad_output.device());
+  musaStream_t stream = c10::musa::getCurrentMUSAStream();
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "roi_pool_backward_musa_kernel", [&] {
+        roi_pool_backward_musa_kernel<scalar_t>
+            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
+                output_size, grad_output.data_ptr<scalar_t>(),
+                rois.data_ptr<scalar_t>(), argmax.data_ptr<int>(),
+                grad_input.data_ptr<scalar_t>(), pooled_height, pooled_width,
+                channels, height, width);
+      });
+
+  AT_MUSA_CHECK(musaGetLastError());
+}
diff --git a/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu b/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu
new file mode 100644
index 0000000000..55e283da37
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/musa/roiaware_pool3d_musa.mu
@@ -0,0 +1,118 @@
+// Modified from
+// https://github.com/sshaoshuai/PCDet/blob/master/pcdet/ops/roiaware_pool3d/src/roiaware_pool3d_kernel.cu
+// Written by Shaoshuai Shi
+// All Rights Reserved 2019.
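+//
+// Launch structure of the forward launcher below (GET_BLOCKS and
+// THREADS_PER_BLOCK come from pytorch_musa_helper.hpp; they are assumed here
+// to be the usual mmcv ceil-divide and 512 threads per block):
+//   1) generate_pts_mask_for_box3d, grid (GET_BLOCKS(pts_num), boxes_num):
+//      flags which points fall inside each box, one row of pts_mask per box;
+//      e.g. pts_num = 20000 would give ceil(20000 / 512) = 40 blocks per box.
+//   2) collect_inside_pts_for_box3d, grid GET_BLOCKS(boxes_num): one thread
+//      per box gathers up to max_pts_each_voxel point indices per voxel into
+//      pts_idx_of_voxels.
+//   3) roiaware_maxpool3d / roiaware_avgpool3d, grid
+//      (GET_BLOCKS(out_x * out_y * out_z), channels, boxes_num): reduces the
+//      gathered point features per voxel; argmax is only filled for max pool.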
+ +#include + +#include "pytorch_musa_helper.hpp" +#include "roiaware_pool3d_musa_kernel.muh" + +void RoiawarePool3dForwardMUSAKernelLauncher( + int boxes_num, int pts_num, int channels, int max_pts_each_voxel, int out_x, + int out_y, int out_z, const Tensor rois, const Tensor pts, + const Tensor pts_feature, Tensor argmax, Tensor pts_idx_of_voxels, + Tensor pooled_features, int pool_method) { + // params rois: (N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR + // coordinate params pts: (npoints, 3) [x, y, z] in LiDAR coordinate params + // pts_feature: (npoints, C) params argmax: (N, out_x, out_y, out_z, C) params + // pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) params + // pooled_features: (N, out_x, out_y, out_z, C) params pool_method: 0: + // max_pool 1: avg_pool + + c10::musa::MUSAGuard device_guard(pts_feature.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + Tensor pts_mask = + -at::ones({boxes_num, pts_num}, pts_feature.options().dtype(at::kInt)); + + dim3 blocks_mask(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rois.scalar_type(), "generate_pts_mask_for_box3d", [&] { + generate_pts_mask_for_box3d + <<>>( + boxes_num, pts_num, out_x, out_y, out_z, + rois.data_ptr(), pts.data_ptr(), + pts_mask.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); + + // TODO: Merge the collect and pool functions, SS + + dim3 blocks_collect(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK)); + + AT_DISPATCH_INTEGRAL_TYPES( + pts_idx_of_voxels.scalar_type(), "collect_inside_pts_for_box3d", [&] { + collect_inside_pts_for_box3d + <<>>( + boxes_num, pts_num, max_pts_each_voxel, out_x, out_y, out_z, + pts_mask.data_ptr(), + pts_idx_of_voxels.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); + + dim3 blocks_pool(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), + channels, boxes_num); + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_maxpool3d", [&] { + roiaware_maxpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr(), argmax.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + pts_feature.scalar_type(), "roiaware_avgpool3d", [&] { + roiaware_avgpool3d<<>>( + boxes_num, pts_num, channels, max_pts_each_voxel, out_x, out_y, + out_z, pts_feature.data_ptr(), + pts_idx_of_voxels.data_ptr(), + pooled_features.data_ptr()); + }); + } + + AT_MUSA_CHECK(musaGetLastError()); +} + +void RoiawarePool3dBackwardMUSAKernelLauncher( + int boxes_num, int out_x, int out_y, int out_z, int channels, + int max_pts_each_voxel, const Tensor pts_idx_of_voxels, const Tensor argmax, + const Tensor grad_out, Tensor grad_in, int pool_method) { + // params pts_idx_of_voxels: (N, out_x, out_y, out_z, max_pts_each_voxel) + // params argmax: (N, out_x, out_y, out_z, C) + // params grad_out: (N, out_x, out_y, out_z, C) + // params grad_in: (npoints, C), return value + // params pool_method: 0: max_pool, 1: avg_pool + + c10::musa::MUSAGuard device_guard(grad_out.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + dim3 blocks(GET_BLOCKS(out_x * out_y * out_z, THREADS_PER_BLOCK), channels, + boxes_num); + dim3 threads(THREADS_PER_BLOCK); + + if (pool_method == 0) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_maxpool3d_backward", [&] { + 
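+          // Max-pool backward: every output voxel routes its gradient back to
+          // the single input point recorded in `argmax` during the forward
+          // pass; voxels that received no points contribute nothing (see
+          // roiaware_pool3d_musa_kernel.muh for the kernel body).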
roiaware_maxpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, argmax.data_ptr(), + grad_out.data_ptr(), grad_in.data_ptr()); + }); + } else if (pool_method == 1) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_in.scalar_type(), "roiaware_avgpool3d_backward", [&] { + roiaware_avgpool3d_backward<<>>( + boxes_num, channels, out_x, out_y, out_z, max_pts_each_voxel, + pts_idx_of_voxels.data_ptr(), grad_out.data_ptr(), + grad_in.data_ptr()); + }); + } + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu b/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu new file mode 100644 index 0000000000..a4c11e7e62 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/roipoint_pool3d_musa.mu @@ -0,0 +1,60 @@ +/* +Modified from +https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/roipoint_pool3d/src/roipoint_pool3d_kernel.cu +Point cloud feature pooling +Written by Shaoshuai Shi +All Rights Reserved 2018. +*/ + +#include +#include + +#include "pytorch_musa_helper.hpp" +#include "roipoint_pool3d_musa_kernel.muh" + +void RoIPointPool3dForwardMUSAKernelLauncher( + int batch_size, int pts_num, int boxes_num, int feature_in_len, + int sampled_pts_num, const Tensor xyz, const Tensor boxes3d, + const Tensor pts_feature, Tensor pooled_features, + Tensor pooled_empty_flag) { + Tensor pts_assign = at::empty({batch_size, pts_num, boxes_num}, + boxes3d.options().dtype(at::kInt)); + + c10::musa::MUSAGuard device_guard(xyz.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(GET_BLOCKS(pts_num, THREADS_PER_BLOCK), boxes_num, batch_size); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz.scalar_type(), "assign_pts_to_box3d", [&] { + assign_pts_to_box3d<<>>( + batch_size, pts_num, boxes_num, xyz.data_ptr(), + boxes3d.data_ptr(), pts_assign.data_ptr()); + }); + + Tensor pts_idx = at::empty({batch_size, boxes_num, sampled_pts_num}, + boxes3d.options().dtype(at::kInt)); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks2(GET_BLOCKS(boxes_num, THREADS_PER_BLOCK), batch_size); + + get_pooled_idx<<>>( + batch_size, pts_num, boxes_num, sampled_pts_num, + pts_assign.data_ptr(), pts_idx.data_ptr(), + pooled_empty_flag.data_ptr()); + + dim3 blocks_pool(GET_BLOCKS(sampled_pts_num, THREADS_PER_BLOCK), boxes_num, + batch_size); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + xyz.scalar_type(), "roipoint_pool3d_forward", [&] { + roipoint_pool3d_forward<<>>( + batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz.data_ptr(), pts_idx.data_ptr(), + pts_feature.data_ptr(), + pooled_features.data_ptr(), + pooled_empty_flag.data_ptr()); + }); +} diff --git a/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu b/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu new file mode 100644 index 0000000000..dd9ffe6c00 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/rotated_feature_align_musa.mu @@ -0,0 +1,53 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+// Modified from +// https://github.com/SJTU-Thinklab-Det/r3det-on-mmdetection/blob/master/mmdet/ops/fr/src/feature_refine_kernel.cu +#include "pytorch_musa_helper.hpp" +#include "rotated_feature_align_musa_kernel.muh" + +void RotatedFeatureAlignForwardMUSAKernelLauncher(const Tensor features, + const Tensor best_bboxes, + const float spatial_scale, + const int points, + Tensor output) { + c10::musa::MUSAGuard device_guard(features.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + const int output_size = features.numel(); + AT_DISPATCH_FLOATING_TYPES( + features.scalar_type(), "rotated_feature_align_forward_musa_kernel", + ([&] { + const scalar_t* bottom_data = features.data_ptr(); + const scalar_t* bboxes_data = best_bboxes.data_ptr(); + scalar_t* top_data = output.data_ptr(); + + rotated_feature_align_forward_kernel + <<>>( + output_size, points, bottom_data, bboxes_data, + scalar_t(spatial_scale), features.size(1), features.size(2), + features.size(3), top_data); + })); + AT_MUSA_CHECK(musaGetLastError()); +} + +void RotatedFeatureAlignBackwardMUSAKernelLauncher(const Tensor top_grad, + const Tensor best_bboxes, + const float spatial_scale, + const int points, + Tensor bottom_grad) { + c10::musa::MUSAGuard device_guard(top_grad.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + const int output_size = top_grad.numel(); + AT_DISPATCH_FLOATING_TYPES( + top_grad.scalar_type(), "rotated_feature_align_backward_musa_kernel", + ([&] { + const scalar_t* top_diff = top_grad.data_ptr(); + const scalar_t* bboxes_data = best_bboxes.data_ptr(); + scalar_t* bottom_diff = bottom_grad.data_ptr(); + + rotated_feature_align_backward_kernel + <<>>( + output_size, points, top_diff, bboxes_data, + scalar_t(spatial_scale), top_grad.size(1), top_grad.size(2), + top_grad.size(3), bottom_diff); + })); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu b/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu new file mode 100644 index 0000000000..1edca61a46 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/scatter_points_musa.mu @@ -0,0 +1,132 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
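+//
+// The forward launcher below reduces point features into voxels in three
+// steps, all visible in this file:
+//   1) coordinates with any negative component are masked to (-1, -1, -1) and
+//      at::unique_dim(coors_clean, 0, ...) yields the unique voxel coords
+//      (out_coors), a point-to-voxel index map (coors_map) and per-voxel
+//      counts (reduce_count); if the all -1 row survives as the first unique
+//      entry, it is sliced off and the map is shifted so that discarded
+//      points end up with index -1;
+//   2) feats_reduce_kernel (declared in scatter_points_musa_kernel.muh)
+//      scatters each point's features into its voxel slot, starting from 0
+//      for SUM/MEAN and from -inf for MAX;
+//   3) for MEAN the accumulated sums are divided by reduce_count afterwards.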
+#include +#include +#include + +#include "pytorch_musa_helper.hpp" +#include "scatter_points_musa_kernel.muh" + +std::vector DynamicPointToVoxelForwardMUSAKernelLauncher( + const at::Tensor &feats, const at::Tensor &coors, + const reduce_t reduce_type) { + const int num_input = feats.size(0); + const int num_feats = feats.size(1); + + if (num_input == 0) + return {feats.clone().detach(), coors.clone().detach(), + coors.new_empty({0}, torch::kInt32), + coors.new_empty({0}, torch::kInt32)}; + + at::Tensor out_coors; + at::Tensor coors_map; + at::Tensor reduce_count; + + auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1); + + std::tie(out_coors, coors_map, reduce_count) = + at::unique_dim(coors_clean, 0, true, true, true); + + if (out_coors[0][0].lt(0).item()) { + // the first element of out_coors (-1,-1,-1) and should be removed + out_coors = out_coors.slice(0, 1); + reduce_count = reduce_count.slice(0, 1); + coors_map = coors_map - 1; + } + + coors_map = coors_map.to(torch::kInt32); + reduce_count = reduce_count.to(torch::kInt32); + + auto reduced_feats = + at::empty({out_coors.size(0), num_feats}, feats.options()); + + c10::musa::MUSAGuard device_guard(feats.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + AT_DISPATCH_FLOATING_TYPES( + feats.scalar_type(), "feats_reduce_kernel", ([&] { + if (reduce_type == reduce_t::MAX) + reduced_feats.fill_(-std::numeric_limits::infinity()); + else + reduced_feats.fill_(static_cast(0)); + + dim3 blocks(std::min( + at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); + feats_reduce_kernel<<>>( + feats.data_ptr(), coors_map.data_ptr(), + reduced_feats.data_ptr(), num_input, num_feats, + reduce_type); + if (reduce_type == reduce_t::MEAN) + reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype()); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + return {reduced_feats, out_coors, coors_map, reduce_count}; +} + +void DynamicPointToVoxelBackwardMUSAKernelLauncher( + at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats, + const at::Tensor &feats, const at::Tensor &reduced_feats, + const at::Tensor &coors_map, const at::Tensor &reduce_count, + const reduce_t reduce_type) { + const int num_input = feats.size(0); + const int num_reduced = reduced_feats.size(0); + const int num_feats = feats.size(1); + + grad_feats.fill_(0); + // copy voxel grad to points + + if (num_input == 0 || num_reduced == 0) return; + c10::musa::MUSAGuard device_guard(feats.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) { + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel", + ([&] { + dim3 blocks(std::min( + at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); + add_reduce_traceback_grad_kernel<<>>( + grad_feats.data_ptr(), + grad_reduced_feats.data_ptr(), + coors_map.data_ptr(), reduce_count.data_ptr(), + num_input, num_feats, reduce_type); + })); + + AT_MUSA_CHECK(musaGetLastError()); + } else { + auto reduce_from = at::full({num_reduced, num_feats}, num_input, + coors_map.options().dtype(torch::kInt32)); + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), + "max_reduce_traceback_scatter_idx_kernel", ([&] { + dim3 blocks(std::min( + at::musa::ATenCeilDiv(num_input, THREADS_PER_BLOCK), maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); + 
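+          // MAX backward, pass 1 of 2: record in `reduce_from` the index of
+          // an input point whose feature equals the reduced maximum for each
+          // voxel/channel (entries start at num_input, meaning "no match yet";
+          // the matching rule itself lives in scatter_points_musa_kernel.muh).
+          // The second launch below, max_reduce_scatter_grad_kernel, then
+          // routes grad_reduced_feats only to those recorded points.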
max_reduce_traceback_scatter_idx_kernel<<>>( + feats.data_ptr(), reduced_feats.data_ptr(), + reduce_from.data_ptr(), coors_map.data_ptr(), + num_input, num_feats); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + AT_DISPATCH_FLOATING_TYPES( + grad_reduced_feats.scalar_type(), + "max_reduce_traceback_scatter_idx_kernel", ([&] { + dim3 blocks( + std::min(at::musa::ATenCeilDiv(num_reduced, THREADS_PER_BLOCK), + maxGridDim)); + dim3 threads(THREADS_PER_BLOCK); + max_reduce_scatter_grad_kernel<<>>( + grad_feats.data_ptr(), + grad_reduced_feats.data_ptr(), + reduce_from.data_ptr(), num_reduced, num_feats); + })); + + AT_MUSA_CHECK(musaGetLastError()); + } +} diff --git a/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu b/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu new file mode 100644 index 0000000000..67a69c1761 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/sparse_maxpool.mu @@ -0,0 +1,486 @@ +// Copyright 2019 Yan Yan +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +// clang-format off +// TODO: make spconv_utils.h order agnostic +#include "../spconv_utils.h" +// clang-format on +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "pytorch_musa_helper.hpp" + +template +__global__ void maxPoolFwdBlockKernel(scalar_t *outFeatures, + const scalar_t *inFeatures, + const Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + scalar_t in, out; + int ILPStrideY[NumILP]; + Index idxo, idxi; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y; + outFeatures += blockIdx.y * NumTLP; + inFeatures += blockIdx.y * NumTLP; + for (int ix = blockIdx.x * blockDim.x; ix < numHot; + ix += blockDim.x * gridDim.x) { + { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + in = inFeatures[idxi]; + out = outFeatures[idxo]; + if (in > out) { + outFeatures[idxo] = in; + } + } + } + } +} + +template +__global__ void maxPoolFwdGenericBlockKernel(scalar_t *outFeatures, + const scalar_t *inFeatures, + const Index *indicesIn, + const Index *indicesOut, + int numHot, int numPlanes) { + int ILPStrideX[NumILP]; + Index RI[NumILP]; + Index RO[NumILP]; + scalar_t in, out; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x; + for (int ix : tv::KernelLoopX(numHot)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) { + RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes; + RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes; + } + for (int iy : tv::KernelLoopY(numPlanes)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + in = inFeatures[RI[ilp] + iy]; + out = outFeatures[RO[ilp] + iy]; + if (in > out) { + outFeatures[RO[ilp] + iy] = in; + } + } + } + } +} + +template +__global__ void maxPoolFwdVecBlockKernel(scalar_t *outFeatures, + const scalar_t *inFeatures, + const 
Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + int ILPStrideY[NumILP]; + constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t); + scalar_t bufi[vecloadFactor]; + scalar_t bufo[vecloadFactor]; + Index idxi, idxo; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y; + outFeatures += blockIdx.y * NumTLP; + inFeatures += blockIdx.y * NumTLP; + for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot; + ix += blockDim.x * gridDim.x * vecloadFactor) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + reinterpret_cast(bufo)[0] = + reinterpret_cast(outFeatures)[idxo]; + reinterpret_cast(bufi)[0] = + reinterpret_cast(inFeatures)[idxi]; +#pragma unroll + for (int i = 0; i < vecloadFactor; i++) { + if (bufi[i] > bufo[i]) { + bufo[i] = bufi[i]; + } + } + reinterpret_cast(outFeatures)[idxo] = + reinterpret_cast(bufo)[0]; + } + } +} + +template +__global__ void maxPoolFwdGenericKernel(scalar_t *outFeatures, + const scalar_t *inFeatures, + const Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + int ILPStrideX[NumILP]; + Index RI[NumILP]; + Index RO[NumILP]; + scalar_t in, out; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x; + for (int ix : tv::KernelLoopX(numHot)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) { + if (ix + ILPStrideX[ilp] < numHot) { + RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes; + RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes; + } + } + for (int iy : tv::KernelLoopY(numPlanes)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + if (ix + ILPStrideX[ilp] < numHot) { + in = inFeatures[RI[ilp] + iy]; + out = outFeatures[RO[ilp] + iy]; + if (in > out) { + outFeatures[RO[ilp] + iy] = in; + } + } + } + } + } +} + +template +__global__ void maxPoolBwdBlockKernel(const scalar_t *outFeatures, + const scalar_t *inFeatures, + const scalar_t *fout, scalar_t *fin, + const Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + scalar_t in, out; + Index idxo, idxi; + int ILPStrideY[NumILP]; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y; + outFeatures += blockIdx.y * NumTLP; + inFeatures += blockIdx.y * NumTLP; + fout += blockIdx.y * NumTLP; + fin += blockIdx.y * NumTLP; + for (int ix = blockIdx.x * blockDim.x; ix < numHot; + ix += blockDim.x * gridDim.x) { + { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + in = inFeatures[idxi]; + out = outFeatures[idxo]; + if (in == out) { + fin[idxi] += fout[idxo]; + } + } + } + } +} + +template +__global__ void maxPoolBwdGenericBlockKernel( + const scalar_t *outFeatures, const scalar_t *inFeatures, + const scalar_t *fout, scalar_t *fin, const Index *indicesIn, + const Index *indicesOut, int numHot, int numPlanes) { + int ILPStrideX[NumILP]; + Index RI[NumILP]; + Index RO[NumILP]; + scalar_t in, out; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x; + for (int ix : tv::KernelLoopX(numHot)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) { + RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * 
numPlanes; + RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes; + } + for (int iy : tv::KernelLoopY(numPlanes)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + in = inFeatures[RI[ilp] + iy]; + out = outFeatures[RO[ilp] + iy]; + if (in == out) { + fin[RI[ilp] + iy] += fout[RO[ilp] + iy]; + } + } + } + } +} + +template +__global__ void maxPoolBwdVecBlockKernel(const scalar_t *outFeatures, + const scalar_t *inFeatures, + const scalar_t *fout, scalar_t *fin, + const Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + int ILPStrideY[NumILP]; + constexpr int vecloadFactor = sizeof(VecType) / sizeof(scalar_t); + scalar_t bufi[vecloadFactor]; + scalar_t bufo[vecloadFactor]; + scalar_t bufdi[vecloadFactor]; + scalar_t bufdo[vecloadFactor]; + Index idxi, idxo; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideY[ilp] = threadIdx.y + ilp * blockDim.y; + outFeatures += blockIdx.y * NumTLP; + inFeatures += blockIdx.y * NumTLP; + for (int ix = blockIdx.x * blockDim.x * vecloadFactor; ix < numHot; + ix += blockDim.x * gridDim.x * vecloadFactor) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + idxi = indicesIn[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + idxo = indicesOut[ix + ILPStrideY[ilp]] * numPlanes + threadIdx.x; + reinterpret_cast(bufo)[0] = + reinterpret_cast(outFeatures)[idxo]; + reinterpret_cast(bufi)[0] = + reinterpret_cast(inFeatures)[idxi]; + reinterpret_cast(bufdo)[0] = + reinterpret_cast(fout)[idxo]; + reinterpret_cast(bufdi)[0] = + reinterpret_cast(fin)[idxi]; + +#pragma unroll + for (int i = 0; i < vecloadFactor; i++) { + if (bufi[i] == bufo[i]) { + bufdi[i] += bufdo[i]; + } + } + reinterpret_cast(fin)[idxi] = + reinterpret_cast(bufdi)[0]; + } + } +} + +template +__global__ void maxPoolBwdGenericKernel(const scalar_t *outFeatures, + const scalar_t *inFeatures, + const scalar_t *fout, scalar_t *fin, + const Index *indicesIn, + const Index *indicesOut, int numHot, + int numPlanes) { + int ILPStrideX[NumILP]; + Index RI[NumILP]; + Index RO[NumILP]; + scalar_t in, out; +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) + ILPStrideX[ilp] = ilp * gridDim.x * blockDim.x; + for (int ix : tv::KernelLoopX(numHot)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ilp++) { + if (ix + ILPStrideX[ilp] < numHot) { + RI[ilp] = indicesIn[ix + ILPStrideX[ilp]] * numPlanes; + RO[ilp] = indicesOut[ix + ILPStrideX[ilp]] * numPlanes; + } + } + for (int iy : tv::KernelLoopY(numPlanes)) { +#pragma unroll + for (int ilp = 0; ilp < NumILP; ++ilp) { + if (ix + ILPStrideX[ilp] < numHot) { + in = inFeatures[RI[ilp] + iy]; + out = outFeatures[RO[ilp] + iy]; + if (in == out) { + fin[RI[ilp] + iy] += fout[RO[ilp] + iy]; + } + } + } + } + } +} + +namespace functor { +template +struct SparseMaxPoolForwardFunctor { + using vecload_type_t = + std::conditional_t::value, int2, int4>; + using kernel_block_t = mp_list_c; + void operator()(const tv::TorchGPU &d, tv::TensorView outFeatures, + tv::TensorView inFeatures, + tv::TensorView indices, int size) { + if (size <= 0) return; + int numPlanes = inFeatures.dim(1); + bool notFound = true; + constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t); + mp_for_each([=, &outFeatures, &inFeatures, &indices, + ¬Found](auto NumTLP) { + constexpr int NumILP = NumTLP / 4; + + int numHotBlock = (size / NumTLP) * NumTLP; + if (notFound) { + if (numPlanes % NumTLP == 0) { + if (numHotBlock >= NumTLP) { + maxPoolFwdVecBlockKernel + <<>>(outFeatures.data(), inFeatures.data(), + 
indices.subview(0).data(), + indices.subview(1).data(), numHotBlock, + numPlanes / vecloadFactor); + TV_CHECK_MUSA_ERR(); + } + + if (size > numHotBlock) { + maxPoolFwdGenericKernel + <<>>(outFeatures.data(), inFeatures.data(), + indices.subview(0).data() + numHotBlock, + indices.subview(1).data() + numHotBlock, + size - numHotBlock, numPlanes); + TV_CHECK_MUSA_ERR(); + } + notFound = false; + } + } + }); + + if (notFound) { + constexpr int NumTLP = 64; + constexpr int NumILP = NumTLP / 4; + int numHotBlock = (size / NumTLP) * NumTLP; + if (numHotBlock >= NumTLP) { + maxPoolFwdGenericBlockKernel + <<>>( + outFeatures.data(), inFeatures.data(), + indices.subview(0).data(), indices.subview(1).data(), + numHotBlock, numPlanes); + TV_CHECK_MUSA_ERR(); + } + + if (size > numHotBlock) { + maxPoolFwdGenericKernel + <<>>( + outFeatures.data(), inFeatures.data(), + indices.subview(0).data() + numHotBlock, + indices.subview(1).data() + numHotBlock, size - numHotBlock, + numPlanes); + TV_CHECK_MUSA_ERR(); + } + } + } +}; + +template +struct SparseMaxPoolBackwardFunctor { + using vecload_type_t = + std::conditional_t::value, int2, int4>; + using kernel_block_t = mp_list_c; + void operator()(const tv::TorchGPU &d, + tv::TensorView outFeatures, + tv::TensorView inFeatures, + tv::TensorView fout, + tv::TensorView fin, + tv::TensorView indices, int size) { + if (size <= 0) return; + int numPlanes = inFeatures.dim(1); + bool notFound = true; + constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(scalar_t); + mp_for_each([=, &outFeatures, &inFeatures, &fout, &fin, + &indices, ¬Found](auto NumTLP) { + constexpr int NumILP = NumTLP / 4; + + int numHotBlock = (size / NumTLP) * NumTLP; + if (notFound) { + if (numPlanes % NumTLP == 0) { + if (numHotBlock >= NumTLP) { + maxPoolBwdVecBlockKernel + <<>>(outFeatures.data(), inFeatures.data(), + fout.data(), fin.data(), + indices.subview(0).data(), + indices.subview(1).data(), numHotBlock, + numPlanes / vecloadFactor); + TV_CHECK_MUSA_ERR(); + } + + if (size > numHotBlock) { + maxPoolBwdGenericKernel + <<>>(outFeatures.data(), inFeatures.data(), + fout.data(), fin.data(), + indices.subview(0).data() + numHotBlock, + indices.subview(1).data() + numHotBlock, + size - numHotBlock, numPlanes); + TV_CHECK_MUSA_ERR(); + } + notFound = false; + } + } + }); + + if (notFound) { + constexpr int NumTLP = 64; + constexpr int NumILP = NumTLP / 4; + int numHotBlock = (size / NumTLP) * NumTLP; + if (numHotBlock >= NumTLP) { + maxPoolBwdGenericBlockKernel + <<>>( + outFeatures.data(), inFeatures.data(), fout.data(), fin.data(), + indices.subview(0).data(), indices.subview(1).data(), + numHotBlock, numPlanes); + TV_CHECK_MUSA_ERR(); + } + + if (size > numHotBlock) { + maxPoolBwdGenericKernel + <<>>( + outFeatures.data(), inFeatures.data(), fout.data(), fin.data(), + indices.subview(0).data() + numHotBlock, + indices.subview(1).data() + numHotBlock, size - numHotBlock, + numPlanes); + TV_CHECK_MUSA_ERR(); + } + } + } +}; + +} // namespace functor + +#define DECLARE_GPU_SPECS_T_INDEX(scalar_t, Index) \ + template struct functor::SparseMaxPoolForwardFunctor; \ + template struct functor::SparseMaxPoolBackwardFunctor; + +#define DECLARE_GPU_SPECS(scalar_t) DECLARE_GPU_SPECS_T_INDEX(scalar_t, int); + +DECLARE_GPU_SPECS(float); +DECLARE_GPU_SPECS(double); +DECLARE_GPU_SPECS(at::Half); + +#undef DECLARE_GPU_SPECS +#undef DECLARE_GPU_SPECS_T_INDEX diff --git a/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu b/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu new file mode 100644 
index 0000000000..a4ce9b2d5c --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/sparse_pool_ops_musa.mu @@ -0,0 +1,91 @@ +#include +#include +// clang-format off +// TODO: make spconv_utils.h order agnostic +#include "../spconv_utils.h" +// clang-format on +#include + +#include "pytorch_musa_helper.hpp" + +torch::Tensor IndiceMaxpoolForwardMUSAKernelLauncher(torch::Tensor features, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numAct) { + c10::musa::MUSAGuard device_guard(features.device()); + auto device = features.device().type(); + auto kernelVolume = indicePairs.size(0); + auto numInPlanes = features.size(1); + auto indicePairNumCpu = indiceNum.to({torch::kCPU}); + auto options = + torch::TensorOptions().dtype(features.dtype()).device(features.device()); + torch::Tensor output = torch::zeros({numAct, numInPlanes}, options); + for (int i = 0; i < kernelVolume; ++i) { + auto nHot = indicePairNumCpu.data_ptr()[i]; + if (nHot <= 0) { + continue; + } + AT_DISPATCH_FLOATING_TYPES( + features.scalar_type(), "IndiceMaxpoolForwardKernel", [&] { + if (device == torch::kCPU) { + functor::SparseMaxPoolForwardFunctor + forwardFtor; + forwardFtor(tv::CPU(), tv::torch2tv(output), + tv::torch2tv(features), + tv::torch2tv(indicePairs).subview(i), nHot); + } else { + functor::SparseMaxPoolForwardFunctor + forwardFtor; + forwardFtor(tv::TorchGPU(), tv::torch2tv(output), + tv::torch2tv(features), + tv::torch2tv(indicePairs).subview(i), nHot); + TV_CHECK_MUSA_ERR(); + } + }); + } + return output; +} + +torch::Tensor IndiceMaxpoolBackwardMUSAKernelLauncher(torch::Tensor features, + torch::Tensor outFeatures, + torch::Tensor outGrad, + torch::Tensor indicePairs, + torch::Tensor indiceNum) { + c10::musa::MUSAGuard device_guard(features.device()); + auto device = features.device().type(); + auto numInPlanes = features.size(1); + auto indicePairNumCpu = indiceNum.to({torch::kCPU}); + auto options = + torch::TensorOptions().dtype(features.dtype()).device(features.device()); + torch::Tensor inputGrad = torch::zeros(features.sizes(), options); + auto kernelVolume = indicePairs.size(0); + for (int i = 0; i < kernelVolume; ++i) { + auto nHot = indicePairNumCpu.data_ptr()[i]; + if (nHot <= 0) { + continue; + } + AT_DISPATCH_FLOATING_TYPES( + features.scalar_type(), "IndiceMaxpoolBackwardKernel", [&] { + if (device == torch::kCPU) { + functor::SparseMaxPoolBackwardFunctor + backwardFtor; + backwardFtor(tv::CPU(), tv::torch2tv(outFeatures), + tv::torch2tv(features), + tv::torch2tv(outGrad), + tv::torch2tv(inputGrad), + tv::torch2tv(indicePairs).subview(i), nHot); + } else { + functor::SparseMaxPoolBackwardFunctor + backwardFtor; + backwardFtor(tv::TorchGPU(), + tv::torch2tv(outFeatures), + tv::torch2tv(features), + tv::torch2tv(outGrad), + tv::torch2tv(inputGrad), + tv::torch2tv(indicePairs).subview(i), nHot); + TV_CHECK_MUSA_ERR(); + } + }); + } + return inputGrad; +} diff --git a/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu b/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu new file mode 100644 index 0000000000..56327f4ed1 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/sync_bn_musa.mu @@ -0,0 +1,110 @@ +// Copyright (c) OpenMMLab. 
All rights reserved +#include "pytorch_musa_helper.hpp" +#include "sync_bn_musa_kernel.muh" + +void SyncBNForwardMeanMUSAKernelLauncher(const Tensor input, Tensor mean) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] { + sync_bn_forward_mean_musa_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), num, + channels, spatial); + }); + AT_MUSA_CHECK(musaGetLastError()); +} + +void SyncBNForwardVarMUSAKernelLauncher(const Tensor input, const Tensor mean, + Tensor var) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] { + sync_bn_forward_var_musa_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), + var.data_ptr(), num, channels, spatial); + }); + AT_MUSA_CHECK(musaGetLastError()); +} + +void SyncBNForwardOutputMUSAKernelLauncher( + const Tensor input, const Tensor mean, const Tensor var, + Tensor running_mean, Tensor running_var, const Tensor weight, + const Tensor bias, Tensor norm, Tensor std, Tensor output, float eps, + float momentum, int group_size) { + int num = input.size(0); + int channels = input.size(1); + int spatial = input.size(2); + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "sync_bn_forward_mean_musa_kernel", [&] { + sync_bn_forward_output_musa_kernel + <<>>( + input.data_ptr(), mean.data_ptr(), + var.data_ptr(), running_mean.data_ptr(), + running_var.data_ptr(), weight.data_ptr(), + bias.data_ptr(), norm.data_ptr(), + std.data_ptr(), output.data_ptr(), num, + channels, spatial, eps, momentum, group_size); + }); + AT_MUSA_CHECK(musaGetLastError()); +} + +void SyncBNBackwardParamMUSAKernelLauncher(const Tensor grad_output, + const Tensor norm, + Tensor grad_weight, + Tensor grad_bias) { + int num = grad_output.size(0); + int channels = grad_output.size(1); + int spatial = grad_output.size(2); + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + grad_output.scalar_type(), "sync_bn_backward_param_musa_kernel", [&] { + sync_bn_backward_param_musa_kernel + <<>>( + grad_output.data_ptr(), norm.data_ptr(), + grad_weight.data_ptr(), grad_bias.data_ptr(), num, + channels, spatial); + }); + AT_MUSA_CHECK(musaGetLastError()); +} + +void SyncBNBackwardDataMUSAKernelLauncher(const Tensor grad_output, + const Tensor weight, + const Tensor grad_weight, + const Tensor grad_bias, + const Tensor norm, const Tensor std, + Tensor grad_input) { + int output_size = grad_input.numel(); + int num = grad_input.size(0); + int channels = grad_input.size(1); + int spatial = grad_input.size(2); + + c10::musa::MUSAGuard device_guard(grad_input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES( + grad_output.scalar_type(), "sync_bn_backward_data_musa_kernel", [&] { + sync_bn_backward_data_musa_kernel + <<>>( + output_size, grad_output.data_ptr(), + weight.data_ptr(), grad_weight.data_ptr(), + grad_bias.data_ptr(), norm.data_ptr(), + std.data_ptr(), 
grad_input.data_ptr(), num, + channels, spatial); + }); + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu b/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu new file mode 100644 index 0000000000..c48314cc39 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/three_interpolate_musa.mu @@ -0,0 +1,66 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu + +#include +#include +#include + +#include "pytorch_musa_helper.hpp" +#include "three_interpolate_musa_kernel.muh" + +void ThreeInterpolateForwardMUSAKernelLauncher(int b, int c, int m, int n, + const Tensor points, + const Tensor idx, + const Tensor weight, + Tensor out) { + // points: (B, C, M) + // idx: (B, N, 3) + // weight: (B, N, 3) + // output: + // out: (B, C, N) + + c10::musa::MUSAGuard device_guard(points.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + points.scalar_type(), "three_interpolate_forward_musa_kernel", [&] { + three_interpolate_forward_musa_kernel + <<>>( + b, c, m, n, points.data_ptr(), idx.data_ptr(), + weight.data_ptr(), out.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void ThreeInterpolateBackwardMUSAKernelLauncher(int b, int c, int n, int m, + const Tensor grad_out, + const Tensor idx, + const Tensor weight, + Tensor grad_points) { + // grad_out: (B, C, N) + // weight: (B, N, 3) + // output: + // grad_points: (B, C, M) + + c10::musa::MUSAGuard device_guard(grad_out.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), c, b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_out.scalar_type(), "three_interpolate_backward_musa_kernel", [&] { + three_interpolate_backward_musa_kernel + <<>>( + b, c, n, m, grad_out.data_ptr(), idx.data_ptr(), + weight.data_ptr(), grad_points.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu b/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu new file mode 100644 index 0000000000..b69caa4039 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/three_nn_musa.mu @@ -0,0 +1,35 @@ +// Modified from +// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/interpolate_gpu.cu + +#include +#include +#include + +#include "pytorch_musa_helper.hpp" +#include "three_nn_musa_kernel.muh" + +void ThreeNNForwardMUSAKernelLauncher(int b, int n, int m, const Tensor unknown, + const Tensor known, Tensor dist2, + Tensor idx) { + // unknown: (B, N, 3) + // known: (B, M, 3) + // output: + // dist2: (B, N, 3) + // idx: (B, N, 3) + + c10::musa::MUSAGuard device_guard(unknown.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + // blockIdx.x(col), blockIdx.y(row) + dim3 blocks(GET_BLOCKS(n, THREADS_PER_BLOCK), b); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_FLOATING_TYPES( + unknown.scalar_type(), "three_nn_forward_musa_kernel", [&] { + three_nn_forward_musa_kernel<<>>( + b, n, m, unknown.data_ptr(), known.data_ptr(), + dist2.data_ptr(), idx.data_ptr()); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu b/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu new file mode 100644 index 
0000000000..705c208d68 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/tin_shift_musa.mu @@ -0,0 +1,55 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "pytorch_musa_helper.hpp" +#include "pytorch_device_registry.hpp" +#include "tin_shift_musa_kernel.muh" + +void TINShiftForwardMUSAKernelLauncher(Tensor input, Tensor shift, + Tensor output) { + int output_size = output.numel(); + int batch_size = input.size(0); + int t_size = input.size(1); + int channels = input.size(2); + int hw_size = input.size(3); + int group_size = shift.size(1); + int group_channel = channels / group_size; + int num_kernels = batch_size * hw_size * channels; + + c10::musa::MUSAGuard device_guard(input.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "tin_shift_forward_musa_kernel", [&] { + tin_shift_forward_musa_kernel + <<>>( + output_size, input.data_ptr(), shift.data_ptr(), + output.data_ptr(), batch_size, channels, t_size, + hw_size, group_size, group_channel); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} + +void TINShiftBackwardMUSAKernelLauncher(Tensor grad_output, Tensor shift, + Tensor grad_input) { + int output_size = grad_output.numel(); + int batch_size = grad_output.size(0); + int t_size = grad_output.size(1); + int channels = grad_output.size(2); + int hw_size = grad_output.size(3); + int group_size = shift.size(1); + int group_channel = channels / group_size; + int num_kernels = batch_size * hw_size * channels; + + c10::musa::MUSAGuard device_guard(grad_output.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "tin_shift_backward_musa_kernel", [&] { + tin_shift_backward_musa_kernel + <<>>( + output_size, grad_output.data_ptr(), + shift.data_ptr(), grad_input.data_ptr(), + batch_size, channels, t_size, hw_size, group_size, + group_channel); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu b/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu new file mode 100644 index 0000000000..9b9a2ffe80 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/upfirdn2d_kernel.mu @@ -0,0 +1,749 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. +#include +#include + +#include "pytorch_musa_helper.hpp" +#if MUSA_ARCH > 21 +struct upfirdn2d_kernel_params { + const void *x; + const float *f; + void *y; + + int2 up; + int2 down; + int2 pad0; + int flip; + float gain; + + int4 inSize; // [width, height, channel, batch] + int4 inStride; + int2 filterSize; // [width, height] + int2 filterStride; + int4 outSize; // [width, height, channel, batch] + int4 outStride; + int sizeMinor; + int sizeMajor; + + int loopMinor; + int loopMajor; + int loopX; + int launchMinor; + int launchMajor; +}; + +//------------------------------------------------------------------------ +// MUSA kernel specialization. 
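+// Each spec describes one compiled kernel variant: the raw kernel pointer,
+// the output tile it produces per block (tileOutW/H, where -1 selects the
+// generic "large" kernel path), how many minor-dimension elements a block
+// covers per iteration (loopMinor), and how many output tiles it walks along
+// X (loopX).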
+ +struct upfirdn2d_kernel_spec { + void *kernel; + int tileOutW; + int tileOutH; + int loopMinor; + int loopX; +}; + +//------------------------------------------------------------------------ +// MUSA kernel selection. + +template +upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params &p); +//------------------------------------------------------------------------ + +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// NVIDIA CORPORATION and its licensors retain all intellectual property +// and proprietary rights in and to this software, related documentation +// and any modifications thereto. Any use, reproduction, disclosure or +// distribution of this software and related documentation without an express +// license agreement from NVIDIA CORPORATION is strictly prohibited. + +//------------------------------------------------------------------------ +// Helpers. + +template +struct InternalType; +template <> +struct InternalType { + typedef double scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; +template <> +struct InternalType { + typedef float scalar_t; +}; + +static __device__ __forceinline__ int floor_div(int a, int b) { + int t = 1 - a / b; + return (a + t * b) / b - t; +} + +//------------------------------------------------------------------------ +// Generic MUSA implementation for large filters. + +template +static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + + // Calculate thread index. + int minorBase = blockIdx.x * blockDim.x + threadIdx.x; + int outY = minorBase / p.launchMinor; + minorBase -= outY * p.launchMinor; + int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y; + int majorBase = blockIdx.z * p.loopMajor; + if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor) + return; + + // Setup Y receptive field. + int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y; + int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y); + int h = + min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY; + int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y; + if (p.flip) filterY = p.filterSize.y - 1 - filterY; + + // Loop over major, minor, and X. + for (int majorIdx = 0, major = majorBase; + majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) + for (int minorIdx = 0, minor = minorBase; + minorIdx < p.loopMinor & minor < p.sizeMinor; + minorIdx++, minor += p.launchMinor) { + int nc = major * p.sizeMinor + minor; + int n = nc / p.inSize.z; + int c = nc - n * p.inSize.z; + for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; + loopX++, outX += blockDim.y) { + // Setup X receptive field. + int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x; + int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x); + int w = + min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - + inX; + int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x; + if (p.flip) filterX = p.filterSize.x - 1 - filterX; + + // Initialize pointers. + const T *xp = + &((const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y + + c * p.inStride.z + n * p.inStride.w]; + const float *fp = + &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y]; + int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x; + int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y; + + // Inner loop. 
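+      // Accumulate the 2-D dot product between the w x h input taps in this
+      // pixel's receptive field and the filter: xp advances with the input's
+      // native strides while fp steps by up.x / up.y filter taps per input
+      // sample, with the step sign flipped for convolution vs. correlation.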
+ scalar_t v = 0; + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += (scalar_t)(*xp) * (scalar_t)(*fp); + xp += p.inStride.x; + fp += filterStepX; + } + xp += p.inStride.y - w * p.inStride.x; + fp += filterStepY - w * filterStepX; + } + + // Store result. + v *= p.gain; + ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y + + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } +} + +//------------------------------------------------------------------------ +// Specialized MUSA implementation for small filters. + +template +static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) { + typedef typename InternalType::scalar_t scalar_t; + const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; + const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; + __shared__ volatile scalar_t sf[filterH][filterW]; + __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; + + // Calculate tile index. + int minorBase = blockIdx.x; + int tileOutY = minorBase / p.launchMinor; + minorBase -= tileOutY * p.launchMinor; + minorBase *= loopMinor; + tileOutY *= tileOutH; + int tileOutXBase = blockIdx.y * p.loopX * tileOutW; + int majorBase = blockIdx.z * p.loopMajor; + if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | + majorBase >= p.sizeMajor) + return; + + // Load filter (flipped). + for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; + tapIdx += blockDim.x) { + int fy = tapIdx / filterW; + int fx = tapIdx - fy * filterW; + scalar_t v = 0; + if (fx < p.filterSize.x & fy < p.filterSize.y) { + int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; + int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; + v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; + } + sf[fy][fx] = v; + } + + // Loop over major and X. + for (int majorIdx = 0, major = majorBase; + majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) { + int baseNC = major * p.sizeMinor + minorBase; + int n = baseNC / p.inSize.z; + int baseC = baseNC - n * p.inSize.z; + for (int loopX = 0, tileOutX = tileOutXBase; + loopX < p.loopX & tileOutX < p.outSize.x; + loopX++, tileOutX += tileOutW) { + // Load input pixels. + int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; + int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; + int tileInX = floor_div(tileMidX, upx); + int tileInY = floor_div(tileMidY, upy); + __syncthreads(); + for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; + inIdx += blockDim.x) { + int relC = inIdx; + int relInX = relC / loopMinor; + int relInY = relInX / tileInW; + relC -= relInX * loopMinor; + relInX -= relInY * tileInW; + int c = baseC + relC; + int inX = tileInX + relInX; + int inY = tileInY + relInY; + scalar_t v = 0; + if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & + c < p.inSize.z) + v = (scalar_t)( + (const T *)p.x)[inX * p.inStride.x + inY * p.inStride.y + + c * p.inStride.z + n * p.inStride.w]; + sx[relInY][relInX][relC] = v; + } + + // Loop over output pixels. + __syncthreads(); + for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; + outIdx += blockDim.x) { + int relC = outIdx; + int relOutX = relC / loopMinor; + int relOutY = relOutX / tileOutW; + relC -= relOutX * loopMinor; + relOutX -= relOutY * tileOutW; + int c = baseC + relC; + int outX = tileOutX + relOutX; + int outY = tileOutY + relOutY; + + // Setup receptive field. 
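+        // Map the output pixel back into the cached tile: midX/midY are its
+        // coordinates on the upsampled grid, relInX/relInY index into the
+        // shared-memory input tile, and filterX/filterY select the first
+        // (already flipped) filter tap overlapping this pixel.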
+ int midX = tileMidX + relOutX * downx; + int midY = tileMidY + relOutY * downy; + int inX = floor_div(midX, upx); + int inY = floor_div(midY, upy); + int relInX = inX - tileInX; + int relInY = inY - tileInY; + int filterX = (inX + 1) * upx - midX - 1; // flipped + int filterY = (inY + 1) * upy - midY - 1; // flipped + + // Inner loop. + if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) { + scalar_t v = 0; +#pragma unroll + for (int y = 0; y < filterH / upy; y++) +#pragma unroll + for (int x = 0; x < filterW / upx; x++) + v += sx[relInY + y][relInX + x][relC] * + sf[filterY + y * upy][filterX + x * upx]; + v *= p.gain; + ((T *)p.y)[outX * p.outStride.x + outY * p.outStride.y + + c * p.outStride.z + n * p.outStride.w] = (T)v; + } + } + } + } +} + +//------------------------------------------------------------------------ +// MUSA kernel selection. + +template +upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p) { + int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; + upfirdn2d_kernel_spec spec = {(void *)upfirdn2d_kernel_large, -1, -1, 1, + 4}; // contiguous + if (s == 1) + spec = {(void *)upfirdn2d_kernel_large, -1, -1, 4, 1}; // channels_last + + // No up/downsampling. + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 7 && fy <= 7) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 5 && fy <= 5) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 3 && fy <= 3) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 7 && fy <= 7) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 5 && fy <= 5) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 3 && fy <= 3) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; 
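+    // Note: each group of checks runs from the loosest filter bound to the
+    // tightest, so the last condition that still fits overwrites spec and the
+    // smallest adequate specialization ends up being chosen.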
+ if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + } + + // 2x upsampling. + if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + if (s != 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, + 64, 16, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + if (s == 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 8, 1}; + } + if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + } + if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + } + + // 2x downsampling. 
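+  // Downsampling-only paths keep up == 1 and fold the stride into
+  // down.x/down.y; their tiles are smaller than in the upsampling branches
+  // because each output pixel now covers a proportionally wider input window.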
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + if (s != 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, 32, + 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 1, 1}; + if (s == 1 && fx <= 16 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 16, 16, 1, 1}; + if (s == 1 && fx <= 8 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 6 && fy <= 6) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 4 && fy <= 4) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + if (s == 1 && fx <= 2 && fy <= 2) + spec = {(void *)upfirdn2d_kernel_small, 8, + 8, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 8, 1, 1}; + if (s != 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 8, 1, 1}; + if (s != 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, 64, + 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 24 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 1, 8, 1}; + if (s == 1 && fx <= 16 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 64, 1, 8, 1}; + if (s == 1 && fx <= 8 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, 64, + 1, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, + 32, 16, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 24) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 16) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 8) + spec = {(void *)upfirdn2d_kernel_small, 1, + 64, 8, 1}; + } + + // 4x upsampling. 
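+  // 4x upsampling only provides the two largest small-kernel variants
+  // (filters up to 48 and 32 taps per axis); larger filters fall back to the
+  // generic upfirdn2d_kernel_large selected at the top of this function.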
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 64, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s == 1 && fx <= 32 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + } + if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + if (s == 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 128, 1, 16, 1}; + } + if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 32, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + if (s == 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 1, 128, 16, 1}; + } + + // 4x downsampling (inefficient). + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) { + // contiguous + if (s != 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + if (s != 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 48 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 1, 8, 1}; + if (s == 1 && fx <= 32 && fy <= 1) + spec = {(void *)upfirdn2d_kernel_small, + 32, 1, 8, 1}; + } + if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) { + // contiguous + if (s != 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + if (s != 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, + 32, 8, 1, 1}; + // channels_last + if (s == 1 && fx <= 1 && fy <= 48) + spec = {(void *)upfirdn2d_kernel_small, 1, + 32, 8, 1}; + if (s == 1 && fx <= 1 && fy <= 32) + spec = {(void *)upfirdn2d_kernel_small, 1, + 32, 8, 1}; + } + return spec; +} + +//------------------------------------------------------------------------ +// Template specializations. + +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); +template upfirdn2d_kernel_spec choose_upfirdn2d_kernel( + const upfirdn2d_kernel_params &p); + +//------------------------------------------------------------------------ + +//------------------------------------------------------------------------ + +torch::Tensor upfirdn2d_op(torch::Tensor x, torch::Tensor f, int upx, int upy, + int downx, int downy, int padx0, int padx1, + int pady0, int pady1, bool flip, float gain) { + // Validate arguments. 
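+  // x must be a 4-D tensor already on the MUSA (PrivateUse1) device, f a
+  // float32 2-D filter on the same device, and both tensors as well as the
+  // computed output must stay within INT_MAX elements and strides because the
+  // kernel parameter struct uses 32-bit sizes and strides.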
+ TORCH_CHECK(x.is_privateuseone(), "x must reside on MUSA device"); + TORCH_CHECK(f.device() == x.device(), + "f must reside on the same device as x"); + TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); + TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); + TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); + TORCH_CHECK(x.numel() > 0, "x has zero size"); + TORCH_CHECK(f.numel() > 0, "f has zero size"); + TORCH_CHECK(x.dim() == 4, "x must be rank 4"); + TORCH_CHECK(f.dim() == 2, "f must be rank 2"); + TORCH_CHECK((x.size(0) - 1) * x.stride(0) + (x.size(1) - 1) * x.stride(1) + + (x.size(2) - 1) * x.stride(2) + + (x.size(3) - 1) * x.stride(3) <= + INT_MAX, + "x memory footprint is too large"); + TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); + TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); + TORCH_CHECK(downx >= 1 && downy >= 1, + "downsampling factor must be at least 1"); + + // Create output tensor. + const at::musa::OptionalMUSAGuard device_guard(device_of(x)); + int outW = + ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; + int outH = + ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; + TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); + torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, + x.options(), x.suggest_memory_format()); + TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); + TORCH_CHECK((y.size(0) - 1) * y.stride(0) + (y.size(1) - 1) * y.stride(1) + + (y.size(2) - 1) * y.stride(2) + + (y.size(3) - 1) * y.stride(3) <= + INT_MAX, + "output memory footprint is too large"); + + // Initialize MUSA kernel parameters. + upfirdn2d_kernel_params p; + p.x = x.data_ptr(); + p.f = f.data_ptr(); + p.y = y.data_ptr(); + p.up = make_int2(upx, upy); + p.down = make_int2(downx, downy); + p.pad0 = make_int2(padx0, pady0); + p.flip = (flip) ? 1 : 0; + p.gain = gain; + p.inSize = + make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); + p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), + (int)x.stride(0)); + p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); + p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); + p.outSize = + make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); + p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), + (int)y.stride(0)); + p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; + p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; + + // Choose MUSA kernel. + upfirdn2d_kernel_spec spec; + AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "upfirdn2d_musa", [&] { + spec = choose_upfirdn2d_kernel(p); + }); + + // Set looping options. + p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; + p.loopMinor = spec.loopMinor; + p.loopX = spec.loopX; + p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; + p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; + + // Compute grid size. + dim3 blockSize, gridSize; + if (spec.tileOutW < 0) // large + { + blockSize = dim3(4, 32, 1); + gridSize = + dim3(((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, + (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor); + } else // small + { + blockSize = dim3(256, 1, 1); + gridSize = + dim3(((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, + (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor); + } + + // Launch MUSA kernel. 
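+  // The selected kernel is launched through musaLaunchKernel (hipLaunchKernel
+  // when MMCV_WITH_HIP is defined) with the whole parameter struct passed by
+  // address as the single kernel argument.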
+ void *args[] = {&p}; +#ifdef MMCV_WITH_HIP + AT_MUSA_CHECK(hipLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, + c10::musa::getCurrentMUSAStream())); +#else + AT_MUSA_CHECK(musaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, + c10::musa::getCurrentMUSAStream())); +#endif + + return y; +} +#else +#warning "upfirdn2d is supported when MUSA_ARCH > 21" +#endif //MUSA_ARCH diff --git a/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu b/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu new file mode 100644 index 0000000000..b243871caa --- /dev/null +++ b/mmcv/ops/csrc/pytorch/musa/voxelization_musa.mu @@ -0,0 +1,286 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include +#include + +#include "pytorch_musa_helper.hpp" +#include "voxelization_musa_kernel.muh" + +int HardVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + // current version tooks about 0.04s for one frame on cpu + // check device + + c10::musa::MUSAGuard device_guard(points.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + // map points to voxel coors + at::Tensor temp_coors = + at::zeros({num_points, NDim}, points.options().dtype(at::kInt)); + + dim3 grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096)); + dim3 block(512); + + // 1. link point to corresponding voxel coors + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "hard_voxelize_kernel", ([&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, + NDim); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + // 2. map point to the idx of the corresponding voxel, find duplicate coor + // create some temporary variables + auto point_to_pointidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + auto point_to_voxelidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + + dim3 map_grid(std::min(at::musa::ATenCeilDiv(num_points, 512), 4096)); + dim3 map_block(512); + + AT_DISPATCH_ALL_TYPES( + temp_coors.scalar_type(), "determin_duplicate", ([&] { + point_to_voxelidx_kernel<<>>( + temp_coors.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + point_to_pointidx.contiguous().data_ptr(), max_points, + max_voxels, num_points, NDim); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + // 3. 
determine voxel num and voxel's coor index + // make the logic in the MUSA device could accelerate about 10 times + auto coor_to_voxelidx = -at::ones( + { + num_points, + }, + points.options().dtype(at::kInt)); + auto voxel_num = at::zeros( + { + 1, + }, + points.options().dtype(at::kInt)); // must be zero from the beginning + + AT_DISPATCH_ALL_TYPES(temp_coors.scalar_type(), "determin_duplicate", ([&] { + determin_voxel_num<<<1, 1, 0, stream>>>( + num_points_per_voxel.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + point_to_pointidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + voxel_num.contiguous().data_ptr(), + max_points, max_voxels, num_points); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + // 4. copy point features to voxels + // Step 4 & 5 could be parallel + auto pts_output_size = num_points * num_features; + dim3 cp_grid(std::min(at::musa::ATenCeilDiv(pts_output_size, 512), 4096)); + dim3 cp_block(512); + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + assign_point_to_voxel<<>>( + pts_output_size, points.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + voxels.contiguous().data_ptr(), max_points, num_features, + num_points, NDim); + })); + // musaDeviceSynchronize(); + // AT_MUSA_CHECK(musaGetLastError()); + + // 5. copy coors of each voxels + auto coors_output_size = num_points * NDim; + dim3 coors_cp_grid( + std::min(at::musa::ATenCeilDiv(coors_output_size, 512), 4096)); + dim3 coors_cp_block(512); + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + assign_voxel_coors + <<>>( + coors_output_size, temp_coors.contiguous().data_ptr(), + point_to_voxelidx.contiguous().data_ptr(), + coor_to_voxelidx.contiguous().data_ptr(), + coors.contiguous().data_ptr(), num_points, NDim); + })); + + AT_MUSA_CHECK(musaGetLastError()); + + auto voxel_num_cpu = voxel_num.to(at::kCPU); + int voxel_num_int = voxel_num_cpu.data_ptr()[0]; + + return voxel_num_int; +} + +int NondeterministicHardVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + c10::musa::MUSAGuard device_guard(points.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + if (num_points == 0) return 0; + + dim3 blocks( + std::min(at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096)); + dim3 threads(THREADS_PER_BLOCK); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + // map points to voxel coors + at::Tensor temp_coors = + at::zeros({num_points, NDim}, points.options().dtype(at::kInt)); + + // 1. 
link point to corresponding voxel coors + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "hard_voxelize_kernel", ([&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, + NDim); + })); + + at::Tensor coors_map; + at::Tensor reduce_count; + + auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1); + + std::tie(temp_coors, coors_map, reduce_count) = + at::unique_dim(coors_clean, 0, true, true, false); + + if (temp_coors[0][0].lt(0).item()) { + // the first element of temp_coors is (-1,-1,-1) and should be removed + temp_coors = temp_coors.slice(0, 1); + coors_map = coors_map - 1; + } + + int num_coors = temp_coors.size(0); + temp_coors = temp_coors.to(at::kInt); + coors_map = coors_map.to(at::kInt); + + at::Tensor coors_count = at::zeros({1}, coors_map.options()); + at::Tensor coors_order = at::empty({num_coors}, coors_map.options()); + at::Tensor pts_id = at::zeros({num_points}, coors_map.options()); + reduce_count = at::zeros({num_coors}, coors_map.options()); + + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "get_assign_pos", ([&] { + nondeterministic_get_assign_pos<<>>( + num_points, coors_map.contiguous().data_ptr(), + pts_id.contiguous().data_ptr(), + coors_count.contiguous().data_ptr(), + reduce_count.contiguous().data_ptr(), + coors_order.contiguous().data_ptr()); + })); + + AT_DISPATCH_ALL_TYPES( + points.scalar_type(), "assign_point_to_voxel", ([&] { + nondeterministic_assign_point_voxel + <<>>( + num_points, points.contiguous().data_ptr(), + coors_map.contiguous().data_ptr(), + pts_id.contiguous().data_ptr(), + temp_coors.contiguous().data_ptr(), + reduce_count.contiguous().data_ptr(), + coors_order.contiguous().data_ptr(), + voxels.contiguous().data_ptr(), + coors.contiguous().data_ptr(), + num_points_per_voxel.contiguous().data_ptr(), + max_voxels, max_points, num_features, NDim); + })); + AT_MUSA_CHECK(musaGetLastError()); + return max_voxels < num_coors ? 
max_voxels : num_coors; +} + +void DynamicVoxelizeForwardMUSAKernelLauncher( + const at::Tensor &points, at::Tensor &coors, + const std::vector voxel_size, const std::vector coors_range, + const int NDim = 3) { + // current version tooks about 0.04s for one frame on cpu + // check device + + c10::musa::MUSAGuard device_guard(points.device()); + musaStream_t stream = c10::musa::getCurrentMUSAStream(); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + const int col_blocks = at::musa::ATenCeilDiv(num_points, THREADS_PER_BLOCK); + dim3 blocks(col_blocks); + dim3 threads(THREADS_PER_BLOCK); + + AT_DISPATCH_ALL_TYPES(points.scalar_type(), "dynamic_voxelize_kernel", [&] { + dynamic_voxelize_kernel<<>>( + points.contiguous().data_ptr(), + coors.contiguous().data_ptr(), voxel_x, voxel_y, voxel_z, + coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max, + coors_z_max, grid_x, grid_y, grid_z, num_points, num_features, NDim); + }); + + AT_MUSA_CHECK(musaGetLastError()); +} diff --git a/mmcv/ops/points_in_boxes.py b/mmcv/ops/points_in_boxes.py index e58a6e2a12..06be6c2a02 100644 --- a/mmcv/ops/points_in_boxes.py +++ b/mmcv/ops/points_in_boxes.py @@ -1,4 +1,5 @@ import torch +from mmengine.device import is_cuda_available, is_musa_available from torch import Tensor from ..utils import ext_loader @@ -10,7 +11,7 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: - """Find the box in which each point is (CUDA). + """Find the box in which each point is (CUDA/MUSA). Args: points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate. @@ -38,7 +39,7 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: # If manually put the tensor 'points' or 'boxes' on a device # which is not the current device, some temporary variables - # will be created on the current device in the cuda op, + # will be created on the current device in the cuda/musa op, # and the output will be incorrect. # Therefore, we force the current device to be the same # as the device of the tensors if it was not. @@ -48,8 +49,12 @@ def points_in_boxes_part(points: Tensor, boxes: Tensor) -> Tensor: assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' if points.device.type != 'npu': - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if is_cuda_available(): + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + elif is_musa_available(): + if torch.musa.current_device() != points_device: + torch.musa.set_device(points_device) else: boxes[:, :, 2] += boxes[:, :, 5] / 2.0 @@ -99,7 +104,7 @@ def points_in_boxes_cpu(points: Tensor, boxes: Tensor) -> Tensor: def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: - """Find all boxes in which each point is (CUDA). + """Find all boxes in which each point is (CUDA/MUSA). 
Args: points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate @@ -131,8 +136,12 @@ def points_in_boxes_all(points: Tensor, boxes: Tensor) -> Tensor: assert points_device == boxes.get_device(), \ 'Points and boxes should be put on the same device' if points.device.type != 'npu': - if torch.cuda.current_device() != points_device: - torch.cuda.set_device(points_device) + if is_cuda_available(): + if torch.cuda.current_device() != points_device: + torch.cuda.set_device(points_device) + elif is_musa_available(): + if torch.musa.current_device() != points_device: + torch.musa.set_device(points_device) ext_module.points_in_boxes_all_forward(boxes.contiguous(), points.contiguous(), diff --git a/mmcv/ops/sync_bn.py b/mmcv/ops/sync_bn.py index 2b14d30376..ef3c742aed 100644 --- a/mmcv/ops/sync_bn.py +++ b/mmcv/ops/sync_bn.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist import torch.nn.functional as F +from mmengine.device import is_cuda_available, is_musa_available from mmengine.registry import MODELS from torch.autograd import Function from torch.autograd.function import once_differentiable @@ -47,10 +48,20 @@ def forward(self, input: torch.Tensor, running_mean: torch.Tensor, self.group_size = group_size self.stats_mode = stats_mode - assert isinstance( - input, (torch.HalfTensor, torch.FloatTensor, - torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \ - f'only support Half or Float Tensor, but {input.type()}' + if is_cuda_available(): + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor, + torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' + elif is_musa_available(): + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor, + torch.musa.HalfTensor, torch.musa.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' + else: + assert isinstance( + input, (torch.HalfTensor, torch.FloatTensor)), \ + f'only support Half or Float Tensor, but {input.type()}' output = torch.zeros_like(input) input3d = input.flatten(start_dim=2) output3d = output.view_as(input3d) diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py index e015095033..6f251720b3 100644 --- a/mmcv/ops/upfirdn2d.py +++ b/mmcv/ops/upfirdn2d.py @@ -116,6 +116,13 @@ def upfirdn2d(input: torch.Tensor, padding=padding, flip_filter=flip_filter, gain=gain).apply(input, filter) + elif use_custom_op and input.device.type == 'musa': + return _upfirdn2d_musa( + up=up, + down=down, + padding=padding, + flip_filter=flip_filter, + gain=gain).apply(input, filter) return _upfirdn2d_ref( input, filter, @@ -303,6 +310,101 @@ def backward(ctx, dy): # pylint: disable=arguments-differ return Upfirdn2dCuda +_upfirdn2d_musa_cache: Dict = dict() + + +def _upfirdn2d_musa(up: int = 1, + down: int = 1, + padding: Union[int, List[int]] = 0, + flip_filter: bool = False, + gain: Union[float, int] = 1): + """Fast MUSA implementation of `upfirdn2d()` using custom ops. + + Args: + up (int): Integer upsampling factor. Can be a single int or a + list/tuple `[x, y]`. Defaults to 1. + down (int): Integer downsampling factor. Can be a single int + or a list/tuple `[x, y]`. Defaults to 1. + padding (int | tuple[int]): Padding with respect to the upsampled + image. Can be a single number or a list/tuple `[x, y]` or + `[x_before, x_after, y_before, y_after]`. Defaults to 0. + flip_filter (bool): False = convolution, True = correlation. + Defaults to False. + gain (int): Overall scaling factor for signal magnitude. + Defaults to 1. 
+ + Returns: + torch.Tensor: Tensor of the shape `[batch_size, num_channels, + out_height, out_width]` + """ + # Parse arguments. + upx, upy = _parse_scaling(up) + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + + # Lookup from cache. + key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, + gain) + if key in _upfirdn2d_musa_cache: + return _upfirdn2d_musa_cache[key] + + # Forward op. + class Upfirdn2dMusa(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, f): # pylint: disable=arguments-differ + assert isinstance(x, torch.Tensor) and x.ndim == 4 + if f is None: + f = torch.ones([1, 1], dtype=torch.float32, device=x.device) + if f.ndim == 1 and f.shape[0] == 1: + f = f.square().unsqueeze( + 0) # Convert separable-1 into full-1x1. + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + y = x + if f.ndim == 2: + y = ext_module.upfirdn2d(y, f, upx, upy, downx, downy, padx0, + padx1, pady0, pady1, flip_filter, + gain) + else: + y = ext_module.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, + padx0, padx1, 0, 0, flip_filter, 1.0) + y = ext_module.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, + 0, 0, pady0, pady1, flip_filter, gain) + ctx.save_for_backward(f) + ctx.x_shape = x.shape + return y + + @staticmethod + def backward(ctx, dy): # pylint: disable=arguments-differ + f, = ctx.saved_tensors + _, _, ih, iw = ctx.x_shape + _, _, oh, ow = dy.shape + fw, fh = _get_filter_size(f) + p = [ + fw - padx0 - 1, + iw * upx - ow * downx + padx0 - upx + 1, + fh - pady0 - 1, + ih * upy - oh * downy + pady0 - upy + 1, + ] + dx = None + df = None + + if ctx.needs_input_grad[0]: + dx = _upfirdn2d_musa( + up=down, + down=up, + padding=p, + flip_filter=(not flip_filter), + gain=gain).apply(dy, f) + + assert not ctx.needs_input_grad[1] + return dx, df + + # Add to cache. 
+ _upfirdn2d_musa_cache[key] = Upfirdn2dMusa + return Upfirdn2dMusa + + def filter2d(input: torch.Tensor, filter: torch.Tensor, padding: Union[int, List[int]] = 0, diff --git a/tests/test_ops/test_points_in_polygons.py b/tests/test_ops/test_points_in_polygons.py index d224d1593a..cdfbef5791 100644 --- a/tests/test_ops/test_points_in_polygons.py +++ b/tests/test_ops/test_points_in_polygons.py @@ -4,7 +4,7 @@ import torch from mmcv.ops import points_in_polygons -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE @pytest.mark.parametrize('device', [ @@ -15,7 +15,11 @@ pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) def test_points_in_polygons(device): points = np.array([[300., 300.], [400., 400.], [100., 100], [300, 250], diff --git a/tests/test_ops/test_prroi_pool.py b/tests/test_ops/test_prroi_pool.py index 0535dfbe21..b7fd52f95e 100644 --- a/tests/test_ops/test_prroi_pool.py +++ b/tests/test_ops/test_prroi_pool.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE _USING_PARROTS = True try: @@ -41,7 +41,11 @@ class TestPrRoiPool: pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_roipool_gradcheck(self, device): from mmcv.ops import PrRoIPool @@ -92,7 +96,11 @@ def _test_roipool_allclose(self, device, dtype=torch.float): pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_roipool_allclose_float(self, device): self._test_roipool_allclose(device, dtype=torch.float) diff --git a/tests/test_ops/test_psa_mask.py b/tests/test_ops/test_psa_mask.py index b0fd86e8f5..d692f4d782 100644 --- a/tests/test_ops/test_psa_mask.py +++ b/tests/test_ops/test_psa_mask.py @@ -4,7 +4,8 @@ import torch import torch.nn as nn -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) class Loss(nn.Module): @@ -32,7 +33,11 @@ class TestPSAMask: pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_psa_mask_collect(self, device): from mmcv.ops import PSAMask @@ -84,7 +89,11 @@ def test_psa_mask_collect(self, device): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_psa_mask_distribute(self, device): from mmcv.ops import PSAMask diff --git a/tests/test_ops/test_roi_align.py 
b/tests/test_ops/test_roi_align.py index dcd2103461..f9a51eb280 100644 --- a/tests/test_ops/test_roi_align.py +++ b/tests/test_ops/test_roi_align.py @@ -3,7 +3,8 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) _USING_PARROTS = True try: @@ -107,7 +108,11 @@ def _test_roialign_allclose(device, dtype): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) def test_roialign_float(device, dtype): _test_roialign_allclose(device=device, dtype=dtype) diff --git a/tests/test_ops/test_roi_align_rotated.py b/tests/test_ops/test_roi_align_rotated.py index 0d5ca432df..ddfacea0b3 100644 --- a/tests/test_ops/test_roi_align_rotated.py +++ b/tests/test_ops/test_roi_align_rotated.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE _USING_PARROTS = True try: @@ -132,15 +132,19 @@ def _test_roialign_rotated_allclose(device, dtype): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, - reason='MLU does not support for 64-bit floating point')), + IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, MUSA does not support for 64-bit floating point')), torch.half ]) def test_roialign_rotated(device, dtype): diff --git a/tests/test_ops/test_roi_pool.py b/tests/test_ops/test_roi_pool.py index 5ab04bce2b..f6d38d5af0 100644 --- a/tests/test_ops/test_roi_pool.py +++ b/tests/test_ops/test_roi_pool.py @@ -5,7 +5,8 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) _USING_PARROTS = True try: @@ -89,16 +90,20 @@ def _test_roipool_allclose(self, device, dtype=torch.float): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE or IS_NPU_AVAILABLE, - reason='MLU, NPU does not support for 64-bit floating point')), - torch.half + IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, NPU, MUSA ' + 'does not support for 64-bit floating point')), torch.half ]) def test_roipool_allclose(self, device, dtype): self._test_roipool_allclose(device, dtype) diff --git a/tests/test_ops/test_roiaware_pool3d.py b/tests/test_ops/test_roiaware_pool3d.py index 189db33cc0..338a1544c3 100644 --- a/tests/test_ops/test_roiaware_pool3d.py +++ b/tests/test_ops/test_roiaware_pool3d.py @@ -5,7 +5,8 @@ from mmcv.ops import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu, 
points_in_boxes_part) -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) @pytest.mark.parametrize('dtype', [ @@ -13,7 +14,8 @@ pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, reason='MLU does not support for double')) + IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, MUSA does not support for double')) ]) @pytest.mark.parametrize('device', [ pytest.param( @@ -23,7 +25,11 @@ pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_RoIAwarePool3d(device, dtype): roiaware_pool3d_max = RoIAwarePool3d( @@ -64,7 +70,11 @@ def test_RoIAwarePool3d(device, dtype): pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_points_in_boxes_part(device): boxes = torch.tensor( diff --git a/tests/test_ops/test_roipoint_pool3d.py b/tests/test_ops/test_roipoint_pool3d.py index f0ad5586a9..c3109a363a 100644 --- a/tests/test_ops/test_roipoint_pool3d.py +++ b/tests/test_ops/test_roipoint_pool3d.py @@ -3,7 +3,8 @@ import torch from mmcv.ops import RoIPointPool3d -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) @pytest.mark.parametrize('device', [ @@ -18,15 +19,19 @@ pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) @pytest.mark.parametrize('dtype', [ torch.float, torch.half, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE or IS_NPU_AVAILABLE, - reason='MLU and NPU does not support for double')) + IS_MLU_AVAILABLE or IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, NPU, MUSA does not support for double')) ]) def test_roipoint(device, dtype): points = torch.tensor( diff --git a/tests/test_ops/test_rotated_feature_align.py b/tests/test_ops/test_rotated_feature_align.py index 23de07e8ef..6447f410ae 100644 --- a/tests/test_ops/test_rotated_feature_align.py +++ b/tests/test_ops/test_rotated_feature_align.py @@ -3,7 +3,8 @@ import torch from mmcv.ops import rotated_feature_align -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) @pytest.mark.skipif( @@ -21,6 +22,10 @@ 'npu', marks=pytest.mark.skipif( not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), pytest.param( 'cpu', marks=pytest.mark.skipif( diff --git a/tests/test_ops/test_scatter_points.py b/tests/test_ops/test_scatter_points.py index b8b569481a..46ab4430e4 100644 --- a/tests/test_ops/test_scatter_points.py +++ b/tests/test_ops/test_scatter_points.py @@ -4,7 +4,7 @@ from torch.autograd import gradcheck from mmcv.ops import DynamicScatter 
-from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE if torch.__version__ == 'parrots': pytest.skip('not supported in parrots now', allow_module_level=True) @@ -18,7 +18,11 @@ pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) def test_dynamic_scatter(device): dsmean = DynamicScatter([0.32, 0.32, 6], diff --git a/tests/test_ops/test_spconv.py b/tests/test_ops/test_spconv.py index 17ca5678ed..6d958e0213 100644 --- a/tests/test_ops/test_spconv.py +++ b/tests/test_ops/test_spconv.py @@ -10,7 +10,7 @@ if torch.__version__ == 'parrots': pytest.skip('not supported in parrots now', allow_module_level=True) -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE def make_sparse_convmodule(in_channels, @@ -86,10 +86,17 @@ def make_sparse_convmodule(in_channels, pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) def test_make_sparse_convmodule(device): - torch.cuda.empty_cache() + if IS_CUDA_AVAILABLE: + torch.cuda.empty_cache() + elif IS_MUSA_AVAILABLE: + torch.musa.empty_cache() voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315], [6.8162713, -2.480431, -1.3616394, 0.36], [11.643568, -4.744306, -1.3580885, 0.16], diff --git a/tests/test_ops/test_syncbn.py b/tests/test_ops/test_syncbn.py index d1c1605ad5..dd046b9e2f 100644 --- a/tests/test_ops/test_syncbn.py +++ b/tests/test_ops/test_syncbn.py @@ -8,6 +8,8 @@ import torch.distributed as dist import torch.nn as nn +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE + if platform.system() == 'Windows': import regex as re else: @@ -29,10 +31,24 @@ def dist_init(self): os.environ['WORLD_SIZE'] = str(world_size) os.environ['RANK'] = str(rank) - dist.init_process_group('nccl') - torch.cuda.set_device(local_rank) - - def _test_syncbn_train(self, size=1, half=False): + if IS_CUDA_AVAILABLE: + dist.init_process_group('nccl') + torch.cuda.set_device(local_rank) + elif IS_MUSA_AVAILABLE: + dist.init_process_group('mccl') + torch.musa.set_device(local_rank) + + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) + ]) + def _test_syncbn_train(self, size=1, half=False, device='cuda'): if 'SLURM_NTASKS' not in os.environ or int( os.environ['SLURM_NTASKS']) != 4: @@ -49,10 +65,13 @@ def _test_syncbn_train(self, size=1, half=False): rank = dist.get_rank() torch.manual_seed(9) - torch.cuda.manual_seed(9) + if IS_CUDA_AVAILABLE: + torch.cuda.manual_seed(9) + elif IS_MUSA_AVAILABLE: + torch.musa.manual_seed(9) - self.x = torch.rand(16, 3, 2, 3).cuda() - self.y_bp = torch.rand(16, 3, 2, 3).cuda() + self.x = torch.rand(16, 3, 2, 3).to(device) + self.y_bp = torch.rand(16, 3, 2, 3).to(device) if half: self.x = self.x.half() @@ -60,7 +79,10 @@ def _test_syncbn_train(self, size=1, half=False): dist.broadcast(self.x, src=0) 
dist.broadcast(self.y_bp, src=0) - torch.cuda.synchronize() + if IS_CUDA_AVAILABLE: + torch.cuda.synchronize() + elif IS_MUSA_AVAILABLE: + torch.musa.synchronize() if size == 1: groups = [None, None, None, None] groups[0] = dist.new_group([0]) @@ -75,13 +97,13 @@ def _test_syncbn_train(self, size=1, half=False): group = groups[rank] elif size == 4: group = dist.group.WORLD - syncbn = SyncBatchNorm(3, group=group).cuda() + syncbn = SyncBatchNorm(3, group=group).to(device) syncbn.weight.data[0] = 0.2 syncbn.weight.data[1] = 0.5 syncbn.weight.data[2] = 0.7 syncbn.train() - bn = nn.BatchNorm2d(3).cuda() + bn = nn.BatchNorm2d(3).to(device) bn.weight.data[0] = 0.2 bn.weight.data[1] = 0.5 bn.weight.data[2] = 0.7 @@ -143,7 +165,17 @@ def _test_syncbn_train(self, size=1, half=False): assert np.allclose(x_grad.data.cpu().numpy(), sx_grad.data.cpu().numpy(), 1e-2) - def _test_syncbn_empty_train(self, size=1, half=False): + @pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) + ]) + def _test_syncbn_empty_train(self, size=1, half=False, device='cuda'): if 'SLURM_NTASKS' not in os.environ or int( os.environ['SLURM_NTASKS']) != 4: @@ -160,10 +192,13 @@ def _test_syncbn_empty_train(self, size=1, half=False): rank = dist.get_rank() torch.manual_seed(9) - torch.cuda.manual_seed(9) + if IS_CUDA_AVAILABLE: + torch.cuda.manual_seed(9) + elif IS_MUSA_AVAILABLE: + torch.musa.manual_seed(9) - self.x = torch.rand(0, 3, 2, 3).cuda() - self.y_bp = torch.rand(0, 3, 2, 3).cuda() + self.x = torch.rand(0, 3, 2, 3).to(device) + self.y_bp = torch.rand(0, 3, 2, 3).to(device) if half: self.x = self.x.half() @@ -171,7 +206,10 @@ def _test_syncbn_empty_train(self, size=1, half=False): dist.broadcast(self.x, src=0) dist.broadcast(self.y_bp, src=0) - torch.cuda.synchronize() + if IS_CUDA_AVAILABLE: + torch.cuda.synchronize() + elif IS_MUSA_AVAILABLE: + torch.musa.synchronize() if size == 1: groups = [None, None, None, None] groups[0] = dist.new_group([0]) @@ -187,13 +225,13 @@ def _test_syncbn_empty_train(self, size=1, half=False): elif size == 4: group = dist.group.WORLD - syncbn = SyncBatchNorm(3, group=group, stats_mode='N').cuda() + syncbn = SyncBatchNorm(3, group=group, stats_mode='N').to(device) syncbn.weight.data[0] = 0.2 syncbn.weight.data[1] = 0.5 syncbn.weight.data[2] = 0.7 syncbn.train() - bn = nn.BatchNorm2d(3).cuda() + bn = nn.BatchNorm2d(3).to(device) bn.weight.data[0] = 0.2 bn.weight.data[1] = 0.5 bn.weight.data[2] = 0.7 diff --git a/tests/test_ops/test_three_interpolate.py b/tests/test_ops/test_three_interpolate.py index d27a795ecf..dd8e25c892 100644 --- a/tests/test_ops/test_three_interpolate.py +++ b/tests/test_ops/test_three_interpolate.py @@ -3,7 +3,7 @@ import torch from mmcv.ops import three_interpolate -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MUSA_AVAILABLE, IS_NPU_AVAILABLE @pytest.mark.parametrize('dtype', [ @@ -11,8 +11,8 @@ pytest.param( torch.double, marks=pytest.mark.skipif( - IS_NPU_AVAILABLE, - reason='NPU does not support for 64-bit floating point')) + IS_NPU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='NPU, MUSA does not support for 64-bit floating point')) ]) @pytest.mark.parametrize('device', [ pytest.param( @@ -22,9 +22,15 @@ pytest.param( 'npu', marks=pytest.mark.skipif( - not IS_NPU_AVAILABLE, reason='requires NPU 
support')) + not IS_NPU_AVAILABLE, reason='requires NPU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) def test_three_interpolate(dtype, device): + if IS_MUSA_AVAILABLE: + torch.musa.empty_cache() features = torch.tensor( [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350], [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236], diff --git a/tests/test_ops/test_three_nn.py b/tests/test_ops/test_three_nn.py index 456188b917..9348dd0d5b 100644 --- a/tests/test_ops/test_three_nn.py +++ b/tests/test_ops/test_three_nn.py @@ -3,7 +3,7 @@ import torch from mmcv.ops import three_nn -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE known = [[[-1.8373, 3.5605, -0.7867], [0.7615, 2.9420, 0.2314], [-0.6503, 3.6637, -1.0622], [-1.8373, 3.5605, -0.7867], @@ -48,7 +48,11 @@ pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')) ]) @pytest.mark.parametrize('dtype,rtol', [(torch.float, 1e-8), (torch.half, 1e-3)]) diff --git a/tests/test_ops/test_tin_shift.py b/tests/test_ops/test_tin_shift.py index c8ce14465c..3d82de73b5 100755 --- a/tests/test_ops/test_tin_shift.py +++ b/tests/test_ops/test_tin_shift.py @@ -5,7 +5,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE _USING_PARROTS = True try: @@ -209,15 +209,19 @@ def _test_tinshift_assert(device, dtype): pytest.param( 'mlu', marks=pytest.mark.skipif( - not IS_MLU_AVAILABLE, reason='requires MLU support')) + not IS_MLU_AVAILABLE, reason='requires MLU support')), + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), ]) @pytest.mark.parametrize('dtype', [ torch.float, pytest.param( torch.double, marks=pytest.mark.skipif( - IS_MLU_AVAILABLE, - reason='MLU does not support for 64-bit floating point')), + IS_MLU_AVAILABLE or IS_MUSA_AVAILABLE, + reason='MLU, MUSA does not support for 64-bit floating point')), torch.half ]) def test_tinshift(device, dtype): diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index 78282a8ad0..06167c1d39 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -4,7 +4,8 @@ import torch from mmcv.ops import Voxelization -from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import (IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MUSA_AVAILABLE, + IS_NPU_AVAILABLE) def _get_voxel_points_indices(points, coors, voxel): @@ -215,3 +216,38 @@ def test_voxelization_npu(device_type): assert np.all(coors == expected_coors) assert np.all(voxels == expected_voxels) assert np.all(num_points_per_voxel == expected_num_points_per_voxel) + + +@pytest.mark.parametrize('device_type', [ + pytest.param( + 'musa', + marks=pytest.mark.skipif( + not IS_MUSA_AVAILABLE, reason='requires MUSA support')), +]) +def test_voxelization_musa(device_type): + voxel_size = [0.5, 0.5, 0.5] + point_cloud_range = [0, -40, -3, 70.4, 40, 1] + + voxel_dict = np.load( + 'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item() + expected_coors = voxel_dict['coors'] + expected_voxels = voxel_dict['voxels'] + 
expected_num_points_per_voxel = voxel_dict['num_points_per_voxel'] + points = voxel_dict['points'] + + points = torch.tensor(points) + max_num_points = 1000 + hard_voxelization = Voxelization(voxel_size, point_cloud_range, + max_num_points) + + device = torch.device(device_type) + + # test hard_voxelization on musa + points = points.contiguous().to(device) + coors, voxels, num_points_per_voxel = hard_voxelization.forward(points) + coors = coors.cpu().detach().numpy() + voxels = voxels.cpu().detach().numpy() + num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy() + assert np.all(coors == expected_coors) + assert np.all(voxels == expected_voxels) + assert np.all(num_points_per_voxel == expected_num_points_per_voxel)
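
For reference, a minimal usage sketch of the MUSA paths introduced by this patch (not part of the diff itself). It assumes mmcv has been built with the MUSA extensions and that torch_musa is installed so tensors can be moved to a 'musa' device; on other machines the guarded block is simply skipped. The tensor shapes and box values below are made up for illustration only.

# Hypothetical usage sketch, assuming a MUSA build of mmcv and torch_musa.
import torch
from mmengine.device import is_musa_available

from mmcv.ops import Voxelization, points_in_boxes_part

if is_musa_available():
    device = torch.device('musa')

    # Hard voxelization on MUSA, mirroring test_voxelization_musa above.
    voxelizer = Voxelization(
        voxel_size=[0.5, 0.5, 0.5],
        point_cloud_range=[0, -40, -3, 70.4, 40, 1],
        max_num_points=35)
    points = torch.rand(1000, 4)  # (x, y, z, feature), inside the cloud range
    points = points.contiguous().to(device)
    # Returns the voxelized points, their grid coordinates and the per-voxel
    # point counts, dispatched to the MUSA voxelization kernels added here.
    voxel_output = voxelizer(points)

    # points_in_boxes_part on MUSA; boxes are
    # [cx, cy, cz, x_size, y_size, z_size, rz] in LiDAR coordinates.
    pts = torch.rand(1, 128, 3).to(device)
    boxes = torch.tensor([[[0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 0.0]]]).to(device)
    box_idx = points_in_boxes_part(pts, boxes)  # (1, 128); -1 means "in no box"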