| |
class Dataset(object):
    """Abstract base for indexable, sized collections of samples.

    Subclasses are expected to override both ``__len__`` and
    ``__getitem__``; the versions defined here always raise
    ``NotImplementedError``.
    """

    def __len__(self):
        """Return the number of samples; must be overridden by subclasses."""
        raise NotImplementedError()

    def __getitem__(self, index):
        """Return the sample at ``index``; must be overridden by subclasses."""
        raise NotImplementedError()
| |
| |
class TensorDataset(Dataset):
    """Dataset that pairs entries of a data tensor with a target tensor.

    Both arguments are indexed along their first dimension, so they must
    have the same size there. A one-dimensional input is reshaped with
    ``view(-1, 1)`` into a single column before being stored.

    Args:
        data_tensor: object exposing ``size``, ``dim`` and ``view``
            (a tensor) holding the samples along its first dimension.
        target_tensor: tensor of targets whose first-dimension size
            matches ``data_tensor``.

    Raises:
        AssertionError: if the first-dimension sizes differ.
    """

    def __init__(self, data_tensor, target_tensor):
        num_samples = data_tensor.size(0)
        num_targets = target_tensor.size(0)
        if num_samples != num_targets:
            # Raised explicitly rather than through an `assert` statement so
            # the check is not stripped under `python -O`. The exception type
            # stays AssertionError so existing callers see no difference.
            raise AssertionError(
                "data_tensor and target_tensor must have the same size in "
                "dimension 0, got {} and {}".format(num_samples, num_targets)
            )

        # Store 1-D inputs as a single column of shape (N, 1).
        if data_tensor.dim() == 1:
            data_tensor = data_tensor.view(-1, 1)
        if target_tensor.dim() == 1:
            target_tensor = target_tensor.view(-1, 1)

        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the size of the first dimension of the data tensor."""
        return self.data_tensor.size(0)
| |