@@ -279,6 +279,41 @@ namespace cuda {
279
279
280
280
template <typename RNG>
281
281
void random_from_to_kernel (TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
282
+ #ifdef FBCODE_CAFFE2
283
+ AT_DISPATCH_V2 (iter.dtype (), " random_from_to_kernel_cuda" , AT_WRAP ([&] {
284
+ if ((
285
+ std::is_same_v<scalar_t , int64_t > ||
286
+ std::is_same_v<scalar_t , double > ||
287
+ std::is_same_v<scalar_t , float > ||
288
+ std::is_same_v<scalar_t , at::BFloat16>) && range >= 1ULL << 32 )
289
+ {
290
+ // define lambda to mod with range and add base
291
+ auto random_func = [range, base] __device__ (uint64_t rand ) {
292
+ return transformation::uniform_int_from_to<scalar_t >(rand , range, base);
293
+ };
294
+ distribution_nullary_kernel<scalar_t , uint64_t , ulonglong2>(iter,
295
+ gen,
296
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
297
+ ulonglong2 ret;
298
+ uint4 rand_val = curand4 (state);
299
+ ret.x = (static_cast <uint64_t >(rand_val.x ) << 32 ) | rand_val.y ;
300
+ ret.y = (static_cast <uint64_t >(rand_val.z ) << 32 ) | rand_val.w ;
301
+ return ret;
302
+ },
303
+ random_func);
304
+ } else {
305
+ auto random_func = [range, base] __device__ (uint32_t rand ) {
306
+ return transformation::uniform_int_from_to<scalar_t >(rand , range, base);
307
+ };
308
+ distribution_nullary_kernel<scalar_t , uint32_t , uint4>(iter,
309
+ gen,
310
+ [] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
311
+ return curand4 (state);
312
+ },
313
+ random_func);
314
+ }
315
+ }), AT_EXPAND (AT_ALL_TYPES), kBool , kHalf , kBFloat16 , AT_EXPAND (AT_BAREBONES_UNSIGNED_TYPES));
316
+ #else
282
317
AT_DISPATCH_V2 (iter.dtype (), " random_from_to_kernel_cuda" , AT_WRAP ([&] {
283
318
if (range >= 1ULL << 28 ) // allow approx 5% skew in uniform int generation using %
284
319
{
@@ -308,6 +343,7 @@ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t bas
308
343
random_func);
309
344
}
310
345
}), AT_EXPAND (AT_ALL_TYPES), kBool , kHalf , kBFloat16 , AT_EXPAND (AT_BAREBONES_UNSIGNED_TYPES));
346
+ #endif
311
347
}
312
348
313
349
// This is the special kernel to handle single specific case:
0 commit comments