# blob: 06c0e44cdcc1e0c840dbd286041bda3fc42a1dea [file] [log] [blame]
# Owner(s): ["module: intel"]
import sys
import unittest
import torch
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_XPU, TestCase
# When TEST_XPU is false, announce it on stderr and rebind TestCase to NoTest
# before any test class below is defined against it.
if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

# True only when at least two XPU devices are reported; gates multi-device tests.
TEST_MULTIXPU = torch.xpu.device_count() >= 2
class TestXpu(TestCase):
    """Tests for the ``torch.xpu`` device, property, fork and stream APIs."""

    def test_device_behavior(self):
        """Setting the already-current device leaves ``current_device()`` unchanged."""
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        """Device context managers switch devices on entry and restore on exit."""
        current_device = torch.xpu.current_device()
        # Pick a different device index, wrapping around at the device count.
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        """Property, name and capability queries agree for explicit and default device."""
        current_device = torch.xpu.current_device()

        # Passing None or omitting the argument must match the explicit index.
        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        # assertGreater reports the offending value on failure.
        self.assertGreater(device_capability["max_work_group_size"], 0)
        self.assertGreater(device_capability["max_num_sub_groups"], 0)

    def test_wrong_xpu_fork(self):
        """Using XPU in the parent before starting child processes reports an error."""
        # The child script is a standalone program run in a subprocess; it needs
        # its own valid indentation or it fails before exercising XPU at all.
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")

    def test_streams(self):
        """``set_stream`` and the ``stream`` context manager control the current stream."""
        # This test rebinds the process-wide current stream; put the entry
        # stream back afterwards so later tests do not inherit s2.
        self.addCleanup(torch.xpu.set_stream, torch.xpu.current_stream())

        s0 = torch.xpu.Stream()
        torch.xpu.set_stream(s0)
        s1 = torch.xpu.current_stream()
        self.assertEqual(s0, s1)

        s2 = torch.xpu.Stream()
        self.assertNotEqual(s0, s2)
        torch.xpu.set_stream(s2)

        # The context manager switches to s0 and restores s2 on exit.
        with torch.xpu.stream(s0):
            self.assertEqual(s0, torch.xpu.current_stream())
        self.assertEqual(s2, torch.xpu.current_stream())

    def test_stream_priority(self):
        """Streams built with either end of ``priority_range()`` keep priority and device."""
        low, high = torch.xpu.Stream.priority_range()

        s0 = torch.xpu.Stream(device=0, priority=low)
        self.assertEqual(low, s0.priority)
        self.assertEqual(torch.device("xpu:0"), s0.device)

        s1 = torch.xpu.Stream(device=0, priority=high)
        self.assertEqual(high, s1.priority)
        self.assertEqual(torch.device("xpu:0"), s1.device)
# When executed directly as a script, hand control to run_tests (imported above
# from torch.testing._internal.common_utils).
if __name__ == "__main__":
    run_tests()