make more iterator attributes private (#23744)

Summary:
1. Prefixed an underscore to every `DataLoaderIter` attribute that is not part of the data loader ctor argument list.
2. Prefixed `DataLoader.dataset_kind` with underscore because it only makes sense with the private enum `_DatasetKind`, and is an implementation detail.
3. Disallow setting `DataLoader.dataset` and `DataLoader.batch_sampler` after initializing a `DataLoader` because they affect other attributes in `__init__`.

These changes should not have a major BC-breaking effect, since the big changes are on the iterator class and most users don't even store it. I searched GitHub for `pin_memory_thread` and (while I didn't look through all result pages) the results I saw were forks of pytorch and blog posts on how the data loader works.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23744

Differential Revision: D16732507

Pulled By: ezyang

fbshipit-source-id: 9f04d000b4200b8047f31eaa3473780b66cebd26
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 48d9887..34e30b3 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -560,7 +560,7 @@
 
     it = iter(loader)
     if use_workers:
-        workers = it.workers
+        workers = it._workers
 
     def kill_pid(pid):
         psutil_p = psutil.Process(pid)
@@ -638,7 +638,7 @@
     data = []
     for d in it:
         data.append(d)
-    worker_pids = [w.pid for w in it.workers]
+    worker_pids = [w.pid for w in it._workers]
     data = torch.cat(data, 0)
     for d in data:
         # each `d` is a [worker_id, worker_pid] pair, which is set in
@@ -735,7 +735,7 @@
 
     def test_invalid_assign_after_init(self):
         dl = DataLoader(self.dataset)
-        for attr in ('batch_size', 'sampler', 'drop_last'):
+        for attr in ('batch_size', 'sampler', 'batch_sampler', 'drop_last', 'dataset'):
             def fn():
                 setattr(dl, attr, {})
 
@@ -919,7 +919,7 @@
             len(dataloader)  # DataLoader with iterable-style dataset should error in __len__
 
         # [no auto-batching] test that workers exit gracefully
-        workers = dataloader_iter.workers
+        workers = dataloader_iter._workers
         del dataloader_iter
         try:
             for w in workers:
@@ -955,7 +955,7 @@
         self.assertEqual(fetched, {tuple(range(4)), tuple(range(7)), tuple(range(7, 14)), tuple(range(14, 20))})
 
         # [auto-batching] test that workers exit gracefully
-        workers = dataloader_iter.workers
+        workers = dataloader_iter._workers
         del dataloader_iter
         try:
             for w in workers:
@@ -991,7 +991,7 @@
         self.assertEqual(fetched, {tuple(range(7)), tuple(range(7, 14))})
 
         # [auto-batching & drop_last] test that workers exit gracefully
-        workers = dataloader_iter.workers
+        workers = dataloader_iter._workers
         del dataloader_iter
         try:
             for w in workers:
@@ -1226,9 +1226,9 @@
 
         for pin_memory in pin_memory_configs:
             loader = iter(DataLoader(self.dataset, batch_size=2, num_workers=4, pin_memory=pin_memory))
-            workers = loader.workers
+            workers = loader._workers
             if pin_memory:
-                pin_memory_thread = loader.pin_memory_thread
+                pin_memory_thread = loader._pin_memory_thread
             for i, _ in enumerate(loader):
                 if i == 10:
                     break
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 0349215..b544c4b 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -142,7 +142,7 @@
         # samplers first, so that they don't learn that this combo doesn't work
         # after spending time fixing the custom sampler errors.
         if isinstance(dataset, IterableDataset):
-            self.dataset_kind = _DatasetKind.Iterable
+            self._dataset_kind = _DatasetKind.Iterable
             # NOTE [ Custom Samplers and `IterableDataset` ]
             #
             # `IterableDataset` does not support custom `batch_sampler` or
@@ -183,7 +183,7 @@
                     "DataLoader with IterableDataset: expected unspecified "
                     "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
         else:
-            self.dataset_kind = _DatasetKind.Map
+            self._dataset_kind = _DatasetKind.Map
 
         if sampler is not None and shuffle:
             raise ValueError('sampler option is mutually exclusive with '
@@ -205,7 +205,7 @@
                                  'shuffle, sampler, and drop_last')
 
         if sampler is None:  # give default samplers
