Multinomial raise error (#12490)

Summary:
Fixes #12260 #2896

```
torch.multinomial(torch.FloatTensor([0, 1, 0, 0]), 3, replacement=False)
```
The old behavior is that we return `0` after we run out of postive categories. Now we raise an error based on discussion in the issue thread.

- Add testcase for cpu & cuda case, in cuda case `n_samples=1` is a simple special case, so we test against `n_sample=2` instead.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12490

Differential Revision: D10278794

Pulled By: ailzhang

fbshipit-source-id: d04de7a60f60d0c0d648b975db3f3961fcf42db1
diff --git a/aten/src/TH/generic/THTensorRandom.cpp b/aten/src/TH/generic/THTensorRandom.cpp
index 84ead95..c5ff9d1 100644
--- a/aten/src/TH/generic/THTensorRandom.cpp
+++ b/aten/src/TH/generic/THTensorRandom.cpp
@@ -285,6 +285,7 @@
     /* Get normalized cumulative distribution from prob distribution */
     double sum = 0;
     double val;
+    int n_zeros = 0;
     for (j=0; j<n_categories; j++)
     {
       val = THStorage_(get)( \
@@ -300,6 +301,9 @@
                             2,
                             "invalid multinomial distribution (encountering probability entry = infinity or NaN)");
       sum += val;
+      if (val == 0) {
+        n_zeros += 1;
+      }
       THDoubleStorage_set(
         THTensor_getStoragePtr(cum_dist), \
         cum_dist->storage_offset()+j*cum_dist->stride(0), \
@@ -310,6 +314,10 @@
                           THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
                           2,
                           "invalid multinomial distribution (sum of probabilities <= 0)");
+    THArgCheckWithCleanup((with_replacement || (n_categories - n_zeros >= n_sample)),
+                          THCleanup(THDoubleTensor_free(cum_dist); if (start_dim == 1) THTensor_(squeeze1d)(prob_dist, prob_dist, 0);),
+                          2,
+                          "invalid multinomial distribution (with replacement=False, not enough non-negative category to sample)");
     /* normalize cumulative probability distribution so that last val is 1
     i.e. doesn't assume original prob_dist row sums to one */
     if ( (sum > 0) || ( ( sum < 1.00001) && (sum > 0.99999) ) )
diff --git a/aten/src/THC/THCTensorRandom.cuh b/aten/src/THC/THCTensorRandom.cuh
index 13933ca..cfe8510 100644
--- a/aten/src/THC/THCTensorRandom.cuh
+++ b/aten/src/THC/THCTensorRandom.cuh
@@ -7,7 +7,7 @@
 
 #include <curand_kernel.h>
 
-#define MAX_NUM_BLOCKS 200 
+#define MAX_NUM_BLOCKS 200
 #define BLOCK_SIZE 256
 /* Separate kernel because curand_log_normal gets extra parameters. */
 
@@ -126,6 +126,8 @@
                                           T val) {
   int start = 0;
   int end = size;
+  // dist[size - 1] = 0 => all zero prob dist
+  assert(THCNumerics<T>::gt(dist[size - 1], 0));
 
   while (end - start > 0) {
     int mid = start + (end - start) / 2;
diff --git a/test/test_cuda.py b/test/test_cuda.py
index cf21208..0ec563a 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1594,17 +1594,6 @@
         r = torch.multinomial(p, 1)
         self.assertNotEqual(r.min().item(), 0)
 
-        # multinomial without repeat but with less nonzero
-        # elements than draws
-        # the intention currently is to return 0 for those
-        # and match CPU behaviour, see issue #9062
-        p = torch.zeros(1, 5, device="cuda")
-        p[:, 1] = 1
-        r = torch.multinomial(p, 2, replacement=False)
-        expected = torch.zeros(1, 2, device="cuda", dtype=torch.long)
-        expected[:, 0] = 1
-        self.assertEqual(r, expected)
-
     @staticmethod
     def mute():
         os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
@@ -1621,7 +1610,7 @@
     def _test_multinomial_invalid_probs_cuda(probs):
         try:
             with torch.random.fork_rng(devices=[0]):
-                torch.multinomial(probs.to('cuda'), 1)
+                torch.multinomial(probs.to('cuda'), 2)
                 torch.cuda.synchronize()
             return False  # Should not be reached
         except RuntimeError as e:
@@ -1635,10 +1624,11 @@
                      but we need it for creating another process with CUDA")
     def test_multinomial_invalid_probs_cuda(self):
         test_method = TestCuda._test_multinomial_invalid_probs_cuda
-        self._spawn_method(test_method, torch.Tensor([0, -1]))
-        self._spawn_method(test_method, torch.Tensor([0, inf]))
-        self._spawn_method(test_method, torch.Tensor([0, -inf]))
-        self._spawn_method(test_method, torch.Tensor([0, nan]))
+        self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
+        self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
 
     @skipIfRocm
     def test_broadcast(self):
diff --git a/test/test_torch.py b/test/test_torch.py
index 8479929..e00b0d0 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2851,7 +2851,8 @@
     @staticmethod
     def _test_multinomial_invalid_probs(probs):
         try:
-            torch.multinomial(probs.to('cpu'), 1)
+            # n_sample = 1 is a special case, test n_sample=2 which is more general
+            torch.multinomial(probs.to('cpu'), 2)
             return False  # Should not be reached
         except RuntimeError as e:
             return 'invalid multinomial distribution' in str(e)
@@ -2864,10 +2865,11 @@
                      but we need it for for testing failure case for CPU RNG on Windows")
     def test_multinomial_invalid_probs(self):
         test_method = TestTorch._test_multinomial_invalid_probs
-        self._spawn_method(test_method, torch.Tensor([0, -1]))
-        self._spawn_method(test_method, torch.Tensor([0, inf]))
-        self._spawn_method(test_method, torch.Tensor([0, -inf]))
-        self._spawn_method(test_method, torch.Tensor([0, nan]))
+        self._spawn_method(test_method, torch.Tensor([1, -1, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, inf, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, -inf, 1]))
+        self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
+        self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
 
     @suppress_warnings
     def test_range(self):