
Commit 25f2aaa

Jeet Kanjani authored and facebook-github-bot committed
nested dispatching of segment_csr on cpu/gpu (#3881)
Summary:
Pull Request resolved: #3881
X-link: facebookresearch/FBGEMM#972

Update the segment_sum_csr function (both CPU and GPU) in predictor to accept either 32- or 64-bit integer offsets (int32_t or int64_t). When the offsets overflow 31 bits, casting them to int32 produces negative numbers, causing unintended behavior in the function.

Reviewed By: zhaozhul, YazhiGao

Differential Revision: D71663741

fbshipit-source-id: 48537b2be23df0e41d3d99faf5d375b37abcabed
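To make the failure mode concrete, here is a minimal standalone sketch (hypothetical values, not from this commit) of how a 32-bit offset goes negative once offset * batch_size crosses 2^31:

// Minimal sketch: once offset * batch_size exceeds 2^31 - 1,
// a 32-bit result wraps around to a negative value.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t offset = 150'000'000;  // a CSR offset that itself fits in int32_t
  const int64_t batch_size = 16;

  const int64_t seg_start_64 = offset * batch_size;  // 2'400'000'000, correct
  // Narrowing to int32_t wraps modulo 2^32, producing a negative index
  // that the kernels would then use to address memory.
  const int32_t seg_start_32 = static_cast<int32_t>(seg_start_64);

  std::cout << seg_start_64 << "\n";  // 2400000000
  std::cout << seg_start_32 << "\n";  // -1894967296
}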
1 parent c407f65 commit 25f2aaa

File tree

2 files changed: +49 -40 lines changed


fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

+20-16
@@ -2288,18 +2288,17 @@ std::tuple<Tensor, Tensor> generic_histogram_binning_calibration_by_feature_cpu(
 
   return std::make_tuple(calibrated_prediction, bin_ids);
 }
-
-template <typename scalar_t>
+template <typename value_t, typename index_t>
 void _segment_sum_csr_cpu_kernel(
     const int num_segments,
     const int batch_size,
-    const int* const csr_seg_data,
-    const scalar_t* const values_data,
-    scalar_t* const output_data) {
+    const index_t* const csr_seg_data,
+    const value_t* const values_data,
+    value_t* const output_data) {
   for (const auto i : c10::irange(num_segments)) {
-    const int seg_start = csr_seg_data[i] * batch_size;
-    const int seg_end = csr_seg_data[i + 1] * batch_size;
-    scalar_t v = 0;
+    const index_t seg_start = csr_seg_data[i] * batch_size;
+    const index_t seg_end = csr_seg_data[i + 1] * batch_size;
+    value_t v = 0;
     for (const auto j : c10::irange(seg_start, seg_end)) {
       v += values_data[j];
     }
@@ -2315,14 +2314,19 @@ Tensor segment_sum_csr_cpu(
   TENSOR_ON_CPU(values);
 
   auto output = at::empty(csr_seg.numel() - 1, values.options());
-  FBGEMM_DISPATCH_ALL_TYPES(values.scalar_type(), "_segment_sum_csr_cpu", [&] {
-    _segment_sum_csr_cpu_kernel<scalar_t>(
-        csr_seg.numel() - 1,
-        batch_size,
-        csr_seg.data_ptr<int>(),
-        values.data_ptr<scalar_t>(),
-        output.data_ptr<scalar_t>());
-  });
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "_segment_sum_csr_cpu_1", [&] {
+        using value_t = scalar_t;
+        AT_DISPATCH_INDEX_TYPES(
+            csr_seg.scalar_type(), "_segment_sum_csr_cpu_2", [&] {
+              _segment_sum_csr_cpu_kernel<value_t, index_t>(
+                  csr_seg.numel() - 1,
+                  batch_size,
+                  csr_seg.data_ptr<index_t>(),
+                  values.data_ptr<value_t>(),
+                  output.data_ptr<value_t>());
+            });
+      });
   return output;
 }
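A note on the nested dispatch above: FBGEMM_DISPATCH_ALL_TYPES binds the value tensor's runtime dtype to the compile-time name scalar_t, and AT_DISPATCH_INDEX_TYPES binds the offset tensor's dtype (int32 or int64) to index_t; the `using value_t = scalar_t;` alias pins down the outer type before entering the inner lambda. A hand-rolled sketch of the same shape (the DType enum and dispatch helpers below are illustrative stand-ins, not the real ATen macros):

#include <cstdint>
#include <iostream>

// Illustrative stand-ins for the dispatch macros: map a runtime dtype tag to
// a compile-time type and hand it to a generic lambda as a value of that type.
enum class DType { Float, Int, Long };

template <typename F>
void dispatch_value_types(DType t, F&& f) {
  switch (t) {
    case DType::Float: f(float{}); break;
    default: std::cerr << "unsupported value type\n"; break;
  }
}

template <typename F>
void dispatch_index_types(DType t, F&& f) {
  switch (t) {
    case DType::Int:  f(int32_t{}); break;
    case DType::Long: f(int64_t{}); break;
    default: std::cerr << "unsupported index type\n"; break;
  }
}

template <typename value_t, typename index_t>
void kernel() {
  std::cout << "value_t bytes: " << sizeof(value_t)
            << ", index_t bytes: " << sizeof(index_t) << "\n";
}

int main() {
  // The runtime dtypes (in FBGEMM, values.scalar_type() and
  // csr_seg.scalar_type()) select the template instantiation:
  dispatch_value_types(DType::Float, [&](auto v) {
    using value_t = decltype(v);  // mirrors `using value_t = scalar_t;`
    dispatch_index_types(DType::Long, [&](auto i) {
      using index_t = decltype(i);
      kernel<value_t, index_t>();  // prints: value_t bytes: 4, index_t bytes: 8
    });
  });
}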

fbgemm_gpu/src/sparse_ops/sparse_segment_sum_csr.cu

+29-24
@@ -14,27 +14,27 @@ namespace fbgemm_gpu {
 
 // Kernel for calculating the segmented sum for sparse matrix with CSR format.
 // See https://moderngpu.github.io/segreduce.html
-template <typename scalar_t>
+template <typename values_t, typename index_t>
 __global__ __launch_bounds__(kMaxThreads) void _segment_sum_csr_cuda_kernel(
     int num_segments,
     int batch_size,
-    const int* csr_seg_data,
-    const scalar_t* values_data,
-    scalar_t* output_data) {
-  typedef FBGEMM_GPU_CUB_NS_PREFIX cub::BlockReduce<scalar_t, 256> BlockReduce;
+    const index_t* csr_seg_data,
+    const values_t* values_data,
+    values_t* output_data) {
+  typedef FBGEMM_GPU_CUB_NS_PREFIX cub::BlockReduce<values_t, 256> BlockReduce;
 
   __shared__ typename BlockReduce::TempStorage temp_storage;
-  int seg_start = csr_seg_data[blockIdx.x] * batch_size;
-  int seg_end = csr_seg_data[blockIdx.x + 1] * batch_size;
-  scalar_t sum = 0;
+  index_t seg_start = csr_seg_data[blockIdx.x] * batch_size;
+  index_t seg_end = csr_seg_data[blockIdx.x + 1] * batch_size;
+  values_t sum = 0;
 
-  for (auto i = seg_start; i < seg_end; i += blockDim.x) {
-    scalar_t thread_data;
+  for (index_t i = seg_start; i < seg_end; i += blockDim.x) {
+    values_t thread_data;
     if (threadIdx.x < seg_end - i) {
       thread_data = values_data[i + threadIdx.x];
     }
 
-    scalar_t aggregate =
+    values_t aggregate =
         BlockReduce(temp_storage).Sum(thread_data, seg_end - i);
 
     __syncthreads();
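For readers new to the pattern in this kernel: each thread block owns one segment (indexed by blockIdx.x) and walks it in tiles of blockDim.x elements. The second argument to cub::BlockReduce's Sum, seg_end - i, is the count of valid inputs, so the final partial tile reduces only the lanes that actually loaded data, and the per-tile aggregates are accumulated into sum across loop iterations. Changing the loop variable from auto (which deduced int from the old 32-bit seg_start) to index_t keeps the tile cursor 64-bit whenever the offsets are.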
@@ -68,19 +68,24 @@ DLL_PUBLIC Tensor segment_sum_csr_cuda(
   constexpr uint32_t threads_per_block = 256;
   const uint32_t num_blocks = csr_seg.numel() - 1;
 
-  FBGEMM_DISPATCH_ALL_TYPES(values.scalar_type(), "_segment_sum_csr_cuda", [&] {
-    _segment_sum_csr_cuda_kernel<scalar_t>
-        <<<num_blocks,
-           threads_per_block,
-           0,
-           at::cuda::getCurrentCUDAStream()>>>(
-            csr_seg.numel() - 1,
-            batch_size,
-            csr_seg.data_ptr<int>(),
-            values.data_ptr<scalar_t>(),
-            output.data_ptr<scalar_t>());
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-  });
+  FBGEMM_DISPATCH_ALL_TYPES(
+      values.scalar_type(), "_segment_sum_csr_cuda_1", [&] {
+        using values_t = scalar_t;
+        AT_DISPATCH_INDEX_TYPES(
+            csr_seg.scalar_type(), "_segment_sum_csr_cuda_2", [&] {
+              _segment_sum_csr_cuda_kernel<values_t, index_t>
+                  <<<num_blocks,
+                     threads_per_block,
+                     0,
+                     at::cuda::getCurrentCUDAStream()>>>(
+                      csr_seg.numel() - 1,
+                      batch_size,
+                      csr_seg.data_ptr<index_t>(),
+                      values.data_ptr<values_t>(),
+                      output.data_ptr<values_t>());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+      });
 
   return output;
 }
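Hedged usage sketch: the example below assumes the (batch_size, csr_seg, values) argument order suggested by the function bodies above and uses a forward declaration in place of the real fbgemm_gpu header. With the nested dispatch, csr_seg can now be an int64 tensor, so offsets past 2^31 survive intact.

#include <torch/torch.h>
#include <iostream>

// Assumed signature (argument order inferred from the diff, not verified
// against the header): declared by fbgemm_gpu's sparse ops.
namespace fbgemm_gpu {
at::Tensor segment_sum_csr_cpu(
    const int64_t batch_size,
    const at::Tensor& csr_seg,
    const at::Tensor& values);
}

int main() {
  // Two segments with row offsets {0, 2, 5} and batch_size = 3.
  const auto csr_seg = torch::tensor({0, 2, 5}, torch::kLong);  // int64 offsets now dispatch
  const auto values = torch::ones({15}, torch::kFloat);
  const auto out = fbgemm_gpu::segment_sum_csr_cpu(/*batch_size=*/3, csr_seg, values);
  // Segment 0 sums elements [0*3, 2*3) and segment 1 sums [2*3, 5*3):
  std::cout << out << "\n";  // expected: {6, 9}
}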
