fix multinomial kernels to properly advance random states (#38046)
Summary:
Before this change, the multinomial kernels did not advance the random state enough, which led to the same sequence being generated over and over, merely shifted by 4 values. This PR fixes that.
Fixes https://github.com/pytorch/pytorch/issues/37403
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38046
Differential Revision: D21516542
Pulled By: ngimel
fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2
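
For context, a minimal reproduction sketch of the symptom (not part of the PR; the device, corpus size, and sample count are illustrative and mirror the new test added below): with the old kernels, two back-to-back multinomial draws on CUDA overlapped almost completely, because the Philox offset advanced by only 4 per call.

    import torch

    device = "cuda"
    corpus_size = 100000  # large uniform corpus, so chance collisions are rare
    freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
    n_sample = 100

    s1 = torch.multinomial(freqs, n_sample, replacement=True)
    s2 = torch.multinomial(freqs, n_sample, replacement=True)
    repeats = 2 * n_sample - torch.cat([s1, s2]).unique().numel()

    # With the fix, the two draws consume disjoint parts of the random stream,
    # so `repeats` should be near 0; before the fix it was close to n_sample.
    print("repeated elements across the two draws:", repeats)
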
diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu
index f57a77c..b5a2e71 100644
--- a/aten/src/ATen/native/cuda/MultinomialKernel.cu
+++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu
@@ -124,36 +124,33 @@
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on.
- // global index formula for 1D grid of 2D blocks
- int idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
+ // global index formula for 2D grid of 1D blocks
+ int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seeds.first, idx, seeds.second, &state);
// The block determines the distribution for which we generate a point
- for (int64_t curDist = blockIdx.x;
+ for (int64_t curDist = blockIdx.y;
curDist < distributions;
- curDist += gridDim.x) {
- for (int sampleBase = 0;
- sampleBase < totalSamples; sampleBase += blockDim.y) {
- // The warp determines the sample
- int sample = sampleBase + threadIdx.y;
+ curDist += gridDim.y) {
+ for (int sample = blockIdx.x*blockDim.x + threadIdx.x;
+ sample < totalSamples; sample += blockDim.x*gridDim.x) {
- // All threads participate in this
+      // We are losing 3 out of 4 generated numbers, but that's OK:
+      // this kernel is not very efficient anyway.
auto rand = curand_uniform4(&state);
scalar_t r = static_cast<scalar_t>(rand.x);
- if (threadIdx.x == 0 && sample < totalSamples) {
- // Find the bucket that a uniform sample lies in
- int choice = binarySearchForMultinomial<scalar_t>(
- normDistPrefixSum + curDist * categories,
- normDist + curDist * categories,
- categories,
- r);
+ // Find the bucket that a uniform sample lies in
+ int choice = binarySearchForMultinomial<scalar_t>(
+ normDistPrefixSum + curDist * categories,
+ normDist + curDist * categories,
+ categories,
+ r);
- // Torch indices are 1-based
- dest[curDist * totalSamples + sample] = choice;
- }
+ dest[curDist * totalSamples + sample] = choice;
+
}
}
}
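
As an aside on the new indexing above, here is a plain-Python sketch (the launch dimensions are made up, not taken from the PR) of the flattened thread index for a 2D grid of 1D blocks. Every thread gets a distinct index, and hence a distinct Philox subsequence, which is why the host-side offset only has to cover the randoms consumed by a single thread.

    # idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x
    def global_idx(bx, by, tx, grid_dim_x, block_dim_x):
        return by * grid_dim_x * block_dim_x + bx * block_dim_x + tx

    grid_dim_x, grid_dim_y, block_dim_x = 4, 8, 128  # illustrative launch shape
    ids = {global_idx(bx, by, tx, grid_dim_x, block_dim_x)
           for by in range(grid_dim_y)
           for bx in range(grid_dim_x)
           for tx in range(block_dim_x)}
    assert len(ids) == grid_dim_x * grid_dim_y * block_dim_x  # every thread index is unique
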
@@ -180,17 +177,14 @@
// The block and warp determines the distribution for which we
// generate a point
- for (int64_t curDistBase = blockIdx.x * blockDim.y;
- curDistBase < distributions;
- curDistBase += gridDim.x * blockDim.y) {
- // The warp determines the distribution
- int64_t curDist = curDistBase + threadIdx.y;
+ for (int64_t curDist = blockIdx.x * blockDim.y + threadIdx.y;
+ curDist < distributions;
+ curDist += gridDim.x * blockDim.y) {
- // All threads must participate in this
auto rand = curand_uniform4(&state);
scalar_t r = static_cast<scalar_t>(rand.x);
- if (threadIdx.x == 0 && curDist < distributions) {
+ if (threadIdx.x == 0) {
// Find the bucket that a uniform sample lies in
int choice = binarySearchForMultinomial<scalar_t>(
normDistPrefixSum + curDist * categories,
@@ -415,26 +409,27 @@
std::pair<uint64_t, uint64_t> rng_engine_inputs;
if (with_replacement) {
+    // Binary search is warp divergent (threads within a warp take
+    // different paths through the search), but for better utilization,
+    // we need each block to have at least 4 warps.
+ dim3 block(128);
+
+    // Each block along grid.y handles one distribution; blocks along
+    // grid.x generate its samples concurrently.
+    int grid_y = std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
+    dim3 grid((n_sample - 1) / block.x + 1, grid_y);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
- // each thread will utilize one random, however, since we have to use
+      // each thread generates one sample for each of the ceil(numDist / grid.y) distributions it covers; however, since we have to use
// curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
- // offset is 4.
- rng_engine_inputs = gen->philox_engine_inputs(4);
+ // offset is 4 times that.
+      auto offset = ((numDist - 1) / grid.y + 1) * 4;
+ rng_engine_inputs = gen->philox_engine_inputs(offset);
}
// Sample with replacement
- // Binary search is warp divergent (so effectively we're running
- // with just a single thread), but for better utilization,
- // we need each block to have at least 4 warps.
- dim3 block(32, 4);
-
- // Each warp in a block will generate a sample from one
- // distribution concurrently.
- dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
-
sampleMultinomialWithReplacement
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
rng_engine_inputs,
@@ -470,10 +465,11 @@
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
- // each thread will utilize one random, however, since we have to use
+      // each thread will utilize ceil(distributions / (gridDim.x * blockDim.y)) randoms; however, since we have to use
// curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
- // offset is 4.
- rng_engine_inputs = gen->philox_engine_inputs(4);
+ // offset is 4 times that.
+      auto offset = ((numDist - 1) / (grid.x * block.y) + 1) * 4;
+ rng_engine_inputs = gen->philox_engine_inputs(offset);
}
// The kernel can only draw one sample before we have to
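
To make the offset arithmetic above concrete, a hedged Python sketch with illustrative sizes (the constants are not from the PR): each thread consumes one curand_uniform4 call, i.e. 4 Philox values, per distribution it covers, so the reserved offset is 4 times the per-thread distribution count.

    num_dist, n_sample, max_grid_y = 1000, 100, 65535  # illustrative sizes

    # With replacement: block = (128,), grid = (ceil(n_sample / 128), min(numDist, maxGridSize[1]))
    grid_y = min(num_dist, max_grid_y)
    offset_with_replacement = ((num_dist - 1) // grid_y + 1) * 4

    # Without replacement: each thread covers ceil(numDist / (grid.x * block.y)) distributions
    grid_x, block_y = 64, 4  # stands in for that kernel's launch shape
    offset_without_replacement = ((num_dist - 1) // (grid_x * block_y) + 1) * 4

    print(offset_with_replacement, offset_without_replacement)
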
diff --git a/test/test_torch.py b/test/test_torch.py
index 38570de..9827980 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -14226,6 +14226,24 @@
self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions")
self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples")
+ @slowTest
+ @dtypes(torch.float)
+ def test_multinomial_rng_state_advance(self, device, dtype):
+ corpus_size = 100000
+ freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
+ n_sample = 100
+ samples1 = torch.multinomial(freqs, n_sample, replacement=True)
+ samples2 = torch.multinomial(freqs, n_sample, replacement=True)
+ samples = torch.cat([samples1, samples2])
+        # expect at most a couple of repeated elements across the 2 draws;
+        # the probability of at least one element repeating is surprisingly large, about 18%
+ self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
+ samples1 = torch.multinomial(freqs, n_sample, replacement=False)
+ samples2 = torch.multinomial(freqs, n_sample, replacement=False)
+ samples = torch.cat([samples1, samples2])
+        # expect no more than 1 repeated element across the 2 draws
+ self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)
+
def test_var_unbiased(self, device):
tensor = torch.randn(100, device=device)
self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))