| r"""Definition of the DataLoader and its iterator _DataLoaderIter classes. |
| |
| To support these two classes, in `./_utils` we define many utility methods and |
| functions to be run in multiprocessing. E.g., the data loading worker loop is |
| in `./_utils/worker.py`. |
| """ |
| |
| import torch |
| import torch.multiprocessing as multiprocessing |
| from . import SequentialSampler, RandomSampler, BatchSampler |
| from . import _utils |
| import threading |
| from torch._six import queue |
| |
| |
| # `default_collate` used to be defined in this file, but it was moved to |
| # _utils/collate.py. Although it is rather hard to reach it from user land |
| # through this module (one has to explicitly |
| # `import torch.utils.data.dataloader`), there probably is user code out there |
| # doing so. This alias maintains backward compatibility (BC) for that code. |
| default_collate = _utils.collate.default_collate |
| |
| |
class DataLoader(object):
    r"""
    Data loader. Pairs a dataset with a sampler and exposes single- or
    multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): number of samples to load per batch
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy used to draw samples
            from the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but yields a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): number of subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy
            tensors into CUDA pinned memory before returning them. If your
            data elements are a custom type, or your ``collate_fn`` returns a
            batch that is a custom type, see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch, if the dataset size is not divisible by the batch size. If
            ``False`` and the size of the dataset is not divisible by the batch
            size, then the last batch will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for
            collecting a batch from workers. Should always be non-negative.
            (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be
            called on each worker subprocess with the worker id (an int in
            ``[0, num_workers - 1]``) as input, after seeding and before data
            loading. (default: ``None``)

    .. note:: When ``num_workers != 0``, the corresponding worker processes are
              created each time an iterator for the DataLoader is obtained (as
              in when you call ``enumerate(dataloader,0)``).
              At this point, the dataset, ``collate_fn`` and ``worker_init_fn``
              are passed to each worker, where they are used to access and
              initialize data based on the indices queued up from the main
              process. This means that dataset access together with its
              internal IO, transforms and collation runs in the worker, while
              any shuffle randomization is done in the main process, which
              guides loading by assigning indices to load. Workers are shut
              down once the end of the iteration is reached.

              Since workers rely on Python multiprocessing, worker launch
              behavior is different on Windows compared to Unix. On Unix,
              fork() is used as the default multiprocessing start method, so
              child workers typically can access the dataset and Python
              argument functions directly through the cloned address space. On
              Windows, another interpreter is launched which runs your main
              script, followed by the internal worker function that receives
              the dataset, collate_fn and other arguments through Pickle
              serialization.

              This separate serialization means that you should take two steps
              to ensure you are compatible with Windows while using workers
              (this also works equally well on Unix):

              - Wrap most of your main script's code within an
                ``if __name__ == '__main__':`` block, to make sure it doesn't
                run again (most likely generating an error) when each worker
                process is launched. You can place your dataset and DataLoader
                instance creation logic here, as it doesn't need to be
                re-executed in workers.
              - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom
                dataset code is declared as a top level def, outside of that
                ``__main__`` check. This ensures they are available in workers
                as well (this is needed since functions are pickled as
                references only, not bytecode).

              By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long
              generated by the main process using its RNG. However, seeds for
              other libraries may be duplicated upon initializing workers
              (e.g., NumPy), causing each worker to return identical random
              numbers. (See :ref:`dataloader-workers-random-seed` section in
              FAQ.) You may use :func:`torch.initial_seed()` to access the
              PyTorch seed for each worker in :attr:`worker_init_fn`, and use
              it to set other seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and
    iterables containing Tensors. By default, if the pinning logic sees a batch
    that is a custom type (which will occur if you have a ``collate_fn`` that
    returns a custom batch type), or if each element of your batch is a custom
    type, the pinning logic will not recognize them, and it will return that
    batch (or those elements) without pinning the memory. To enable memory
    pinning for custom batch or data types, define a ``pin_memory`` method on
    your custom type(s).

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())

    """

    # Flipped to True at the very end of `__init__`; `__setattr__` consults it
    # to decide whether the protected attributes may still be assigned.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        torch._C._log_api_usage_once("python.data_loader")

        # Store the plain configuration first; validation follows below.
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        user_supplied_batch_sampler = batch_sampler is not None
        if user_supplied_batch_sampler:
            # A custom batch sampler fully determines batching, so none of the
            # batching-related options may be customized alongside it.
            conflicting = (batch_size > 1 or shuffle or sampler is not None
                           or drop_last)
            if conflicting:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if not user_supplied_batch_sampler:
            if sampler is None:
                # Default sampling strategy: random order when shuffling,
                # sequential order otherwise.
                sampler_cls = RandomSampler if shuffle else SequentialSampler
                sampler = sampler_cls(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # Must stay the last statement: from here on `__setattr__` rejects
        # changes to `batch_size`, `sampler` and `drop_last`.
        self.__initialized = True

    def __setattr__(self, attr, val):
        # These attributes are baked into `batch_sampler` during `__init__`,
        # so assigning them once construction has finished is refused.
        protected = ('batch_size', 'sampler', 'drop_last')
        if attr in protected and self.__initialized:
            message = ('{} attribute should not be set after {} is '
                       'initialized')
            raise ValueError(message.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        # Every call creates a brand new iterator object.
        return _DataLoaderIter(self)

    def __len__(self):
        # Number of batches, as reported by the batch sampler.
        return len(self.batch_sampler)
| |
| |
class _DataLoaderIter(object):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      `_DataLoaderIter`.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exits normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children. Thus, we register an `atexit` hook that sets a global flag
    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
    #      reverse order of registration, we are guaranteed that this flag is
    #      set before library resources we use are freed. (Hooks freeing those
    #      resources are registered at importing the Python core libraries at
    #      the top of this file.) So in `__del__`, we check if
    #      `_utils.python_exit_status` is set or `None` (freed), and perform
    #      no-op if so.
    #
    #      Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
    #      together with a few implementation details of multiprocessing, forces
    #      us to make workers daemonic. All of our problems arise when a
    #      DataLoader is used in a subprocess, and are caused by multiprocessing
    #      code which looks more or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` to be freed. If the frame has any
    #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #      its `__del__`, which starts the shutdown procedure, will not be
    #      called. That, in turn, means that workers aren't notified. Attempting
    #      to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it were
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we set a timeout whenever we try to get data
    #           from `data_queue`, and check the workers' status on each timeout
    #           and error.
    #           See `_DataLoaderIter._get_batch()` and
    #           `_DataLoaderIter._try_get_batch()` for details.
    #
    #           Additionally, for child exit on non-Windows platforms, we also
    #           register a SIGCHLD handler (which is not supported on Windows)
    #           on the main process, which checks if any of the workers fail in
    #           the (Python) handler. This is more efficient and faster in
    #           detecting worker failures, compared to only using the above
    #           mechanism.
    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when timeout happens:
    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
    #               checks the status of the main process.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           exits.
    #
    #           However, in case that the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. Therefore,
    #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we use
    #           `q.cancel_join_thread()` in sender process before any `q.put` to
    #           prevent this automatic join.
    #
    #           Moreover, having all queues called `cancel_join_thread` makes
    #           implementing graceful shutdown logic in `__del__` much easier.
    #           It won't need to get from any queue, which would also need to be
    #           guarded by periodic status checks.
    #
    #           Note that this may leave corrupted data in the queue, but we
    #           don't care about the data anyways once we are shutting down.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # [worker processes]
    #   While loader process is alive:
    #     Get from index_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #       If timed out,
    #          No matter `done_event` is set (still need to see `None`) or not,
    #          must continue to next iteration.
    #
    # [pin_memory_thread]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While True:
    #     Get from worker_result_queue.
    #       If got a `None`, exit.
    #       If get anything else,
    #          Check `done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [main process]
    #   In the DataLoader Iter's `__del__`
    #     a. Set `done_event` (shared with `pin_memory_thread` and workers).
    #
    #        Note: from here on, the workers & `pin_memory_thread` may exit at
    #              any time after they receive `None`.
    #
    #     b. Exit `pin_memory_thread`
    #          i.   Put `None` in `worker_result_queue`.
    #          ii.  Join the `pin_memory_thread`.
    #
    #     c. Exit the workers.
    #          i.   Put `None` in each worker's `index_queue`.
    #          ii.  Join the workers.
    #
    #        NOTE: This has to be after (b) because it may leave corrupted data
    #              in `worker_result_queue`, which `pin_memory_thread` reads
    #              from.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted
    #
    # NB: `done_event` isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

    def __init__(self, loader):
        # Snapshot the loader's configuration and, when `num_workers > 0`,
        # start the worker processes (plus the optional pin-memory thread) and
        # prefetch the first `2 * num_workers` batches of indices.
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        # Pinning is only enabled when CUDA is actually available.
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        # Drawn from the main process RNG; worker `i` is given
        # `base_seed + i` below.
        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # Index of the worker that receives the next batch of indices
            # (round-robin, see `_put_indices`).
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            # Number of index batches sent to workers but not yet received.
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            # `send_idx` tags outgoing index batches; `rcvd_idx` is the tag of
            # the batch that must be returned next. Batches arriving early are
            # parked in `reorder_dict`.
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop,
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event,
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually take some time as it needs to
                #     start a process and pass the arguments over via a pipe.
                #     Therefore, we only add a worker to self.workers list after
                #     it started, so that we do not call .join() if program dies
                #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue,
                          torch.cuda.current_device(), self.done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self.pin_memory_thread = pin_memory_thread
            else:
                # Without pinning, results are consumed straight from the
                # workers' result queue.
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        # Number of batches, as reported by the batch sampler.
        return len(self.batch_sampler)

    def _try_get_batch(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `data_queue` for a given timeout. This can
        # also be used as inner loop of fetching without timeout, with the
        # sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self.data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            # Take a single liveness snapshot so that the decision to raise and
            # the pids reported in the message always agree.
            failed_workers = [w for w in self.workers if not w.is_alive()]
            if failed_workers:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_batch(self):
        # Fetches data from `self.data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_batch(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        #
        # Raises `RuntimeError` when the user-specified timeout elapses or when
        # `pin_memory_thread` is found dead.
        if self.timeout > 0:
            success, data = self._try_get_batch(self.timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
        elif self.pin_memory:
            # In this case, `self.data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
            while self.pin_memory_thread.is_alive():
                success, data = self._try_get_batch()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
        else:
            while True:
                success, data = self._try_get_batch()
                if success:
                    return data

    def __next__(self):
        # Returns the next batch in sampler order; raises `StopIteration` once
        # the sampler is exhausted and all outstanding batches were returned.
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            # Nothing in flight and nothing buffered: the epoch is over.
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        # Sends the next batch of indices (if the sampler has one left) to the
        # next worker in round-robin order, tagged with `send_idx`.
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
        self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        # Advances the receive counter, refills the prefetch pipeline, and then
        # either returns `batch` or re-raises the error a worker wrapped in an
        # `_utils.ExceptionWrapper`.
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, _utils.ExceptionWrapper):
            # make multiline KeyError msg readable by working around
            # a python bug https://bugs.python.org/issue2651
            if batch.exc_type == KeyError and "\n" in batch.exc_msg:
                raise Exception("KeyError:" + batch.exc_msg)
            try:
                exc = batch.exc_type(batch.exc_msg)
            except TypeError:
                # Some exception types cannot be built from a single message
                # argument. Surface the worker's message instead of an
                # unrelated constructor error.
                raise RuntimeError(batch.exc_msg)
            raise exc
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("_DataLoaderIter cannot be pickled")

    def _shutdown_workers(self):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self.shutdown:
            self.shutdown = True
            try:
                self.done_event.set()

                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, 'pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    # First time do `worker_result_queue.put` in this process.

                    # `cancel_join_thread` in case that `pin_memory_thread` exited.
                    self.worker_result_queue.cancel_join_thread()
                    self.worker_result_queue.put(None)
                    self.pin_memory_thread.join()
                    # Indicate that no more data will be put on this queue by the
                    # current process. This **must** be called after
                    # `pin_memory_thread` is joined because that thread shares the
                    # same pipe handles with this loader thread. If the handle is
                    # closed, Py3 will error in this case, but Py2 will just time
                    # out even if there is data in the queue.
                    self.worker_result_queue.close()

                # Exit workers now.
                for q in self.index_queues:
                    q.put(None)
                    # Indicate that no more data will be put on this queue by the
                    # current process.
                    q.close()
                for w in self.workers:
                    w.join()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism here in this function, as it
                #        doesn't provide a SIGCHLD handler.
                if self.worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self.worker_pids_set = False

    def __del__(self):
        # `num_workers` may be missing if `__init__` raised before assigning
        # it; in that case no worker was started and there is nothing to do.
        if getattr(self, 'num_workers', 0) > 0:
            self._shutdown_workers()