|  | r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter | 
|  |  | 
|  | To support these two classes, in `./_utils` we define many utility methods and | 
|  | functions to be run in multiprocessing. E.g., the data loading worker loop is | 
|  | in `./_utils/worker.py`. | 
|  | """ | 
|  |  | 
|  | import functools | 
|  | import itertools | 
|  | import logging | 
|  | import os | 
|  | import queue | 
|  | import threading | 
|  | import warnings | 
|  |  | 
|  | from typing import Any, Callable, Iterable, TypeVar, Generic, List, Optional, Union | 
|  |  | 
|  | import multiprocessing as python_multiprocessing | 
|  | import torch | 
|  | import torch.distributed as dist | 
|  | import torch.multiprocessing as multiprocessing | 
|  | import torch.utils.data.graph_settings | 
|  |  | 
|  | from torch._utils import ExceptionWrapper | 
|  |  | 
|  | from . import ( | 
|  | IterDataPipe, | 
|  | MapDataPipe, | 
|  | IterableDataset, | 
|  | Sampler, | 
|  | SequentialSampler, | 
|  | RandomSampler, | 
|  | BatchSampler, | 
|  | Dataset,) | 
|  |  | 
|  | from torch.utils.data.datapipes.datapipe import _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper | 
|  |  | 
|  | from . import _utils | 
|  |  | 
# Names exported by `from torch.utils.data.dataloader import *`.
__all__ = [
    "DataLoader",
    "get_worker_info",
    "default_collate",
    "default_convert",
]

# Type variables used by the generic annotations in this module.
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
# Signature of a `worker_init_fn`: takes the integer worker id and returns nothing.
_worker_init_fn_t = Callable[[int], None]

# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
# See https://github.com/python/mypy/issues/3737.
_collate_fn_t = Callable[[List[T]], Any]


# These functions used to be defined in this file. However, they were moved to
# _utils/collate.py. Although it is rather hard to access them from user land
# (one has to explicitly and directly `import torch.utils.data.dataloader`),
# there probably is user code out there using them. This aliasing maintains BC
# in this aspect.
default_collate: _collate_fn_t = _utils.collate.default_collate
default_convert = _utils.collate.default_convert

# Alias of `_utils.worker.get_worker_info`, exported through `__all__` above.
get_worker_info = _utils.worker.get_worker_info

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|  |  | 
|  |  | 
class _DatasetKind:
    r"""Integer tags for the two dataset flavours, plus a fetcher factory."""

    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        r"""Instantiate the fetcher matching ``kind``.

        ``_DatasetKind.Map`` selects ``_utils.fetch._MapDatasetFetcher``; every
        other value selects ``_utils.fetch._IterableDatasetFetcher``. The
        remaining arguments are forwarded to the fetcher constructor unchanged.
        """
        fetcher_cls = (
            _utils.fetch._MapDatasetFetcher
            if kind == _DatasetKind.Map
            else _utils.fetch._IterableDatasetFetcher
        )
        return fetcher_cls(dataset, auto_collation, collate_fn, drop_last)
|  |  | 
|  |  | 
|  | class _InfiniteConstantSampler(Sampler): | 
|  | r"""Analogous to ``itertools.repeat(None, None)``. | 
|  | Used as sampler for :class:`~torch.utils.data.IterableDataset`. | 
|  | """ | 
|  |  | 
|  | def __iter__(self): | 
|  | while True: | 
|  | yield None | 
|  |  | 
|  |  | 
def _get_distributed_settings():
    r"""Return ``(world_size, rank)`` for the current process.

    Falls back to ``(1, 0)`` when ``torch.distributed`` is unavailable or has
    not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1, 0
    return dist.get_world_size(), dist.get_rank()
|  |  | 
|  |  | 
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
    r"""Shard the worker's datapipe, then run the user's ``worker_init_fn``.

    The dataset is read from ``torch.utils.data.get_worker_info()`` and must be
    an ``IterDataPipe`` or ``MapDataPipe``; both conditions are asserted.

    Args:
        worker_init_fn: optional callable invoked with ``worker_id`` after sharding.
        world_size: number of distributed processes.
        rank_id: rank of the current distributed process.
        worker_id: id of this worker within its DataLoader.
    """
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info is not None
    workers_per_rank = worker_info.num_workers
    pipe = worker_info.dataset
    assert isinstance(pipe, (IterDataPipe, MapDataPipe))
    # To distribute elements across distributed process evenly, we should shard data on distributed
    # processes first then shard on worker processes
    num_shards = workers_per_rank * world_size
    shard_index = worker_id * world_size + rank_id
    # For BC, use default SHARDING_PRIORITIES
    torch.utils.data.graph_settings.apply_sharding(pipe, num_shards, shard_index)
    if worker_init_fn is None:
        return
    worker_init_fn(worker_id)
|  |  | 
|  |  | 
def _share_dist_seed(generator, pg):
    r"""Draw a random int64 seed and, given a process group, broadcast it from rank 0.

    Args:
        generator: RNG passed to ``Tensor.random_`` when drawing the seed.
        pg: a ``dist.ProcessGroup`` to broadcast over; any other value skips
            the broadcast and the locally drawn seed is returned.

    Returns:
        The seed as a Python int.
    """
    seed_tensor = torch.empty((), dtype=torch.int64)
    seed_tensor = seed_tensor.random_(generator=generator)
    if isinstance(pg, dist.ProcessGroup):
        dist.broadcast(seed_tensor, src=0, group=pg)
    return seed_tensor.item()
|  |  | 
|  |  | 
class DataLoader(Generic[T_co]):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
            ``None``, the default `multiprocessing context`_ of your operating system will
            be used. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            ``base_seed`` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise, if value of ``num_workers > 0`` default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
                 configurations. This represents the best guess PyTorch can make because PyTorch
                 trusts user :attr:`dataset` code in correctly handling multi-process
                 loading to avoid duplicate data.

                 However, if sharding results in multiple workers having incomplete last batches,
                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
                 be broken into multiple ones and (2) more than one batch worth of samples can be
                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
                 cases in general.

                 See `Dataset Types`_ for more details on these two types of datasets and how
                 :class:`~torch.utils.data.IterableDataset` interacts with
                 `Multi-process data loading`_.

    .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
                 :ref:`data-loading-randomness` notes for random seed related questions.

    .. _multiprocessing context:
        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    """
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Union[Sampler, Iterable]
    pin_memory_device: str
    prefetch_factor: Optional[int]
    _iterator: Optional['_BaseDataLoaderIter']
    # Flipped to True at the end of `__init__`; `__setattr__` uses it to freeze
    # the attributes that must not change after construction.
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: Optional[bool] = None, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: Optional[int] = None,
                 persistent_workers: bool = False,
                 pin_memory_device: str = ""):
        r"""Validate the options and resolve the default sampler, batch sampler and ``collate_fn``.

        Raises:
            ValueError: on negative ``num_workers``, ``timeout`` or ``prefetch_factor``,
                on options that require ``num_workers > 0`` when it is ``0``, and on
                mutually exclusive sampler/batching options.
            TypeError: from the ``multiprocessing_context`` setter when the context
                is neither a start-method string nor a context object.
        """
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor is not None:
            # The two literals are concatenated, so the first must end with a space.
            raise ValueError('prefetch_factor option could only be specified in multiprocessing. '
                             'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.')
        elif num_workers > 0 and prefetch_factor is None:
            prefetch_factor = 2
        elif prefetch_factor is not None and prefetch_factor < 0:
            raise ValueError('prefetch_factor option should be non-negative')

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')

        self.dataset = dataset
        # `num_workers` must be assigned before `multiprocessing_context`: the
        # property setter below reads `self.num_workers`.
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.pin_memory_device = pin_memory_device
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
        if isinstance(self.dataset, IterDataPipe):
            self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
        elif isinstance(self.dataset, MapDataPipe):
            self.dataset = _MapDataPipeSerializationWrapper(self.dataset)

        # Arg-check dataset related before checking samplers because we want to
        # tell users that iterable-style datasets are incompatible with custom
        # samplers first, so that they don't learn that this combo doesn't work
        # after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and IterableDataset ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data loss (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if isinstance(dataset, IterDataPipe):
                if shuffle is not None:
                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
            # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
            elif shuffle not in {False, None}:
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}")

            if sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}")
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    f"batch_sampler option, but got batch_sampler={batch_sampler}")
        else:
            # Map-style: an unspecified shuffle (None) means "do not shuffle".
            shuffle = bool(shuffle)
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        # From here on `__setattr__` rejects changes to the frozen attributes.
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()

        torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        r"""Create a single-process iterator when ``num_workers == 0``, else a multi-process one."""
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

    @property
    def multiprocessing_context(self):
        r"""The stored multiprocessing context object, or ``None``."""
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        r"""Validate and store the multiprocessing context.

        A string is checked against the available start methods and converted
        to a context object. A non-``None`` value requires ``num_workers > 0``.

        Raises:
            ValueError: unknown start method, or ``num_workers`` is not positive.
            TypeError: the value is neither a string nor a ``BaseContext``.
        """
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if isinstance(multiprocessing_context, str):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            'multiprocessing_context option '
                            f'should specify a valid start method in {valid_start_methods!r}, but got '
                            f'multiprocessing_context={multiprocessing_context!r}')
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise TypeError('multiprocessing_context option should be a valid context '
                                    'object or a string specifying the start method, but got '
                                    f'multiprocessing_context={multiprocessing_context}')
            else:
                raise ValueError('multiprocessing_context can only be used with '
                                 'multi-process loading (num_workers > 0), but got '
                                 f'num_workers={self.num_workers}')

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        r"""Reject changes to the batching/sampling attributes once ``__init__`` has finished."""
        if self.__initialized and attr in (
                'batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset', 'persistent_workers'):
            raise ValueError(f'{attr} attribute should not be set after {self.__class__.__name__} is initialized')

        super().__setattr__(attr, val)

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        r"""Return an iterator over the dataset.

        With ``persistent_workers`` and ``num_workers > 0`` a single iterator is
        cached and reset on each call; otherwise a new iterator is created.
        """
        # When using a single worker the returned iterator should be
        # created everytime to avoid resetting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        r"""``True`` when a batch sampler is in use, i.e. samples are collated into batches."""
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        r"""Number of items the loader yields per pass.

        For iterable-style datasets this is ``len(dataset)``, divided by
        ``batch_size`` (floored if ``drop_last``, else rounded up) when batching
        is enabled. For map-style datasets it is the length of the index sampler.
        """
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in
            # `self._IterableDataset_len_called`, and warn if the iterator ends up
            # yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
            if self.batch_size is not None:  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil
                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)

    def check_worker_number_rationality(self):
        r"""Warn when ``num_workers`` exceeds the CPUs available to this process.

        Also warns when no CPU count can be determined at all. Does nothing
        when ``num_workers`` is falsy.
        """
        # This function checks whether the dataloader's worker number is rational based on
        # current system's resource. Current rule is that if the number of workers this
        # Dataloader will create is bigger than the number of logical cpus that is allowed to
        # use, then we will pop up a warning to let user pay attention.
        #
        # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
        #     DataLoader process can use half of them which is 32, then the rational max number of
        #     worker that initiated from this process is 32.
        #     Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
        #     So the warning message is triggered to notify the user to lower the worker number if
        #     necessary.
        #
        #
        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
        #        available (available in most of Linux system, but not OSX and Windows).
        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
        #        it doesn't respect cpuset.
        #        We don't take threading into account since each worker process is single threaded
        #        at this time.
        #
        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
        #        other than `torch.set_num_threads` to 1 in the worker process, if the passing
        #        in functions use 3rd party modules that rely on those threading flags to determine
        #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to
        #        set those flags correctly.
        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
            # Compose the warning text; `num_worker_suggest` may be None when no
            # CPU count could be determined.
            if num_worker_suggest is None:
                suggested_max_worker_msg = (
                    "DataLoader is not able to compute a suggested max number of worker in current system.")
            else:
                cpuset_note = "" if cpuset_checked else " (`cpuset` is not taken into account)"
                suggested_max_worker_msg = (
                    f"Our suggested max number of worker in current system is {num_worker_suggest}{cpuset_note}, "
                    "which is smaller than what this DataLoader is going to create.")

            return (
                f"This DataLoader will create {num_worker_created} worker processes in total. "
                f"{suggested_max_worker_msg} "
                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
                "lower the worker number to avoid potential slowness/freeze if necessary.")

        if not self.num_workers:
            return

        # try to compute a suggested max number of worker based on system's resource
        max_num_worker_suggest = None
        cpuset_checked = False
        if hasattr(os, 'sched_getaffinity'):
            try:
                max_num_worker_suggest = len(os.sched_getaffinity(0))
                cpuset_checked = True
            except Exception:
                # Best effort only: fall through to os.cpu_count() below.
                pass
        if max_num_worker_suggest is None:
            # os.cpu_count() could return Optional[int]
            # get cpu count first and check None in order to satisfy mypy check
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                max_num_worker_suggest = cpu_count

        # Warn both when no suggestion could be computed and when the requested
        # worker count exceeds the suggestion.
        if max_num_worker_suggest is None or self.num_workers > max_num_worker_suggest:
            warnings.warn(_create_warning_msg(
                max_num_worker_suggest,
                self.num_workers,
                cpuset_checked))
|  |  | 
|  |  | 
class _BaseDataLoaderIter:
    r"""State and iteration protocol shared by the DataLoader iterators.

    The constructor snapshots the loader's configuration; subclasses supply
    :meth:`_next_data`, which :meth:`__next__` calls to produce each item.
    """

    def __init__(self, loader: DataLoader) -> None:
        r"""Copy the iteration settings from ``loader`` and start a fresh sampler iterator."""
        self._dataset = loader.dataset
        self._shared_seed = None
        self._pg = None
        if isinstance(self._dataset, IterDataPipe):
            if dist.is_available() and dist.is_initialized():
                self._pg = dist.new_group(backend="gloo")
            self._apply_shared_seed(loader.generator)
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        ws, rank = _get_distributed_settings()
        self._world_size = ws
        self._rank = rank
        # for other backends, pin_memory_device need to set. if not set
        # default behaviour is CUDA device. if pin_memory_device is selected
        # and pin_memory is not set, the default behaviour false.
        if len(loader.pin_memory_device) == 0:
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._pin_memory_device = None
        else:
            if not loader.pin_memory:
                # The two literals are concatenated, so a separator is needed between them.
                warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used, "
                            "please set pin_memory to true, if you need to use the device pin memory")
                warnings.warn(warn_msg)

            self._pin_memory = loader.pin_memory
            self._pin_memory_device = loader.pin_memory_device
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"

    def _apply_shared_seed(self, generator) -> None:
        r"""Draw a seed shared across ``self._pg`` and apply it to the datapipe."""
        self._shared_seed = _share_dist_seed(generator, self._pg)
        shared_rng = torch.Generator()
        shared_rng.manual_seed(self._shared_seed)
        self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        r"""Restart iteration: new sampler iterator, zeroed counter, re-seeded datapipe.

        ``first_iter`` is accepted but not used by this base implementation.
        """
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        if isinstance(self._dataset, IterDataPipe):
            self._apply_shared_seed(loader.generator)

    def _next_index(self):
        r"""Return the next index (or batch of indices) from the sampler iterator."""
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        r"""Produce the next item; must be implemented by subclasses."""
        raise NotImplementedError

    def __next__(self) -> Any:
        r"""Return the next item, warning if an IterableDataset yields more than ``len(dataloader)`` reported."""
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # TODO(https://github.com/pytorch/pytorch/issues/76750)
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = (f"Length of IterableDataset {self._dataset} was reported to be "
                            f"{self._IterableDataset_len_called} (when accessing len(dataloader)), but "
                            f"{self._num_yielded} samples have been fetched. ")
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        r"""Refuse pickling.

        Raises:
            NotImplementedError: always.
        """
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError(f"{self.__class__.__name__} cannot be pickled")
|  |  | 
|  |  | 
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterator that fetches every item in the calling process (``num_workers == 0``)."""

    def __init__(self, loader):
        super().__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Taking care of distributed sharding
        dataset_is_datapipe = isinstance(self._dataset, (IterDataPipe, MapDataPipe))
        if dataset_is_datapipe:
            # For BC, use default SHARDING_PRIORITIES
            torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )

    def _next_data(self):
        r"""Fetch the item for the next sampler index, pinning its memory when enabled."""
        # Both calls may raise StopIteration, which ends the iteration.
        next_index = self._next_index()
        fetched = self._dataset_fetcher.fetch(next_index)
        if not self._pin_memory:
            return fetched
        return _utils.pin_memory.pin_memory(fetched, self._pin_memory_device)
