fix multinomial kernels to properly advance random states (#38046)

Summary:
Before, the multinomial CUDA kernels did not advance their random states far enough, which led to the same sequence being generated over and over again, just shifted by 4. This PR fixes that.
Fixes https://github.com/pytorch/pytorch/issues/37403
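
A minimal sketch of the symptom, mirroring the regression test added below (illustrative only; assumes a CUDA device, since the affected kernels are the CUDA ones):

    import torch

    freqs = torch.ones(100000, device="cuda")
    s1 = torch.multinomial(freqs, 100, replacement=True)
    s2 = torch.multinomial(freqs, 100, replacement=True)
    # Before this fix, the second call largely replays the same random
    # sequence, so the concatenation contains many repeated indices;
    # after the fix, repeats among 200 draws from 100000 categories are rare.
    num_repeats = 200 - torch.cat([s1, s2]).unique().numel()
    print(num_repeats)
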
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38046

Differential Revision: D21516542

Pulled By: ngimel

fbshipit-source-id: 23248a8c3a5c44316c4c35cd71a8c3b5f76c90f2
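
The gist of the fix: each thread now consumes one curand_uniform4 call (which counts as 4 in the Philox offset) per distribution it handles, so the engine offset has to scale with the number of distributions per thread instead of staying at the constant 4. A rough sketch of the launch-side arithmetic, in illustrative Python that mirrors the expressions in the diff below:

    # Illustrative only; the grid/block names mirror the CUDA launch code.
    def with_replacement_offset(num_dist, grid_y):
        # each thread handles ceil(num_dist / grid_y) distributions,
        # with one curand_uniform4 (4 randoms) per distribution
        return ((num_dist - 1) // grid_y + 1) * 4

    def without_replacement_offset(num_dist, grid_x, block_y):
        # each thread handles ceil(num_dist / (grid_x * block_y)) distributions
        return ((num_dist - 1) // (grid_x * block_y) + 1) * 4
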
diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu
index f57a77c..b5a2e71 100644
--- a/aten/src/ATen/native/cuda/MultinomialKernel.cu
+++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu
@@ -124,36 +124,33 @@
   // search due to divergence. It seems possible to compute multiple
   // values and limit divergence though later on.
 
-  // global index formula for 1D grid of 2D blocks
-  int idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x;
+  // global index formula for 2D grid of 1D blocks
+  int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + threadIdx.x;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seeds.first, idx, seeds.second, &state);
 
   // The block determines the distribution for which we generate a point
-  for (int64_t curDist = blockIdx.x;
+  for (int64_t curDist = blockIdx.y;
        curDist < distributions;
-       curDist += gridDim.x) {
-    for (int sampleBase = 0;
-         sampleBase < totalSamples; sampleBase += blockDim.y) {
-      // The warp determines the sample
-      int sample = sampleBase + threadIdx.y;
+       curDist += gridDim.y) {
+    for (int sample = blockIdx.x*blockDim.x + threadIdx.x;
+         sample < totalSamples; sample += blockDim.x*gridDim.x) {
 
-      // All threads participate in this
+      // We are losing 3 out of 4 generated numbers, but that's OK:
+      // this kernel is not very efficient anyway.
       auto rand = curand_uniform4(&state);
       scalar_t r = static_cast<scalar_t>(rand.x);
 
-      if (threadIdx.x == 0 && sample < totalSamples) {
-        // Find the bucket that a uniform sample lies in
-        int choice = binarySearchForMultinomial<scalar_t>(
-            normDistPrefixSum + curDist * categories,
-            normDist + curDist * categories,
-            categories,
-            r);
+      // Find the bucket that a uniform sample lies in
+      int choice = binarySearchForMultinomial<scalar_t>(
+          normDistPrefixSum + curDist * categories,
+          normDist + curDist * categories,
+          categories,
+          r);
 
-        // Torch indices are 1-based
-        dest[curDist * totalSamples + sample] = choice;
-      }
+      dest[curDist * totalSamples + sample] = choice;
+
     }
   }
 }
@@ -180,17 +177,14 @@
 
   // The block and warp determines the distribution for which we
   // generate a point
-  for (int64_t curDistBase = blockIdx.x * blockDim.y;
-       curDistBase < distributions;
-       curDistBase += gridDim.x * blockDim.y) {
-    // The warp determines the distribution
-    int64_t curDist = curDistBase + threadIdx.y;
+  for (int64_t curDist = blockIdx.x * blockDim.y + threadIdx.y;
+       curDist < distributions;
+       curDist += gridDim.x * blockDim.y) {
 
-    // All threads must participate in this
     auto rand = curand_uniform4(&state);
     scalar_t r = static_cast<scalar_t>(rand.x);
 
-    if (threadIdx.x == 0 && curDist < distributions) {
+    if (threadIdx.x == 0) {
       // Find the bucket that a uniform sample lies in
       int choice = binarySearchForMultinomial<scalar_t>(
           normDistPrefixSum + curDist * categories,
@@ -415,26 +409,27 @@
       std::pair<uint64_t, uint64_t> rng_engine_inputs;
 
       if (with_replacement) {
+        // Binary search is warp divergent (so effectively we're running
+        // with just a single thread), but for better utilization,
+        // we need each block to have at least 4 warps.
+        dim3 block(128);
+
+        // Blocks along grid.y pick the distribution; blocks along grid.x
+        // and their threads generate the samples concurrently.
+        int grid_y = std::min<int>(numDist, at::cuda::getCurrentDeviceProperties()->maxGridSize[1]);
+        dim3 grid((n_sample - 1) / block.x + 1, grid_y);
         {
           // See Note [Acquire lock when using random generators]
           std::lock_guard<std::mutex> lock(gen->mutex_);
 
-          // each thread will utilize one random, however, since we have to use
+          // each thread generates a single sample for (numDist/grid.y) distributions; however, since we have to use
           // curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
-          // offset is 4.
-          rng_engine_inputs = gen->philox_engine_inputs(4);
+          // offset is 4 times that.
+          auto offset = ((numDist-1)/grid.y+1)*4;
+          rng_engine_inputs = gen->philox_engine_inputs(offset);
         }
         // Sample with replacement
 
-        // Binary search is warp divergent (so effectively we're running
-        // with just a single thread), but for better utilization,
-        // we need each block to have at least 4 warps.
-        dim3 block(32, 4);
-
-        // Each warp in a block will generate a sample from one
-        // distribution concurrently.
-        dim3 grid(numDist < MAX_NUM_BLOCKS ? numDist : MAX_NUM_BLOCKS);
-
         sampleMultinomialWithReplacement
             <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
             rng_engine_inputs,
@@ -470,10 +465,11 @@
             // See Note [Acquire lock when using random generators]
             std::lock_guard<std::mutex> lock(gen->mutex_);
 
-            // each thread will utilize one random, however, since we have to use
+            // each thread will utilize numDist/(grid.x*block.y) randoms; however, since we have to use
             // curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
-            // offset is 4.
-            rng_engine_inputs = gen->philox_engine_inputs(4);
+            // offset is 4 times that.
+            auto offset = ((numDist-1)/(grid.x*block.y)+1)*4;
+            rng_engine_inputs = gen->philox_engine_inputs(offset);
           }
 
           // The kernel can only draw one sample before we have to
diff --git a/test/test_torch.py b/test/test_torch.py
index 38570de..9827980 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -14226,6 +14226,24 @@
             self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions")
             self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples")
 
+    @slowTest
+    @dtypes(torch.float)
+    def test_multinomial_rng_state_advance(self, device, dtype):
+        corpus_size = 100000
+        freqs = torch.ones(corpus_size, dtype=torch.float, device=device)
+        n_sample = 100
+        samples1 = torch.multinomial(freqs, n_sample, replacement=True)
+        samples2 = torch.multinomial(freqs, n_sample, replacement=True)
+        samples = torch.cat([samples1, samples2])
+        # expect no more than 2 repeated elements generated in 2 attempts;
+        # the probability of at least one element being repeated is surprisingly large, about 18%
+        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 2)
+        samples1 = torch.multinomial(freqs, n_sample, replacement=False)
+        samples2 = torch.multinomial(freqs, n_sample, replacement=False)
+        samples = torch.cat([samples1, samples2])
+        # expect no more than 1 repeated element generated in 2 attempts
+        self.assertLessEqual(2 * n_sample - samples.unique().size(0), 1)
+
     def test_var_unbiased(self, device):
         tensor = torch.randn(100, device=device)
         self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True))