| import ctypes |
| import torch |
| from common_utils import TestCase, run_tests, skipIfRocm |
| import unittest |
| |
# NOTE: this file needs to be run in a brand-new process, because the test
# below first asserts that no CUDA primary context has been created yet.
| |
# TEST_CUDA and TEST_MULTIGPU are recomputed here on purpose instead of being
# imported from common_cuda: importing them would also execute the TEST_CUDNN
# line of common_cuda again during this test suite, and that causes a CUDA OOM
# error on Windows.
TEST_CUDA = torch.cuda.is_available()
TEST_MULTIGPU = torch.cuda.device_count() >= 2 if TEST_CUDA else False

if not TEST_CUDA:
    # Without CUDA, rebind TestCase to plain `object` (presumably so that
    # classes deriving from it are no longer picked up by the test runner).
    print('CUDA not available, skipping tests')
    TestCase = object  # noqa: F811
| |
| |
def get_is_primary_context_created(device):
    """Return True if the CUDA primary context of ``device`` is active.

    Calls ``cuDevicePrimaryCtxGetState`` through the library handle returned
    by ``torch.cuda.cudart()``.

    Args:
        device: integer device index; it is passed to the call as a C ``int``.

    Returns:
        bool: truthiness of the ``active`` value written by the call.

    Raises:
        AssertionError: if the call returns a non-zero status.
    """
    # Output parameters filled in by the call; only `active` is inspected.
    flags = ctypes.c_uint()
    active = ctypes.c_int()
    status = torch.cuda.cudart().cuDevicePrimaryCtxGetState(
        ctypes.c_int(device), ctypes.byref(flags), ctypes.byref(active))
    # Raise explicitly rather than with an `assert` statement, so the status
    # check is still performed when Python runs with -O.
    if status != 0:
        raise AssertionError('cuDevicePrimaryCtxGetState failed')
    return bool(active.value)
| |
| |
class TestCudaPrimaryCtx(TestCase):
    """Checks which devices hold a CUDA primary context after using 'cuda:1'."""

    @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
    @skipIfRocm
    def test_cuda_primary_ctx(self):
        def expect_context_only_on_cuda1():
            # Device 0 must stay untouched while device 1 has a context.
            self.assertFalse(get_is_primary_context_created(0))
            self.assertTrue(get_is_primary_context_created(1))

        # Precondition: neither device has a primary context yet.
        for device_index in (0, 1):
            self.assertFalse(get_is_primary_context_created(device_index))

        # Allocating on 'cuda:1' should create a context there and only there.
        gpu_tensor = torch.randn(1, device='cuda:1')
        expect_context_only_on_cuda1()

        # Printing the tensor should not create a context on 'cuda:0'.
        print(gpu_tensor)
        expect_context_only_on_cuda1()

        # Copying the tensor into a CPU tensor should not create one either.
        host_tensor = torch.randn(1, device='cpu')
        host_tensor.copy_(gpu_tensor)
        expect_context_only_on_cuda1()

    # DO NOT ADD ANY OTHER TESTS HERE!  ABOVE TEST REQUIRES FRESH PROCESS
| |
# Script entry point: hand control to the run_tests helper imported from
# common_utils when this file is executed directly.
if __name__ == '__main__':
    run_tests()