add warning if DataLoader is going to create an excessive number of threads (#46867)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46867
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D24545540
Pulled By: glaringlee
fbshipit-source-id: a3bef0d417e535b8ec0bb33f39cfa2308aadfff0
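
For reference, a minimal sketch (not part of this patch) of how the new warning surfaces from user code; `ToyDataset` and the worker count below are made up for illustration, and the exact message text is the one defined in torch/utils/data/dataloader.py further down:

    import warnings

    from torch.utils.data import DataLoader, Dataset


    class ToyDataset(Dataset):
        # Hypothetical dataset, used only so a DataLoader can be constructed.
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return idx


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Requesting far more workers than the machine has logical CPUs
        # triggers the new UserWarning at DataLoader construction time.
        loader = DataLoader(ToyDataset(), batch_size=2, num_workers=1000)

    print([str(w.message) for w in caught])
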
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 67a9c84..0d6ee2e 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -1564,7 +1564,7 @@
             # In all cases, all processes should end properly.
             if use_workers:
                 exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill']
-                persistent_workers = self.persistent_workers
+                persistent_workers = self.persistent_workers
             else:
                 exit_methods = [None, 'loader_error', 'loader_kill']
                 persistent_workers = False
@@ -1840,6 +1840,12 @@
         finally:
             _utils.worker._worker_info = old
 
+    def test_excessive_thread_creation_warning(self):
+        with self.assertWarnsRegex(
+                UserWarning,
+                r"Excessive worker creation may cause DataLoader to run slowly or even freeze"):
+            dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
+
 class StringDataset(Dataset):
     def __init__(self):
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 59c1827..8d7726e 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -5,6 +5,7 @@
 in `./_utils/worker.py`.
 """
+import os
 import threading
 import itertools
 import warnings
@@ -290,10 +291,13 @@
         self._iterator = None
 
+        self.check_worker_number_rationality()
+
     def _get_iterator(self) -> '_BaseDataLoaderIter':
         if self.num_workers == 0:
             return _SingleProcessDataLoaderIter(self)
         else:
+            self.check_worker_number_rationality()
             return _MultiProcessingDataLoaderIter(self)
 
     @property
@@ -399,6 +403,83 @@
         else:
             return len(self._index_sampler)
+    def check_worker_number_rationality(self):
+        # This function checks whether the number of workers this DataLoader will create
+        # is reasonable given the current system's resources. The current rule is: if the
+        # number of workers is larger than the number of logical CPUs this process is
+        # allowed to use, emit a warning so the user can pay attention to it.
+        #
+        # e.g. Suppose the current system has 2 physical CPUs with 16 cores each, and each
+        # core supports 2 threads; the total number of logical CPUs is then 2 * 16 * 2 = 64.
+        # Suppose further that the current DataLoader process may use half of them, i.e. 32;
+        # then the reasonable maximum number of workers initiated from this process is 32.
+        # Now, if the created DataLoader has num_workers = 40, which is larger than 32,
+        # the warning message is triggered to notify the user to lower the number of
+        # workers if necessary.
+        #
+        # [Note] This function respects `cpuset` only when os.sched_getaffinity is
+        #        available (which it is on most Linux systems, but not on macOS or
+        #        Windows). When os.sched_getaffinity is not available, os.cpu_count()
+        #        is called instead, which does not respect `cpuset`.
+        #        We don't take threading into account since each worker process is
+        #        single-threaded at this time.
+        #
+        #        We don't set any threading flags (e.g. OMP_NUM_THREADS, MKL_NUM_THREADS,
+        #        etc.) other than calling `torch.set_num_threads(1)` in the worker
+        #        process. If the functions passed in use third-party modules that rely on
+        #        those flags to decide how many threads to create (e.g. numpy), then it is
+        #        the caller's responsibility to set those flags correctly.
+        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+
+            suggested_max_worker_msg = ((
+                "Our suggested max number of workers on the current system is {}{}, which is smaller "
+                "than what this DataLoader is going to create.").format(
+                    num_worker_suggest,
+                    ("" if cpuset_checked else " (`cpuset` is not taken into account)"))
+            ) if num_worker_suggest is not None else (
+                "DataLoader is not able to compute a suggested max number of workers for the current system.")
+
+            warn_msg = (
+                "This DataLoader will create {} worker processes in total. {} "
+                "Excessive worker creation may cause DataLoader to run slowly or even freeze; "
+                "lower the number of workers to avoid potential slowness or freezes if necessary.").format(
+                    num_worker_created,
+                    suggested_max_worker_msg)
+            return warn_msg
+
+        if not self.num_workers:
+            return
+
+        # try to compute a suggested max number of workers based on the system's resources
+        max_num_worker_suggest = None
+        cpuset_checked = False
+        if hasattr(os, 'sched_getaffinity'):
+            try:
+                max_num_worker_suggest = len(os.sched_getaffinity(0))
+                cpuset_checked = True
+            except Exception:
+                pass
+        if max_num_worker_suggest is None:
+            # os.cpu_count() could return Optional[int];
+            # get the cpu count first and check for None in order to satisfy mypy
+            cpu_count = os.cpu_count()
+            if cpu_count is not None:
+                max_num_worker_suggest = cpu_count
+
+        if max_num_worker_suggest is None:
+            warnings.warn(_create_warning_msg(
+                max_num_worker_suggest,
+                self.num_workers,
+                cpuset_checked))
+            return
+
+        if self.num_workers > max_num_worker_suggest:
+            warnings.warn(_create_warning_msg(
+                max_num_worker_suggest,
+                self.num_workers,
+                cpuset_checked))
+
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
@@ -843,7 +924,7 @@
         # contains all `True`s if not using an iterable-style dataset
         # (i.e., if kind != Iterable).
         # Note that this indicates that a worker still has work to do *for this epoch*.
-        # It does not mean that a worker is dead. In case of `_persistent_workers`,
+        # It does not mean that a worker is dead. In case of `_persistent_workers`,
         # the worker will be reset to available in the next epoch.
         self._workers_status = [True for i in range(self._num_workers)]
         # We resume the prefetching in case it was enabled
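
The hunk above documents the heuristic used to pick the suggested maximum: prefer os.sched_getaffinity(0), which respects `cpuset`, and fall back to os.cpu_count() otherwise. A standalone sketch of that heuristic (the helper name is hypothetical and not part of this patch) that users can run to estimate a sensible num_workers for their own machine:

    import os


    def suggested_max_workers():
        # Hypothetical helper mirroring the patch's heuristic: prefer
        # os.sched_getaffinity(0), which respects `cpuset` restrictions
        # (available on most Linux systems), and fall back to os.cpu_count(),
        # which does not. May return None if neither source is available.
        if hasattr(os, "sched_getaffinity"):
            try:
                return len(os.sched_getaffinity(0))
            except Exception:
                pass
        return os.cpu_count()


    print("suggested max num_workers:", suggested_max_workers())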