|  |  | 
|  |  | 
|  | class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): | 
|  | r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" | 
|  |  | 
|  | # NOTE [ Data Loader Multiprocessing Shutdown Logic ] | 
|  | # | 
|  | # Preliminary: | 
|  | # | 
|  | # Our data model looks like this (queues are indicated with curly brackets): | 
|  | # | 
|  | #                main process                              || | 
|  | #                     |                                    || | 
|  | #               {index_queue}                              || | 
|  | #                     |                                    || | 
|  | #              worker processes                            ||     DATA | 
|  | #                     |                                    || | 
|  | #            {worker_result_queue}                         ||     FLOW | 
|  | #                     |                                    || | 
|  | #      pin_memory_thread of main process                   ||   DIRECTION | 
|  | #                     |                                    || | 
|  | #               {data_queue}                               || | 
|  | #                     |                                    || | 
|  | #                data output                               \/ | 
|  | # | 
|  | # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if | 
|  | #      `pin_memory=False`. | 
|  | # | 
|  | # | 
|  | # Terminating multiprocessing logic requires very careful design. In | 
|  | # particular, we need to make sure that | 
|  | # | 
|  | #   1. The iterator gracefully exits the workers when its last reference is | 
|  | #      gone or it is depleted. | 
|  | # | 
|  | #      In this case, the workers should be gracefully exited because the | 
|  | #      main process may still need to continue to run, and we want cleaning | 
|  | #      up code in the workers to be executed (e.g., releasing GPU memory). | 
|  | #      Naturally, we implement the shutdown logic in `__del__` of | 
|  | #      DataLoaderIterator. | 
|  | # | 
|  | #      We delay the discussion on the logic in this case until later. | 
|  | # | 
|  | #   2. The iterator exits the workers when the loader process and/or worker | 
|  | #      processes exits normally or with error. | 
|  | # | 
|  | #      We set all workers and `pin_memory_thread` to have `daemon=True`. | 
|  | # | 
|  | #      You may ask, why can't we make the workers non-daemonic, and | 
|  | #      gracefully exit using the same logic as we have in `__del__` when the | 
|  | #      iterator gets deleted (see 1 above)? | 
|  | # | 
|  | #      First of all, `__del__` is **not** guaranteed to be called when | 
|  | #      interpreter exits. Even if it is called, by the time it executes, | 
|  | #      many Python core library resources may already be freed, and even | 
|  | #      simple things like acquiring an internal lock of a queue may hang. | 
|  | #      Therefore, in this case, we actually need to prevent `__del__` from | 
|  | #      being executed, and rely on the automatic termination of daemonic | 
|  | #      children. | 
|  | # | 
|  | #      Thus, we register an `atexit` hook that sets a global flag | 
|  | #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the | 
|  | #      reverse order of registration, we are guaranteed that this flag is | 
|  | #      set before library resources we use are freed (which, at least in | 
|  | #      CPython, is done via an `atexit` handler defined in | 
|  | #      `multiprocessing/util.py` | 
|  | #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 | 
|  | #      registered when an object requiring this mechanism is first | 
|  | #      created, e.g., `mp.Queue` | 
|  | #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 | 
|  | #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 | 
|  | #      ) | 
|  | # | 
|  | #      So in `__del__`, we check if `_utils.python_exit_status` is set or | 
|  | #      `None` (freed), and perform no-op if so. | 
|  | # | 
|  | #      However, simply letting library clean-up codes run can also be bad, | 
|  | #      because such codes (i.e., `multiprocessing.util._exit_function()`) | 
|  | #      include join putting threads for `mp.Queue`, which can be blocking. | 
|  | #      Hence, the main process putting threads are called with | 
|  | #      `cancel_join_thread` at creation.  See later section | 
|  | #      [ 3b. A process won't hang when putting into a queue; ] | 
|  | #      for more details. | 
|  | # | 
|  | #      Here are two example cases where library clean-up codes can run | 
|  | #      before `__del__` is called: | 
|  | # | 
|  | #        1. If we hold onto a reference to the iterator, it more often | 
|  | #           than not tries to do `multiprocessing` library cleaning before | 
|  | #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) | 
|  | #           and thus prevents our cleaning-up code to run first. | 
|  | # | 
|  | #        2. A similar issue arises when a `DataLoader` is used in a subprocess. | 
|  | #           When a process ends, it shuts all its daemonic children | 
|  | #           down with a SIGTERM (instead of joining them without a timeout). | 
|  | #           Similarly for threads, but by a different mechanism. This fact, | 
|  | #           together with a few implementation details of multiprocessing, forces | 
|  | #           us to make workers daemonic. All of our problems arise when a | 
|  | #           DataLoader is used in a subprocess, and are caused by multiprocessing | 
|  | #           code which looks more or less like this: | 
|  | # | 
|  | #               try: | 
|  | #                   your_function_using_a_dataloader() | 
|  | #               finally: | 
|  | #                   multiprocessing.util._exit_function() | 
|  | # | 
|  | #           The joining/termination mentioned above happens inside | 
|  | #           `_exit_function()`. Now, if `your_function_using_a_dataloader()` | 
|  | #           throws, the stack trace stored in the exception will prevent the | 
|  | #           frame which uses `DataLoaderIter` to be freed. If the frame has any | 
|  | #           reference to the `DataLoaderIter` (e.g., in a method of the iter), | 
|  | #           its  `__del__`, which starts the shutdown procedure, will not be | 
|  | #           called. That, in turn, means that workers aren't notified. Attempting | 
|  | #           to join in `_exit_function` will then result in a hang. | 
|  | # | 
|  | #           For context, `_exit_function` is also registered as an `atexit` call. | 
|  | #           So it is unclear to me (@ssnl) why this is needed in a finally block. | 
|  | #           The code dates back to 2008 and there is no comment on the original | 
|  | #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both | 
|  | #           the finally block and the `atexit` registration) that explains this. | 
|  | # | 
|  | # | 
|  | #      Finally, another choice is to just shutdown workers with logic in 1 | 
|  | #      above whenever we see an error in `next`. This isn't ideal because | 
|  | #        a. It prevents users from using try-catch to resume data loading. | 
|  | #        b. It doesn't prevent hanging if users have references to the | 
|  | #           iterator. | 
|  | # | 
|  | #   3. All processes exit if any of them die unexpectedly by fatal signals. | 
|  | # | 
|  | #      As shown above, the workers are set as daemonic children of the main | 
|  | #      process. However, automatic cleaning-up of such child processes only | 
|  | #      happens if the parent process exits gracefully (e.g., not via fatal | 
|  | #      signals like SIGKILL). So we must ensure that each process will exit | 
|  | #      even if the process that should send/receive data to/from it was | 
|  | #      killed, i.e., | 
|  | # | 
|  | #        a. A process won't hang when getting from a queue. | 
|  | # | 
|  | #           Even with carefully designed data dependencies (i.e., a `put()` | 
|  | #           always corresponding to a `get()`), hanging on `get()` can still | 
|  | #           happen when data in queue is corrupted (e.g., due to | 
|  | #           `cancel_join_thread` or unexpected exit). | 
|  | # | 
|  | #           For child exit, we set a timeout whenever we try to get data | 
|  | #           from `data_queue`, and check the workers' status on each timeout | 
|  | #           and error. | 
|  | #           See `_MultiProcessingDataLoaderIter._get_data()` and | 
|  | #           `_MultiProcessingDataLoaderIter._try_get_data()` for details. | 
|  | # | 
|  | #           Additionally, for child exit on non-Windows platforms, we also | 
|  | #           register a SIGCHLD handler (which is not supported on Windows) on | 
|  | #           the main process, which checks if any of the workers fail in the | 
|  | #           (Python) handler. This is more efficient and faster in detecting | 
|  | #           worker failures, compared to only using the above mechanism. | 
|  | #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details. | 
|  | # | 
|  | #           For `.get()` calls where the sender(s) is not the workers, we | 
|  | #           guard them with timeouts, and check the status of the sender | 
|  | #           when timeout happens: | 
|  | #             + in the workers, the `_utils.worker.ManagerWatchdog` class | 
|  | #               checks the status of the main process. | 
|  | #             + if `pin_memory=True`, when getting from `pin_memory_thread`, | 
|  | #               check `pin_memory_thread` status periodically until `.get()` | 
|  | #               returns or see that `pin_memory_thread` died. | 
|  | # | 
|  | #        b. A process won't hang when putting into a queue; | 
|  | # | 
|  | #           We use `mp.Queue` which has a separate background thread to put | 
|  | #           objects from an unbounded buffer array. The background thread is | 
|  | #           daemonic and usually automatically joined when the process | 
|  | #           *exits*. | 
|  | # | 
|  | #           In case that the receiver has ended abruptly while | 
|  | #           reading from the pipe, the join will hang forever.  The usual | 
|  | #           solution for this in Python is calling  `q.cancel_join_thread`, | 
|  | #           which prevents automatically joining it when finalizing | 
|  | #           (exiting). | 
|  | # | 
|  | #           Nonetheless, `cancel_join_thread` must only be called when the | 
|  | #           queue is **not** going to be read from or write into by another | 
|  | #           process, because it may hold onto a lock or leave corrupted data | 
|  | #           in the queue, leading other readers/writers to hang. | 
|  | # | 
|  | #           Hence, | 
|  | #             + For worker processes, we only do so (for their output | 
|  | #               queues, i.e., `worker_result_queue`) before exiting. | 
|  | #             + For `pin_memory_thread`, its output queue `data_queue` is a | 
|  | #               `queue.Queue` that does blocking `put` if the queue is full. | 
|  | #               So there is no above problem, but as a result, in | 
|  | #               `_pin_memory_loop`, we do need to  wrap the `put` in a loop | 
|  | #               that breaks not only upon success, but also when the main | 
|  | #               process stops reading, i.e., is shutting down. | 
|  | #             + For loader process, we `cancel_join_thread()` for all | 
|  | #               `_index_queues` because the whole purpose of workers and | 
|  | #               `pin_memory_thread` is to serve the loader process.  If | 
|  | #               loader process is already exiting, we don't really care if | 
|  | #               the queues are corrupted. | 
|  | # | 
|  | # | 
|  | # Now let's get back to 1: | 
|  | #   how we gracefully exit the workers when the last reference to the | 
|  | #   iterator is gone. | 
|  | # | 
|  | # To achieve this, we implement the following logic along with the design | 
|  | # choices mentioned above: | 
|  | # | 
|  | # `workers_done_event`: | 
|  | #   A `multiprocessing.Event` shared among the main process and all worker | 
|  | #   processes. This is used to signal the workers that the iterator is | 
|  | #   shutting down. After it is set, they will not send processed data to | 
|  | #   queues anymore, and only wait for the final `None` before exiting. | 
|  | #   `done_event` isn't strictly needed. I.e., we can just check for `None` | 
|  | #   from the input queue, but it allows us to skip wasting resources | 
|  | #   processing data if we are already shutting down. | 
|  | # | 
|  | # `pin_memory_thread_done_event`: | 
|  | #   A `threading.Event` for a similar purpose to that of | 
|  | #   `workers_done_event`, but is for the `pin_memory_thread`. The reason | 
|  | #   that separate events are needed is that `pin_memory_thread` reads from | 
|  | #   the output queue of the workers. But the workers, upon seeing that | 
|  | #   `workers_done_event` is set, only wants to see the final `None`, and is | 
|  | #   not required to flush all data in the output queue (e.g., it may call | 
|  | #   `cancel_join_thread` on that queue if its `IterableDataset` iterator | 
|  | #   happens to exhaust coincidentally, which is out of the control of the | 
|  | #   main process). Thus, since we will exit `pin_memory_thread` before the | 
|  | #   workers (see below), two separate events are used. | 
|  | # | 
|  | # NOTE: In short, the protocol is that the main process will set these | 
|  | #       `done_event`s and then send the corresponding processes/threads a `None`, | 
|  | #       and that they may exit at any time after receiving the `None`. | 
|  | # | 
|  | # NOTE: Using `None` as the final signal is valid, since normal data will | 
|  | #       always be a 2-tuple with the 1st element being the index of the data | 
|  | #       transferred (different from dataset index/key), and the 2nd being | 
|  | #       either the dataset key or the data sample (depending on which part | 
|  | #       of the data model the queue is at). | 
|  | # | 
|  | # [ worker processes ] | 
|  | #   While loader process is alive: | 
|  | #     Get from `index_queue`. | 
|  | #       If get anything else, | 
|  | #          Check `workers_done_event`. | 
|  | #            If set, continue to next iteration | 
|  | #                    i.e., keep getting until see the `None`, then exit. | 
|  | #            Otherwise, process data: | 
|  | #                If is fetching from an `IterableDataset` and the iterator | 
|  | #                    is exhausted, send an `_IterableDatasetStopIteration` | 
|  | #                    object to signal iteration end. The main process, upon | 
|  | #                    receiving such an object, will send `None` to this | 
|  | #                    worker and not use the corresponding `index_queue` | 
|  | #                    anymore. | 
|  | #       If timed out, | 
|  | #          No matter `workers_done_event` is set (still need to see `None`) | 
|  | #          or not, must continue to next iteration. | 
|  | #   (outside loop) | 
|  | #   If `workers_done_event` is set,  (this can be False with `IterableDataset`) | 
|  | #     `data_queue.cancel_join_thread()`.  (Everything is ending here: | 
|  | #                                          main process won't read from it; | 
|  | #                                          other workers will also call | 
|  | #                                          `cancel_join_thread`.) | 
|  | # | 
|  | # [ pin_memory_thread ] | 
|  | #   # No need to check main thread. If this thread is alive, the main loader | 
|  | #   # thread must be alive, because this thread is set as daemonic. | 
|  | #   While `pin_memory_thread_done_event` is not set: | 
|  | #     Get from `worker_result_queue`. | 
|  | #       If timed out, continue to get in the next iteration. | 
|  | #       Otherwise, process data. | 
|  | #       While `pin_memory_thread_done_event` is not set: | 
|  | #         Put processed data to `data_queue` (a `queue.Queue` with blocking put) | 
|  | #         If timed out, continue to put in the next iteration. | 
|  | #         Otherwise, break, i.e., continuing to the outer loop. | 
|  | # | 
|  | #   NOTE: we don't check the status of the main thread because | 
|  | #           1. if the process is killed by fatal signal, `pin_memory_thread` | 
|  | #              ends. | 
|  | #           2. in other cases, either the cleaning-up in __del__ or the | 
|  | #              automatic exit of daemonic thread will take care of it. | 
|  | #              This won't busy-wait either because `.get(timeout)` does not | 
|  | #              busy-wait. | 
|  | # | 
|  | # [ main process ] | 
|  | #   In the DataLoader Iter's `__del__` | 
|  | #     b. Exit `pin_memory_thread` | 
|  | #          i.   Set `pin_memory_thread_done_event`. | 
|  | #          ii   Put `None` in `worker_result_queue`. | 
|  | #          iii. Join the `pin_memory_thread`. | 
|  | #          iv.  `worker_result_queue.cancel_join_thread()`. | 
|  | # | 
|  | #     c. Exit the workers. | 
|  | #          i.   Set `workers_done_event`. | 
|  | #          ii.  Put `None` in each worker's `index_queue`. | 
|  | #          iii. Join the workers. | 
|  | #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`. | 
|  | # | 
|  | #        NOTE: (c) is better placed after (b) because it may leave corrupted | 
|  | #              data in `worker_result_queue`, which `pin_memory_thread` | 
|  | #              reads from, in which case the `pin_memory_thread` can only | 
|  | #              exit upon timing out, which is slow. Nonetheless, the same thing | 
|  | #              happens if a worker is killed by signal at unfortunate times, | 
|  | #              but in other cases, we are better off having a non-corrupted | 
|  | #              `worker_result_queue` for `pin_memory_thread`. | 
|  | # | 
|  | #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) | 
|  | #         can be omitted | 
|  | # | 
|  | # NB: `done_event`s isn't strictly needed. E.g., we can just check for | 
|  | #     `None` from `index_queue`, but it allows us to skip wasting resources | 
|  | #     processing indices already in `index_queue` if we are already shutting | 
|  | #     down. | 
|  |  | 
def __init__(self, loader):
    """Start the worker processes (and optional pin-memory thread) for ``loader``.

    Creates one index queue and one daemonic worker process per worker, a
    result queue shared by all workers, and, when ``self._pin_memory`` is
    set, a daemonic thread running ``_utils.pin_memory._pin_memory_loop``
    that moves items from the result queue into ``self._data_queue``.
    Finally registers the worker PIDs with ``_utils.signal_handling`` and
    calls ``self._reset(loader, first_iter=True)``.

    Args:
        loader: the owning loader. This method reads its ``prefetch_factor``,
            ``multiprocessing_context`` and ``worker_init_fn`` attributes and
            forwards it to the base ``__init__`` and to ``_reset``.
    """
    super().__init__(loader)

    self._prefetch_factor = loader.prefetch_factor

    assert self._num_workers > 0
    assert self._prefetch_factor > 0

    # Fall back to the `torch.multiprocessing` module imported at the top of
    # this file when the loader does not specify a context.
    if loader.multiprocessing_context is None:
        multiprocessing_context = multiprocessing
    else:
        multiprocessing_context = loader.multiprocessing_context

    self._worker_init_fn = loader.worker_init_fn

    # Adds forward compatibilities so classic DataLoader can work with DataPipes:
    #   Additional worker init function will take care of sharding in MP and Distributed
    if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
        self._worker_init_fn = functools.partial(
            _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank)

    # No certainty which module multiprocessing_context is
    self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
    self._worker_pids_set = False
    self._shutdown = False
    self._workers_done_event = multiprocessing_context.Event()

    # Parallel lists: entry `i` of each belongs to worker `i`.
    self._index_queues = []
    self._workers = []
    for i in range(self._num_workers):
        # No certainty which module multiprocessing_context is
        index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        # Need to `cancel_join_thread` here!
        # See sections (2) and (3b) above.
        index_queue.cancel_join_thread()
        w = multiprocessing_context.Process(
            target=_utils.worker._worker_loop,
            args=(self._dataset_kind, self._dataset, index_queue,
                  self._worker_result_queue, self._workers_done_event,
                  self._auto_collation, self._collate_fn, self._drop_last,
                  self._base_seed, self._worker_init_fn, i, self._num_workers,
                  self._persistent_workers, self._shared_seed))
        w.daemon = True
        # NB: Process.start() actually take some time as it needs to
        #     start a process and pass the arguments over via a pipe.
        #     Therefore, we only add a worker to self._workers list after
        #     it started, so that we do not call .join() if program dies
        #     before it starts, and __del__ tries to join but will get:
        #     AssertionError: can only join a started process.
        w.start()
        self._index_queues.append(index_queue)
        self._workers.append(w)

    if self._pin_memory:
        self._pin_memory_thread_done_event = threading.Event()

        # Queue is not type-annotated
        self._data_queue = queue.Queue()  # type: ignore[var-annotated]
        # Pick the current device index from the backend named by
        # `self._pin_memory_device`; it is handed to the pin-memory loop.
        if self._pin_memory_device == "xpu":
            current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
        elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
            custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
            current_device = custom_device_mod.current_device()
        else:
            current_device = torch.cuda.current_device()  # choose cuda for default
        pin_memory_thread = threading.Thread(
            target=_utils.pin_memory._pin_memory_loop,
            args=(self._worker_result_queue, self._data_queue,
                  current_device,
                  self._pin_memory_thread_done_event, self._pin_memory_device))
        pin_memory_thread.daemon = True
        pin_memory_thread.start()
        # Similar to workers (see comment above), we only register
        # pin_memory_thread once it is started.
        self._pin_memory_thread = pin_memory_thread
    else:
        # Without pinning there is no intermediate thread: items are read
        # straight from the workers' result queue.
        self._data_queue = self._worker_result_queue  # type: ignore[assignment]

    # In some rare cases, persistent workers (daemonic processes)
    # would be terminated before `__del__` of iterator is invoked
    # when main process exits
    # It would cause failure when pin_memory_thread tries to read
    # corrupted data from worker_result_queue
    # atexit is used to shutdown thread and child processes in the
    # right sequence before main process exits
    if self._persistent_workers and self._pin_memory:
        import atexit
        for w in self._workers:
            atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

    # .pid can be None only before process is spawned (not the case, so ignore)
    _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
    _utils.signal_handling._set_SIGCHLD_handler()
    self._worker_pids_set = True
    self._reset(loader, first_iter=True)
