add warning if DataLoader is going to create excessive number of thread (#46867)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46867

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D24545540

Pulled By: glaringlee

fbshipit-source-id: a3bef0d417e535b8ec0bb33f39cfa2308aadfff0
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 67a9c84..0d6ee2e 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -1564,7 +1564,7 @@
             # In all cases, all processes should end properly.
             if use_workers:
                 exit_methods = [None, 'loader_error', 'loader_kill', 'worker_error', 'worker_kill']
-                persistent_workers = self.persistent_workers 
+                persistent_workers = self.persistent_workers
             else:
                 exit_methods = [None, 'loader_error', 'loader_kill']
                 persistent_workers = False
@@ -1840,6 +1840,12 @@
         finally:
             _utils.worker._worker_info = old
 
+    def test_excessive_thread_creation_warning(self):
+        with self.assertWarnsRegex(
+            UserWarning,
+                r"excessive worker creation might get DataLoader running slow or even freeze"):
+            dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000)
+
 
 class StringDataset(Dataset):
     def __init__(self):
diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 59c1827..8d7726e 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -5,6 +5,7 @@
 in `./_utils/worker.py`.
 """
 
+import os
 import threading
 import itertools
 import warnings
@@ -290,10 +291,13 @@
 
         self._iterator = None
 
+        self.check_worker_number_rationality()
+
     def _get_iterator(self) -> '_BaseDataLoaderIter':
         if self.num_workers == 0:
             return _SingleProcessDataLoaderIter(self)
         else:
+            self.check_worker_number_rationality()
             return _MultiProcessingDataLoaderIter(self)
 
     @property
@@ -399,6 +403,83 @@
         else:
             return len(self._index_sampler)
 
+    def check_worker_number_rationality(self):
+        # This function check whether the dataloader's worker number is rational based on
+        # current system's resource. Current rule is that if the number of workers this
+        # Dataloader will create is bigger than the number of logical cpus that is allowed to
+        # use, than we will pop up a warning to let user pay attention.
+        #
+        # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
+        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
+        #     DataLoader process can use half of them which is 32, then the rational max number of
+        #     worker that initiated from this process is 32.
+        #     Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
+        #     So the warning message is triggered to notify the user to lower the worker number if
+        #     necessary.
+        #
+        #
+        # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
+        #        available (available in most of Linux system, but not OSX and Windows).
+        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
+        #        it doesn't repect cpuset.
+        #        We don't take threading into account since each worker process is single threaded
+        #        at this time.
+        #
+        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
+        #        other than `torch.set_num_threads` to 1 in the worker process, if the passing
+        #        in functions use 3rd party modules that rely on those threading flags to determine
+        #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to
+        #        set those flags correctly.
+        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+
+            suggested_max_worker_msg = ((
+                "Our suggested max number of worker in current system is {}{}, which is smaller "
+                "than what this DataLoader is going to create.").format(
+                    num_worker_suggest,
+                    ("" if cpuset_checked else " (`cpuset` is not taken into account)"))
+            ) if num_worker_suggest is not None else (
+                "DataLoader is not able to compute a suggested max number of worker in current system.")
+
+            warn_msg = (
+                "This DataLoader will create {} worker processes in total. {} "
+                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
+                "lower the worker number to avoid potential slowness/freeze if necessary.").format(
+                    num_worker_created,
+                    suggested_max_worker_msg)
+            return warn_msg
+
+        if not self.num_workers or self.num_workers == 0:
+            return
+
+        # try to compute a suggested max number of worker based on system's resource
+        max_num_worker_suggest = None
+        cpuset_checked = False
+        if hasattr(os, 'sched_getaffinity'):
+            try:
+                max_num_worker_suggest = len(os.sched_getaffinity(0))
+                cpuset_checked = True
+            except Exception:
+                pass
+        if max_num_worker_suggest is None:
+            # os.cpu_count() could return Optional[int]
+            # get cpu count first and check None in order to satify mypy check
+            cpu_count = os.cpu_count()
+            if cpu_count is not None:
+                max_num_worker_suggest = cpu_count
+
+        if max_num_worker_suggest is None:
+            warnings.warn(_create_warning_msg(
+                max_num_worker_suggest,
+                self.num_workers,
+                cpuset_checked))
+            return
+
+        if self.num_workers > max_num_worker_suggest:
+            warnings.warn(_create_warning_msg(
+                max_num_worker_suggest,
+                self.num_workers,
+                cpuset_checked))
+
 
 class _BaseDataLoaderIter(object):
     def __init__(self, loader: DataLoader) -> None:
@@ -843,7 +924,7 @@
         # contains all `True`s if not using an iterable-style dataset
         # (i.e., if kind != Iterable).
         # Not that this indicates that a worker still has work to do *for this epoch*.
-        # It does not mean that a worker is dead. In case of `_persistent_workers`, 
+        # It does not mean that a worker is dead. In case of `_persistent_workers`,
         # the worker will be reset to available in the next epoch.
         self._workers_status = [True for i in range(self._num_workers)]
         # We resume the prefetching in case it was enabled