make more iterator attributes private (#23744)
Summary:
1. Prefixed every `DataLoaderIter` attribute that is not part of the `DataLoader` constructor argument list with an underscore.
2. Prefixed `DataLoader.dataset_kind` with an underscore because it is only meaningful together with the private enum `_DatasetKind` and is an implementation detail.
3. Disallowed setting `DataLoader.dataset` and `DataLoader.batch_sampler` after a `DataLoader` has been initialized, because other attributes are derived from them in `__init__` (see the sketch below).
These changes should not have a major BC-breaking effect, since the big changes are on the iterator class and most users don't even store it. I searched GitHub for `pin_memory_thread` and (while I didn't look through all result pages) the results I saw were forks of PyTorch and blog posts on how the data loader works.
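As a rough illustration of the user-facing effect, here is a minimal sketch (the toy `_RangeDataset` below is invented for this example and is not part of the change): assigning to `dataset` or `batch_sampler` after construction now raises a `ValueError`, and iterator internals such as the former `workers` list are only reachable under underscore-prefixed names.
```python
# Minimal sketch of the behavior after this change. `_RangeDataset` is a
# made-up toy dataset used only for illustration.
import torch
from torch.utils.data import DataLoader, Dataset


class _RangeDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)


loader = DataLoader(_RangeDataset(), batch_size=2, num_workers=0)

# Attributes that other state is derived from in __init__ can no longer be
# reassigned; DataLoader.__setattr__ now rejects `dataset` and
# `batch_sampler` in addition to `batch_size`, `sampler`, and `drop_last`.
try:
    loader.dataset = _RangeDataset()
except ValueError as exc:
    print(exc)

# Iterator internals are underscore-prefixed now. With num_workers=0 the
# single-process iterator has no worker list at all, so code that poked at
# `it.workers` (multiprocessing only) must use `it._workers` instead.
it = iter(loader)
print(getattr(it, "_workers", "no workers (single-process iterator)"))
```
The error message itself comes from the existing `__setattr__` guard; this patch simply extends the guarded attribute list to cover `dataset` and `batch_sampler`.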
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23744
Differential Revision: D16732507
Pulled By: ezyang
fbshipit-source-id: 9f04d000b4200b8047f31eaa3473780b66cebd26
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 48d9887..34e30b3 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -560,7 +560,7 @@
it = iter(loader)
if use_workers:
- workers = it.workers
+ workers = it._workers
def kill_pid(pid):
psutil_p = psutil.Process(pid)
@@ -638,7 +638,7 @@
data = []
for d in it:
data.append(d)
- worker_pids = [w.pid for w in it.workers]
+ worker_pids = [w.pid for w in it._workers]
data = torch.cat(data, 0)
for d in data:
# each `d` is a [worker_id, worker_pid] pair, which is set in
@@ -735,7 +735,7 @@
def test_invalid_assign_after_init(self):
dl = DataLoader(self.dataset)
- for attr in ('batch_size', 'sampler', 'drop_last'):
+ for attr in ('batch_size', 'sampler', 'batch_sampler', 'drop_last', 'dataset'):
def fn():
setattr(dl, attr, {})
@@ -919,7 +919,7 @@
len(dataloader) # DataLoader with iterable-style dataset should error in __len__
# [no auto-batching] test that workers exit gracefully
- workers = dataloader_iter.workers
+ workers = dataloader_iter._workers
del dataloader_iter
try:
for w in workers:
@@ -955,7 +955,7 @@
self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})
# [auto-batching] test that workers exit gracefully
- workers = dataloader_iter.workers
+ workers = dataloader_iter._workers
del dataloader_iter
try:
for w in workers:
@@ -991,7 +991,7 @@
self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
# [auto-batching & drop_last] test that workers exit gracefully
- workers = dataloader_iter.workers
+ workers = dataloader_iter._workers
del dataloader_iter
try:
for w in workers:
@@ -1226,9 +1226,9 @@
for pin_memory in pin_memory_configs:
loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
- workers = loader.workers
+ workers = loader._workers
if pin_memory:
- pin_memory_thread = loader.pin_memory_thread
+ pin_memory_thread = loader._pin_memory_thread
for i, _ in enumerate(loader):
if i == 10:
break
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 0349215..b544c4b 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -142,7 +142,7 @@
# samplers first, so that they don't learn that this combo doesn't work
# after spending time fixing the custom sampler errors.
if isinstance(dataset, IterableDataset):
- self.dataset_kind = _DatasetKind.Iterable
+ self._dataset_kind = _DatasetKind.Iterable
# NOTE [ Custom Samplers and `IterableDataset` ]
#
# `IterableDataset` does not support custom `batch_sampler` or
@@ -183,7 +183,7 @@
"DataLoader with IterableDataset: expected unspecified "
"batch_sampler option, but got batch_sampler={}".format(batch_sampler))
else:
- self.dataset_kind = _DatasetKind.Map
+ self._dataset_kind = _DatasetKind.Map
if sampler is not None and shuffle:
raise ValueError('sampler option is mutually exclusive with '
@@ -205,7 +205,7 @@
'shuffle, sampler, and drop_last')
if sampler is None: # give default samplers
- if self.dataset_kind == _DatasetKind.Iterable:
+ if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
@@ -265,7 +265,7 @@
self.__multiprocessing_context = multiprocessing_context
def __setattr__(self, attr, val):
- if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+ if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
raise ValueError('{} attribute should not be set after {} is '
'initialized'.format(attr, self.__class__.__name__))
@@ -299,29 +299,29 @@
class _BaseDataLoaderIter(object):
def __init__(self, loader):
- self.dataset = loader.dataset
- self.dataset_kind = loader.dataset_kind
- self.auto_collation = loader._auto_collation
- self.drop_last = loader.drop_last
- self.index_sampler = loader._index_sampler
- self.num_workers = loader.num_workers
- self.pin_memory = loader.pin_memory and torch.cuda.is_available()
- self.timeout = loader.timeout
- self.collate_fn = loader.collate_fn
- self.sampler_iter = iter(self.index_sampler)
- self.base_seed = torch.empty((), dtype=torch.int64).random_().item()
+ self._dataset = loader.dataset
+ self._dataset_kind = loader._dataset_kind
+ self._auto_collation = loader._auto_collation
+ self._drop_last = loader.drop_last
+ self._index_sampler = loader._index_sampler
+ self._num_workers = loader.num_workers
+ self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+ self._timeout = loader.timeout
+ self._collate_fn = loader.collate_fn
+ self._sampler_iter = iter(self._index_sampler)
+ self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
def __iter__(self):
return self
def _next_index(self):
- return next(self.sampler_iter) # may raise StopIteration
+ return next(self._sampler_iter) # may raise StopIteration
def __next__(self):
raise NotImplementedError
def __len__(self):
- return len(self.index_sampler)
+ return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
@@ -335,16 +335,16 @@
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
- assert self.timeout == 0
- assert self.num_workers == 0
+ assert self._timeout == 0
+ assert self._num_workers == 0
- self.dataset_fetcher = _DatasetKind.create_fetcher(
- self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)
+ self._dataset_fetcher = _DatasetKind.create_fetcher(
+ self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index() # may raise StopIteration
- data = self.dataset_fetcher.fetch(index) # may raise StopIteration
- if self.pin_memory:
+ data = self._dataset_fetcher.fetch(index) # may raise StopIteration
+ if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
@@ -635,81 +635,81 @@
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
- assert self.num_workers > 0
+ assert self._num_workers > 0
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
- self.worker_init_fn = loader.worker_init_fn
- self.worker_queue_idx_cycle = itertools.cycle(range(self.num_workers))
- self.worker_result_queue = multiprocessing_context.Queue()
- self.worker_pids_set = False
- self.shutdown = False
- self.send_idx = 0 # idx of the next task to be sent to workers
- self.rcvd_idx = 0 # idx of the next task to be returned in __next__
+ self._worker_init_fn = loader.worker_init_fn
+ self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
+ self._worker_result_queue = multiprocessing_context.Queue()
+ self._worker_pids_set = False
+ self._shutdown = False
+ self._send_idx = 0 # idx of the next task to be sent to workers
+ self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
- self.task_info = {}
- self.tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
- self.workers_done_event = multiprocessing_context.Event()
+ self._task_info = {}
+ self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
+ self._workers_done_event = multiprocessing_context.Event()
- self.index_queues = []
- self.workers = []
+ self._index_queues = []
+ self._workers = []
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
- self.workers_status = []
- for i in range(self.num_workers):
+ self._workers_status = []
+ for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue()
# index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop,
- args=(self.dataset_kind, self.dataset, index_queue,
- self.worker_result_queue, self.workers_done_event,
- self.auto_collation, self.collate_fn, self.drop_last,
- self.base_seed + i, self.worker_init_fn, i, self.num_workers))
+ args=(self._dataset_kind, self._dataset, index_queue,
+ self._worker_result_queue, self._workers_done_event,
+ self._auto_collation, self._collate_fn, self._drop_last,
+ self._base_seed + i, self._worker_init_fn, i, self._num_workers))
w.daemon = True
# NB: Process.start() actually takes some time as it needs to
# start a process and pass the arguments over via a pipe.
- # Therefore, we only add a worker to self.workers list after
+ # Therefore, we only add a worker to self._workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
- self.index_queues.append(index_queue)
- self.workers.append(w)
- self.workers_status.append(True)
+ self._index_queues.append(index_queue)
+ self._workers.append(w)
+ self._workers_status.append(True)
- if self.pin_memory:
- self.pin_memory_thread_done_event = threading.Event()
- self.data_queue = queue.Queue()
+ if self._pin_memory:
+ self._pin_memory_thread_done_event = threading.Event()
+ self._data_queue = queue.Queue()
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
- args=(self.worker_result_queue, self.data_queue,
+ args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
- self.pin_memory_thread_done_event))
+ self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
- self.pin_memory_thread = pin_memory_thread
+ self._pin_memory_thread = pin_memory_thread
else:
- self.data_queue = self.worker_result_queue
+ self._data_queue = self._worker_result_queue
- _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+ _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
_utils.signal_handling._set_SIGCHLD_handler()
- self.worker_pids_set = True
+ self._worker_pids_set = True
# prime the prefetch loop
- for _ in range(2 * self.num_workers):
+ for _ in range(2 * self._num_workers):
self._try_put_index()
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
- # Tries to fetch data from `self.data_queue` once for a given timeout.
+ # Tries to fetch data from `self._data_queue` once for a given timeout.
# This can also be used as inner loop of fetching without timeout, with
# the sender status as the loop condition.
#
@@ -721,15 +721,15 @@
# Returns a 2-tuple:
# (bool: whether successfully get data, any: data if successful else None)
try:
- data = self.data_queue.get(timeout=timeout)
+ data = self._data_queue.get(timeout=timeout)
return (True, data)
except Exception as e:
# At timeout and error, we manually check whether any worker has
# failed. Note that this is the only mechanism for Windows to detect
# worker failures.
failed_workers = []
- for worker_id, w in enumerate(self.workers):
- if self.workers_status[worker_id] and not w.is_alive():
+ for worker_id, w in enumerate(self._workers):
+ if self._workers_status[worker_id] and not w.is_alive():
failed_workers.append(w)
self._shutdown_worker(worker_id)
if len(failed_workers) > 0:
@@ -740,7 +740,7 @@
raise
def _get_data(self):
- # Fetches data from `self.data_queue`.
+ # Fetches data from `self._data_queue`.
#
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
@@ -750,21 +750,21 @@
#
# If `pin_memory=True`, we also need to check if `pin_memory_thread` had
# died at timeouts.
- if self.timeout > 0:
- success, data = self._try_get_data(self.timeout)
+ if self._timeout > 0:
+ success, data = self._try_get_data(self._timeout)
if success:
return data
else:
- raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
- elif self.pin_memory:
- while self.pin_memory_thread.is_alive():
+ raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
+ elif self._pin_memory:
+ while self._pin_memory_thread.is_alive():
success, data = self._try_get_data()
if success:
return data
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
- # In this case, `self.data_queue` is a `queue.Queue`. But we don't
+ # In this case, `self._data_queue` is a `queue.Queue`. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
while True:
@@ -774,73 +774,73 @@
def __next__(self):
while True:
- # If the worker responsible for `self.rcvd_idx` has already ended
+ # If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
- # we try to advance `self.rcvd_idx` to find the next valid index.
+ # we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
- while self.rcvd_idx < self.send_idx:
- info = self.task_info[self.rcvd_idx]
+ while self._rcvd_idx < self._send_idx:
+ info = self._task_info[self._rcvd_idx]
worker_id = info[0]
- if len(info) == 2 or self.workers_status[worker_id]: # has data or is still active
+ if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break
- del self.task_info[self.rcvd_idx]
- self.rcvd_idx += 1
+ del self._task_info[self._rcvd_idx]
+ self._rcvd_idx += 1
else:
- # no valid `self.rcvd_idx` is found (i.e., didn't break)
+ # no valid `self._rcvd_idx` is found (i.e., didn't break)
self._shutdown_workers()
raise StopIteration
- # Now `self.rcvd_idx` is the batch index we want to fetch
+ # Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
- if len(self.task_info[self.rcvd_idx]) == 2:
- data = self.task_info.pop(self.rcvd_idx)[1]
+ if len(self._task_info[self._rcvd_idx]) == 2:
+ data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data)
- assert not self.shutdown and self.tasks_outstanding > 0
+ assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data()
- self.tasks_outstanding -= 1
+ self._tasks_outstanding -= 1
- if self.dataset_kind == _DatasetKind.Iterable:
+ if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
self._shutdown_worker(data.worker_id)
self._try_put_index()
continue
- if idx != self.rcvd_idx:
+ if idx != self._rcvd_idx:
# store out-of-order samples
- self.task_info[idx] += (data,)
+ self._task_info[idx] += (data,)
else:
- del self.task_info[idx]
+ del self._task_info[idx]
return self._process_data(data)
next = __next__ # Python 2 compatibility
def _try_put_index(self):
- assert self.tasks_outstanding < 2 * self.num_workers
+ assert self._tasks_outstanding < 2 * self._num_workers
try:
index = self._next_index()
except StopIteration:
return
- for _ in range(self.num_workers): # find the next active worker, if any
- worker_queue_idx = next(self.worker_queue_idx_cycle)
- if self.workers_status[worker_queue_idx]:
+ for _ in range(self._num_workers): # find the next active worker, if any
+ worker_queue_idx = next(self._worker_queue_idx_cycle)
+ if self._workers_status[worker_queue_idx]:
break
else:
# not found (i.e., didn't break)
return
- self.index_queues[worker_queue_idx].put((self.send_idx, index))
- self.task_info[self.send_idx] = (worker_queue_idx,)
- self.tasks_outstanding += 1
- self.send_idx += 1
+ self._index_queues[worker_queue_idx].put((self._send_idx, index))
+ self._task_info[self._send_idx] = (worker_queue_idx,)
+ self._tasks_outstanding += 1
+ self._send_idx += 1
def _process_data(self, data):
- self.rcvd_idx += 1
+ self._rcvd_idx += 1
self._try_put_index()
if isinstance(data, ExceptionWrapper):
data.reraise()
@@ -851,10 +851,10 @@
# exhausting an `IterableDataset`. This should be used only when this
# `_MultiProcessingDataLoaderIter` is going to continue running.
- assert self.workers_status[worker_id]
+ assert self._workers_status[worker_id]
# Signal termination to that specific worker.
- q = self.index_queues[worker_id]
+ q = self._index_queues[worker_id]
# Indicate that no more data will be put on this queue by the current
# process.
q.put(None)
@@ -867,7 +867,7 @@
# Joining is deferred to `_shutdown_workers`, which is called when
# all workers finish their jobs (e.g., `IterableDataset` replicas) or
# when this iterator is garbage collected.
- self.workers_status[worker_id] = False
+ self._workers_status[worker_id] = False
def _shutdown_workers(self):
# Called when shutting down this `_MultiProcessingDataLoaderIter`.
@@ -879,32 +879,32 @@
return
# Normal exit when last reference is gone / iterator is depleted.
# See (1) and the second half of the note.
- if not self.shutdown:
- self.shutdown = True
+ if not self._shutdown:
+ self._shutdown = True
try:
# Exit `pin_memory_thread` first because exiting workers may leave
# corrupted data in `worker_result_queue` which `pin_memory_thread`
# reads from.
- if hasattr(self, 'pin_memory_thread'):
+ if hasattr(self, '_pin_memory_thread'):
# Use hasattr in case an error happens before we set the attribute.
- self.pin_memory_thread_done_event.set()
+ self._pin_memory_thread_done_event.set()
# Send something to pin_memory_thread in case it is waiting
# so that it can wake up and check `pin_memory_thread_done_event`
- self.worker_result_queue.put((None, None))
- self.pin_memory_thread.join()
- self.worker_result_queue.close()
+ self._worker_result_queue.put((None, None))
+ self._pin_memory_thread.join()
+ self._worker_result_queue.close()
# Exit workers now.
- self.workers_done_event.set()
- for worker_id in range(len(self.workers)):
- # Get number of workers from `len(self.workers)` instead of
- # `self.num_workers` in case we error before starting all
+ self._workers_done_event.set()
+ for worker_id in range(len(self._workers)):
+ # Get number of workers from `len(self._workers)` instead of
+ # `self._num_workers` in case we error before starting all
# workers.
- if self.workers_status[worker_id]:
+ if self._workers_status[worker_id]:
self._shutdown_worker(worker_id)
- for w in self.workers:
+ for w in self._workers:
w.join()
- for q in self.index_queues:
+ for q in self._index_queues:
q.cancel_join_thread()
q.close()
finally:
@@ -918,9 +918,9 @@
# FIXME: Unfortunately, for Windows, we are missing a worker
# error detection mechanism here in this function, as it
# doesn't provide a SIGCHLD handler.
- if self.worker_pids_set:
+ if self._worker_pids_set:
_utils.signal_handling._remove_worker_pids(id(self))
- self.worker_pids_set = False
+ self._worker_pids_set = False
def __del__(self):
self._shutdown_workers()