Skip to content

Commit 0b17c09

Browse files
ngimel authored and pytorchmergebot committed
restore rng generation for fbcode (pytorch#144819)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#144819
Approved by: https://github.com/malfet, https://github.com/kit1980
1 parent 49bdc41 commit 0b17c09

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

aten/src/ATen/core/DistributionsHelper.h

+8
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,15 @@ struct uniform_int_from_to_distribution {
4040

4141
template <typename RNG>
4242
C10_HOST_DEVICE inline T operator()(RNG generator) {
43+
#ifdef FBCODE_CAFFE2
44+
if ((
45+
std::is_same_v<T, int64_t> ||
46+
std::is_same_v<T, double> ||
47+
std::is_same_v<T, float> ||
48+
std::is_same_v<T, at::BFloat16>) && range_ >= 1ULL << 32)
49+
#else
4350
if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
51+
#endif
4452
{
4553
return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
4654
} else {

aten/src/ATen/native/cuda/DistributionTemplates.h

+36
Original file line number | Diff line number | Diff line change
@@ -279,6 +279,41 @@ namespace cuda {
279279

280280
template<typename RNG>
281281
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
282+
#ifdef FBCODE_CAFFE2
283+
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
284+
if ((
285+
std::is_same_v<scalar_t, int64_t> ||
286+
std::is_same_v<scalar_t, double> ||
287+
std::is_same_v<scalar_t, float> ||
288+
std::is_same_v<scalar_t, at::BFloat16>) && range >= 1ULL << 32)
289+
{
290+
// define lambda to mod with range and add base
291+
auto random_func = [range, base] __device__ (uint64_t rand) {
292+
return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
293+
};
294+
distribution_nullary_kernel<scalar_t, uint64_t, ulonglong2>(iter,
295+
gen,
296+
[] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
297+
ulonglong2 ret;
298+
uint4 rand_val = curand4(state);
299+
ret.x = (static_cast<uint64_t>(rand_val.x) << 32) | rand_val.y;
300+
ret.y = (static_cast<uint64_t>(rand_val.z) << 32) | rand_val.w;
301+
return ret;
302+
},
303+
random_func);
304+
} else {
305+
auto random_func = [range, base] __device__ (uint32_t rand) {
306+
return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
307+
};
308+
distribution_nullary_kernel<scalar_t, uint32_t, uint4>(iter,
309+
gen,
310+
[] __device__ (curandStatePhilox4_32_10_t* state) -> uint4 {
311+
return curand4(state);
312+
},
313+
random_func);
314+
}
315+
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
316+
#else
282317
AT_DISPATCH_V2(iter.dtype(), "random_from_to_kernel_cuda", AT_WRAP([&] {
283318
if (range >= 1ULL << 28) // allow approx 5% skew in uniform int generation using %
284319
{
@@ -308,6 +343,7 @@ void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t bas
308343
random_func);
309344
}
310345
}), AT_EXPAND(AT_ALL_TYPES), kBool, kHalf, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
346+
#endif
311347
}
312348

313349
// This is the special kernel to handle single specific case:

aten/src/ATen/test/rng_test.h

+4
Original file line number | Diff line number | Diff line change
@@ -137,9 +137,13 @@ void test_random_from_to(const at::Device& device) {
137137
range = static_cast<uint64_t>(max_to) - static_cast<uint64_t>(from) + 1;
138138
from_case_covered = true;
139139
}
140+
#ifdef FBCODE_CAFFE2
141+
if (range < (1ULL << 32)) {
142+
#else
140143
// this is leaking details of implementation into test
141144
// we are starting to use random64() at 2^28 to minimize skew due to %
142145
if (range < (1ULL << 28)) {
146+
#endif
143147
exp = static_cast<T>(static_cast<int64_t>((static_cast<uint32_t>(val) % range + from)));
144148
} else {
145149
exp = static_cast<T>(static_cast<int64_t>((val % range + from)));

test/test_tensor_creation_ops.py

+1
Original file line number | Diff line number | Diff line change
@@ -3502,6 +3502,7 @@ def seed(generator):
35023502
self.assertTrue((res1 >= 0).all().item())
35033503

35043504

3505+
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "For fb compatibility random not changed in fbcode")
35053506
def test_randint_distribution(self, device):
35063507
size = 1_000_000
35073508
n_max = int(0.75 * 2 ** 32)

0 commit comments

Comments
 (0)