-            if self.dataset_kind == _DatasetKind.Iterable:
+            if self._dataset_kind == _DatasetKind.Iterable:
                 # See NOTE [ Custom Samplers and IterableDataset ]
                 sampler = _InfiniteConstantSampler()
             else:  # map-style
@@ -265,7 +265,7 @@
         self.__multiprocessing_context = multiprocessing_context
 
     def __setattr__(self, attr, val):
-        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
+        if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
             raise ValueError('{} attribute should not be set after {} is '
                              'initialized'.format(attr, self.__class__.__name__))
 
@@ -299,29 +299,29 @@
 
 class _BaseDataLoaderIter(object):
     def __init__(self, loader):
-        self.dataset = loader.dataset
-        self.dataset_kind = loader.dataset_kind
-        self.auto_collation = loader._auto_collation
-        self.drop_last = loader.drop_last
-        self.index_sampler = loader._index_sampler
-        self.num_workers = loader.num_workers
-        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
-        self.timeout = loader.timeout
-        self.collate_fn = loader.collate_fn
-        self.sampler_iter = iter(self.index_sampler)
-        self.base_seed = torch.empty((), dtype=torch.int64).random_().item()
+        self._dataset = loader.dataset
+        self._dataset_kind = loader._dataset_kind
+        self._auto_collation = loader._auto_collation
+        self._drop_last = loader.drop_last
+        self._index_sampler = loader._index_sampler
+        self._num_workers = loader.num_workers
+        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+        self._timeout = loader.timeout
+        self._collate_fn = loader.collate_fn
+        self._sampler_iter = iter(self._index_sampler)
+        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
 
     def __iter__(self):
         return self
 
     def _next_index(self):
-        return next(self.sampler_iter)  # may raise StopIteration
+        return next(self._sampler_iter)  # may raise StopIteration
 
     def __next__(self):
         raise NotImplementedError
 
     def __len__(self):
-        return len(self.index_sampler)
+        return len(self._index_sampler)
 
     def __getstate__(self):
         # TODO: add limited pickling support for sharing an iterator
@@ -335,16 +335,16 @@
 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
     def __init__(self, loader):
         super(_SingleProcessDataLoaderIter, self).__init__(loader)
-        assert self.timeout == 0
-        assert self.num_workers == 0
+        assert self._timeout == 0
+        assert self._num_workers == 0
 
