blob: 53e3ff5c0143cb798652d9ca62dcbaa2502b71cf [file] [log] [blame]
import torch
class Sampler(object):
    """Base class for index samplers.

    A concrete sampler must override ``__iter__`` to return an iterator
    over indices and ``__len__`` to report how many indices that iterator
    produces. Both are left unimplemented here.
    """

    def __init__(self, data_source):
        """Accept ``data_source`` and store nothing.

        The argument exists so the base constructor matches the signature
        used by the sampler subclasses.
        """

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()
class SequentialSampler(Sampler):
    """Yield the indices ``0, 1, ..., len(data_source) - 1`` in ascending order.

    Args:
        data_source: any object supporting ``len()``. Only its length is
            read, once, when the sampler is constructed.
    """

    def __init__(self, data_source):
        # Snapshot the length now; later changes to data_source are not seen.
        total = len(data_source)
        self.num_samples = total

    def __iter__(self):
        indices = range(self.num_samples)
        return iter(indices)

    def __len__(self):
        """Return the number of indices a single pass yields."""
        count = self.num_samples
        return count
class RandomSampler(Sampler):
    """Yield the indices ``0 .. len(data_source) - 1`` in a random order.

    Every index appears exactly once per pass, and each call to ``iter()``
    draws a fresh permutation via ``torch.randperm``.

    Args:
        data_source: any object supporting ``len()``. Only its length is
            read, once, when the sampler is constructed.
    """

    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        # ``tolist()`` turns the permutation into plain Python ints, so this
        # sampler yields the same element type as SequentialSampler.
        # Iterating the tensor directly yields 0-dim tensors in current
        # PyTorch. ``randperm`` already returns int64, so the former
        # ``.long()`` cast was redundant.
        return iter(torch.randperm(self.num_samples).tolist())

    def __len__(self):
        """Return the number of indices a single pass yields."""
        return self.num_samples