|  |  | 
def _reset(self, loader, first_iter=False):
    """Re-initialise per-epoch bookkeeping and prime the prefetch pipeline.

    After delegating to the base ``_reset``, this clears the task tracking
    state. When ``first_iter`` is false it sends a ``_ResumeIteration``
    message to every worker's index queue and waits until one
    ``_ResumeIteration`` reply per worker has come back through
    ``self._get_data()``. Lastly it calls ``self._try_put_index()``
    ``prefetch_factor * num_workers`` times.
    """
    super()._reset(loader, first_iter)

    # idx of the next task to be sent to workers
    self._send_idx = 0
    # idx of the next task to be returned in __next__
    self._rcvd_idx = 0
    # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
    # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
    #                  \ (worker_id, data)   if data is already fetched (out-of-order)
    self._task_info = {}
    # always equal to count(v for v in task_info.values() if len(v) == 1)
    self._tasks_outstanding = 0

    # One boolean per worker telling whether it still has work to do, i.e.,
    # has not exhausted its iterable dataset object. Every entry stays `True`
    # when the dataset is not iterable-style (i.e., if kind != Iterable).
    # Note that this only describes whether a worker still has work to do
    # *for this epoch*; it does not mean that a worker is dead. With
    # `_persistent_workers`, the worker is reset to available in the next
    # epoch.
    self._workers_status = [True] * self._num_workers
    # Reset the worker queue cycle so it resumes next epoch at worker 0
    self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))

    # We resume the prefetching in case it was enabled
    if not first_iter:
        for worker_id in range(self._num_workers):
            resume_msg = _utils.worker._ResumeIteration(self._shared_seed)
            self._index_queues[worker_id].put(resume_msg)

        # Wait for one acknowledgement per worker; anything else received
        # in the meantime is dropped.
        acks_pending = self._num_workers
        while acks_pending > 0:
            reply_idx, reply_data = self._get_data()
            if not isinstance(reply_idx, _utils.worker._ResumeIteration):
                continue
            assert reply_data is None
            acks_pending -= 1

    # prime the prefetch loop
    prefetch_count = self._prefetch_factor * self._num_workers
    for _ in range(prefetch_count):
        self._try_put_index()
