| import torch |
| from torch._six import int_classes as _int_classes |
| |
| |
class Sampler(object):
    r"""Abstract base for every sampler.

    A concrete subclass must override ``__iter__`` so that it yields indices
    of dataset elements, and ``__len__`` so that it reports how many indices
    one pass of that iterator produces. The base versions of both methods
    only raise ``NotImplementedError``.
    """

    def __init__(self, data_source):
        # The base class keeps no state; ``data_source`` is accepted and ignored.
        del data_source

    def __iter__(self):
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()
| |
| |
class SequentialSampler(Sampler):
    r"""Yields ``0, 1, ..., len(data_source) - 1`` in ascending order on every pass.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # Read the dataset size at iteration time so each pass reflects it.
        count = len(self.data_source)
        positions = range(count)
        return iter(positions)
| |
| |
class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.

    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=False
        num_samples (int): number of samples to draw, default=len(dataset). May only be
            given when ``replacement`` is ``True``.

    Raises:
        ValueError: if ``num_samples`` is given while ``replacement`` is ``False``, if
            ``num_samples`` is not a positive integer (``bool`` is rejected), or if
            ``replacement`` is not a bool.
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self.num_samples = num_samples

        if self.num_samples is not None and replacement is False:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if self.num_samples is None:
            self.num_samples = len(self.data_source)

        # bool is a subclass of int, so it is excluded explicitly (matching the
        # validation used by the other samplers in this module).
        if not isinstance(self.num_samples, _int_classes) or \
                isinstance(self.num_samples, bool) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))
        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # num_samples independent draws from [0, n).
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # One full random permutation of [0, n).
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        # Must agree with the number of indices __iter__ yields: num_samples draws
        # when sampling with replacement, one full permutation otherwise.
        if self.replacement:
            return self.num_samples
        return len(self.data_source)
| |
| |
class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices; it must support ``len()`` and
            indexing by integer position
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Convert the permutation to plain Python ints up front (as the other
        # samplers do with ``.tolist()``). Iterating the tensor directly would create
        # a 0-dim tensor per element and index ``self.indices`` with tensors.
        order = torch.randperm(len(self.indices)).tolist()
        return (self.indices[i] for i in order)

    def __len__(self):
        return len(self.indices)
| |
| |
class WeightedRandomSampler(Sampler):
    r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (sequence) : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that once an
            index has been drawn it cannot be drawn again within the same pass.

    Raises:
        ValueError: if ``num_samples`` is not a positive integer (``bool`` is
            rejected) or ``replacement`` is not a bool.
    """

    def __init__(self, weights, num_samples, replacement=True):
        # bool is a subclass of int, so it is excluded explicitly.
        if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        if isinstance(weights, torch.Tensor):
            # torch.tensor() on an existing tensor emits a UserWarning; clone().detach()
            # is the equivalent copy it recommends. Copying keeps the sampler
            # independent of later in-place changes to the caller's tensor.
            self.weights = weights.detach().clone().double()
        else:
            self.weights = torch.tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples
| |
| |
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``sampler`` is not a :class:`Sampler`, if ``batch_size`` is
            not a positive integer (``bool`` is rejected), or if ``drop_last`` is
            not a bool.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # bool is a subclass of int, so it is excluded explicitly.
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                # Start a fresh list so the yielded batch is never mutated afterwards.
                batch = []
        # Emit the trailing partial batch unless drop_last asks to discard it.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            # Only full batches are produced.
            return len(self.sampler) // self.batch_size
        # Ceiling division: full batches plus one for a trailing partial batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size