| import bisect |
| import functools |
| import warnings |
| from typing import ( |
| Callable, |
| Dict, |
| Generic, |
| Iterable, |
| Iterator, |
| List, |
| Optional, |
| Sequence, |
| Tuple, |
| TypeVar, |
| ) |
| |
| # No 'default_generator' in torch/__init__.pyi |
| from torch import default_generator, randperm |
| from torch._utils import _accumulate |
| from torch.utils.data._typing import _DataPipeMeta |
| |
| from ... import Generator, Tensor |
| |
# Type variables for the generic containers defined in this module.
# ``T_co`` is covariant (read-only containers such as datasets); ``T`` is invariant.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')

# DataPipe function names whose result must NOT be wrapped with
# ``trace_as_dataframe()`` when dataframe API tracing is active.
UNTRACABLE_DATAFRAME_PIPES = [
    'batch',                  # returns DataChunks
    'groupby',                # returns DataChunks
    '_dataframes_as_tuples',  # unpacks the DataFrame
    'trace_as_dataframe',     # is itself used to mark a DataFrame for tracing
]
| |
class DataChunk(list, Generic[T]):
    """A ``list`` subclass that also remembers the object it was built from.

    The list contents are a copy of ``items`` (via ``list.__init__``), while
    ``self.items`` keeps a reference to the original object so that
    :meth:`raw_iterator` can iterate it directly.
    """

    def __init__(self, items):
        super().__init__(items)
        # Keep the original object: ``raw_iterator`` walks this, not the list copy.
        self.items = items

    def as_str(self, indent=''):
        """Return ``indent`` + ``"[a, b, ...]"`` built from ``str()`` of each element."""
        return indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"

    def __iter__(self) -> Iterator[T]:
        """Iterate over the list contents of this chunk."""
        yield from super().__iter__()

    def raw_iterator(self) -> Iterator[T]:
        """Yield the elements of the original ``items`` object given to ``__init__``.

        This is a generator, so the return annotation is ``Iterator[T]``
        (it was previously mis-annotated as ``T``).
        """
        yield from self.items