|  |  | 
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
    """Try once to fetch an item from ``self._data_queue``.

    This can also be used as the inner loop of fetching without timeout,
    with the sender status as the loop condition.

    Args:
        timeout: passed as the ``timeout`` of the queue's ``get``.

    Returns:
        A 2-tuple ``(success, data)``: ``(True, data)`` when an item was
        fetched, ``(False, None)`` when the queue stayed empty and no
        worker was found dead.

    Raises:
        RuntimeError: if any worker still marked as available is no longer
            alive, or if the failure coincides with the process running
            into its open-files limit (``EMFILE``).
        Exception: any other error raised by the queue ``get`` is re-raised
            unchanged.
    """
    # A `RuntimeError` for a worker that died unexpectedly can come from
    # either the SIGCHLD handler in `_utils/signal_handling.py` (only for
    # non-Windows platforms), or the manual check below on errors and
    # timeouts.
    try:
        data = self._data_queue.get(timeout=timeout)
        return (True, data)
    except Exception as e:
        # At timeout and error, we manually check whether any worker has
        # failed. Note that this is the only mechanism for Windows to detect
        # worker failures.
        failed_workers = []
        for worker_id, w in enumerate(self._workers):
            if self._workers_status[worker_id] and not w.is_alive():
                failed_workers.append(w)
                self._mark_worker_as_unavailable(worker_id)
        if failed_workers:
            pids_str = ', '.join(str(w.pid) for w in failed_workers)
            raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
        if isinstance(e, queue.Empty):
            return (False, None)
        import tempfile
        import errno
        # Raise an exception if we are this close to the FDs limit.
        # Apparently, trying to open only one file is not a sufficient
        # test.
        # See NOTE [ DataLoader on Linux and open files limit ]
        fds_limit_margin = 10
        probes = []
        try:
            for _ in range(fds_limit_margin):
                probes.append(tempfile.NamedTemporaryFile())
        except OSError as fd_err:
            if fd_err.errno == errno.EMFILE:
                raise RuntimeError(
                    "Too many open files. Communication with the"
                    " workers is no longer possible. Please increase the"
                    " limit using `ulimit -n` in the shell or change the"
                    " sharing strategy by calling"
                    " `torch.multiprocessing.set_sharing_strategy('file_system')`"
                    " at the beginning of your code") from None
        finally:
            # Release the probe descriptors right away. Otherwise they stay
            # open for as long as the propagating exception's traceback
            # keeps this frame alive, which is exactly when descriptors are
            # scarce.
            for probe in probes:
                probe.close()
        # Not an FD-limit problem (or the probe failed for another reason):
        # re-raise the original error from the queue `get`.
        raise
