add onlyprivateuse1 decorator for test framework (#103664)
Fixes #ISSUE_NUMBER
The current community testing framework does not have a decorator for privateuse1, we have fixed this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103664
Approved by: https://github.com/albanD
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 80cc4fd..1f0fe65 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -973,6 +973,11 @@
def __init__(self, dep, reason):
super().__init__(dep, reason, device_type='xla')
+class skipPRIVATEUSE1If(skipIf):
+
+ def __init__(self, dep, reason):
+ device_type = torch._C._get_privateuse1_backend_name()
+ super().__init__(dep, reason, device_type=device_type)
def _has_sufficient_memory(device, size):
if torch.device(device).type == 'cuda':
@@ -1231,6 +1236,13 @@
def onlyMPS(fn):
return onlyOn('mps')(fn)
+def onlyPRIVATEUSE1(fn):
+ device_type = torch._C._get_privateuse1_backend_name()
+ device_mod = getattr(torch, device_type, None)
+ if device_mod is None:
+ reason = "Skip as torch has no module of {0}".format(device_type)
+ return unittest.skip(reason)(fn)
+ return onlyOn(device_type)(fn)
def disablecuDNN(fn):
@@ -1429,6 +1441,9 @@
def skipMPS(fn):
return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
+def skipPRIVATEUSE1(fn):
+ return skipPRIVATEUSE1If(True, "test doesn't work on privateuse1 backend")(fn)
+
# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]: