|  | import functools | 
|  | import pickle | 
|  | from typing import Callable, Dict, Iterable, Iterator, List, Optional, TypeVar | 
|  |  | 
|  | from torch.utils._import_utils import import_dill | 
|  | from torch.utils.data.datapipes._hook_iterator import _SnapshotState | 
|  | from torch.utils.data.datapipes._typing import _DataPipeMeta, _IterDataPipeMeta | 
|  | from torch.utils.data.datapipes.utils.common import ( | 
|  | _deprecation_warning, | 
|  | _iter_deprecated_functional_names, | 
|  | _map_deprecated_functional_names, | 
|  | ) | 
|  | from torch.utils.data.dataset import Dataset, IterableDataset | 
|  |  | 
|  |  | 
# `import_dill()` is evaluated once at import time. Its result is compared
# against None, so HAS_DILL records whether a usable `dill` object came back.
dill = import_dill()
HAS_DILL = dill is not None

# Names exported by `from <this module> import *`.
__all__ = [
    "DataChunk",
    "DFIterDataPipe",
    "IterDataPipe",
    "MapDataPipe",
]


# Type variables used by the generic classes below.
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

# Functional names whose resulting pipe is NOT passed through
# `trace_as_dataframe()` by `IterDataPipe.register_datapipe_as_function`.
UNTRACABLE_DATAFRAME_PIPES = [
    "batch",  # As it returns DataChunks
    "groupby",  # As it returns DataChunks
    "_dataframes_as_tuples",  # As it unpacks DF
    "trace_as_dataframe",  # As it used to mark DF for tracing
]
|  |  | 
|  |  | 
class DataChunk(List[_T]):
    """A ``list`` subclass that also remembers the items it was built from.

    The constructor materializes ``items`` once. The list itself is filled
    from that materialized sequence, and the same sequence object is kept on
    ``self.items``. Later mutations of the chunk (``append`` etc.) therefore
    do not show up in ``self.items``.
    """

    def __init__(self, items: Iterable[_T]) -> None:
        # Consume the iterable exactly once so that one-shot iterators work.
        materialized = list(items)
        super().__init__(materialized)
        self.items = materialized

    def as_str(self, indent: str = "") -> str:
        """Return ``indent`` followed by a bracketed, comma-separated rendering."""
        rendered = ", ".join(map(str, iter(self)))
        return indent + "[" + rendered + "]"

    def __iter__(self) -> Iterator[_T]:
        """Iterate over the current contents of the list."""
        for element in super().__iter__():
            yield element

    def raw_iterator(self) -> Iterator[_T]:
        """Iterate over the items captured at construction time."""
        for element in self.items:
            yield element