-        self.dataset_fetcher = _DatasetKind.create_fetcher(
-            self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)
+        self._dataset_fetcher = _DatasetKind.create_fetcher(
+            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
 
     def __next__(self):
         index = self._next_index()  # may raise StopIteration
-        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
-        if self.pin_memory:
+        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
+        if self._pin_memory:
             data = _utils.pin_memory.pin_memory(data)
         return data
 
@@ -635,81 +635,81 @@
     def __init__(self, loader):
         super(_MultiProcessingDataLoaderIter, self).__init__(loader)
 
-        assert self.num_workers > 0
+        assert self._num_workers > 0
 
         if loader.multiprocessing_context is None:
             multiprocessing_context = multiprocessing
         else:
             multiprocessing_context = loader.multiprocessing_context
 
-        self.worker_init_fn = loader.worker_init_fn
-        self.worker_queue_idx_cycle = itertools.cycle(range(self.num_workers))
-        self.worker_result_queue = multiprocessing_context.Queue()
-        self.worker_pids_set = False
-        self.shutdown = False
-        self.send_idx = 0  # idx of the next task to be sent to workers
-        self.rcvd_idx = 0  # idx of the next task to be returned in __next__
+        self._worker_init_fn = loader.worker_init_fn
+        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
+        self._worker_result_queue = multiprocessing_context.Queue()
+        self._worker_pids_set = False
+        self._shutdown = False
+        self._send_idx = 0  # idx of the next task to be sent to workers
+        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
         # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
         # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
         #                  \ (worker_id, data)   if data is already fetched (out-of-order)
-        self.task_info = {}
-        self.tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
-        self.workers_done_event = multiprocessing_context.Event()
+        self._task_info = {}
+        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
+        self._workers_done_event = multiprocessing_context.Event()
 
-        self.index_queues = []
-        self.workers = []
+        self._index_queues = []
+        self._workers = []
         # A list of booleans representing whether each worker still has work to
         # do, i.e., not having exhausted its iterable dataset object. It always
         # contains all `True`s if not using an iterable-style dataset
         # (i.e., if kind != Iterable).
-        self.workers_status = []
-        for i in range(self.num_workers):
+        self._workers_status = []
+        for i in range(self._num_workers):
             index_queue = multiprocessing_context.Queue()
             # index_queue.cancel_join_thread()
             w = multiprocessing_context.Process(
                 target=_utils.worker._worker_loop,
-                args=(self.dataset_kind, self.dataset, index_queue,
-                      self.worker_result_queue, self.workers_done_event,
-                      self.auto_collation, self.collate_fn, self.drop_last,
-                      self.base_seed + i, self.worker_init_fn, i, self.num_workers))
+                args=(self._dataset_kind, self._dataset, index_queue,
+                      self._worker_result_queue, self._workers_done_event,
+                      self._auto_collation, self._collate_fn, self._drop_last,
+                      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
             w.daemon = True
             # NB: Process.start() actually take some time as it needs to
             #     start a process and pass the arguments over via a pipe.
-            #     Therefore, we only add a worker to self.workers list after
+            #     Therefore, we only add a worker to self._workers list after
             #     it started, so that we do not call .join() if program dies
             #     before it starts, and __del__ tries to join but will get:
             #     AssertionError: can only join a started process.
             w.start()
-            self.index_queues.append(index_queue)
-            self.workers.append(w)
-            self.workers_status.append(True)
+            self._index_queues.append(index_queue)
+            self._workers.append(w)
+            self._workers_status.append(True)
 
-        if self.pin_memory:
-            self.pin_memory_thread_done_event = threading.Event()
-            self.data_queue = queue.Queue()
+        if self._pin_memory:
+            self._pin_memory_thread_done_event = threading.Event()
+            self._data_queue = queue.Queue()
             pin_memory_thread = threading.Thread(
                 target=_utils.pin_memory._pin_memory_loop,
-                args=(self.worker_result_queue, self.data_queue,
+                args=(self._worker_result_queue, self._data_queue,
                       torch.cuda.current_device(),
-                      self.pin_memory_thread_done_event))
+                      self._pin_memory_thread_done_event))
             pin_memory_thread.daemon = True
             pin_memory_thread.start()
             # Similar to workers (see comment above), we only register
             # pin_memory_thread once it is started.
-            self.pin_memory_thread = pin_memory_thread
+            self._pin_memory_thread = pin_memory_thread
         else:
-            self.data_queue = self.worker_result_queue
+            self._data_queue = self._worker_result_queue
 
-        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
+        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
         _utils.signal_handling._set_SIGCHLD_handler()
-        self.worker_pids_set = True
+        self._worker_pids_set = True
 
         # prime the prefetch loop
-        for _ in range(2 * self.num_workers):
+        for _ in range(2 * self._num_workers):
             self._try_put_index()
 
     def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
-        # Tries to fetch data from `self.data_queue` once for a given timeout.
+        # Tries to fetch data from `self._data_queue` once for a given timeout.
         # This can also be used as inner loop of fetching without timeout, with
         # the sender status as the loop condition.
         #
@@ -721,15 +721,15 @@
         # Returns a 2-tuple:
         #   (bool: whether successfully get data, any: data if successful else None)
         try:
-            data = self.data_queue.get(timeout=timeout)
+            data = self._data_queue.get(timeout=timeout)
             return (True, data)
         except Exception as e:
             # At timeout and error, we manually check whether any worker has
             # failed. Note that this is the only mechanism for Windows to detect
             # worker failures.
             failed_workers = []
-            for worker_id, w in enumerate(self.workers):
-                if self.workers_status[worker_id] and not w.is_alive():
+            for worker_id, w in enumerate(self._workers):
+                if self._workers_status[worker_id] and not w.is_alive():
                     failed_workers.append(w)
                     self._shutdown_worker(worker_id)
             if len(failed_workers) > 0:
@@ -740,7 +740,7 @@
             raise
 
     def _get_data(self):
-        # Fetches data from `self.data_queue`.
+        # Fetches data from `self._data_queue`.
         #
         # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
         # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
@@ -750,21 +750,21 @@
         #
         # If `pin_memory=True`, we also need check if `pin_memory_thread` had
         # died at timeouts.
-        if self.timeout > 0:
-            success, data = self._try_get_data(self.timeout)
+        if self._timeout > 0:
+            success, data = self._try_get_data(self._timeout)
             if success:
                 return data
             else:
-                raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
-        elif self.pin_memory:
-            while self.pin_memory_thread.is_alive():
+                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
+        elif self._pin_memory:
+            while self._pin_memory_thread.is_alive():
                 success, data = self._try_get_data()
                 if success:
                     return data
             else:
                 # while condition is false, i.e., pin_memory_thread died.
                 raise RuntimeError('Pin memory thread exited unexpectedly')
-            # In this case, `self.data_queue` is a `queue.Queue`,. But we don't
+            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
             # need to call `.task_done()` because we don't use `.join()`.
         else:
             while True:
@@ -774,73 +774,73 @@
 
     def __next__(self):
         while True:
-            # If the worker responsible for `self.rcvd_idx` has already ended
+            # If the worker responsible for `self._rcvd_idx` has already ended
             # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
-            # we try to advance `self.rcvd_idx` to find the next valid index.
+            # we try to advance `self._rcvd_idx` to find the next valid index.
             #
             # This part needs to run in the loop because both the `self._get_data()`
             # call and `_IterableDatasetStopIteration` check below can mark
             # extra worker(s) as dead.
-            while self.rcvd_idx < self.send_idx:
-                info = self.task_info[self.rcvd_idx]
+            while self._rcvd_idx < self._send_idx:
+                info = self._task_info[self._rcvd_idx]
                 worker_id = info[0]
-                if len(info) == 2 or self.workers_status[worker_id]:  # has data or is still active
+                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                     break
-                del self.task_info[self.rcvd_idx]
-                self.rcvd_idx += 1
+                del self._task_info[self._rcvd_idx]
+                self._rcvd_idx += 1
             else:
-                # no valid `self.rcvd_idx` is found (i.e., didn't break)
+                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                 self._shutdown_workers()
                 raise StopIteration
 
-            # Now `self.rcvd_idx` is the batch index we want to fetch
+            # Now `self._rcvd_idx` is the batch index we want to fetch
 
             # Check if the next sample has already been generated
-            if len(self.task_info[self.rcvd_idx]) == 2:
-                data = self.task_info.pop(self.rcvd_idx)[1]
+            if len(self._task_info[self._rcvd_idx]) == 2:
+                data = self._task_info.pop(self._rcvd_idx)[1]
                 return self._process_data(data)
 
-            assert not self.shutdown and self.tasks_outstanding > 0
+            assert not self._shutdown and self._tasks_outstanding > 0
             idx, data = self._get_data()
-            self.tasks_outstanding -= 1
+            self._tasks_outstanding -= 1
 
-            if self.dataset_kind == _DatasetKind.Iterable:
+            if self._dataset_kind == _DatasetKind.Iterable:
                 # Check for _IterableDatasetStopIteration
                 if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                     self._shutdown_worker(data.worker_id)
                     self._try_put_index()
                     continue
 
-            if idx != self.rcvd_idx:
+            if idx != self._rcvd_idx:
                 # store out-of-order samples
-                self.task_info[idx] += (data,)
+                self._task_info[idx] += (data,)
             else:
-                del self.task_info[idx]
+                del self._task_info[idx]
                 return self._process_data(data)
 
     next = __next__  # Python 2 compatibility
 
     def _try_put_index(self):
-        assert self.tasks_outstanding < 2 * self.num_workers
+        assert self._tasks_outstanding < 2 * self._num_workers
         try:
             index = self._next_index()
         except StopIteration:
             return
-        for _ in range(self.num_workers):  # find the next active worker, if any
-            worker_queue_idx = next(self.worker_queue_idx_cycle)
-            if self.workers_status[worker_queue_idx]:
+        for _ in range(self._num_workers):  # find the next active worker, if any
+            worker_queue_idx = next(self._worker_queue_idx_cycle)
+            if self._workers_status[worker_queue_idx]:
                 break
         else:
             # not found (i.e., didn't break)
             return
 
-        self.index_queues[worker_queue_idx].put((self.send_idx, index))
-        self.task_info[self.send_idx] = (worker_queue_idx,)
-        self.tasks_outstanding += 1
-        self.send_idx += 1
+        self._index_queues[worker_queue_idx].put((self._send_idx, index))
+        self._task_info[self._send_idx] = (worker_queue_idx,)
+        self._tasks_outstanding += 1
+        self._send_idx += 1
 
     def _process_data(self, data):
-        self.rcvd_idx += 1
+        self._rcvd_idx += 1
         self._try_put_index()
         if isinstance(data, ExceptionWrapper):
             data.reraise()
@@ -851,10 +851,10 @@
         # exhausting an `IterableDataset`. This should be used only when this
         # `_MultiProcessingDataLoaderIter` is going to continue running.
 
-        assert self.workers_status[worker_id]
+        assert self._workers_status[worker_id]
 
         # Signal termination to that specific worker.
-        q = self.index_queues[worker_id]
+        q = self._index_queues[worker_id]
         # Indicate that no more data will be put on this queue by the current
         # process.
         q.put(None)
@@ -867,7 +867,7 @@
         # Joinning is deferred to `_shutdown_workers`, which it is called when
         # all workers finish their jobs (e.g., `IterableDataset` replicas) or
         # when this iterator is garbage collected.
-        self.workers_status[worker_id] = False
+        self._workers_status[worker_id] = False
 
     def _shutdown_workers(self):
         # Called when shutting down this `_MultiProcessingDataLoaderIter`.
@@ -879,32 +879,32 @@
             return
         # Normal exit when last reference is gone / iterator is depleted.
         # See (1) and the second half of the note.
-        if not self.shutdown:
-            self.shutdown = True
+        if not self._shutdown:
+            self._shutdown = True
             try:
                 # Exit `pin_memory_thread` first because exiting workers may leave
                 # corrupted data in `worker_result_queue` which `pin_memory_thread`
                 # reads from.
-                if hasattr(self, 'pin_memory_thread'):
+                if hasattr(self, '_pin_memory_thread'):
                     # Use hasattr in case error happens before we set the attribute.
-                    self.pin_memory_thread_done_event.set()
+                    self._pin_memory_thread_done_event.set()
                     # Send something to pin_memory_thread in case it is waiting
                     # so that it can wake up and check `pin_memory_thread_done_event`
-                    self.worker_result_queue.put((None, None))
-                    self.pin_memory_thread.join()
-                    self.worker_result_queue.close()
+                    self._worker_result_queue.put((None, None))
+                    self._pin_memory_thread.join()
+                    self._worker_result_queue.close()
 
                 # Exit workers now.
-                self.workers_done_event.set()
-                for worker_id in range(len(self.workers)):
-                    # Get number of workers from `len(self.workers)` instead of
-                    # `self.num_workers` in case we error before starting all
+                self._workers_done_event.set()
+                for worker_id in range(len(self._workers)):
+                    # Get number of workers from `len(self._workers)` instead of
+                    # `self._num_workers` in case we error before starting all
                     # workers.
-                    if self.workers_status[worker_id]:
+                    if self._workers_status[worker_id]:
                         self._shutdown_worker(worker_id)
-                for w in self.workers:
+                for w in self._workers:
                     w.join()
-                for q in self.index_queues:
+                for q in self._index_queues:
                     q.cancel_join_thread()
                     q.close()
             finally:
@@ -918,9 +918,9 @@
                 # FIXME: Unfortunately, for Windows, we are missing a worker
                 #        error detection mechanism here in this function, as it
                 #        doesn't provide a SIGCHLD handler.
-                if self.worker_pids_set:
+                if self._worker_pids_set:
                     _utils.signal_handling._remove_worker_pids(id(self))
-                    self.worker_pids_set = False
+                    self._worker_pids_set = False
 
     def __del__(self):
         self._shutdown_workers()