allow RandomSampler to sample with replacement (#9911)
Summary:
fixes #7908
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9911
Reviewed By: yf225
Differential Revision: D9023223
Pulled By: weiyangfb
fbshipit-source-id: 68b199bef3940b7205d0fdad75e7c46e6fe65ba7
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index c6f15b9..020486c 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -437,6 +437,36 @@
self.assertEqual(len(input), 3)
self.assertEqual(input, self.data[offset:offset + 3])
+ def test_RandomSampler(self):
+
+ from collections import Counter
+ from torch.utils.data import RandomSampler
+
+ def sample_stat(sampler, num_samples):
+ counts = Counter(sampler)
+ count_repeated = sum(val > 1 for val in counts.values())
+ return (count_repeated, min(counts.keys()), max(counts.keys()))
+
+ # test sample with replacement
+ n = len(self.dataset) + 1 # ensure at least one sample is drawn more than once
+ sampler_with_replacement = RandomSampler(self.dataset, replacement=True, num_samples=n)
+ count_repeated, minval, maxval = sample_stat(sampler_with_replacement, n)
+ self.assertTrue(count_repeated > 0)
+ self.assertTrue(minval >= 0)
+ self.assertTrue(maxval < len(self.dataset))
+
+ # test sample without replacement
+ sampler_without_replacement = RandomSampler(self.dataset)
+ count_repeated, minval, maxval = sample_stat(sampler_without_replacement, len(self.dataset))
+ self.assertTrue(count_repeated == 0)
+ self.assertTrue(minval == 0)
+ self.assertTrue(maxval == len(self.dataset) - 1)
+
+ # raise error when replacement=False and num_samples is not None
+ self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=len(self.dataset)))
+
+ self.assertRaises(ValueError, lambda: RandomSampler(self.dataset, num_samples=0))
+
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
don't support multiprocessing with spawn start method")
def test_batch_sampler(self):
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 3738e48..37303f0 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -38,17 +38,39 @@
class RandomSampler(Sampler):
- r"""Samples elements randomly, without replacement.
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+ If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
+ num_samples (int): number of samples to draw, default=len(dataset)
+ replacement (bool): samples are drawn with replacement if ``True``, default=False
"""
- def __init__(self, data_source):
+ def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
+ self.replacement = replacement
+ self.num_samples = num_samples
+
+ if self.num_samples is not None and replacement is False:
+ raise ValueError("With replacement=False, num_samples should not be specified, "
+ "since a random permute will be performed.")
+
+ if self.num_samples is None:
+ self.num_samples = len(self.data_source)
+
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+ raise ValueError("num_samples should be a positive integeral "
+ "value, but got num_samples={}".format(self.num_samples))
+ if not isinstance(self.replacement, bool):
+ raise ValueError("replacement should be a boolean value, but got "
+ "replacement={}".format(self.replacement))
def __iter__(self):
- return iter(torch.randperm(len(self.data_source)).tolist())
+ n = len(self.data_source)
+ if self.replacement:
+ return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
+ return iter(torch.randperm(n).tolist())
def __len__(self):
return len(self.data_source)