Commit d81ac0e

Merge pull request #455 from ccccjunkang/main
low frequency filter
2 parents: 2e8e855, 0182e7c

15 files changed: +435 -17 lines

HugeCTR/embedding/all2all_embedding_collection.cu

+124 -2

@@ -206,7 +206,7 @@ void weighted_sparse_forward_per_gpu(
     const core23::Tensor &sp_weights_all_gather_recv_buffer, ILookup *emb_storage,
     std::vector<core23::Tensor> &emb_vec_model_buffer, int64_t *num_model_key,
     int64_t *num_model_offsets, core23::Tensor &ret_model_key, core23::Tensor &ret_model_offset,
-    core23::Tensor &ret_sp_weight) {
+    core23::Tensor &ret_sp_weight, bool use_filter) {
   HugeCTR::CudaDeviceContext context(core->get_device_id());

   int tensor_device_id = core->get_device_id();
@@ -369,14 +369,88 @@ void weighted_sparse_forward_per_gpu(
   *num_model_offsets = model_offsets.num_elements();
 }

+template <typename offset_t>
+__global__ void cal_lookup_idx(size_t lookup_num, offset_t *bucket_after_filter, size_t batch_size,
+                               offset_t *lookup_offset, size_t bucket_num) {
+  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  int32_t step = blockDim.x * gridDim.x;
+  for (; i < (lookup_num); i += step) {
+    lookup_offset[i] = bucket_after_filter[i * batch_size];
+  }
+}
+
+template <typename offset_t>
+__global__ void count_ratio_filter(size_t bucket_num, char *filterd, const offset_t *bucket_range,
+                                   offset_t *bucket_after_filter) {
+  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
+  int32_t step = blockDim.x * gridDim.x;
+  for (; i < (bucket_num); i += step) {
+    offset_t start = bucket_range[i];
+    offset_t end = bucket_range[i + 1];
+    bucket_after_filter[i + 1] = 0;
+    for (offset_t idx = start; idx < end; idx++) {
+      if (filterd[idx] == 1) {
+        bucket_after_filter[i + 1]++;
+      }
+    }
+    if (i == 0) {
+      bucket_after_filter[i] = 0;
+    }
+  }
+}
+
+void filter(std::shared_ptr<CoreResourceManager> core,
+            const UniformModelParallelEmbeddingMeta &meta, const core23::Tensor &filterd,
+            core23::Tensor &bucket_range, core23::Tensor &bucket_after_filter,
+            core23::TensorParams &params, EmbeddingInput &emb_input, core23::Tensor &lookup_offset,
+            core23::Tensor &temp_scan_storage, core23::Tensor &temp_select_storage,
+            size_t temp_scan_bytes, size_t temp_select_bytes, core23::Tensor &keys_after_filter) {
+  auto stream = core->get_local_gpu()->get_stream();
+  // bucket_range length = bucket_num+1 , so here we minus 1.
+  int bucket_num = bucket_range.num_elements() - 1;
+  const int block_size = 256;
+  const int grid_size =
+      core->get_kernel_param().num_sms * core->get_kernel_param().max_thread_per_block / block_size;
+
+  DISPATCH_INTEGRAL_FUNCTION_CORE23(bucket_range.data_type().type(), offset_t, [&] {
+    DISPATCH_INTEGRAL_FUNCTION_CORE23(keys_after_filter.data_type().type(), key_t, [&] {
+      offset_t *bucket_after_filter_ptr = bucket_after_filter.data<offset_t>();
+      const offset_t *bucket_range_ptr = bucket_range.data<offset_t>();
+      char *filterd_ptr = filterd.data<char>();
+      count_ratio_filter<<<grid_size, block_size, 0, stream>>>(
+          bucket_num, filterd_ptr, bucket_range_ptr, bucket_after_filter_ptr);
+      cub::DeviceScan::InclusiveSum(
+          temp_scan_storage.data(), temp_scan_bytes, bucket_after_filter.data<offset_t>(),
+          bucket_after_filter.data<offset_t>(), bucket_after_filter.num_elements(), stream);
+
+      key_t *keys_ptr = emb_input.keys.data<key_t>();
+
+      cub::DeviceSelect::Flagged(temp_select_storage.data(), temp_select_bytes, keys_ptr,
+                                 filterd_ptr, keys_after_filter.data<key_t>(),
+                                 emb_input.num_keys.data<uint64_t>(), emb_input.h_num_keys, stream);
+
+      size_t batch_size = (bucket_num) / meta.num_lookup_;
+
+      cal_lookup_idx<<<1, block_size, 0, stream>>>(meta.num_lookup_ + 1,
+                                                   bucket_after_filter.data<offset_t>(), batch_size,
+                                                   lookup_offset.data<offset_t>(), bucket_num);
+      HCTR_LIB_THROW(cudaStreamSynchronize(stream));
+      emb_input.h_num_keys = static_cast<size_t>(emb_input.num_keys.data<uint64_t>()[0]);
+      emb_input.keys = keys_after_filter;
+      emb_input.bucket_range = bucket_after_filter;
+    });
+  });
+}
+
 void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
                             const EmbeddingCollectionParam &ebc_param,
                             const UniformModelParallelEmbeddingMeta &meta,
                             const core23::Tensor &key_all_gather_recv_buffer,
                             const core23::Tensor &row_lengths_all_gather_recv_buffer,
                             ILookup *emb_storage, std::vector<core23::Tensor> &emb_vec_model_buffer,
                             int64_t *num_model_key, int64_t *num_model_offsets,
-                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset) {
+                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset,
+                            bool use_filter) {
   /*
   There are some steps in this function:
   1.reorder key to feature major
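
The filter path added above is a four-step pipeline: count_ratio_filter writes the number of surviving keys per bucket (shifted by one slot), cub::DeviceScan::InclusiveSum turns those counts into a new bucket_range, cub::DeviceSelect::Flagged compacts the keys whose flag is 1, and cal_lookup_idx samples the scanned array every batch_size buckets to produce per-lookup offsets. A host-only sketch of the same arithmetic on toy data (plain C++, none of the HugeCTR types; values are illustrative only):

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Toy layout: 2 lookups, batch_size 2 -> 4 buckets; bucket_range has bucket_num + 1 entries.
  std::vector<int64_t> bucket_range = {0, 2, 4, 7, 9};
  std::vector<int64_t> keys = {10, 11, 12, 13, 20, 21, 22, 23, 24};
  std::vector<char> flags = {1, 0, 1, 1, 0, 1, 1, 0, 1};  // 1 = key survives the frequency filter

  const size_t bucket_num = bucket_range.size() - 1;
  const size_t num_lookup = 2;
  const size_t batch_size = bucket_num / num_lookup;

  // count_ratio_filter: surviving keys per bucket, shifted by one slot for the scan.
  std::vector<int64_t> bucket_after_filter(bucket_num + 1, 0);
  for (size_t b = 0; b < bucket_num; ++b) {
    for (int64_t i = bucket_range[b]; i < bucket_range[b + 1]; ++i) {
      bucket_after_filter[b + 1] += (flags[i] == 1);
    }
  }

  // cub::DeviceScan::InclusiveSum: counts become the new bucket_range.
  std::partial_sum(bucket_after_filter.begin(), bucket_after_filter.end(),
                   bucket_after_filter.begin());

  // cub::DeviceSelect::Flagged: compact the surviving keys.
  std::vector<int64_t> keys_after_filter;
  for (size_t i = 0; i < keys.size(); ++i) {
    if (flags[i] == 1) keys_after_filter.push_back(keys[i]);
  }

  // cal_lookup_idx: one offset per lookup, sampled every batch_size buckets.
  std::vector<int64_t> lookup_offset(num_lookup + 1);
  for (size_t l = 0; l <= num_lookup; ++l) {
    lookup_offset[l] = bucket_after_filter[l * batch_size];
  }

  for (auto v : bucket_after_filter) std::cout << v << ' ';  // 0 1 3 5 6
  std::cout << '\n';
  for (auto v : lookup_offset) std::cout << v << ' ';  // 0 3 6
  std::cout << '\n';
  return 0;
}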
@@ -500,8 +574,56 @@ void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
   compress_offset_.compute(embedding_input.bucket_range, batch_size, &num_key_per_lookup_offset);
   HCTR_LIB_THROW(cudaStreamSynchronize(stream));

+  if (use_filter) {
+    core23::Tensor bucket_range_after_filter;
+    core23::Tensor keys_after_filter;
+    core23::Tensor filtered;
+
+    filtered = core23::Tensor(
+        params.shape({(int64_t)embedding_input.h_num_keys}).data_type(core23::ScalarType::Char));
+    bucket_range_after_filter =
+        core23::Tensor(params.shape({embedding_input.bucket_range.num_elements()})
+                           .data_type(embedding_input.bucket_range.data_type().type()));
+    keys_after_filter = core23::Tensor(params.shape({(int64_t)embedding_input.h_num_keys + 1})
+                                           .data_type(embedding_input.keys.data_type().type()));
+
+    core23::Tensor temp_scan_storage;
+    core23::Tensor temp_select_storage;
+
+    size_t temp_scan_bytes = 0;
+    size_t temp_select_bytes = 0;
+
+    DISPATCH_INTEGRAL_FUNCTION_CORE23(
+        embedding_input.bucket_range.data_type().type(), offset_t, [&] {
+          DISPATCH_INTEGRAL_FUNCTION_CORE23(embedding_input.keys.data_type().type(), key_t, [&] {
+            cub::DeviceScan::InclusiveSum(nullptr, temp_scan_bytes, (offset_t *)nullptr,
+                                          (offset_t *)nullptr,
+                                          bucket_range_after_filter.num_elements());
+
+            temp_scan_storage = core23::Tensor(params.shape({static_cast<int64_t>(temp_scan_bytes)})
+                                                   .data_type(core23::ScalarType::Char));
+
+            cub::DeviceSelect::Flagged(nullptr, temp_select_bytes, (key_t *)nullptr,
+                                       (char *)nullptr, (key_t *)nullptr, (uint64_t *)nullptr,
+                                       embedding_input.h_num_keys);
+
+            temp_select_storage =
+                core23::Tensor(params.shape({static_cast<int64_t>(temp_select_bytes)})
+                                   .data_type(core23::ScalarType::Char));
+          });
+        });
+
+    emb_storage->ratio_filter(embedding_input.keys, embedding_input.h_num_keys,
+                              num_key_per_lookup_offset, meta.num_local_lookup_ + 1,
+                              meta.d_local_table_id_list_, filtered);
+
+    filter(core, meta, filtered, embedding_input.bucket_range, bucket_range_after_filter, params,
+           embedding_input, num_key_per_lookup_offset, temp_scan_storage, temp_select_storage,
+           temp_scan_bytes, temp_select_bytes, keys_after_filter);
+  }
   core23::Tensor embedding_vec = core23::init_tensor_list<float>(
       key_all_gather_recv_buffer.num_elements(), params.device().index());
+
   emb_storage->lookup(embedding_input.keys, embedding_input.h_num_keys, num_key_per_lookup_offset,
                       meta.num_local_lookup_ + 1, meta.d_local_table_id_list_, embedding_vec);
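
Both cub calls above follow CUB's two-phase convention: invoked with a null temp-storage pointer they only report the required scratch size (the sizing calls inside the DISPATCH block), and invoked again with an allocated buffer they perform the actual work (inside filter()). A minimal standalone sketch of that convention on a toy array, assuming a CUDA toolkit with CUB available:

#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 8;
  int h_in[n] = {1, 0, 2, 0, 3, 0, 1, 1};
  int h_out[n];
  int *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

  // Phase 1: null temp-storage pointer -> only the required scratch size is written.
  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n);

  // Phase 2: allocate the scratch buffer and run the scan for real.
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceScan::InclusiveSum(d_temp, temp_bytes, d_in, d_out, n);

  cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) printf("%d ", h_out[i]);  // 1 1 3 3 6 6 7 8
  printf("\n");

  cudaFree(d_temp);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}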

HugeCTR/embedding/all2all_embedding_collection.hpp

+3 -2

@@ -57,7 +57,7 @@ void weighted_sparse_forward_per_gpu(
     const core23::Tensor &sp_weights_all_gather_recv_buffer, ILookup *emb_storage,
     std::vector<core23::Tensor> &emb_vec_model_buffer, int64_t *num_model_key,
     int64_t *num_model_offsets, core23::Tensor &ret_model_key, core23::Tensor &ret_model_offset,
-    core23::Tensor &ret_sp_weight);
+    core23::Tensor &ret_sp_weight, bool use_filter);

 void weighted_copy_model_keys_and_offsets(
     std::shared_ptr<CoreResourceManager> core, const core23::Tensor &model_key,
@@ -71,7 +71,8 @@ void sparse_forward_per_gpu(std::shared_ptr<CoreResourceManager> core,
                             const core23::Tensor &row_lengths_all_gather_recv_buffer,
                             ILookup *emb_storage, std::vector<core23::Tensor> &emb_vec_model_buffer,
                             int64_t *num_model_key, int64_t *num_model_offsets,
-                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset);
+                            core23::Tensor *ret_model_key, core23::Tensor *ret_model_offset,
+                            bool use_filter);

 void copy_model_keys_and_offsets(std::shared_ptr<CoreResourceManager> core,
                                  const core23::Tensor &model_key,

HugeCTR/embedding/embedding_table.hpp

+7 -3

@@ -23,9 +23,13 @@ class ILookup {
  public:
   virtual ~ILookup() = default;

-  virtual void lookup(const core23::Tensor &keys, size_t num_keys,
-                      const core23::Tensor &num_keys_per_table_offset, size_t num_table_offset,
-                      const core23::Tensor &table_id_list, core23::Tensor &embedding_vec) = 0;
+  virtual void lookup(const core23::Tensor& keys, size_t num_keys,
+                      const core23::Tensor& num_keys_per_table_offset, size_t num_table_offset,
+                      const core23::Tensor& table_id_list, core23::Tensor& embedding_vec) = 0;
+
+  virtual void ratio_filter(const core23::Tensor& keys, size_t num_keys,
+                            const core23::Tensor& id_space_offset, size_t num_id_space_offset,
+                            const core23::Tensor& id_space, core23::Tensor& filtered){};
 };

 }  // namespace embedding
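
ratio_filter is introduced with an empty default body instead of being pure virtual, so existing ILookup backends that never filter keep compiling unchanged, while backends such as DummyVarAdapter override it. A small standalone sketch of that interface-evolution pattern, using toy types rather than the HugeCTR classes:

#include <iostream>

struct ILookupLike {
  virtual ~ILookupLike() = default;
  virtual void lookup() = 0;      // must be implemented by every backend
  virtual void ratio_filter() {}  // optional: default is a no-op
};

struct LegacyBackend : ILookupLike {  // unchanged code still compiles
  void lookup() override { std::cout << "lookup only\n"; }
};

struct FilteringBackend : ILookupLike {  // new code opts in to filtering
  void lookup() override { std::cout << "lookup\n"; }
  void ratio_filter() override { std::cout << "filter low-frequency keys\n"; }
};

int main() {
  LegacyBackend a;
  FilteringBackend b;
  a.ratio_filter();  // silently does nothing
  b.ratio_filter();  // prints the filtering message
  return 0;
}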

sparse_operation_kit/kit_src/lookup/impl/embedding_collection_adapter.cu

+50

@@ -380,6 +380,56 @@ void DummyVarAdapter<KeyType, OffsetType, DType>::lookup(
   }
 }

+template <typename KeyType, typename OffsetType, typename DType>
+void DummyVarAdapter<KeyType, OffsetType, DType>::ratio_filter(
+    const core23::Tensor& keys, size_t num_keys, const core23::Tensor& id_space_offset,
+    size_t num_id_space_offset, const core23::Tensor& id_space, core23::Tensor& filtered) {
+  // clang-format off
+  id_space_offset_.clear();
+  id_space_.clear();
+  id_space_offset_.resize(num_id_space_offset);
+  CUDACHECK(cudaMemcpyAsync(id_space_offset_.data(),
+                            id_space_offset.data<OffsetType>(),
+                            sizeof(OffsetType) * (num_id_space_offset),
+                            cudaMemcpyDeviceToHost, stream_));
+  id_space_.resize(num_id_space_offset - 1);
+  CUDACHECK(cudaMemcpyAsync(id_space_.data(),
+                            id_space.data<int>(),
+                            sizeof(int) * (num_id_space_offset - 1),
+                            cudaMemcpyDeviceToHost, stream_));
+  // clang-format on
+  CUDACHECK(cudaStreamSynchronize(stream_));
+  const KeyType* input = keys.data<KeyType>();
+  bool* output_filtered = filtered.data<bool>();
+  int start_index = 0;
+  size_t num = 0;
+  bool is_lookup = false;
+
+  for (int i = 0; i < num_id_space_offset - 1; ++i) {
+    if (i == num_id_space_offset - 2) {
+      num += id_space_offset_[i + 1] - id_space_offset_[i];
+      is_lookup = true;
+    } else {
+      if (same_table_[i + 1] != same_table_[i]) {
+        num += id_space_offset_[i + 1] - id_space_offset_[i];
+        is_lookup = true;
+      } else {
+        num += id_space_offset_[i + 1] - id_space_offset_[i];
+      }
+    }
+    if (num != 0 && is_lookup) {
+      auto var = vars_[id_space_[start_index]];
+      var->ratio_filter(input, output_filtered, num, stream_);
+      CUDACHECK(cudaStreamSynchronize(stream_));
+      input += num;
+      output_filtered += num;
+      num = 0;
+      is_lookup = false;
+      start_index = i + 1;
+    }
+  }
+}
+
 template class DummyVarAdapter<int32_t, int32_t, float>;
 template class DummyVarAdapter<int32_t, int64_t, float>;
 // template class DummyVarAdapter<int32_t, __half>;
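
The loop above merges consecutive lookups that resolve to the same table (tracked by same_table_) and issues a single var->ratio_filter call per table group, advancing the key and flag pointers by the group size each time. A host-only sketch of that grouping logic with made-up offsets (toy data, not the adapter's real members):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // 4 lookups; lookups 0-1 share table 0, lookups 2-3 share table 1.
  std::vector<int> same_table = {0, 0, 1, 1};
  std::vector<std::size_t> id_space_offset = {0, 3, 5, 9, 12};  // per-lookup key ranges

  std::size_t num = 0, start_index = 0;
  const std::size_t num_lookup = same_table.size();
  for (std::size_t i = 0; i < num_lookup; ++i) {
    num += id_space_offset[i + 1] - id_space_offset[i];
    const bool last = (i == num_lookup - 1);
    const bool table_changes = !last && (same_table[i + 1] != same_table[i]);
    if (num != 0 && (last || table_changes)) {
      // In the adapter this is where vars_[id_space_[start_index]]->ratio_filter(...)
      // would be called on the next `num` keys.
      std::cout << "table group starting at lookup " << start_index
                << " covers " << num << " keys\n";
      num = 0;
      start_index = i + 1;
    }
  }
  return 0;
}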

sparse_operation_kit/kit_src/lookup/impl/embedding_collection_adapter.h

+4

@@ -88,6 +88,10 @@ class DummyVarAdapter : public ::embedding::ILookup {
               size_t num_id_space_offset, const core23::Tensor& id_space,
               core23::Tensor& embedding_vec) override;

+  void ratio_filter(const core23::Tensor& keys, size_t num_keys,
+                    const core23::Tensor& id_space_offset, size_t num_id_space_offset,
+                    const core23::Tensor& id_space, core23::Tensor& filtered) override;
+
  private:
   std::shared_ptr<sok::CoreResourceManager> tf_backend_;
   int sm_count_;

sparse_operation_kit/kit_src/lookup/kernels/embedding_collection.cc

+4 -2

@@ -72,6 +72,7 @@ class EmbeddingCollectionBase : public OpKernel {
   int global_gpu_id_;
   int num_local_lookups_;
   bool use_sp_weight_;
+  bool use_filter_;
   HugeCTR::core23::KernelParams kernel_params_;

   std::unique_ptr<sok::EmbeddingCollectionParam> ebc_param_;
@@ -143,6 +144,7 @@ class EmbeddingCollectionBase : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("id_in_local_rank", &id_in_local_rank_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("num_gpus", &num_gpus_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_sp_weight", &use_sp_weight_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_filter", &use_filter_));

     // check rank/num_ranks/id_in_local_rank/num_gpus
     OP_REQUIRES(ctx, rank_ >= 0 && rank_ < num_ranks_, errors::InvalidArgument("Invalid rank."));
@@ -477,13 +479,13 @@ class LookupFowardBase : public EmbeddingCollectionBase<KeyType, OffsetType, DTy
           tf_backend, *this->meta_, this->global_gpu_id_, key_recv_buffer_tensor,
           row_length_recv_buffer_tensor, sp_weight_recv_buffer_tensor, &adapter_,
           emb_vec_model_buffer, &num_model_key, &num_model_offsets, ret_model_key, ret_model_offset,
-          ret_sp_weight);
+          ret_sp_weight, this->use_filter_);

     } else {
       ::embedding::tf::model_forward::sparse_forward_per_gpu(
           tf_backend, *this->ebc_param_, *this->meta_, key_recv_buffer_tensor,
           row_length_recv_buffer_tensor, &adapter_, emb_vec_model_buffer, &num_model_key,
-          &num_model_offsets, &ret_model_key, &ret_model_offset);
+          &num_model_offsets, &ret_model_key, &ret_model_offset, this->use_filter_);
     }

     // Prepare model_key & model_offsets
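
use_filter_ follows the same plumbing as use_sp_weight_: declared as an op attr, read once in the kernel constructor, and forwarded into the forward call. A minimal sketch of the constructor-time read, with a hypothetical kernel name and assuming a TensorFlow build environment:

#include "tensorflow/core/framework/op_kernel.h"

class ExampleLookupOp : public tensorflow::OpKernel {
 public:
  explicit ExampleLookupOp(tensorflow::OpKernelConstruction* ctx)
      : tensorflow::OpKernel(ctx) {
    // Aborts kernel construction with a clear error if the attr is missing or mistyped.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_filter", &use_filter_));
  }

  void Compute(tensorflow::OpKernelContext* ctx) override {
    // use_filter_ would select the filtered or unfiltered lookup path here.
  }

 private:
  bool use_filter_ = false;
};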

sparse_operation_kit/kit_src/lookup/ops/embedding_collection.cc

+9

@@ -52,6 +52,7 @@ REGISTER_OP("PreprocessingForward")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -80,6 +81,7 @@ REGISTER_OP("PreprocessingForwardWithWeight")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -112,6 +114,7 @@ REGISTER_OP("LookupForward")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -165,6 +168,7 @@ REGISTER_OP("LookupForwardVariable")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -218,6 +222,7 @@ REGISTER_OP("LookupForwardDynamic")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -273,6 +278,7 @@ REGISTER_OP("LookupForwardEmbeddingVarGPU")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -324,6 +330,7 @@ REGISTER_OP("LookupBackward")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -362,6 +369,7 @@ REGISTER_OP("PostprocessingForward")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
@@ -403,6 +411,7 @@ REGISTER_OP("PostprocessingBackward")
     .Attr("id_in_local_rank: int")
     .Attr("num_gpus: int")
     .Attr("use_sp_weight: bool")
+    .Attr("use_filter: bool")
     .Attr("Tindices: {int32, int64} = DT_INT64")
     .Attr("Toffsets: {int32, int64} = DT_INT64")
     .Attr("dtype: {float32, float16} = DT_FLOAT")
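
Each SOK op gains the same use_filter attr. Because it is registered without a default value, graph builders must now supply it explicitly wherever these ops are constructed. For contrast, the sketch below (hypothetical op name, TensorFlow headers assumed) declares the same kind of attr with a default; the commit itself does not add one:

#include "tensorflow/core/framework/op.h"

// "ExampleLookup" is not a real SOK op; the default "= false" is illustrative only.
REGISTER_OP("ExampleLookup")
    .Input("keys: Tindices")
    .Output("emb_vec: dtype")
    .Attr("use_filter: bool = false")  // the commit registers the attr without a default
    .Attr("Tindices: {int32, int64} = DT_INT64")
    .Attr("dtype: {float32, float16} = DT_FLOAT");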

sparse_operation_kit/kit_src/variable/impl/det_variable.cu

+6

@@ -248,6 +248,12 @@ void DETVariable<KeyType, ValueType>::scatter_update(const KeyType* keys, const
   map_->scatter_update(keys, values, num_keys, stream);
 }

+template <typename KeyType, typename ValueType>
+void DETVariable<KeyType, ValueType>::ratio_filter(const KeyType* keys, bool* filtered,
+                                                   size_t num_keys, cudaStream_t stream) {
+  throw std::runtime_error("SOK dynamic variable with DET backend don't support ratio_filter!");
+}
+
 template class DETVariable<int32_t, float>;
 template class DETVariable<int64_t, float>;

sparse_operation_kit/kit_src/variable/impl/det_variable.h

+3

@@ -56,10 +56,13 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
                       cudaStream_t stream = 0) override;
   void scatter_update(const KeyType *keys, const ValueType *values, size_t num_keys,
                       cudaStream_t stream = 0) override;
+  void ratio_filter(const KeyType *keys, bool *filtered, size_t num_keys,
+                    cudaStream_t stream = 0) override;

  private:
   std::unique_ptr<cuco::dynamic_map<KeyType, ValueType, cuco::initializer>> map_;

+  float filter_ratio_;
   size_t dimension_;
   size_t initial_capacity_;
   std::string initializer_;
