import bisect
import warnings

from torch._utils import _accumulate
from torch import randperm


class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
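

# Illustrative sketch (not part of this module): a minimal subclass only needs
# to override ``__len__`` and ``__getitem__``, e.g.
#
#     class SquaresDataset(Dataset):      # hypothetical example class
#         def __init__(self, n):
#             self.n = n
#
#         def __getitem__(self, index):
#             return index ** 2
#
#         def __len__(self):
#             return self.n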


class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).
    """

    def __init__(self, data_tensor, target_tensor):
        assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
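

# Illustrative usage sketch (assumes ``torch`` is importable; not part of this
# module):
#
#     import torch
#     data = torch.randn(100, 3)         # 100 samples, 3 features each
#     targets = torch.randn(100)         # 100 scalar targets
#     ds = TensorDataset(data, targets)
#     x, y = ds[0]                       # sample and target at index 0
#     assert len(ds) == 100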


class ConcatDataset(Dataset):
    """Dataset to concatenate multiple datasets.

    Useful for assembling different existing datasets, possibly large-scale
    ones, since the concatenation is done on the fly.

    Arguments:
        datasets (iterable): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Locate the dataset that contains the global index, then translate
        # the global index into an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
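

# Illustrative usage sketch (assumes ``torch`` is importable; not part of this
# module):
#
#     import torch
#     a = TensorDataset(torch.zeros(3, 2), torch.zeros(3))
#     b = TensorDataset(torch.ones(5, 2), torch.ones(5))
#     c = ConcatDataset([a, b])          # equivalently ``a + b`` via Dataset.__add__
#     assert len(c) == 8
#     assert c.cumulative_sizes == [3, 8]
#     x, y = c[4]                        # global index 4 maps to b[4 - 3] == b[1]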


class Subset(Dataset):
    """Subset of a dataset at the specified indices.

    Arguments:
        dataset (Dataset): The whole dataset.
        indices (sequence): Indices in the whole set selected for the subset.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def random_split(dataset, lengths):
    """Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (iterable): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths))
    return [Subset(dataset, indices[offset - length:offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
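

# Illustrative usage sketch (assumes ``torch`` is importable; not part of this
# module):
#
#     import torch
#     ds = TensorDataset(torch.randn(10, 4), torch.randn(10))
#     train_ds, val_ds = random_split(ds, [8, 2])
#     assert len(train_ds) == 8 and len(val_ds) == 2
#     # Each split is a Subset that views the original dataset through a
#     # disjoint slice of one random permutation of the indices.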