| |
| |
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    # Registry mapping a function name to a callable. ``__getattr__`` exposes each
    # entry on instances, with the instance bound as the first positional argument.
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        # Only reached when normal attribute lookup fails: fall back to the
        # function registry before reporting a missing attribute.
        if attribute_name in Dataset.functions:
            return functools.partial(Dataset.functions[attribute_name], self)
        raise AttributeError("'{0}' object has no attribute '{1}'".format(self.__class__.__name__, attribute_name))

    @classmethod
    def register_function(cls, function_name, function):
        """Store ``function`` under ``function_name`` in ``cls.functions``.

        An existing entry with the same name is silently overwritten.
        """
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        """Register ``cls_to_register`` as a function named ``function_name``.

        The stored callable is invoked as ``fn(source_dp, *args, **kwargs)`` and
        returns ``cls_to_register(source_dp, *args, **kwargs)``. If that result is
        a :class:`Dataset`, and either ``enable_df_api_tracing`` is true or
        ``source_dp`` is a ``DFIterDataPipe``, the result is wrapped with
        ``trace_as_dataframe()`` unless ``function_name`` is listed in
        ``UNTRACABLE_DATAFRAME_PIPES``.

        Raises:
            Exception: if ``function_name`` is already present in ``cls.functions``.
        """
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        # ``datapipe_cls`` (formerly ``cls``) is bound via functools.partial below;
        # renamed so it no longer shadows the enclosing classmethod's ``cls``.
        def class_function(datapipe_cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = datapipe_cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function
| |
| |
class IterableDataset(Dataset[T_co], metaclass=_DataPipeMeta):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
    """
    # Function registry for iterable datasets, separate from ``Dataset.functions``.
    functions: Dict[str, Callable] = {}
    # Optional process-wide hooks used by ``__reduce_ex__`` / ``__getstate__``;
    # installed via ``set_reduce_ex_hook`` / ``set_getstate_hook``.
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]

    def __getattr__(self, attribute_name):
        # Only reached when normal attribute lookup fails: fall back to the
        # iterable-dataset function registry before reporting a missing attribute.
        if attribute_name in IterableDataset.functions:
            return functools.partial(IterableDataset.functions[attribute_name], self)
        raise AttributeError("'{0}' object has no attribute '{1}'".format(self.__class__.__name__, attribute_name))

    def __getstate__(self):
        # A globally installed hook, if any, fully replaces default state extraction.
        if IterableDataset.getstate_hook is not None:
            return IterableDataset.getstate_hook(self)
        return self.__dict__

    def __reduce_ex__(self, *args, **kwargs):
        # Give the global hook first chance; it declines by raising
        # NotImplementedError, in which case default pickling is used.
        if IterableDataset.reduce_ex_hook is not None:
            try:
                return IterableDataset.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        """Install (or clear, with ``None``) the global ``__getstate__`` hook.

        Raises:
            Exception: if a hook is already installed and ``hook_fn`` is not ``None``.
        """
        if IterableDataset.getstate_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing getstate_hook")
        IterableDataset.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        """Install (or clear, with ``None``) the global ``__reduce_ex__`` hook.

        Raises:
            Exception: if a hook is already installed and ``hook_fn`` is not ``None``.
        """
        if IterableDataset.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        IterableDataset.reduce_ex_hook = hook_fn
| |
class DFIterDataPipe(IterableDataset):
    """Marker subclass of :class:`IterableDataset` for dataframe DataPipes.

    Functions registered via ``register_datapipe_as_function`` wrap their
    result with ``trace_as_dataframe()`` when the source pipe is an instance
    of this class (unless the function name is listed as untracable).
    """

    def _is_dfpipe(self):
        """Report that this pipe is a dataframe pipe (always ``True``)."""
        return True
| |
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Sample ``i`` is the tuple formed by indexing every wrapped tensor at ``i``
    along its first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # The reference size is read inside the generator, so an empty ``tensors``
        # tuple passes (all() of nothing is True) exactly as before.
        assert all(t.size(0) == tensors[0].size(0) for t in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        # One entry per wrapped tensor, in the order they were passed in.
        return tuple([t[index] for t in self.tensors])

    def __len__(self):
        # All tensors share the first dimension, so the first one is representative.
        first_tensor = self.tensors[0]
        return first_tensor.size(0)
| |
| |
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets. Global index
    ``i`` is mapped to a (dataset, local index) pair using the running totals
    stored in ``cumulative_sizes``.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of ``len(e)`` for each ``e`` in ``sequence``."""
        running_total = 0
        totals = []
        for dataset in sequence:
            running_total += len(dataset)
            totals.append(running_total)
        return totals

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        self.datasets = list(datasets)
        assert self.datasets, 'datasets should not be an empty iterable'
        assert not any(isinstance(d, IterableDataset) for d in self.datasets), \
            "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Negative indices count from the end, like a regular sequence.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # bisect_right picks the first dataset whose running total exceeds idx,
        # which also skips over zero-length datasets.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        """Deprecated misspelled alias of ``cumulative_sizes``."""
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
| |
| |
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        # Stored as given (not copied), so chaining stays lazy.
        self.datasets = datasets

    def __iter__(self):
        # Each member is type-checked only when the chain reaches it.
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            for sample in member:
                yield sample

    def __len__(self):
        member_lengths = []
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            member_lengths.append(len(member))
        return sum(member_lengths)
| |
| |
class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Scalar-like index: translate it once and fetch a single sample.
        if not isinstance(idx, list):
            return self.dataset[self.indices[idx]]
        # List index: translate every position, then forward the whole list
        # to the underlying dataset in a single call.
        translated = [self.indices[position] for position in idx]
        return self.dataset[translated]

    def __len__(self):
        return len(self.indices)
| |
| |
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Returns:
        A list of :class:`Subset` objects, one per entry of ``lengths``, whose
        index lists are consecutive slices of a single random permutation.

    Raises:
        ValueError: if ``sum(lengths)`` differs from ``len(dataset)``.
    """
    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    permutation = randperm(sum(lengths), generator=generator).tolist()
    splits = []
    for end, length in zip(_accumulate(lengths), lengths):
        # ``end`` is the running total up to and including this split, so the
        # split owns the half-open slice [end - length, end) of the permutation.
        splits.append(Subset(dataset, permutation[end - length:end]))
    return splits