| |
class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, which provides the size of the dataset, and ``__getitem__``,
    which supports integer indexing in the range from 0 to ``len(self)``
    exclusive.
    """

    def __getitem__(self, index):
        """Return the sample stored at ``index``.

        Arguments:
            index (int): position of the requested sample.

        Raises:
            NotImplementedError: always, in this base class; subclasses must
                override this method.
        """
        # Name the concrete class so the failure points at the subclass that
        # forgot to override the method, not at this base class.
        raise NotImplementedError(
            "{} must implement __getitem__".format(type(self).__name__))

    def __len__(self):
        """Return the number of samples in the dataset.

        Raises:
            NotImplementedError: always, in this base class; subclasses must
                override this method.
        """
        raise NotImplementedError(
            "{} must implement __len__".format(type(self).__name__))
| |
| |
class TensorDataset(Dataset):
    """Dataset wrapping data and target tensors.

    Each sample will be retrieved by indexing both tensors along the first
    dimension, so both tensors must have the same size in that dimension.

    Arguments:
        data_tensor (Tensor): contains sample data.
        target_tensor (Tensor): contains sample targets (labels).

    Raises:
        ValueError: if ``data_tensor`` and ``target_tensor`` differ in size
            along the first dimension.
    """

    def __init__(self, data_tensor, target_tensor):
        num_samples = data_tensor.size(0)
        num_targets = target_tensor.size(0)
        # Validate with an explicit raise rather than ``assert``: assert
        # statements are removed under ``python -O``, which would let
        # mismatched tensors through silently.
        if num_samples != num_targets:
            raise ValueError(
                "Size mismatch between tensors: data_tensor has {} samples "
                "but target_tensor has {}".format(num_samples, num_targets))
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair found at ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the number of samples (size of the first dimension)."""
        return self.data_tensor.size(0)