Skip to content

Commit e5f021f

Browse files
committed
Merge branch 'huikang/incremental_dump' into 'main'
SOK incremental dump. See merge request dl/hugectr/hugectr!1524
2 parents af9c40c + bb71a4f commit e5f021f

18 files changed

+934
-10
lines changed

sparse_operation_kit/kit_src/variable/impl/det_variable.cu

+6
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ void DETVariable<KeyType, ValueType>::eXport(KeyType* keys, ValueType* values,
171171
CUDACHECK(cudaFree(d_values));
172172
}
173173

174+
template <typename KeyType, typename ValueType>
void DETVariable<KeyType, ValueType>::eXport_if(KeyType* keys, ValueType* values, size_t* counter,
                                                uint64_t threshold, cudaStream_t stream) {
  // Incremental (score-filtered) dump is only implemented for the HKV backend;
  // the DET backend has no score to filter on, so this always fails loudly
  // instead of silently dumping nothing.
  throw std::runtime_error("SOK dynamic variable with DET backend doesn't support eXport_if");
}
179+
174180
template <typename KeyType, typename ValueType>
175181
void DETVariable<KeyType, ValueType>::assign(const KeyType* keys, const ValueType* values,
176182
size_t num_keys, cudaStream_t stream) {

sparse_operation_kit/kit_src/variable/impl/det_variable.h

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class DETVariable : public VariableBase<KeyType, ValueType> {
3636
int64_t cols() override;
3737

3838
void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) override;
39+
void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
40+
cudaStream_t stream = 0) override;
3941
void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
4042
cudaStream_t stream = 0) override;
4143

sparse_operation_kit/kit_src/variable/impl/hkv_variable.cu

+42-8
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ __global__ void generate_normal_kernel(curandState* state, T** result, bool* d_f
137137
}
138138
}
139139

140+
// Selection predicate for HKV's export_batch_if: a key/value pair is dumped
// only when its score is strictly greater than `threshold`. The `key` and
// `pattern` arguments are required by the HKV predicate interface but are
// not consulted here.
template <class K, class S>
struct ExportIfPredFunctor {
  __forceinline__ __device__ bool operator()(const K& key, S& score, const K& pattern,
                                             const S& threshold) {
    return threshold < score;
  }
};
147+
140148
static void set_curand_states(curandState** states, cudaStream_t stream = 0) {
141149
int device;
142150
CUDACHECK(cudaGetDevice(&device));
@@ -197,7 +205,6 @@ HKVVariable<KeyType, ValueType>::HKVVariable(int64_t dimension, int64_t initial_
197205
nv::merlin::EvictStrategy hkv_evict_strategy;
198206
parse_evict_strategy(evict_strategy, hkv_evict_strategy);
199207
hkv_table_option_.evict_strategy = hkv_evict_strategy;
200-
201208
hkv_table_->init(hkv_table_option_);
202209
}
203210

@@ -230,22 +237,49 @@ void HKVVariable<KeyType, ValueType>::eXport(KeyType* keys, ValueType* values,
230237
ValueType* d_values;
231238
CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));
232239

233-
// KeyType* d_keys;
234-
// CUDACHECK(cudaMalloc(&d_keys, sizeof(KeyType) * num_keys));
235-
// ValueType* d_values;
236-
// CUDACHECK(cudaMalloc(&d_values, sizeof(ValueType) * num_keys * dim));
237240
hkv_table_->export_batch(hkv_table_option_.max_capacity, 0, d_keys, d_values, nullptr,
238241
stream); // Meta missing
239242
CUDACHECK(cudaStreamSynchronize(stream));
240243

