allow RandomSampler to sample with replacement (#9911)

Summary:
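Fixes #7908. Adds `replacement` and `num_samples` arguments to `RandomSampler` so that indices can be drawn with replacement; the default (`replacement=False`) still yields a random permutation of the dataset indices.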
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9911

Reviewed By: yf225

Differential Revision: D9023223

Pulled By: weiyangfb

fbshipit-source-id: 68b199bef3940b7205d0fdad75e7c46e6fe65ba7
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index c6f15b9..020486c 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -437,6 +437,40 @@
                 self.assertEqual(len(input), 3)
                 self.assertEqual(input, self.data[offset:offset + 3])
 
+    def test_RandomSampler(self):
+        # Exercises RandomSampler: index range, repetition, and argument validation.
+        from collections import Counter
+        from torch.utils.data import RandomSampler
+
+        def summarize(indices):
+            # Returns (number of indices drawn more than once, smallest index, largest index).
+            tally = Counter(indices)
+            duplicated = sum(1 for hits in tally.values() if hits > 1)
+            return duplicated, min(tally), max(tally)
+
+        size = len(self.dataset)
+
+        # With replacement: drawing size + 1 indices out of size possible values
+        # forces at least one repeat (pigeonhole principle).
+        with_repl = RandomSampler(self.dataset, replacement=True, num_samples=size + 1)
+        duplicated, lowest, highest = summarize(with_repl)
+        self.assertTrue(duplicated > 0)
+        self.assertTrue(lowest >= 0)
+        self.assertTrue(highest < size)
+
+        # Without replacement: no index may repeat and the extremes must be 0 and size - 1.
+        without_repl = RandomSampler(self.dataset)
+        duplicated, lowest, highest = summarize(without_repl)
+        self.assertTrue(duplicated == 0)
+        self.assertTrue(lowest == 0)
+        self.assertTrue(highest == size - 1)
+
+        # num_samples may not be combined with the default replacement=False,
+        # regardless of its value.
+        for bad_count in (size, 0):
+            with self.assertRaises(ValueError):
+                RandomSampler(self.dataset, num_samples=bad_count)
+
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     def test_batch_sampler(self):
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 3738e48..37303f0 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -38,17 +38,48 @@
 
 
 class RandomSampler(Sampler):
-    r"""Samples elements randomly, without replacement.
+    r"""Samples elements randomly, either without replacement (one random
+    permutation of the dataset indices) or with replacement.
 
     Arguments:
         data_source (Dataset): dataset to sample from
+        replacement (bool): samples are drawn with replacement if ``True``, default=False
+        num_samples (int): number of samples to draw; may only be given together
+            with ``replacement=True``, default=len(dataset)
+
+    Raises:
+        ValueError: if ``replacement`` is not a bool, if ``num_samples`` is given
+            while ``replacement`` is ``False``, or if ``num_samples`` is not a
+            positive integer
     """
 
-    def __init__(self, data_source):
+    def __init__(self, data_source, replacement=False, num_samples=None):
+        # Validate the flag first so that a non-bool value is reported as such
+        # rather than slipping past the ``not replacement`` check below.
+        if not isinstance(replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(replacement))
+        if num_samples is not None and not replacement:
+            raise ValueError("With replacement=False, num_samples should not be specified, "
+                             "since a random permute will be performed.")
+
         self.data_source = data_source
+        self.replacement = replacement
+        # Default: draw as many indices as the dataset holds.
+        self.num_samples = len(data_source) if num_samples is None else num_samples
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError("num_samples should be a positive integral "
+                             "value, but got num_samples={}".format(self.num_samples))
 
     def __iter__(self):
-        return iter(torch.randperm(len(self.data_source)).tolist())
+        n = len(self.data_source)
+        if self.replacement:
+            # Independent uniform draws from [0, n).
+            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
+        return iter(torch.randperm(n).tolist())
 
     def __len__(self):
-        return len(self.data_source)
+        # Report the number of indices __iter__ yields: num_samples when drawing
+        # with replacement, otherwise the current dataset size.
+        return self.num_samples if self.replacement else len(self.data_source)