|  |  | 
|  |  | 
class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
    r"""
    Iterable-style DataPipe.

    All DataPipes that represent an iterable of data samples should subclass this.
    This style of DataPipes is particularly useful when data come from a stream, or
    when the number of samples is too large to fit them all in memory. ``IterDataPipe`` is lazily initialized and its
    elements are computed only when ``next()`` is called on the iterator of an ``IterDataPipe``.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this DataPipe. Calling ``__iter__`` of an ``IterDataPipe`` automatically invokes its
    method ``reset()``, which by default performs no operation. When writing a custom ``IterDataPipe``, users should
    override ``reset()`` if necessary. The common usages include resetting buffers, pointers,
    and various state variables within the custom ``IterDataPipe``.

    Note:
        Only `one` iterator can be valid for each ``IterDataPipe`` at a time,
        and the creation of a second iterator will invalidate the first one. This constraint is necessary because
        some ``IterDataPipe`` have internal buffers, whose states can become invalid if there are multiple iterators.
        The code example below presents details on how this constraint looks in practice.
        If you have any feedback related to this constraint, please see `GitHub IterDataPipe Single Iterator Issue`_.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing ``IterDataPipe`` (recommended, available to most but not all DataPipes).
    You can chain multiple `IterDataPipe` together to form a pipeline that will perform multiple
    operations in succession.

    .. _GitHub IterDataPipe Single Iterator Issue:
        https://github.com/pytorch/data/issues/45

    Note:
        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator. When :attr:`num_workers > 0`, each worker process will have a
        different copy of the DataPipe object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

    Examples:
        General Usage:
            >>> # xdoctest: +SKIP
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> dp = IterableWrapper(range(10))
            >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
            >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
            >>> list(map_dp_1)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> list(map_dp_2)
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
            >>> list(filter_dp)
            [2, 4, 6, 8, 10]
        Single Iterator Constraint Example:
            >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
            >>> source_dp = IterableWrapper(range(10))
            >>> it1 = iter(source_dp)
            >>> list(it1)
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            >>> it1 = iter(source_dp)
            >>> it2 = iter(source_dp)  # The creation of a new iterator invalidates `it1`
            >>> next(it2)
            0
            >>> next(it1)  # Further usage of `it1` will raise a `RuntimeError`
    """

    # Registry of functional-form names -> callables, consulted by `__getattr__`.
    functions: Dict[str, Callable] = {}
    # Optional process-wide hooks consulted by `__reduce_ex__` / `__getstate__` /
    # `__str__` / `__repr__` below.
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None
    # Iterator bookkeeping defaults. They are not read anywhere in this class;
    # presumably they are maintained by the iterator hook installed through the
    # metaclass -- verify against `_hook_iterator`.
    _valid_iterator_id: Optional[int] = None
    _number_of_samples_yielded: int = 0
    _snapshot_state: _SnapshotState = _SnapshotState.NotStarted
    _fast_forward_iterator: Optional[Iterator] = None

    def __iter__(self) -> Iterator[_T_co]:
        """Return ``self``; subclasses are expected to override this."""
        return self

    def __getattr__(self, attribute_name):
        """
        Resolve registered functional forms such as ``dp.map``.

        Only invoked when normal attribute lookup fails. If ``attribute_name`` is a
        registered function, a ``functools.partial`` with ``self`` bound as the first
        argument (and the registered callable's ``__doc__``) is returned.

        Raises:
            AttributeError: if ``attribute_name`` is not a registered function.
        """
        if attribute_name not in IterDataPipe.functions:
            # The closing quote after the attribute name was missing originally.
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )
        if attribute_name in _iter_deprecated_functional_names:
            kwargs = _iter_deprecated_functional_names[attribute_name]
            _deprecation_warning(**kwargs)
        f = IterDataPipe.functions[attribute_name]
        function = functools.partial(f, self)
        functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
        return function

    @classmethod
    def register_function(cls, function_name, function):
        """Register ``function`` under ``function_name``, replacing any existing entry."""
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(
        cls, function_name, cls_to_register, enable_df_api_tracing=False
    ):
        """
        Register the DataPipe class ``cls_to_register`` as functional form ``function_name``.

        The registered callable builds ``cls_to_register(source_dp, *args, **kwargs)``.
        When the result is an ``IterDataPipe`` and either ``enable_df_api_tracing`` is
        true or ``source_dp`` is a ``DFIterDataPipe``, the result is additionally passed
        through ``trace_as_dataframe()`` unless ``function_name`` is listed in
        ``UNTRACABLE_DATAFRAME_PIPES``.

        Raises:
            Exception: if ``function_name`` is already registered.
        """
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        # NOTE: the inner `cls` parameter intentionally shadows the outer one; it is
        # bound to `cls_to_register` by the partial below.
        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, IterDataPipe):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(
            class_function, cls_to_register, enable_df_api_tracing
        )
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.

        Returns the instance ``__dict__`` itself (not a copy), or whatever
        ``IterDataPipe.getstate_hook`` returns for it when a hook is installed.
        """
        state = self.__dict__
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        """Use ``reduce_ex_hook`` when set; fall back to the default on ``NotImplementedError``."""
        if IterDataPipe.reduce_ex_hook is not None:
            try:
                return IterDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                # The hook declined to handle this object; use the default reduction.
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        """
        Install (or clear, with ``None``) the global ``getstate_hook``.

        Raises:
            RuntimeError: if a hook is already set and ``hook_fn`` is not ``None``.
        """
        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        IterDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        """
        Install (or clear, with ``None``) the global ``reduce_ex_hook``.

        Raises:
            RuntimeError: if a hook is already set and ``hook_fn`` is not ``None``.
        """
        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        IterDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        # NOTE(review): the hook is looked up through the instance, so a plain
        # function stored on the class would be bound and receive the instance
        # twice. Confirm how `repr_hook` is meant to be installed.
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        # NOTE(review): same instance-level lookup caveat as in `__repr__`.
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperIterDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())

    def reset(self) -> None:
        r"""
        Reset the `IterDataPipe` to the initial state.

        By default, no-op. For subclasses of `IterDataPipe`, depending on their functionalities,
        they may want to override this method with implementations that
        may clear the buffers and reset pointers of the DataPipe.
        The `reset` method is always called when `__iter__` is called as part of `hook_iterator`.
        """
|  |  | 
|  |  | 
class DFIterDataPipe(IterDataPipe):
    """Marker subclass of ``IterDataPipe`` for DataFrame-style pipes.

    ``IterDataPipe.register_datapipe_as_function`` checks for this type to
    decide whether a derived pipe should be traced as a dataframe.
    """

    def _is_dfpipe(self):
        """Always report ``True`` for this class."""
        return True
|  |  | 
|  |  | 
class MapDataPipe(Dataset[_T_co], metaclass=_DataPipeMeta):
    r"""
    Map-style DataPipe.

    All datasets that represent a map from keys to data samples should subclass this.
    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given, unique key. Subclasses can also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    These DataPipes can be invoked in two ways, using the class constructor or applying their
    functional form onto an existing `MapDataPipe` (recommended, available to most but not all DataPipes).

    Note:
        :class:`~torch.utils.data.DataLoader` by default constructs an index
        sampler that yields integral indices. To make it work with a map-style
        DataPipe with non-integral indices/keys, a custom sampler must be provided.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_dp = map_dp_1.batch(batch_size=2)
        >>> list(batch_dp)
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    """

    # Registry of functional-form names -> callables, consulted by `__getattr__`.
    functions: Dict[str, Callable] = {}
    # Optional process-wide hooks consulted by `__reduce_ex__` / `__getstate__` /
    # `__str__` / `__repr__` below.
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None
    str_hook: Optional[Callable] = None
    repr_hook: Optional[Callable] = None

    def __getattr__(self, attribute_name):
        """
        Resolve registered functional forms such as ``dp.map``.

        Only invoked when normal attribute lookup fails. If ``attribute_name`` is a
        registered function, a ``functools.partial`` with ``self`` bound as the first
        argument (and the registered callable's ``__doc__``) is returned.

        Raises:
            AttributeError: if ``attribute_name`` is not a registered function.
        """
        if attribute_name not in MapDataPipe.functions:
            # The closing quote after the attribute name was missing originally.
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{attribute_name}'"
            )
        if attribute_name in _map_deprecated_functional_names:
            kwargs = _map_deprecated_functional_names[attribute_name]
            _deprecation_warning(**kwargs)
        f = MapDataPipe.functions[attribute_name]
        function = functools.partial(f, self)
        functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
        return function

    @classmethod
    def register_function(cls, function_name, function):
        """Register ``function`` under ``function_name``, replacing any existing entry."""
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        """
        Register the DataPipe class ``cls_to_register`` as functional form ``function_name``.

        The registered callable returns ``cls_to_register(source_dp, *args, **kwargs)``.

        Raises:
            Exception: if ``function_name`` is already registered.
        """
        if function_name in cls.functions:
            raise Exception(  # noqa: TRY002
                f"Unable to add DataPipe function name {function_name} as it is already taken"
            )

        # NOTE: the inner `cls` parameter intentionally shadows the outer one; it is
        # bound to `cls_to_register` by the partial below.
        def class_function(cls, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            return result_pipe

        function = functools.partial(class_function, cls_to_register)
        functools.update_wrapper(
            wrapper=function, wrapped=cls_to_register, assigned=("__doc__",)
        )
        cls.functions[function_name] = function

    def __getstate__(self):
        """
        Serialize `lambda` functions when `dill` is available.

        If this doesn't cover your custom DataPipe's use case, consider writing custom methods for
        `__getstate__` and `__setstate__`, or use `pickle.dumps` for serialization.

        Returns the instance ``__dict__`` itself (not a copy), or whatever
        ``MapDataPipe.getstate_hook`` returns for it when a hook is installed.
        """
        state = self.__dict__
        if MapDataPipe.getstate_hook is not None:
            return MapDataPipe.getstate_hook(state)
        return state

    def __reduce_ex__(self, *args, **kwargs):
        """Use ``reduce_ex_hook`` when set; fall back to the default on ``NotImplementedError``."""
        if MapDataPipe.reduce_ex_hook is not None:
            try:
                return MapDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                # The hook declined to handle this object; use the default reduction.
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        """
        Install (or clear, with ``None``) the global ``getstate_hook``.

        Raises:
            RuntimeError: if a hook is already set and ``hook_fn`` is not ``None``.
        """
        if MapDataPipe.getstate_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing getstate_hook")
        MapDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        """
        Install (or clear, with ``None``) the global ``reduce_ex_hook``.

        Raises:
            RuntimeError: if a hook is already set and ``hook_fn`` is not ``None``.
        """
        if MapDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise RuntimeError("Attempt to override existing reduce_ex_hook")
        MapDataPipe.reduce_ex_hook = hook_fn

    def __repr__(self):
        # NOTE(review): the hook is looked up through the instance, so a plain
        # function stored on the class would be bound and receive the instance
        # twice. Confirm how `repr_hook` is meant to be installed.
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __str__(self):
        # NOTE(review): same instance-level lookup caveat as in `__repr__`.
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        return str(self.__class__.__qualname__)

    def __dir__(self):
        # for auto-completion in a REPL (e.g. Jupyter notebook)
        return list(super().__dir__()) + list(self.functions.keys())
|  |  | 
|  |  | 
class _DataPipeSerializationWrapper:
    """Wrap a datapipe so it is pickled as pre-serialized bytes.

    The state is a ``(payload, use_dill)`` tuple: ``payload`` is produced by
    ``pickle.dumps``, or by ``dill.dumps`` when pickling fails and ``HAS_DILL``
    is true.
    """

    def __init__(self, datapipe):
        self._datapipe = datapipe

    def __getstate__(self):
        """Return ``(serialized_bytes, use_dill)`` for the wrapped datapipe."""
        try:
            payload = pickle.dumps(self._datapipe)
        except Exception:
            # Without dill there is no fallback: surface the pickling error.
            if not HAS_DILL:
                raise
            return (dill.dumps(self._datapipe), True)
        return (payload, False)

    def __setstate__(self, state):
        """Rebuild the wrapped datapipe with the serializer recorded in ``state``."""
        # NOTE: unpickling can execute arbitrary code; `state` must be trusted.
        payload, use_dill = state
        deserialize = dill.loads if use_dill else pickle.loads
        self._datapipe = deserialize(payload)

    def __len__(self):
        """Return ``len`` of the wrapped datapipe, mapping any failure to ``TypeError``."""
        try:
            return len(self._datapipe)
        except Exception as err:
            message = f"{type(self).__name__} instance doesn't have valid length"
            raise TypeError(message) from err
|  |  | 
|  |  | 
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
    """Serialization wrapper that also forwards iteration to the wrapped ``IterDataPipe``."""

    def __init__(self, datapipe: IterDataPipe[_T_co]):
        super().__init__(datapipe)
        # Populated by `__iter__`; stays None until iteration is started.
        self._datapipe_iter: Optional[Iterator[_T_co]] = None

    def __iter__(self) -> "_IterDataPipeSerializationWrapper":
        """Start a fresh iterator over the wrapped datapipe and return ``self``."""
        fresh_iterator = iter(self._datapipe)
        self._datapipe_iter = fresh_iterator
        return self

    def __next__(self) -> _T_co:  # type: ignore[type-var]
        """Return the next item from the iterator created by ``__iter__``."""
        active_iterator = self._datapipe_iter
        assert active_iterator is not None
        return next(active_iterator)
|  |  | 
|  |  | 
class _MapDataPipeSerializationWrapper(_DataPipeSerializationWrapper, MapDataPipe):
    """Serialization wrapper that forwards indexing to the wrapped ``MapDataPipe``."""

    def __getitem__(self, idx):
        """Return the wrapped datapipe's item for ``idx``."""
        wrapped = self._datapipe
        return wrapped[idx]