blob: 69b4c36eb699aab786129257f1aa78c88aa81724 [file] [log] [blame]
class Dataset(object):
    """Abstract base class for datasets.

    Both methods defined here only raise ``NotImplementedError``; a concrete
    subclass is expected to override ``__getitem__`` and ``__len__``.
    """

    def __len__(self):
        """Return the number of items; not implemented on the base class."""
        raise NotImplementedError()

    def __getitem__(self, index):
        """Return the item at ``index``; not implemented on the base class."""
        raise NotImplementedError()
class TensorDataset(Dataset):
    """Dataset that pairs entries of a data tensor with entries of a target tensor.

    Both tensors are indexed along their first dimension, so item ``i`` is
    ``(data_tensor[i], target_tensor[i])``. One-dimensional inputs are stored
    as column tensors of shape ``(N, 1)``.

    Args:
        data_tensor: tensor of samples; dimension 0 enumerates the samples.
        target_tensor: tensor of targets; must have the same size as
            ``data_tensor`` along dimension 0.

    Raises:
        AssertionError: if the two tensors differ in size along dimension 0.
    """

    def __init__(self, data_tensor, target_tensor):
        n_data = data_tensor.size(0)
        n_target = target_tensor.size(0)
        if n_data != n_target:
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is kept as
            # AssertionError so existing callers see the same error class.
            raise AssertionError(
                "data_tensor and target_tensor must have the same size along "
                "dimension 0, got {} and {}".format(n_data, n_target)
            )

        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

        # Promote 1-D tensors to column tensors of shape (N, 1).
        if self.data_tensor.dim() == 1:
            self.data_tensor = self.data_tensor.view(-1, 1)
        if self.target_tensor.dim() == 1:
            self.target_tensor = self.target_tensor.view(-1, 1)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair stored at ``index``."""
        return self.data_tensor[index], self.target_tensor[index]

    def __len__(self):
        """Return the size of the data tensor along dimension 0."""
        return self.data_tensor.size(0)