| # Owner(s): ["module: multiprocessing"] | 
 |  | 
 | import os | 
 | import pickle | 
 | import random | 
 | import signal | 
 | import sys | 
 | import time | 
 | import unittest | 
 |  | 
 | from torch.testing._internal.common_utils import (TestCase, run_tests, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN) | 
 | import torch.multiprocessing as mp | 
 |  | 
 |  | 
 | def _test_success_func(i): | 
 |     pass | 
 |  | 
 |  | 
 | def _test_success_single_arg_func(i, arg): | 
 |     if arg: | 
 |         arg.put(i) | 
 |  | 
 |  | 
 | def _test_exception_single_func(i, arg): | 
 |     if i == arg: | 
 |         raise ValueError("legitimate exception from process %d" % i) | 
 |     time.sleep(1.0) | 
 |  | 
 |  | 
 | def _test_exception_all_func(i): | 
 |     time.sleep(random.random() / 10) | 
 |     raise ValueError("legitimate exception from process %d" % i) | 
 |  | 
 |  | 
 | def _test_terminate_signal_func(i): | 
 |     if i == 0: | 
 |         os.kill(os.getpid(), signal.SIGABRT) | 
 |     time.sleep(1.0) | 
 |  | 
 |  | 
 | def _test_terminate_exit_func(i, arg): | 
 |     if i == 0: | 
 |         sys.exit(arg) | 
 |     time.sleep(1.0) | 
 |  | 
 |  | 
 | def _test_success_first_then_exception_func(i, arg): | 
 |     if i == 0: | 
 |         return | 
 |     time.sleep(0.1) | 
 |     raise ValueError("legitimate exception") | 
 |  | 
 |  | 
 | def _test_nested_child_body(i, ready_queue, nested_child_sleep): | 
 |     ready_queue.put(None) | 
 |     time.sleep(nested_child_sleep) | 
 |  | 
 |  | 
 | def _test_infinite_task(i): | 
 |     while True: | 
 |         time.sleep(1) | 
 |  | 
 |  | 
 | def _test_process_exit(idx): | 
 |     sys.exit(12) | 
 |  | 
 |  | 
 | def _test_nested(i, pids_queue, nested_child_sleep, start_method): | 
 |     context = mp.get_context(start_method) | 
 |     nested_child_ready_queue = context.Queue() | 
 |     nprocs = 2 | 
 |     mp_context = mp.start_processes( | 
 |         fn=_test_nested_child_body, | 
 |         args=(nested_child_ready_queue, nested_child_sleep), | 
 |         nprocs=nprocs, | 
 |         join=False, | 
 |         daemon=False, | 
 |         start_method=start_method, | 
 |     ) | 
 |     pids_queue.put(mp_context.pids()) | 
 |  | 
 |     # Wait for both children to have started, to ensure that they | 
 |     # have called prctl(2) to register a parent death signal. | 
 |     for _ in range(nprocs): | 
 |         nested_child_ready_queue.get() | 
 |  | 
 |     # Kill self. This should take down the child processes as well. | 
 |     os.kill(os.getpid(), signal.SIGTERM) | 
 |  | 
 | class _TestMultiProcessing(object): | 
 |     start_method = None | 
 |  | 
 |     def test_success(self): | 
 |         mp.start_processes(_test_success_func, nprocs=2, start_method=self.start_method) | 
 |  | 
 |     def test_success_non_blocking(self): | 
 |         mp_context = mp.start_processes(_test_success_func, nprocs=2, join=False, start_method=self.start_method) | 
 |  | 
 |         # After all processes (nproc=2) have joined it must return True | 
 |         mp_context.join(timeout=None) | 
 |         mp_context.join(timeout=None) | 
 |         self.assertTrue(mp_context.join(timeout=None)) | 
 |  | 
 |     def test_first_argument_index(self): | 
 |         context = mp.get_context(self.start_method) | 
 |         queue = context.SimpleQueue() | 
 |         mp.start_processes(_test_success_single_arg_func, args=(queue,), nprocs=2, start_method=self.start_method) | 
 |         self.assertEqual([0, 1], sorted([queue.get(), queue.get()])) | 
 |  | 
 |     def test_exception_single(self): | 
 |         nprocs = 2 | 
 |         for i in range(nprocs): | 
 |             with self.assertRaisesRegex( | 
 |                 Exception, | 
 |                 "\nValueError: legitimate exception from process %d$" % i, | 
 |             ): | 
 |                 mp.start_processes(_test_exception_single_func, args=(i,), nprocs=nprocs, start_method=self.start_method) | 
 |  | 
 |     def test_exception_all(self): | 
 |         with self.assertRaisesRegex( | 
 |             Exception, | 
 |             "\nValueError: legitimate exception from process (0|1)$", | 
 |         ): | 
 |             mp.start_processes(_test_exception_all_func, nprocs=2, start_method=self.start_method) | 
 |  | 
 |     def test_terminate_signal(self): | 
 |         # SIGABRT is aliased with SIGIOT | 
 |         message = "process 0 terminated with signal (SIGABRT|SIGIOT)" | 
 |  | 
 |         # Termination through with signal is expressed as a negative exit code | 
 |         # in multiprocessing, so we know it was a signal that caused the exit. | 
 |         # This doesn't appear to exist on Windows, where the exit code is always | 
 |         # positive, and therefore results in a different exception message. | 
 |         # Exit code 22 means "ERROR_BAD_COMMAND". | 
 |         if IS_WINDOWS: | 
 |             message = "process 0 terminated with exit code 22" | 
 |  | 
 |         with self.assertRaisesRegex(Exception, message): | 
 |             mp.start_processes(_test_terminate_signal_func, nprocs=2, start_method=self.start_method) | 
 |  | 
 |     def test_terminate_exit(self): | 
 |         exitcode = 123 | 
 |         with self.assertRaisesRegex( | 
 |             Exception, | 
 |             "process 0 terminated with exit code %d" % exitcode, | 
 |         ): | 
 |             mp.start_processes(_test_terminate_exit_func, args=(exitcode,), nprocs=2, start_method=self.start_method) | 
 |  | 
 |     def test_success_first_then_exception(self): | 
 |         exitcode = 123 | 
 |         with self.assertRaisesRegex( | 
 |             Exception, | 
 |             "ValueError: legitimate exception", | 
 |         ): | 
 |             mp.start_processes(_test_success_first_then_exception_func, args=(exitcode,), nprocs=2, start_method=self.start_method) | 
 |  | 
 |     @unittest.skipIf( | 
 |         sys.platform != "linux", | 
 |         "Only runs on Linux; requires prctl(2)", | 
 |     ) | 
 |     def _test_nested(self): | 
 |         context = mp.get_context(self.start_method) | 
 |         pids_queue = context.Queue() | 
 |         nested_child_sleep = 20.0 | 
 |         mp_context = mp.start_processes( | 
 |             fn=_test_nested, | 
 |             args=(pids_queue, nested_child_sleep, self.start_method), | 
 |             nprocs=1, | 
 |             join=False, | 
 |             daemon=False, | 
 |             start_method=self.start_method, | 
 |         ) | 
 |  | 
 |         # Wait for nested children to terminate in time | 
 |         pids = pids_queue.get() | 
 |         start = time.time() | 
 |         while len(pids) > 0: | 
 |             for pid in pids: | 
 |                 try: | 
 |                     os.kill(pid, 0) | 
 |                 except ProcessLookupError: | 
 |                     pids.remove(pid) | 
 |                     break | 
 |  | 
 |             # This assert fails if any nested child process is still | 
 |             # alive after (nested_child_sleep / 2) seconds. By | 
 |             # extension, this test times out with an assertion error | 
 |             # after (nested_child_sleep / 2) seconds. | 
 |             self.assertLess(time.time() - start, nested_child_sleep / 2) | 
 |             time.sleep(0.1) | 
 |  | 
 | @unittest.skipIf( | 
 |     NO_MULTIPROCESSING_SPAWN, | 
 |     "Disabled for environments that don't support the spawn start method") | 
 | class SpawnTest(TestCase, _TestMultiProcessing): | 
 |     start_method = 'spawn' | 
 |  | 
 |     def test_exception_raises(self): | 
 |         with self.assertRaises(mp.ProcessRaisedException): | 
 |             mp.spawn(_test_success_first_then_exception_func, args=(), nprocs=1) | 
 |  | 
 |     def test_signal_raises(self): | 
 |         context = mp.spawn(_test_infinite_task, args=(), nprocs=1, join=False) | 
 |         for pid in context.pids(): | 
 |             os.kill(pid, signal.SIGTERM) | 
 |         with self.assertRaises(mp.ProcessExitedException): | 
 |             context.join() | 
 |  | 
 |     def _test_process_exited(self): | 
 |         with self.assertRaises(mp.ProcessExitedException) as e: | 
 |             mp.spawn(_test_process_exit, args=(), nprocs=1) | 
 |             self.assertEqual(12, e.exit_code) | 
 |  | 
 |  | 
 | @unittest.skipIf( | 
 |     IS_WINDOWS, | 
 |     "Fork is only available on Unix", | 
 | ) | 
 | class ForkTest(TestCase, _TestMultiProcessing): | 
 |     start_method = 'fork' | 
 |  | 
 |  | 
 | class ErrorTest(TestCase): | 
 |     def test_errors_pickleable(self): | 
 |         for error in ( | 
 |             mp.ProcessRaisedException("Oh no!", 1, 1), | 
 |             mp.ProcessExitedException("Oh no!", 1, 1, 1), | 
 |         ): | 
 |             pickle.loads(pickle.dumps(error)) | 
 |  | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |