[threaded_pg] fix the comments of MultiThreadTestCase (#92373)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92373
Approved by: https://github.com/wz337
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 61501d9..795dd74 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -937,13 +937,15 @@
class MultiThreadedTestCase(TestCase):
"""
- Simple test runner that executes all tests with the in-proc process group.
+ Test runner that runs all tests with the in-proc process group using
+ multiple threads with the threaded process group.
- A single instance of the TestCase object for all threads.
+ Each test spawns world_size threads and run the test method in each thread.
- Difference from regular test runner:
+ Difference from regular MultiProcess test runner:
+ Must explicitly defines SetUp and call self._spawn_threads() to run the tests.
Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown)
- Not sure what these two would be good for though.
+ to set up / tear down each thread when running each test.
No global state possible
How bad of a limitation is this?
"""
@@ -994,7 +996,7 @@
def _spawn_threads(self):
"""
- class method to spawn threads and run test, this is shared by both wrapper and base class approach
+ class method to spawn threads and run test, use this method in the SetUp of your TestCase
"""
test_name = self._current_test_name
# for each test case, we need to create thread local world, and a global store
@@ -1028,7 +1030,7 @@
def run_test_with_threaded_pg(self, test_name, rank, world_size):
"""
- Run ``callback`` with ``world_size`` threads using the in-proc process group
+ Run the current test associated with `test_name` using the threaded process group.
"""
c10d.init_process_group(
@@ -1128,7 +1130,6 @@
@property
def world_size(self) -> int:
- # raise RuntimeError("world size not implemented")
return DEFAULT_WORLD_SIZE
@property