241-
// clang-format off
242244
std::memcpy(keys, d_keys, sizeof(KeyType) * num_keys);
243245
std::memcpy(values, d_values, sizeof(ValueType) * num_keys * dim);
244-
//CUDACHECK(cudaMemcpy(keys, d_keys, sizeof(KeyType) * num_keys,cudaMemcpyDeviceToHost));
245-
//CUDACHECK(cudaMemcpy(values, d_values, sizeof(ValueType) * num_keys * dim,cudaMemcpyDeviceToHost));
246+
CUDACHECK(cudaFree(d_keys));
247+
CUDACHECK(cudaFree(d_values));
248+
}
249+
250+
template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::eXport_if(KeyType* keys, ValueType* values, size_t* counter,
                                                uint64_t threshold, cudaStream_t stream) {
  // Incremental dump: copy out only the pairs selected by ExportIfPredFunctor
  // (score strictly greater than `threshold`). `keys` and `values` are host
  // pointers sized for the worst case (rows() entries); `counter[0]` receives
  // the number of entries actually dumped.
  int64_t num_keys = rows();
  int64_t dim = cols();

  // Managed staging buffers: written by the device-side export, then read by
  // the host-side memcpy after the stream synchronize below.
  KeyType* d_keys;
  CUDACHECK(cudaMallocManaged(&d_keys, sizeof(KeyType) * num_keys));
  ValueType* d_values;
  CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));

  // Per-key scores written by export_batch_if; required by the API but unused
  // afterwards. (Renamed from the typo'd `d_socre_type`.)
  uint64_t* d_scores;
  CUDACHECK(cudaMallocManaged(&d_scores, sizeof(uint64_t) * num_keys));

  uint64_t* d_dump_counter;
  CUDACHECK(cudaMallocManaged(&d_dump_counter, sizeof(uint64_t)));

  // `pattern` is required by the HKV export_batch_if signature but ignored by
  // ExportIfPredFunctor; only `threshold` participates in the selection.
  KeyType pattern = 100;

  hkv_table_->template export_batch_if<ExportIfPredFunctor>(
      pattern, threshold, hkv_table_->capacity(), 0, d_dump_counter, d_keys, d_values, d_scores,
      stream);
  // Must synchronize before the host reads *d_dump_counter or the staging buffers.
  CUDACHECK(cudaStreamSynchronize(stream));

  std::memcpy(keys, d_keys, sizeof(KeyType) * (*d_dump_counter));
  std::memcpy(values, d_values, sizeof(ValueType) * (*d_dump_counter) * dim);
  counter[0] = static_cast<size_t>(*d_dump_counter);

  CUDACHECK(cudaFree(d_keys));
  CUDACHECK(cudaFree(d_values));
  CUDACHECK(cudaFree(d_scores));
  CUDACHECK(cudaFree(d_dump_counter));
}
250284

251285
template <typename KeyType, typename ValueType>

sparse_operation_kit/kit_src/variable/impl/hkv_variable.h

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ class HKVVariable : public VariableBase<KeyType, ValueType> {
4040
int64_t cols() override;
4141

4242
void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) override;
43+
void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
44+
cudaStream_t stream = 0) override;
4345
void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
4446
cudaStream_t stream = 0) override;
4547

