| import torch |
| |
| |
class Sampler(object):
    """Base class for index samplers.

    A concrete sampler is iterable and sized: ``__iter__`` produces the
    sequence of indices to visit and ``__len__`` reports how many there
    are. This base class implements neither; subclasses must override both.
    """

    def __init__(self, data_source):
        # The base class keeps no state; subclasses decide what (if anything)
        # to retain from ``data_source``.
        del data_source

    def __iter__(self):
        """Return an iterator over indices. Must be overridden."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of indices produced. Must be overridden."""
        raise NotImplementedError
| |
| |
class SequentialSampler(Sampler):
    """Yield indices ``0, 1, ..., len(data_source) - 1`` in ascending order.

    Only the length of ``data_source`` is recorded at construction time;
    the data itself is not referenced afterwards.
    """

    def __init__(self, data_source):
        size = len(data_source)
        self.num_samples = size

    def __len__(self):
        """Return the number of indices this sampler yields."""
        return self.num_samples

    def __iter__(self):
        """Return an iterator over ``range(num_samples)``."""
        indices = range(self.num_samples)
        return iter(indices)
| |
| |
class RandomSampler(Sampler):
    """Yield indices ``0 .. len(data_source) - 1`` in random order, without replacement.

    A new permutation is drawn with ``torch.randperm`` every time the sampler
    is iterated. Only the length of ``data_source`` is recorded at
    construction time.
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        """Return an iterator over a fresh random permutation of indices.

        Indices are plain Python ``int`` values, matching what
        ``SequentialSampler`` yields.
        """
        # ``torch.randperm`` already returns int64, so no ``.long()`` cast is
        # needed. ``tolist()`` converts to Python ints; iterating the tensor
        # directly would yield 0-dim tensors instead of ints.
        return iter(torch.randperm(self.num_samples).tolist())

    def __len__(self):
        """Return the number of indices this sampler yields."""
        return self.num_samples