|  |  | 
|  | # NOTE [ DataLoader on Linux and open files limit ] | 
|  | # | 
|  | # On Linux when DataLoader is used with multiprocessing we pass the data between | 
|  | # the root process and the workers through SHM files. We remove those files from | 
|  | # the filesystem as soon as they are created and keep them alive by | 
|  | # passing around their file descriptors through AF_UNIX sockets. (See | 
|  | # docs/source/multiprocessing.rst and `Multiprocessing Technical Notes` in | 
|  | # the wiki (https://github.com/pytorch/pytorch/wiki).) | 
|  | # | 
|  | # This sometimes leads us to exceeding the open files limit. When that happens, | 
|  | # and the offending file descriptor is coming over a socket, the `socket` Python | 
|  | # package silently strips the file descriptor from the message, setting only the | 
|  | # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that | 
|  | # it _indicates that some control data were discarded due to lack of space in | 
|  | # the buffer for ancillary data_). This might reflect the C implementation of | 
|  | # AF_UNIX sockets. | 
|  | # | 
|  | # This behaviour can be reproduced with the script and instructions at the | 
|  | # bottom of this note. | 
|  | # | 
|  | # When that happens, the standard Python `multiprocessing` (and not | 
|  | # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` | 
|  | # | 
|  | # Sometimes, instead of the FD being stripped, you may get an `OSError: | 
|  | # Too many open files`, both in the script below and in DataLoader. However, | 
|  | # this is rare and seems to be nondeterministic. | 
|  | # | 
|  | # | 
|  | #   #!/usr/bin/env python3 | 
|  | #   import sys | 
|  | #   import socket | 
|  | #   import os | 
|  | #   import array | 
|  | #   import shutil | 
|  | #   import socket | 
|  | # | 
|  | # | 
|  | #   if len(sys.argv) != 4: | 
|  | #       print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") | 
|  | #       sys.exit(1) | 
|  | # | 
|  | #   if __name__ == '__main__': | 
|  | #       dirname = sys.argv[1] | 
|  | #       sock_path = dirname + "/sock" | 
|  | #       iterations = int(sys.argv[2]) | 
|  | #       def dummy_path(i): | 
|  | #           return dirname + "/" + str(i) + ".dummy" | 
|  | # | 
|  | # | 
|  | #       if sys.argv[3] == 'send': | 
|  | #           while not os.path.exists(sock_path): | 
|  | #               pass | 
|  | #           client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) | 
|  | #           client.connect(sock_path) | 
|  | #           for i in range(iterations): | 
|  | #               fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) | 
|  | #               ancdata = array.array('i', [fd]) | 
|  | #               msg = bytes([i % 256]) | 
|  | #               print("Sending fd ", fd, " (iteration #", i, ")") | 
|  | #               client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) | 
|  | # | 
|  | # | 
|  | #       else: | 
|  | #           assert sys.argv[3] == 'recv' | 
|  | # | 
|  | #           if os.path.exists(dirname): | 
|  | #               raise Exception("Directory exists") | 
|  | # | 
|  | #           os.mkdir(dirname) | 
|  | # | 
|  | #           print("Opening socket...") | 
|  | #           server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) | 
|  | #           server.bind(sock_path) | 
|  | # | 
|  | #           print("Listening...") | 
|  | #           for i in range(iterations): | 
|  | #               a = array.array('i') | 
|  | #               msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) | 
|  | #               assert(len(ancdata) == 1) | 
|  | #               cmsg_level, cmsg_type, cmsg_data = ancdata[0] | 
|  | #               a.frombytes(cmsg_data) | 
|  | #               print("Received fd ", a[0], " (iteration #", i, ")") | 
|  | # | 
|  | #           shutil.rmtree(dirname) | 
|  | # | 
|  | # Steps to reproduce: | 
|  | # | 
# 1. Open two shells and set a lower file descriptor limit in the receiving one:
|  | # (shell1) ulimit -n 1020 | 
|  | # (shell2) ulimit -n 1022 | 
|  | # | 
|  | # 2. Run the script above with the `recv` option in the first shell | 
|  | # (shell1) ./test_socket.py sock_tmp 1017 recv | 
|  | # | 
|  | # 3. Run the script with the `send` option in the second shell: | 
|  | # (shell2) ./test_socket.py sock_tmp 1017 send | 
|  |  | 
|  | def _get_data(self): | 
|  | # Fetches data from `self._data_queue`. | 
|  | # | 
|  | # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, | 
|  | # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` | 
|  | # in a loop. This is the only mechanism to detect worker failures for | 
|  | # Windows. For other platforms, a SIGCHLD handler is also used for | 
|  | # worker failure detection. | 
|  | # | 
|  | # If `pin_memory=True`, we also need check if `pin_memory_thread` had | 
|  | # died at timeouts. | 
|  | if self._timeout > 0: | 
|  | success, data = self._try_get_data(self._timeout) | 
|  | if success: | 
|  | return data | 
|  | else: | 
|  | raise RuntimeError(f'DataLoader timed out after {self._timeout} seconds') | 
|  | elif self._pin_memory: | 
|  | while self._pin_memory_thread.is_alive(): | 
|  | success, data = self._try_get_data() | 
|  | if success: | 
|  | return data | 
|  | else: | 
|  | # while condition is false, i.e., pin_memory_thread died. | 
|  | raise RuntimeError('Pin memory thread exited unexpectedly') | 
|  | # In this case, `self._data_queue` is a `queue.Queue`,. But we don't | 
|  | # need to call `.task_done()` because we don't use `.join()`. | 
|  | else: | 
|  | while True: | 
|  | success, data = self._try_get_data() | 
|  | if success: | 
|  | return data | 
|  |  | 
|  | def _next_data(self): | 
|  | while True: | 
|  | # If the worker responsible for `self._rcvd_idx` has already ended | 
|  | # and was unable to fulfill this task (due to exhausting an `IterableDataset`), | 
|  | # we try to advance `self._rcvd_idx` to find the next valid index. | 
|  | # | 
|  | # This part needs to run in the loop because both the `self._get_data()` | 
|  | # call and `_IterableDatasetStopIteration` check below can mark | 
|  | # extra worker(s) as dead. | 
|  | while self._rcvd_idx < self._send_idx: | 
|  | info = self._task_info[self._rcvd_idx] | 
|  | worker_id = info[0] | 
|  | if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active | 
|  | break | 
|  | del self._task_info[self._rcvd_idx] | 
|  | self._rcvd_idx += 1 | 
|  | else: | 
|  | # no valid `self._rcvd_idx` is found (i.e., didn't break) | 
|  | if not self._persistent_workers: | 
|  | self._shutdown_workers() | 
|  | raise StopIteration | 
|  |  | 
|  | # Now `self._rcvd_idx` is the batch index we want to fetch | 
|  |  | 
|  | # Check if the next sample has already been generated | 
|  | if len(self._task_info[self._rcvd_idx]) == 2: | 
|  | data = self._task_info.pop(self._rcvd_idx)[1] | 
|  | return self._process_data(data) | 
|  |  | 
|  | assert not self._shutdown and self._tasks_outstanding > 0 | 
|  | idx, data = self._get_data() | 
|  | self._tasks_outstanding -= 1 | 
|  | if self._dataset_kind == _DatasetKind.Iterable: | 
|  | # Check for _IterableDatasetStopIteration | 
|  | if isinstance(data, _utils.worker._IterableDatasetStopIteration): | 
|  | if self._persistent_workers: | 
|  | self._workers_status[data.worker_id] = False | 
|  | else: | 
|  | self._mark_worker_as_unavailable(data.worker_id) | 
|  | self._try_put_index() | 
|  | continue | 
|  |  | 
|  | if idx != self._rcvd_idx: | 
|  | # store out-of-order samples | 
|  | self._task_info[idx] += (data,) | 
|  | else: | 
|  | del self._task_info[idx] | 
|  | return self._process_data(data) | 
|  |  | 
|  | def _try_put_index(self): | 
|  | assert self._tasks_outstanding < self._prefetch_factor * self._num_workers | 
|  |  | 
|  | try: | 
|  | index = self._next_index() | 
|  | except StopIteration: | 
|  | return | 
|  | for _ in range(self._num_workers):  # find the next active worker, if any | 
|  | worker_queue_idx = next(self._worker_queue_idx_cycle) | 
|  | if self._workers_status[worker_queue_idx]: | 
|  | break | 
|  | else: | 
|  | # not found (i.e., didn't break) | 
|  | return | 
|  |  | 
|  | self._index_queues[worker_queue_idx].put((self._send_idx, index)) | 
|  | self._task_info[self._send_idx] = (worker_queue_idx,) | 
|  | self._tasks_outstanding += 1 | 
|  | self._send_idx += 1 | 
|  |  | 
|  | def _process_data(self, data): | 
|  | self._rcvd_idx += 1 | 
|  | self._try_put_index() | 
|  | if isinstance(data, ExceptionWrapper): | 
|  | data.reraise() | 
|  | return data | 
|  |  | 
|  | def _mark_worker_as_unavailable(self, worker_id, shutdown=False): | 
|  | # Mark a worker as having finished its work e.g., due to | 
|  | # exhausting an `IterableDataset`. This should be used only when this | 
|  | # `_MultiProcessingDataLoaderIter` is going to continue running. | 
|  |  | 
|  | assert self._workers_status[worker_id] or (self._persistent_workers and shutdown) | 
|  |  | 
|  | # Signal termination to that specific worker. | 
|  | q = self._index_queues[worker_id] | 
|  | # Indicate that no more data will be put on this queue by the current | 
|  | # process. | 
|  | q.put(None) | 
|  |  | 
|  | # Note that we don't actually join the worker here, nor do we remove the | 
|  | # worker's pid from C side struct because (1) joining may be slow, and | 
|  | # (2) since we don't join, the worker may still raise error, and we | 
|  | # prefer capturing those, rather than ignoring them, even though they | 
|  | # are raised after the worker has finished its job. | 
|  | # Joinning is deferred to `_shutdown_workers`, which it is called when | 
|  | # all workers finish their jobs (e.g., `IterableDataset` replicas) or | 
|  | # when this iterator is garbage collected. | 
|  |  | 
|  | self._workers_status[worker_id] = False | 
|  |  | 
|  | assert self._workers_done_event.is_set() == shutdown | 
|  |  | 
def _shutdown_workers(self):
    """Stop the pin-memory thread and all workers, then release the queues.

    Called when shutting down this `_MultiProcessingDataLoaderIter`.
    See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
    the logic of this function.

    The body runs at most once per iterator (guarded by `self._shutdown`)
    and is skipped entirely when `_utils` is already `None` or
    `_utils.python_exit_status` is `True` or `None`.
    """
    if _utils is None or _utils.python_exit_status is True or _utils.python_exit_status is None:
        # See (2) of the note. If Python is shutting down, do no-op.
        return
    # Normal exit when last reference is gone / iterator is depleted.
    # See (1) and the second half of the note.
    if not self._shutdown:
        self._shutdown = True
        try:
            # Exit `pin_memory_thread` first because exiting workers may leave
            # corrupted data in `worker_result_queue` which `pin_memory_thread`
            # reads from.
            if hasattr(self, '_pin_memory_thread'):
                # Use hasattr in case error happens before we set the attribute.
                self._pin_memory_thread_done_event.set()
                # Send something to pin_memory_thread in case it is waiting
                # so that it can wake up and check `pin_memory_thread_done_event`
                self._worker_result_queue.put((None, None))
                self._pin_memory_thread.join()
                self._worker_result_queue.cancel_join_thread()
                self._worker_result_queue.close()

            # Exit workers now.
            self._workers_done_event.set()
            for worker_id in range(len(self._workers)):
                # Get number of workers from `len(self._workers)` instead of
                # `self._num_workers` in case we error before starting all
                # workers.
                # If we are using workers_status with persistent_workers
                # we have to shut it down because the worker is paused
                if self._persistent_workers or self._workers_status[worker_id]:
                    self._mark_worker_as_unavailable(worker_id, shutdown=True)
            for w in self._workers:
                # We should be able to join here, but in case anything went
                # wrong, we set a timeout and if the workers fail to join,
                # they are killed in the `finally` block.
                w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
            for q in self._index_queues:
                q.cancel_join_thread()
                q.close()
        finally:
            # Even though all this function does is putting into queues that
            # we have called `cancel_join_thread` on, weird things can
            # happen when a worker is killed by a signal, e.g., hanging in
            # `Event.set()`. So we need to guard this with SIGCHLD handler,
            # and remove pids from the C side data structure only at the
            # end.
            #
            # FIXME: Unfortunately, for Windows, we are missing a worker
            #        error detection mechanism here in this function, as it
            #        doesn't provide a SIGCHLD handler.
            if self._worker_pids_set:
                _utils.signal_handling._remove_worker_pids(id(self))
                self._worker_pids_set = False
            for w in self._workers:
                if w.is_alive():
                    # Existing mechanisms try to make the workers exit
                    # peacefully, but in case that we unfortunately reach
                    # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                    # we kill the worker.
                    w.terminate()
|  |  | 
# staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
@staticmethod
def _clean_up_worker(w):
    """Join worker `w` with a bounded wait and terminate it if still alive.

    `w` only needs to provide `join(timeout=...)`, `is_alive()` and
    `terminate()`.
    """
    try:
        w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
    finally:
        # Runs even if `join` raised, so a worker that did not exit is
        # still terminated.
        if w.is_alive():
            w.terminate()
|  |  | 
def __del__(self):
    """Shut down the workers when this iterator is finalized."""
    self._shutdown_workers()