sparse_operation_kit/kit_src/variable/impl/variable_base.h

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class VariableBase {
3232
virtual int64_t cols() = 0;
3333

3434
virtual void eXport(KeyType *keys, ValueType *values, cudaStream_t stream = 0) = 0;
35+
36+
virtual void eXport_if(KeyType *keys, ValueType *values, size_t *counter, uint64_t threshold,
37+
cudaStream_t stream = 0) = 0;
38+
3539
virtual void assign(const KeyType *keys, const ValueType *values, size_t num_keys,
3640
cudaStream_t stream = 0) = 0;
3741

sparse_operation_kit/kit_src/variable/kernels/dummy_var.cc

+6
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ void DummyVar<KeyType, ValueType>::Export(void* keys, void* values, cudaStream_t
6363
var_->eXport(static_cast<KeyType*>(keys), static_cast<ValueType*>(values), stream);
6464
}
6565

66+
template <typename KeyType, typename ValueType>
67+
void DummyVar<KeyType, ValueType>::ExportIf(void* keys, void* values,size_t* counter,uint64_t threshold, cudaStream_t stream) {
68+
check_var();
69+
var_->eXport_if(static_cast<KeyType*>(keys), static_cast<ValueType*>(values),counter,threshold, stream);
70+
}
71+
6672
template <typename KeyType, typename ValueType>
6773
void DummyVar<KeyType, ValueType>::Assign(const void* keys, const void* values, size_t num_keys,
6874
cudaStream_t stream) {

sparse_operation_kit/kit_src/variable/kernels/dummy_var.h

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class DummyVar : public ResourceBase {
5353
int64_t cols();
5454

5555
void Export(void *keys, void *values, cudaStream_t stream);
56+
void ExportIf(void *keys, void *values, size_t *counter, uint64_t threshold, cudaStream_t stream);
5657
void Assign(const void *keys, const void *values, size_t num_keys, cudaStream_t stream);
5758

5859
void SparseRead(const void *keys, void *values, size_t num_keys, cudaStream_t stream);

sparse_operation_kit/kit_src/variable/kernels/dummy_var_ops.cc

+69
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,75 @@ REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
132132
#endif
133133
#undef REGISTER_GPU_KERNELS
134134

135+
// -----------------------------------------------------------------------------------------------
136+
// DummyVarExportIf
137+
// -----------------------------------------------------------------------------------------------
138+
template <typename KeyType, typename ValueType>
139+
class DummyVarExportIfOp : public OpKernel {
140+
public:
141+
explicit DummyVarExportIfOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
142+
143+
void Compute(OpKernelContext* ctx) override {
144+
// Get DummyVar
145+
core::RefCountPtr<DummyVar<KeyType, ValueType>> var;
146+
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
147+
148+
tf_shared_lock ml(*var->mu());
149+
150+
// Get shape
151+
int64_t rows = var->rows();
152+
int64_t cols = var->cols();
153+
154+
const Tensor* threshold_tensor = nullptr;
155+
OP_REQUIRES_OK(ctx, ctx->input("threshold", &threshold_tensor));
156+
157+
AllocatorAttributes alloc_attr;
158+
alloc_attr.set_on_host(true);
159+
// temp buffer
160+
Tensor tmp_indices;
161+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT64, {rows}, &tmp_indices, alloc_attr));
162+
163+
Tensor tmp_values;
164+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ValueType>::v(), {rows*cols}, &tmp_values, alloc_attr));
165+
166+
Tensor counter;
167+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_UINT64, {1}, &counter, alloc_attr));
168+
// Get cuda stream of tensorflow
169+
auto device_ctx = ctx->op_device_context();
170+
OP_REQUIRES(ctx, device_ctx != nullptr, errors::Aborted("No valid device context."));
171+
cudaStream_t stream = stream_executor::gpu::AsGpuStreamValue(device_ctx->stream());
172+
var->ExportIf(tmp_indices.data(), tmp_values.data(),(size_t*)counter.data(),((uint64_t*)threshold_tensor->data())[0], stream);
173+
// Allocate output
174+
Tensor* indices = nullptr;
175+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {((size_t*)counter.data())[0]}, &indices));
176+
Tensor* values = nullptr;
177+
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {((size_t*)counter.data())[0], cols}, &values));
178+
179+
std::memcpy(indices->data(), tmp_indices.data(), sizeof(KeyType) * ((size_t*)counter.data())[0]);
180+
std::memcpy(values->data(), tmp_values.data(), sizeof(ValueType) * ((size_t*)counter.data())[0] * cols);
181+
}
182+
};
183+
184+
#define REGISTER_GPU_KERNELS(key_type_tf, key_type, dtype_tf, dtype) \
185+
REGISTER_KERNEL_BUILDER(Name("DummyVarExportIf") \
186+
.Device(DEVICE_GPU) \
187+
.HostMemory("resource") \
188+
.HostMemory("threshold") \
189+
.HostMemory("indices") \
190+
.HostMemory("values") \
191+
.TypeConstraint<key_type_tf>("key_type") \
192+
.TypeConstraint<dtype_tf>("dtype"), \
193+
DummyVarExportIfOp<key_type, dtype>)
194+
#if TF_VERSION_MAJOR == 1
195+
REGISTER_GPU_KERNELS(int64, int64_t, float, float);
196+
REGISTER_GPU_KERNELS(int32, int32_t, float, float);
197+
#else
198+
REGISTER_GPU_KERNELS(int64_t, int64_t, float, float);
199+
REGISTER_GPU_KERNELS(int32_t, int32_t, float, float);
200+
#endif
201+
#undef REGISTER_GPU_KERNELS
202+
203+
135204
// -----------------------------------------------------------------------------------------------
136205
// DummyVarSparseRead
137206
// -----------------------------------------------------------------------------------------------

sparse_operation_kit/kit_src/variable/ops/dummy_var_ops.cc

+10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ REGISTER_OP("DummyVarExport")
4141
.Attr("dtype: {float32} = DT_FLOAT")
4242
.SetShapeFn([](InferenceContext* c) { return sok_tsl_status(); });
4343

44+
// Op definition for the incremental dump. Output sizes depend on how many
// entries pass the threshold, which is only known at kernel execution time,
// so the shape function (shared with the sibling SOK ops) infers nothing.
REGISTER_OP("DummyVarExportIf")
    .Input("resource: resource")
    .Input("threshold: uint64")
    .Output("indices: key_type")
    .Output("values: dtype")
    .Attr("key_type: {int32, int64} = DT_INT64")
    .Attr("dtype: {float32} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) { return sok_tsl_status(); });
52+
53+
4454
REGISTER_OP("DummyVarSparseRead")
4555
.Input("resource: resource")
4656
.Input("indices: key_type")

sparse_operation_kit/sparse_operation_kit/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from sparse_operation_kit.lookup import lookup_sparse
6969
from sparse_operation_kit.lookup import all2all_dense_embedding
7070

71-
from sparse_operation_kit.dump_load import dump, load
71+
from sparse_operation_kit.dump_load import dump, load, incremental_model_dump
7272

7373

7474
# a specific code path for dl framework tf2.11.0

0 commit comments

Comments
 (0)