@@ -137,6 +137,14 @@ __global__ void generate_normal_kernel(curandState* state, T** result, bool* d_f
137
137
}
138
138
}
139
139
140
template <class K, class S>
struct ExportIfPredFunctor {
  // Predicate functor for hkv_table_->export_batch_if: a (key, value, score)
  // entry is exported when its score is strictly greater than `threshold`.
  //
  // `key` and `pattern` are required by the predicate signature HKV expects
  // (see the export_batch_if call site) but are intentionally not consulted;
  // `score` is a non-const reference for the same signature-compatibility
  // reason and is never modified here.
  __forceinline__ __device__ bool operator()(const K& key, S& score, const K& pattern,
                                             const S& threshold) {
    return threshold < score;
  }
};
147
+
140
148
static void set_curand_states (curandState** states, cudaStream_t stream = 0 ) {
141
149
int device;
142
150
CUDACHECK (cudaGetDevice (&device));
@@ -197,7 +205,6 @@ HKVVariable<KeyType, ValueType>::HKVVariable(int64_t dimension, int64_t initial_
197
205
nv::merlin::EvictStrategy hkv_evict_strategy;
198
206
parse_evict_strategy (evict_strategy, hkv_evict_strategy);
199
207
hkv_table_option_.evict_strategy = hkv_evict_strategy;
200
-
201
208
hkv_table_->init (hkv_table_option_);
202
209
}
203
210
@@ -230,22 +237,49 @@ void HKVVariable<KeyType, ValueType>::eXport(KeyType* keys, ValueType* values,
230
237
ValueType* d_values;
231
238
CUDACHECK (cudaMallocManaged (&d_values, sizeof (ValueType) * num_keys * dim));
232
239
233
- // KeyType* d_keys;
234
- // CUDACHECK(cudaMalloc(&d_keys, sizeof(KeyType) * num_keys));
235
- // ValueType* d_values;
236
- // CUDACHECK(cudaMalloc(&d_values, sizeof(ValueType) * num_keys * dim));
237
240
hkv_table_->export_batch (hkv_table_option_.max_capacity , 0 , d_keys, d_values, nullptr ,
238
241
stream); // Meta missing
239
242
CUDACHECK (cudaStreamSynchronize (stream));
240
243
241
- // clang-format off
242
244
std::memcpy (keys, d_keys, sizeof (KeyType) * num_keys);
243
245
std::memcpy (values, d_values, sizeof (ValueType) * num_keys * dim);
244
- // CUDACHECK(cudaMemcpy(keys, d_keys, sizeof(KeyType) * num_keys,cudaMemcpyDeviceToHost));
245
- // CUDACHECK(cudaMemcpy(values, d_values, sizeof(ValueType) * num_keys * dim,cudaMemcpyDeviceToHost));
246
+ CUDACHECK (cudaFree (d_keys));
247
+ CUDACHECK (cudaFree (d_values));
248
+ }
249
+
250
template <typename KeyType, typename ValueType>
void HKVVariable<KeyType, ValueType>::eXport_if(KeyType* keys, ValueType* values, size_t* counter,
                                                uint64_t threshold, cudaStream_t stream) {
  // Export every (key, value) pair whose score is strictly greater than
  // `threshold` (see ExportIfPredFunctor) into the host buffers `keys` /
  // `values`, and write the number of exported pairs into `counter[0]`.
  //
  // `keys`, `values`, and `counter` are host pointers.  Since up to the whole
  // table may match the predicate, the caller must size `keys` for at least
  // rows() elements and `values` for at least rows() * cols() elements.
  int64_t num_keys = rows();
  int64_t dim = cols();

  // Managed staging buffers: filled by the device during export_batch_if and
  // read directly from the host after the stream synchronization below.
  KeyType* d_keys;
  CUDACHECK(cudaMallocManaged(&d_keys, sizeof(KeyType) * num_keys));
  ValueType* d_values;
  CUDACHECK(cudaMallocManaged(&d_values, sizeof(ValueType) * num_keys * dim));

  // Scores for the exported entries; allocated because export_batch_if
  // requires the output slot, but not copied back to the caller.
  uint64_t* d_score_type;
  CUDACHECK(cudaMallocManaged(&d_score_type, sizeof(uint64_t) * num_keys));

  // Device-written count of entries that matched the predicate.
  uint64_t* d_dump_counter;
  CUDACHECK(cudaMallocManaged(&d_dump_counter, sizeof(uint64_t)));

  // export_batch_if requires a `pattern` argument, but ExportIfPredFunctor
  // ignores it (only score vs. threshold is compared), so this value is
  // irrelevant.
  KeyType pattern = 100;

  hkv_table_->template export_batch_if<ExportIfPredFunctor>(
      pattern, threshold, hkv_table_->capacity(), 0, d_dump_counter, d_keys, d_values,
      d_score_type, stream);
  // Must synchronize before the host touches the managed buffers: the export
  // runs asynchronously on `stream`.
  CUDACHECK(cudaStreamSynchronize(stream));

  // Only the first *d_dump_counter entries of the staging buffers are valid.
  std::memcpy(keys, d_keys, sizeof(KeyType) * (*d_dump_counter));
  std::memcpy(values, d_values, sizeof(ValueType) * (*d_dump_counter) * dim);
  counter[0] = (size_t)(*d_dump_counter);

  CUDACHECK(cudaFree(d_keys));
  CUDACHECK(cudaFree(d_values));
  CUDACHECK(cudaFree(d_score_type));
  CUDACHECK(cudaFree(d_dump_counter));
}
250
284
251
285
template <typename KeyType, typename ValueType>
0 commit comments