| # Owner(s): ["module: intel"] |
| |
| import sys |
| import unittest |
| |
| import torch |
| from torch.testing._internal.common_utils import NoTest, run_tests, TEST_XPU, TestCase |
| |
| if not TEST_XPU: |
| print("XPU not available, skipping tests", file=sys.stderr) |
| TestCase = NoTest # noqa: F811 |
| |
| TEST_MULTIXPU = torch.xpu.device_count() > 1 |
| |
| |
class TestXpu(TestCase):
    """Tests for the ``torch.xpu`` device, stream, event and RNG APIs."""

    def test_device_behavior(self):
        """Setting the current device to itself leaves it unchanged."""
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        """``torch.xpu.device`` and ``_DeviceGuard`` switch device only inside their scope."""
        current_device = torch.xpu.current_device()
        # Pick a device different from the current one (wrapping around).
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        """Device queries agree whether the device is explicit, ``None`` or omitted."""
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        # assertGreater reports the offending value on failure, unlike
        # assertTrue(x > 0).
        self.assertGreater(device_capability["max_work_group_size"], 0)
        self.assertGreater(device_capability["max_num_sub_groups"], 0)

    def test_wrong_xpu_fork(self):
        """Using XPU in the parent and then in a forked child must raise an error."""
        # The child script below sets the device in the parent before starting
        # each Process; the child's own set_device is expected to fail.
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        """``set_stream``/``current_stream`` round-trip and the ``stream`` context manager."""
        # set_stream changes the current stream beyond this test's scope.
        # Restore the stream that was current on entry, so later tests in the
        # same process do not silently run on a leftover stream.
        self.addCleanup(torch.xpu.set_stream, torch.xpu.current_stream())

        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)
        s2 = torch.xpu.Stream()
        # Two separately created streams must not compare equal.
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        # Leaving the context manager restores the previously set stream.
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        """Streams created with each end of ``priority_range`` report that priority."""
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)

        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)

        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        """``str()`` of streams and events names the XPU types and event state."""
        s = torch.xpu.current_stream()
        self.assertIn("torch.xpu.Stream", str(s))
        e = torch.xpu.Event()
        # An event that was never recorded is reported as uninitialized.
        self.assertIn("torch.xpu.Event(uninitialized)", str(e))
        s.record_event(e)
        self.assertIn("torch.xpu.Event", str(e))

    def test_events(self):
        """An event queries as complete before recording and after synchronizing."""
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        # Nothing has been recorded yet, so there is no pending work.
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

    def test_generator(self):
        """Seeding and get/set of the XPU RNG state are consistent."""
        # torch.manual_seed must reseed the XPU generator as well.
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        self.assertNotEqual(g_state0, g_state1)

        # torch.xpu.manual_seed with the same seed reproduces the same state.
        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        # Restoring a saved state must also restore the seed it was created
        # with, overriding the intervening manual_seed(1234).
        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())
| |
| |
if __name__ == "__main__":
    # Run this file's tests through PyTorch's shared test runner, which is
    # imported from torch.testing._internal.common_utils.
    run_tests()