# blob: 19f55b2f08676263c155539b34ee5cc142d76b2c [file] [log] [blame]
# Owner(s): ["module: intel"]
import sys
import unittest
import torch
from torch.testing._internal.common_utils import NoTest, run_tests, TEST_XPU, TestCase
# When XPU support is not available, report it on stderr and rebind the
# module-level TestCase name to NoTest, so that the test class defined in
# this file derives from NoTest instead of the real TestCase.
if not TEST_XPU:
    print("XPU not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811

# True when at least two XPU devices are reported; used to gate the
# multi-device test.
TEST_MULTIXPU = torch.xpu.device_count() >= 2
class TestXpu(TestCase):
    """Checks of the ``torch.xpu`` device-management API."""

    def test_device_behavior(self):
        """Re-selecting the current device leaves the current device unchanged."""
        current_device = torch.xpu.current_device()
        torch.xpu.set_device(current_device)
        self.assertEqual(current_device, torch.xpu.current_device())

    @unittest.skipIf(not TEST_MULTIXPU, "only one GPU detected")
    def test_multi_device_behavior(self):
        """Device context managers switch the device and restore it on exit."""
        current_device = torch.xpu.current_device()
        # Pick a different device index, wrapping around at the device count.
        target_device = (current_device + 1) % torch.xpu.device_count()

        with torch.xpu.device(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

        with torch.xpu._DeviceGuard(target_device):
            self.assertEqual(target_device, torch.xpu.current_device())
        self.assertEqual(current_device, torch.xpu.current_device())

    def test_get_device_properties(self):
        """Property/name queries agree for explicit, ``None`` and omitted device."""
        current_device = torch.xpu.current_device()

        device_properties = torch.xpu.get_device_properties(current_device)
        self.assertEqual(device_properties, torch.xpu.get_device_properties(None))
        self.assertEqual(device_properties, torch.xpu.get_device_properties())

        device_name = torch.xpu.get_device_name(current_device)
        self.assertEqual(device_name, torch.xpu.get_device_name(None))
        self.assertEqual(device_name, torch.xpu.get_device_name())

        device_capability = torch.xpu.get_device_capability(current_device)
        # assertGreater reports the offending value if the check fails.
        self.assertGreater(device_capability["max_work_group_size"], 0)
        self.assertGreater(device_capability["max_num_sub_groups"], 0)

    def test_wrong_xpu_fork(self):
        """Using XPU in a child process after the parent used it reports an error.

        The script below calls ``torch.xpu.set_device`` in the parent before
        starting each child ``Process``; every child then calls
        ``torch.xpu.set_device`` itself. The captured stderr must match the
        re-initialization error message.
        """
        # The script text starts at column 0 because it is executed as a
        # standalone program; its internal indentation must be valid Python.
        stderr = TestCase.runWithPytorchAPIUsageStderr(
            """\
import torch
from torch.multiprocessing import Process
def run(rank):
    torch.xpu.set_device(rank)
if __name__ == "__main__":
    size = 2
    processes = []
    for rank in range(size):
        # it would work fine without the line below
        torch.xpu.set_device(0)
        p = Process(target=run, args=(rank,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
"""
        )
        self.assertRegex(stderr, "Cannot re-initialize XPU in forked subprocess.")
# Script entry point: hand control to the shared test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    run_tests()