# blob: 09da0569e288ebfd3fef1a16467c409f5815bdd8
# Owner(s): ["module: intel"]
import sys
import unittest
import torch
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_XPU, TestCase
if not TEST_XPU:
print("XPU not available, skipping tests", file=sys.stderr)
TestCase = NoTest # noqa: F811
TEST_MULTIXPU = torch.xpu.device_count() > 1
class TestXpu(TestCase):
    """Tests for the ``torch.xpu`` device, stream, event, and RNG APIs."""

    def test_device_behavior(self):
        """Setting the current device to itself leaves it unchanged."""
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        """Device context managers switch the device only inside their scope."""
        current_device = torch.xpu.current_device()
        # Any device other than the current one; wraps around to 0 after the last.
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        # Leaving the context must restore the previous device.
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        """Querying with ``None`` or no argument matches querying the current device."""
        current_device = torch.xpu.current_device()
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        # assertGreater reports the actual value on failure, unlike assertTrue(x > 0).
        self.assertGreater(device_capability["max_work_group_size"], 0)
        self.assertGreater(device_capability["max_num_sub_groups"], 0)

    def test_wrong_xpu_fork(self):
        """A child that uses XPU after the parent already did must report the fork error.

        The script run below calls ``torch.xpu.set_device(0)`` in the parent
        before starting each child, and each child then calls
        ``torch.xpu.set_device``. The test expects the re-initialization
        error on stderr.
        """
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        # Escape the period so the pattern matches a literal "." only.
        self.assertRegex(stderr, r"Cannot re-initialize XPU in forked subprocess\.")

    def test_streams(self):
        """set_stream / current_stream round-trip and the ``stream`` context manager."""
        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)

        s2 = torch.xpu.Stream()
        self.assertFalse(s0 == s2)
        torch.xpu.set_stream(s2)
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        # Leaving the context must restore the stream set before it.
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        """Streams created at both ends of the priority range report them back."""
        low, high = torch.xpu.Stream.priority_range()
        s0 = torch.xpu.Stream(device=0, priority=low)
        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)
        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)

    def test_stream_event_repr(self):
        """Stream and Event reprs, before and after recording the event."""
        s = torch.xpu.current_stream()
        self.assertIn("torch.xpu.Stream", str(s))
        e = torch.xpu.Event()
        self.assertIn("torch.xpu.Event(uninitialized)", str(e))
        s.record_event(e)
        self.assertIn("torch.xpu.Event", str(e))
        # The assertIn above is also satisfied by the uninitialized repr, so
        # additionally check that recording changed the repr.
        self.assertNotIn("uninitialized", str(e))

    def test_events(self):
        """An event reports completion both before recording and after synchronize()."""
        stream = torch.xpu.current_stream()
        event = torch.xpu.Event()
        self.assertTrue(event.query())
        stream.record_event(event)
        event.synchronize()
        self.assertTrue(event.query())

    def test_generator(self):
        """XPU RNG state follows global and XPU seeding and can be restored."""
        torch.manual_seed(2024)
        g_state0 = torch.xpu.get_rng_state()
        torch.manual_seed(1234)
        g_state1 = torch.xpu.get_rng_state()
        # Different global seeds must yield different XPU generator states.
        self.assertNotEqual(g_state0, g_state1)

        # Seeding only the XPU generator with 2024 reproduces the first state.
        torch.xpu.manual_seed(2024)
        g_state2 = torch.xpu.get_rng_state()
        self.assertEqual(g_state0, g_state2)

        torch.xpu.set_rng_state(g_state1)
        self.assertEqual(g_state1, torch.xpu.get_rng_state())

        # Reseed with 1234 first so that initial_seed() returning 2024 can
        # only come from the restored state g_state0.
        torch.manual_seed(1234)
        torch.xpu.set_rng_state(g_state0)
        self.assertEqual(2024, torch.xpu.initial_seed())
if __name__ == "__main__":
    # Hand off to PyTorch's shared test runner when this file is run directly.
    run_tests()