| import torch |
| |
| |
class Sampler(object):
    """Abstract base that every sampler derives from.

    A concrete sampler must override ``__iter__`` so that it produces the
    indices of dataset elements to visit, and ``__len__`` so that it reports
    how many indices that iterator produces.
    """

    def __init__(self, data_source):
        # The base class keeps no state derived from data_source.
        del data_source

    def __len__(self):
        """Return the number of indices produced; subclasses must override."""
        raise NotImplementedError

    def __iter__(self):
        """Return an iterator over dataset indices; subclasses must override."""
        raise NotImplementedError
| |
| |
class SequentialSampler(Sampler):
    """Yields the indices ``0, 1, ..., len(data_source) - 1`` in ascending order.

    The order is the same on every pass.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        # Only the size of the dataset is recorded; the dataset itself is not kept.
        self.num_samples = len(data_source)

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        ordered = range(self.num_samples)
        return iter(ordered)
| |
| |
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Every call to ``__iter__`` draws a fresh random permutation of
    ``range(len(data_source))``.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        # torch.randperm already returns int64 indices, so no dtype cast is
        # needed. Converting to a list yields plain Python ints (matching
        # SequentialSampler) instead of 0-dim tensors, which are slower to
        # index with and hash by object identity.
        return iter(torch.randperm(self.num_samples).tolist())

    def __len__(self):
        return self.num_samples
| |
| |
class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Every call to ``__iter__`` draws a fresh random permutation of the
    positions in ``indices`` and yields the entries of ``indices`` in that order.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # The permutation is drawn eagerly here, as in a generator expression's
        # first iterable; element lookup stays lazy. ``tolist()`` gives plain
        # ints, so ``self.indices`` is indexed without a per-element tensor
        # ``__index__`` conversion.
        order = torch.randperm(len(self.indices)).tolist()
        return (self.indices[i] for i in order)

    def __len__(self):
        return len(self.indices)
| |
| |
class WeightedRandomSampler(Sampler):
    """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (sequence or Tensor): a sequence of weights, not necessarily
            summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True`` (default), samples are drawn with
            replacement; if ``False``, an index is never drawn twice within
            one pass (see ``torch.multinomial`` for the resulting limits on
            ``num_samples``)
    """

    def __init__(self, weights, num_samples, replacement=True):
        # torch.as_tensor accepts lists, NumPy arrays and tensors of any dtype.
        # The legacy torch.DoubleTensor constructor rejects tensors whose
        # dtype is not already double.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        # Yield plain Python ints rather than 0-dim tensors.
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples