# blob: fa194340dd195474a96f14ccd6274ab7712c71de [file] [log] [blame]
import torch
from torch.testing._internal.common_utils import \
(TestCase, run_tests)
from torch.testing._internal.common_methods_invocations import \
(unary_ufuncs)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm)
# Tests for unary "universal functions (ufuncs)" that accept a single
# tensor and have common properties like:
# - they are elementwise functions
# - the output has the same shape as the input
# - they typically have method and inplace variants
# - they typically support the out kwarg
# - they typically have NumPy or SciPy references
# See NumPy's universal function documentation
# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
# about the concept of ufuncs.
class TestUnaryUfuncs(TestCase):
    """Device-generic tests for the unary ufunc OpInfos in ``unary_ufuncs``."""

    # NOTE(review): presumably makes TestCase.assertEqual also require matching
    # dtypes by default — confirm against torch.testing._internal.common_utils.
    exact_dtype = True

    # Dtype-registration check: each op is invoked once for every dtype it
    # does NOT list as supported (``unsupported_dtypes_only=True``), and that
    # call must raise RuntimeError. An op that quietly accepts an unlisted
    # dtype therefore fails here, which flags an incomplete dtype list.
    # The decorator order matches the original; judging by their names they
    # skip ROCm builds and limit the test to the CPU and CUDA device types.
    @skipCUDAIfRocm
    @onlyOnCPUAndCUDA
    @ops(unary_ufuncs, unsupported_dtypes_only=True)
    def test_unsupported_dtypes(self, device, dtype, op):
        # One uninitialized element is enough: the tensor's contents are never
        # inspected, only whether applying the op raises.
        probe = torch.empty(1, device=device, dtype=dtype)
        self.assertRaises(RuntimeError, op, probe)
# NOTE(review): presumably generates per-device variants of TestUnaryUfuncs
# and registers them in this module's globals() so the test runner can
# discover them — confirm against torch.testing._internal.common_device_type.
# This must run at import time, before run_tests() below.
instantiate_device_type_tests(TestUnaryUfuncs, globals())

# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    run_tests()