class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.
    """

    def __init__(self, data_source):
        # The base class keeps no state; subclasses decide what (if anything)
        # to record about ``data_source``. Kept as a no-op so that subclasses
        # may safely call ``super().__init__(data_source)``.
        pass

    def __iter__(self):
        """Return an iterator over dataset indices (must be overridden)."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of indices the iterator yields (must be overridden)."""
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Only the length is kept; it is read once here, so later changes to
        # the size of ``data_source`` are not reflected by this sampler.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Yield the indices ``0, 1, ..., num_samples - 1`` in order."""
        return iter(range(self.num_samples))

    def __len__(self):
        """Return the number of indices produced by ``__iter__``."""
        # Required by the Sampler contract; the base class raises otherwise.
        return self.num_samples
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Args:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Only the length is kept; it is read once here, so later changes to
        # the size of ``data_source`` are not reflected by this sampler.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Yield a fresh random permutation of ``0 .. num_samples - 1``.

        A new permutation is drawn on every call, so each pass over the
        sampler visits every index exactly once in a new order. Elements are
        0-dim int64 tensors (the result of ``.long()``).
        """
        return iter(torch.randperm(self.num_samples).long())

    def __len__(self):
        """Return the number of indices produced by ``__iter__``."""
        # Required by the Sampler contract; the base class raises otherwise.
        return self.num_samples