@@ -206,7 +206,7 @@ void weighted_sparse_forward_per_gpu(
     const core23::Tensor &sp_weights_all_gather_recv_buffer, ILookup *emb_storage,
     std::vector<core23::Tensor> &emb_vec_model_buffer, int64_t *num_model_key,
     int64_t *num_model_offsets, core23::Tensor &ret_model_key, core23::Tensor &ret_model_offset,
-    core23::Tensor &ret_sp_weight) {
+    core23::Tensor &ret_sp_weight, bool use_filter) {
   HugeCTR::CudaDeviceContext context(core->get_device_id());
 
   int tensor_device_id = core->get_device_id();
@@ -369,14 +369,88 @@ void weighted_sparse_forward_per_gpu(
   *num_model_offsets = model_offsets.num_elements();
 }
 
+template <typename offset_t>
+__global__ void cal_lookup_idx(size_t lookup_num, offset_t *bucket_after_filter, size_t batch_size,
+                               offset_t *lookup_offset, size_t bucket_num) {
+  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  int32_t step = blockDim.x * gridDim.x;
+  for (; i < (lookup_num); i += step) {
+    lookup_offset[i] = bucket_after_filter[i * batch_size];
+  }
+}
+
+template <typename offset_t>
+__global__ void count_ratio_filter(size_t bucket_num, char *filterd, const offset_t *bucket_range,
+                                   offset_t *bucket_after_filter) {
+  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  int32_t step = blockDim.x * gridDim.x;
+  for (; i < (bucket_num); i += step) {
+    offset_t start = bucket_range[i];
+    offset_t end = bucket_range[i + 1];
+    bucket_after_filter[i + 1] = 0;
+    for (offset_t idx = start; idx < end; idx++) {
+      if (filterd[idx] == 1) {
+        bucket_after_filter[i + 1]++;
+      }
+    }
+    if (i == 0) {
+      bucket_after_filter[i] = 0;
+    }
+  }
+}
+
+void filter(std::shared_ptr<CoreResourceManager> core,
+            const UniformModelParallelEmbeddingMeta &meta, const core23::Tensor &filterd,
+            core23::Tensor &bucket_range, core23::Tensor &bucket_after_filter,
+            core23::TensorParams &params, EmbeddingInput &emb_input, core23::Tensor &lookup_offset,
+            core23::Tensor &temp_scan_storage, core23::Tensor &temp_select_storage,
+            size_t temp_scan_bytes, size_t temp_select_bytes, core23::Tensor &keys_after_filter) {
+  auto stream = core->get_local_gpu()->get_stream();
+  // bucket_range has bucket_num + 1 elements, so subtract 1 here.
+  int bucket_num = bucket_range.num_elements() - 1;
+  const int block_size = 256;
+  const int grid_size =
+      core->get_kernel_param().num_sms * core->get_kernel_param().max_thread_per_block / block_size;
+
+  DISPATCH_INTEGRAL_FUNCTION_CORE23(bucket_range.data_type().type(), offset_t, [&] {
+    DISPATCH_INTEGRAL_FUNCTION_CORE23(keys_after_filter.data_type().type(), key_t, [&] {
+      offset_t *bucket_after_filter_ptr = bucket_after_filter.data<offset_t>();
+      const offset_t *bucket_range_ptr = bucket_range.data<offset_t>();
+      char *filterd_ptr = filterd.data<char>();
+      count_ratio_filter<<<grid_size, block_size, 0, stream>>>(
+          bucket_num, filterd_ptr, bucket_range_ptr, bucket_after_filter_ptr);
+      cub::DeviceScan::InclusiveSum(
+          temp_scan_storage.data(), temp_scan_bytes, bucket_after_filter.data<offset_t>(),
+          bucket_after_filter.data<offset_t>(), bucket_after_filter.num_elements(), stream);
+
+      key_t *keys_ptr = emb_input.keys.data<key_t>();
+
+      cub::DeviceSelect::Flagged(temp_select_storage.data(), temp_select_bytes, keys_ptr,
+                                 filterd_ptr, keys_after_filter.data<key_t>(),
+                                 emb_input.num_keys.data<uint64_t>(), emb_input.h_num_keys, stream);
+
+      size_t batch_size = (bucket_num) / meta.num_lookup_;
+
+      cal_lookup_idx<<<1, block_size, 0, stream>>>(meta.num_lookup_ + 1,
+                                                   bucket_after_filter.data<offset_t>(), batch_size,
+                                                   lookup_offset.data<offset_t>(), bucket_num);
+      HCTR_LIB_THROW(cudaStreamSynchronize(stream));
+      emb_input.h_num_keys = static_cast<size_t>(emb_input.num_keys.data<uint64_t>()[0]);
+      emb_input.keys = keys_after_filter;
+      emb_input.bucket_range = bucket_after_filter;
+    });
+  });
+}
+
 void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
                             const EmbeddingCollectionParam &ebc_param,
                             const UniformModelParallelEmbeddingMeta &meta,
                             const core23::Tensor &key_all_gather_recv_buffer,
                             const core23::Tensor &row_lengths_all_gather_recv_buffer,
                             ILookup *emb_storage, std::vector<core23::Tensor> &emb_vec_model_buffer,
                             int64_t *num_model_key, int64_t *num_model_offsets,
-                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset) {
+                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset,
+                            bool use_filter) {
   /*
   There are some steps in this function:
   1.reorder key to feature major
@@ -500,8 +574,56 @@ void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
   compress_offset_.compute(embedding_input.bucket_range, batch_size, &num_key_per_lookup_offset);
   HCTR_LIB_THROW(cudaStreamSynchronize(stream));
 
+  if (use_filter) {
+    core23::Tensor bucket_range_after_filter;
+    core23::Tensor keys_after_filter;
+    core23::Tensor filtered;
+
+    filtered = core23::Tensor(
+        params.shape({(int64_t)embedding_input.h_num_keys}).data_type(core23::ScalarType::Char));
+    bucket_range_after_filter =
+        core23::Tensor(params.shape({embedding_input.bucket_range.num_elements()})
+                           .data_type(embedding_input.bucket_range.data_type().type()));
+    keys_after_filter = core23::Tensor(params.shape({(int64_t)embedding_input.h_num_keys + 1})
+                                           .data_type(embedding_input.keys.data_type().type()));
+
+    core23::Tensor temp_scan_storage;
+    core23::Tensor temp_select_storage;
+
+    size_t temp_scan_bytes = 0;
+    size_t temp_select_bytes = 0;
+
+    DISPATCH_INTEGRAL_FUNCTION_CORE23(
+        embedding_input.bucket_range.data_type().type(), offset_t, [&] {
+          DISPATCH_INTEGRAL_FUNCTION_CORE23(embedding_input.keys.data_type().type(), key_t, [&] {
+            cub::DeviceScan::InclusiveSum(nullptr, temp_scan_bytes, (offset_t *)nullptr,
+                                          (offset_t *)nullptr,
+                                          bucket_range_after_filter.num_elements());
+
+            temp_scan_storage = core23::Tensor(params.shape({static_cast<int64_t>(temp_scan_bytes)})
+                                                   .data_type(core23::ScalarType::Char));
+
+            cub::DeviceSelect::Flagged(nullptr, temp_select_bytes, (key_t *)nullptr,
+                                       (char *)nullptr, (key_t *)nullptr, (uint64_t *)nullptr,
+                                       embedding_input.h_num_keys);
+
+            temp_select_storage =
+                core23::Tensor(params.shape({static_cast<int64_t>(temp_select_bytes)})
+                                   .data_type(core23::ScalarType::Char));
+          });
+        });
+
+    emb_storage->ratio_filter(embedding_input.keys, embedding_input.h_num_keys,
+                              num_key_per_lookup_offset, meta.num_local_lookup_ + 1,
+                              meta.d_local_table_id_list_, filtered);
+
+    filter(core, meta, filtered, embedding_input.bucket_range, bucket_range_after_filter, params,
+           embedding_input, num_key_per_lookup_offset, temp_scan_storage, temp_select_storage,
+           temp_scan_bytes, temp_select_bytes, keys_after_filter);
+  }
   core23::Tensor embedding_vec = core23::init_tensor_list<float>(
       key_all_gather_recv_buffer.num_elements(), params.device().index());
+
   emb_storage->lookup(embedding_input.keys, embedding_input.h_num_keys, num_key_per_lookup_offset,
                       meta.num_local_lookup_ + 1, meta.d_local_table_id_list_, embedding_vec);
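
The filtering added above composes four standard pieces: count_ratio_filter counts the keys each bucket keeps, cub::DeviceScan::InclusiveSum turns those per-bucket counts into a new bucket_range, cub::DeviceSelect::Flagged compacts the keys themselves (writing the surviving count to emb_input.num_keys), and cal_lookup_idx gathers the per-lookup offsets from the scanned array. The standalone CUDA sketch below mirrors that flow with plain device buffers; count_flags_per_bucket, the toy sizes, and the main driver are illustrative assumptions, not HugeCTR code.

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

// Per-bucket count of surviving keys (flag == 1), same idea as count_ratio_filter above.
__global__ void count_flags_per_bucket(int bucket_num, const char *flags, const int *bucket_range,
                                       int *bucket_after_filter) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= bucket_num) return;
  int count = 0;
  for (int idx = bucket_range[i]; idx < bucket_range[i + 1]; ++idx) {
    if (flags[idx] == 1) ++count;
  }
  bucket_after_filter[i + 1] = count;
  if (i == 0) bucket_after_filter[0] = 0;
}

int main() {
  const int bucket_num = 2, num_keys = 6;
  int h_keys[num_keys] = {10, 11, 12, 20, 21, 22};
  char h_flags[num_keys] = {1, 0, 1, 0, 0, 1};  // keep keys 10, 12, 22
  int h_range[bucket_num + 1] = {0, 3, 6};

  int *d_keys, *d_range, *d_after, *d_keys_out, *d_num_out;
  char *d_flags;
  cudaMalloc(&d_keys, sizeof(h_keys));
  cudaMalloc(&d_flags, sizeof(h_flags));
  cudaMalloc(&d_range, sizeof(h_range));
  cudaMalloc(&d_after, sizeof(h_range));
  cudaMalloc(&d_keys_out, sizeof(h_keys));
  cudaMalloc(&d_num_out, sizeof(int));
  cudaMemcpy(d_keys, h_keys, sizeof(h_keys), cudaMemcpyHostToDevice);
  cudaMemcpy(d_flags, h_flags, sizeof(h_flags), cudaMemcpyHostToDevice);
  cudaMemcpy(d_range, h_range, sizeof(h_range), cudaMemcpyHostToDevice);

  // 1) count survivors per bucket, 2) inclusive scan rebuilds bucket_range in place,
  // 3) DeviceSelect::Flagged compacts the keys themselves.
  count_flags_per_bucket<<<1, 32>>>(bucket_num, d_flags, d_range, d_after);

  void *d_temp_scan = nullptr;
  size_t scan_bytes = 0;
  cub::DeviceScan::InclusiveSum(d_temp_scan, scan_bytes, d_after, d_after, bucket_num + 1);
  cudaMalloc(&d_temp_scan, scan_bytes);
  cub::DeviceScan::InclusiveSum(d_temp_scan, scan_bytes, d_after, d_after, bucket_num + 1);

  void *d_temp_sel = nullptr;
  size_t sel_bytes = 0;
  cub::DeviceSelect::Flagged(d_temp_sel, sel_bytes, d_keys, d_flags, d_keys_out, d_num_out, num_keys);
  cudaMalloc(&d_temp_sel, sel_bytes);
  cub::DeviceSelect::Flagged(d_temp_sel, sel_bytes, d_keys, d_flags, d_keys_out, d_num_out, num_keys);

  int h_after[bucket_num + 1], h_keys_out[num_keys], h_num_out;
  cudaMemcpy(h_after, d_after, sizeof(h_after), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_keys_out, d_keys_out, sizeof(h_keys_out), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_num_out, d_num_out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("bucket_range after filter: %d %d %d\n", h_after[0], h_after[1], h_after[2]);  // 0 2 3
  printf("kept %d keys: %d %d %d\n", h_num_out, h_keys_out[0], h_keys_out[1], h_keys_out[2]);
  return 0;
}

The nullptr-then-allocate pairs here are the same two-phase CUB temp-storage convention that sparse_forward_per_gpu uses above to size temp_scan_storage and temp_select_storage before calling filter().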
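cal_lookup_idx appears to rely on the buckets being laid out lookup-major, batch_size buckets per lookup (batch_size = bucket_num / meta.num_lookup_), so after the inclusive scan the key offset of lookup i is simply the scanned value at bucket index i * batch_size. A host-side illustration of that gather, with toy numbers and assuming that layout:

#include <cstdio>
#include <vector>

int main() {
  // bucket_after_filter after the inclusive scan: 2 lookups x batch_size 2 = 4 buckets.
  std::vector<int> bucket_after_filter = {0, 2, 3, 5, 9};
  const int num_lookup = 2, batch_size = 2;
  std::vector<int> lookup_offset(num_lookup + 1);
  // Same gather that cal_lookup_idx performs on the GPU.
  for (int i = 0; i <= num_lookup; ++i) {
    lookup_offset[i] = bucket_after_filter[i * batch_size];
  }
  printf("%d %d %d\n", lookup_offset[0], lookup_offset[1], lookup_offset[2]);  // prints: 0 3 9
  return 0;
}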