Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mmcv support musa, split pr 4 #3260

Merged
merged 3 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions mmcv/ops/csrc/common/musa/points_in_boxes_musa_kernel.muh
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINT_IN_BOXES_MUSA_KERNEL_MUH
#define POINT_IN_BOXES_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

template <typename T>
__device__ inline void lidar_to_local_coords(T shift_x, T shift_y, T rz,
T &local_x, T &local_y) {
T cosa = cos(-rz), sina = sin(-rz);
local_x = shift_x * cosa + shift_y * (-sina);
local_y = shift_x * sina + shift_y * cosa;
}

template <typename T>
__device__ inline int check_pt_in_box3d(const T *pt, const T *box3d, T &local_x,
T &local_y) {
// param pt: (x, y, z)
// param box3d: (cx, cy, cz, x_size, y_size, z_size, rz) in LiDAR coordinate,
// cz in the bottom center
T x = pt[0], y = pt[1], z = pt[2];
T cx = box3d[0], cy = box3d[1], cz = box3d[2];
T x_size = box3d[3], y_size = box3d[4], z_size = box3d[5], rz = box3d[6];
cz += z_size /
2.0; // shift to the center since cz in box3d is the bottom center

if (fabsf(z - cz) > z_size / 2.0) return 0;
lidar_to_local_coords(x - cx, y - cy, rz, local_x, local_y);
float in_flag = (local_x > -x_size / 2.0) & (local_x < x_size / 2.0) &
(local_y > -y_size / 2.0) & (local_y < y_size / 2.0);
return in_flag;
}

template <typename T>
__global__ void points_in_boxes_part_forward_musa_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num + pt_idx;

T local_x = 0, local_y = 0;
int cur_in_flag = 0;
for (int k = 0; k < boxes_num; k++) {
cur_in_flag = check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[0] = k;
break;
}
}
}
}

template <typename T>
__global__ void points_in_boxes_all_forward_musa_kernel(
int batch_size, int boxes_num, int pts_num, const T *boxes, const T *pts,
int *box_idx_of_points) {
// params boxes: (B, N, 7) [x, y, z, x_size, y_size, z_size, rz] in LiDAR
// coordinate, z is the bottom center, each box DO NOT overlaps params pts:
// (B, npoints, 3) [x, y, z] in LiDAR coordinate params boxes_idx_of_points:
// (B, npoints), default -1

int bs_idx = blockIdx.y;
MUSA_1D_KERNEL_LOOP(pt_idx, pts_num) {
if (bs_idx >= batch_size) return;

boxes += bs_idx * boxes_num * 7;
pts += bs_idx * pts_num * 3 + pt_idx * 3;
box_idx_of_points += bs_idx * pts_num * boxes_num + pt_idx * boxes_num;

T local_x = 0, local_y = 0;
for (int k = 0; k < boxes_num; k++) {
const int cur_in_flag =
check_pt_in_box3d(pts, boxes + k * 7, local_x, local_y);
if (cur_in_flag) {
box_idx_of_points[k] = 1;
}
}
}
}

#endif // POINT_IN_BOXES_MUSA_KERNEL_MUH
75 changes: 75 additions & 0 deletions mmcv/ops/csrc/common/musa/points_in_polygons_musa_kernel.muh
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
#define POINTS_IN_POLYGONS_MUSA_KERNEL_MUH

#include "pytorch_musa_helper.hpp"

struct point {
float x, y;
};

template <typename scalar_t>
__global__ void points_in_polygons_forward_musa_kernel(
const int nthreads, const scalar_t *vertex1, const scalar_t *vertex2,
const int rows, const int cols, scalar_t *inside_flag) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
int row = index / cols;
int col = index % cols;

const scalar_t *offset_vertex1 = vertex1 + row * 2;
const scalar_t *offset_vertex2 = vertex2 + col * 8;

point point_[1];
point polygon[4];

point_[0].x = offset_vertex1[0];
point_[0].y = offset_vertex1[1];

polygon[0].x = offset_vertex2[0];
polygon[0].y = offset_vertex2[1];
polygon[1].x = offset_vertex2[2];
polygon[1].y = offset_vertex2[3];
polygon[2].x = offset_vertex2[4];
polygon[2].y = offset_vertex2[5];
polygon[3].x = offset_vertex2[6];
polygon[3].y = offset_vertex2[7];

int nCross = 0;
int i, j;
float sx, sy, tx, ty, px, py, x;
for (i = 0, j = 3; i < 4; j = i, i++) {
sx = polygon[i].x;
sy = polygon[i].y;
tx = polygon[j].x;
ty = polygon[j].y;

px = point_[0].x;
py = point_[0].y;

if (py < min(sy, ty)) continue;
if (py > max(sy, ty)) continue;

if ((sx == px && sy == py) || (tx == px && ty == py)) {
break;
} else {
if ((sy < py && ty >= py) || (sy >= py && ty < py)) {
x = sx + (py - sy) * (tx - sx) / (ty - sy);
if (x == px) {
break;
}
if (x > px) {
nCross++;
}
}
}
}
if (nCross % 2 == 1) {
inside_flag[index] = 1.0;
} else {
inside_flag[index] = 0.0;
}
return;
}
}

#endif // POINTS_IN_POLYGONS_MUSA_KERNEL_MUH
Loading