import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to ``len(self)`` exclusive.
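
    Example (a minimal sketch of a custom subclass; ``SquaresDataset`` is a
    hypothetical name used only for illustration)::

        >>> class SquaresDataset(Dataset):
        ...     def __init__(self, n):
        ...         self.n = n
        ...     def __getitem__(self, index):
        ...         return index ** 2
        ...     def __len__(self):
        ...         return self.n
        >>> ds = SquaresDataset(4)
        >>> len(ds), ds[3]
        (4, 9)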
| """ |
| |
| def __getitem__(self, index): |
| raise NotImplementedError |
| |
| def __len__(self): |
| raise NotImplementedError |
| |
| def __add__(self, other): |
| return ConcatDataset([self, other]) |
| |
| |
| class TensorDataset(Dataset): |
| """Dataset wrapping tensors. |
| |
| Each sample will be retrieved by indexing tensors along the first dimension. |
| |
| Arguments: |
| *tensors (Tensor): tensors that have the same size of the first dimension. |
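
    Example (an illustrative usage sketch; the sample tensors are made up)::

        >>> import torch
        >>> inputs = torch.randn(10, 3)   # 10 samples, 3 features each
        >>> targets = torch.randn(10)     # one target per sample
        >>> dataset = TensorDataset(inputs, targets)
        >>> len(dataset)
        10
        >>> x, y = dataset[0]             # aligned (input, target) pair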
| """ |
| |
| def __init__(self, *tensors): |
| assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors) |
| self.tensors = tensors |
| |
| def __getitem__(self, index): |
| return tuple(tensor[index] for tensor in self.tensors) |
| |
| def __len__(self): |
| return self.tensors[0].size(0) |
| |
| |
| class ConcatDataset(Dataset): |
| """ |
| Dataset to concatenate multiple datasets. |
| Purpose: useful to assemble different existing datasets, possibly |
| large-scale datasets as the concatenation operation is done in an |
| on-the-fly manner. |
| |
| Arguments: |
| datasets (sequence): List of datasets to be concatenated |
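
    Example (an illustrative sketch built from two small ``TensorDataset``
    instances; the tensors are made up)::

        >>> import torch
        >>> a = TensorDataset(torch.zeros(3, 2))
        >>> b = TensorDataset(torch.ones(5, 2))
        >>> combined = ConcatDataset([a, b])
        >>> len(combined)
        8
        >>> combined[3]  # first sample of ``b``, reached through the combined index
        (tensor([1., 1.]),)
        >>> len(a + b)   # ``Dataset.__add__`` builds the same concatenation
        8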
| """ |
| |
| @staticmethod |
| def cumsum(sequence): |
| r, s = [], 0 |
| for e in sequence: |
| l = len(e) |
| r.append(l + s) |
| s += l |
| return r |
| |
| def __init__(self, datasets): |
| super(ConcatDataset, self).__init__() |
| assert len(datasets) > 0, 'datasets should not be an empty iterable' |
| self.datasets = list(datasets) |
| self.cumulative_sizes = self.cumsum(self.datasets) |
| |
| def __len__(self): |
| return self.cumulative_sizes[-1] |
| |
| def __getitem__(self, idx): |
| if idx < 0: |
| if -idx > len(self): |
| raise ValueError("absolute value of index should not exceed dataset length") |
| idx = len(self) + idx |
| dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) |
| if dataset_idx == 0: |
| sample_idx = idx |
| else: |
| sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] |
| return self.datasets[dataset_idx][sample_idx] |
| |
| @property |
| def cummulative_sizes(self): |
| warnings.warn("cummulative_sizes attribute is renamed to " |
| "cumulative_sizes", DeprecationWarning, stacklevel=2) |
| return self.cumulative_sizes |
| |
| |
| class Subset(Dataset): |
| """ |
| Subset of a dataset at specified indices. |
| |
| Arguments: |
| dataset (Dataset): The whole Dataset |
| indices (sequence): Indices in the whole set selected for subset |
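
    Example (an illustrative sketch; the underlying dataset is made up)::

        >>> import torch
        >>> full = TensorDataset(torch.arange(10))
        >>> evens = Subset(full, [0, 2, 4, 6, 8])
        >>> len(evens)
        5
        >>> evens[1]  # second selected index, i.e. full[2]
        (tensor(2),)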
| """ |
| def __init__(self, dataset, indices): |
| self.dataset = dataset |
| self.indices = indices |
| |
| def __getitem__(self, idx): |
| return self.dataset[self.indices[idx]] |
| |
| def __len__(self): |
| return len(self.indices) |
| |
| |
| def random_split(dataset, lengths): |
| """ |
| Randomly split a dataset into non-overlapping new datasets of given lengths. |
| |
| Arguments: |
| dataset (Dataset): Dataset to be split |
| lengths (sequence): lengths of splits to be produced |
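
    Example (an illustrative sketch; the dataset and split sizes are made up)::

        >>> import torch
        >>> full = TensorDataset(torch.randn(10, 2))
        >>> train, val = random_split(full, [8, 2])
        >>> len(train), len(val)
        (8, 2)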
| """ |
| if sum(lengths) != len(dataset): |
| raise ValueError("Sum of input lengths does not equal the length of the input dataset!") |
| |
| indices = randperm(sum(lengths)).tolist() |
| return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] |