Adds list of operator-related information for testing (#41662)
Summary:
This PR adds:
- an "OpInfo" class in common_method_invocations that can contain useful information about an operator, like what dtypes it supports
- a more specialized "UnaryUfuncInfo" class designed to help test the unary ufuncs
- the `ops` decorator, which can generate test variants from lists of OpInfos
- test_unary_ufuncs.py, a new test suite stub that shows how the `ops` decorator and operator information can be used to improve the thoroughness of our testing
The single test in test_unary_ufuncs.py simply ensures that the dtypes associated with a unary ufunc operator in its OpInfo entry are correct. Writing a test like this previously, however, would have required manually constructing test-specific operator information and writing a custom test generator. The `ops` decorator and a common place to put operator information make writing tests like this easier and allows what would have been test-specific information to be reused.
The `ops` decorator extends and composes with the existing device generic test framework, allowing its decorators to be reused. For example, the `onlyOnCPUAndCUDA` decorator works with the new `ops` decorator. This should keep the tests readable and consistent.
Future PRs will likely:
- continue refactoring the too large test_torch.py into more verticals (unary ufuncs, binary ufuncs, reductions...)
- add more operator information to common_method_invocations.py
- refactor tests for unary ufuncs into test_unary_ufunc
Examples of possible future extensions are [here](https://github.com/pytorch/pytorch/pull/41662/commits/616747e50dbb5a3338deedf41ff44957b162ab51), where an example unary ufunc test is added, and [here](https://github.com/pytorch/pytorch/pull/41662/commits/d0b624f110d470b9a37ad02b389d2f4258c3d632), where example autograd tests are added. Both tests leverage the operator info in common_method_invocations to simplify testing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41662
Reviewed By: ngimel
Differential Revision: D23048416
Pulled By: mruberry
fbshipit-source-id: ecce279ac8767f742150d45854404921a6855f2c
diff --git a/test/run_test.py b/test/run_test.py
index 32cbebb..63ae5b3 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -61,6 +61,7 @@
'test_torch',
'test_type_info',
'test_type_hints',
+ 'test_unary_ufuncs',
'test_utils',
'test_namedtuple_return_api',
'test_jit_profiling',
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
new file mode 100644
index 0000000..fa19434
--- /dev/null
+++ b/test/test_unary_ufuncs.py
@@ -0,0 +1,40 @@
+import torch
+
+from torch.testing._internal.common_utils import \
+ (TestCase, run_tests)
+from torch.testing._internal.common_methods_invocations import \
+ (unary_ufuncs)
+from torch.testing._internal.common_device_type import \
+ (instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm)
+
+
+# Tests for unary "universal functions (ufuncs)" that accept a single
+# tensor and have common properties like:
+# - they are elementwise functions
+# - the input shape is the output shape
+# - they typically have method and inplace variants
+# - they typically support the out kwarg
+# - they typically have NumPy or SciPy references
+
+# See NumPy's universal function documentation
+# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+# about the concept of ufuncs.
+class TestUnaryUfuncs(TestCase):
+ exact_dtype = True
+
+ # Verifies that the unary ufuncs have their supported dtypes
+ # registered correctly by testing that each unlisted dtype
+ # throws a runtime error
+ @skipCUDAIfRocm
+ @onlyOnCPUAndCUDA
+ @ops(unary_ufuncs, unsupported_dtypes_only=True)
+ def test_unsupported_dtypes(self, device, dtype, op):
+ t = torch.empty(1, device=device, dtype=dtype)
+ with self.assertRaises(RuntimeError):
+ op(t)
+
+
+instantiate_device_type_tests(TestUnaryUfuncs, globals())
+
+if __name__ == '__main__':
+ run_tests()
diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 02ec017..c70abed 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -251,6 +251,68 @@
return input.data
+# Functions and classes for describing the dtypes a function supports
+# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
+
+# Verifies each given dtype is a torch.dtype
+def _validate_dtypes(*dtypes):
+ for dtype in dtypes:
+ assert isinstance(dtype, torch.dtype)
+ return dtypes
+
+# class for tuples corresponding to a PyTorch dispatch macro
+class _dispatch_dtypes(tuple):
+ def __add__(self, other):
+ assert isinstance(other, tuple)
+ return _dispatch_dtypes(tuple.__add__(self, other))
+
+_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
+def floating_types():
+ return _floating_types
+
+_floating_types_and_half = _floating_types + (torch.half,)
+def floating_types_and_half():
+ return _floating_types_and_half
+
+def floating_types_and(*dtypes):
+ return _floating_types + _validate_dtypes(*dtypes)
+
+_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
+def floating_and_complex_types():
+ return _floating_and_complex_types
+
+def floating_and_complex_types_and(*dtypes):
+ return _floating_and_complex_types + _validate_dtypes(*dtypes)
+
+_integral_types = _dispatch_dtypes((torch.uint8, torch.int, torch.int16, torch.int32, torch.int64))
+def integral_types():
+ return _integral_types
+
+def integral_types_and(*dtypes):
+ return _integral_types + _validate_dtypes(*dtypes)
+
+_all_types = _floating_types + _integral_types
+def all_types():
+ return _all_types
+
+def all_types_and(*dtypes):
+ return _all_types + _validate_dtypes(*dtypes)
+
+_complex_types = (torch.cfloat, torch.cdouble)
+def complex_types():
+ return _complex_types
+
+_all_types_and_complex = _all_types + _complex_types
+def all_types_and_complex():
+ return _all_types_and_complex
+
+def all_types_and_complex_and(*dtypes):
+ return _all_types_and_complex + _validate_dtypes(*dtypes)
+
+_all_types_and_half = _all_types + (torch.half,)
+def all_types_and_half():
+ return _all_types_and_half
+
def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True) -> List[torch.dtype]:
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
if include_bool:
@@ -259,12 +321,10 @@
dtypes += get_all_complex_dtypes()
return dtypes
-
def get_all_math_dtypes(device) -> List[torch.dtype]:
return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
include_bfloat16=False) + get_all_complex_dtypes()
-
def get_all_complex_dtypes() -> List[torch.dtype]:
return [torch.complex64, torch.complex128]
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 034eeea..ad64ade 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -9,6 +9,8 @@
import torch
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN
+from torch.testing import \
+ (get_all_dtypes)
try:
import psutil
@@ -21,7 +23,7 @@
# [WRITING TESTS]
#
# Write your test class as usual except:
-# (1) Each test method should have one of four signatures:
+# (1) Each test method should have one of following five signatures:
#
# (1a) testX(self, device)
#
@@ -35,9 +37,10 @@
# @dtypes(<list of dtypes> or <list of tuples of dtypes>)
# testX(self, devices, dtype)
#
+# (1e) @ops(<list of OpInfo instances>)
+# testX(self, device, dtype, op)
#
-# Note that the decorators are required for signatures (1b), (1c) and
-# (1d).
+# Note that the decorators are required for signatures 1b--1e.
#
# When a test like (1a) is called it will be given a device string,
# like 'cpu' or 'cuda:0.'
@@ -55,6 +58,10 @@
# Tests like (1d) take a devices argument like (1b) and a dtype
# argument from (1c).
#
+# Tests like (1e) are instantiated for each provided OpInfo instance,
+# with dtypes specified by the OpInfo instance (unless overridden with
+# an additional @dtypes decorator).
+#
# (2) Prefer using test decorators defined in this file to others.
# For example, using the @skipIfNoLapack decorator instead of the
# @skipCPUIfNoLapack will cause the test to not run on CUDA if
@@ -205,46 +212,83 @@
# Creates device-specific tests.
@classmethod
def instantiate_test(cls, name, test):
- test_name = name + "_" + cls.device_type
- dtypes = cls._get_dtypes(test)
- if dtypes is None: # Test has no dtype variants
- assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
+ def instantiate_test_helper(cls, name, *, test, dtype, op):
- @wraps(test)
- def instantiated_test(self, test=test):
- device_arg = cls.get_primary_device() if not hasattr(test, 'num_required_devices') else cls.get_all_devices()
- return test(self, device_arg)
-
- setattr(cls, test_name, instantiated_test)
- else: # Test has dtype variants
- for dtype in dtypes:
- # Constructs dtype suffix
+ # Constructs the test's name
+ test_name = name
+ if op is not None:
+ test_name += "_" + op.name
+ test_name += "_" + cls.device_type
+ if dtype is not None:
if isinstance(dtype, (list, tuple)):
- dtype_str = ""
for d in dtype:
- dtype_str += "_" + str(d).split('.')[1]
+ test_name += "_" + str(d).split('.')[1]
else:
- dtype_str = "_" + str(dtype).split('.')[1]
+ test_name += "_" + str(dtype).split('.')[1]
- dtype_test_name = test_name + dtype_str
- assert not hasattr(cls, dtype_test_name), "Redefinition of test {0}".format(dtype_test_name)
+ # Constructs the test
+ @wraps(test)
+ def instantiated_test(self, test=test, dtype=dtype, op=op):
+ device_arg = cls.get_primary_device()
+ if hasattr(test, 'num_required_devices'):
+ device_arg = cls.get_all_devices()
- @wraps(test)
- def instantiated_test(self, test=test, dtype=dtype):
- device_arg = cls.get_primary_device() if not hasattr(test, 'num_required_devices') else cls.get_all_devices()
- # Sets precision and runs test
- # Note: precision is reset after the test is run
- guard_precision = self.precision
- try :
- self.precision = self._get_precision_override(test, dtype)
- result = test(self, device_arg, dtype)
- finally:
- self.precision = guard_precision
+ # Sets precision and runs test
+ # Note: precision is reset after the test is run
+ guard_precision = self.precision
+ try:
+ self.precision = self._get_precision_override(test, dtype)
+ args = (device_arg, dtype, op)
+ args = (arg for arg in args if arg is not None)
+ result = test(self, *args)
+ finally:
+ self.precision = guard_precision
- return result
+ return result
- setattr(cls, dtype_test_name, instantiated_test)
+ # wraps with op decorators
+ if op is not None and op.decorators is not None:
+ for decorator in op.decorators:
+ instantiated_test = decorator(instantiated_test)
+
+ assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
+ setattr(cls, test_name, instantiated_test)
+
+ # Handles tests using the ops decorator
+ if hasattr(test, "op_list"):
+ for op in test.op_list:
+ # Acquires dtypes, using the op data if unspecified
+ dtypes = cls._get_dtypes(test)
+ if dtypes is None:
+ if cls.device_type == 'cpu' and op.dtypesIfCPU is not None:
+ dtypes = op.dtypesIfCPU
+ elif (cls.device_type == 'cuda' and not TEST_WITH_ROCM
+ and op.dtypesIfCUDA is not None):
+ dtypes = op.dtypesIfCUDA
+ elif (cls.device_type == 'cuda' and TEST_WITH_ROCM
+ and op.dtypesIfROCM is not None):
+ dtypes = op.dtypesIfROCM
+ else:
+ dtypes = op.dtypes
+
+ # Inverts dtypes if the function wants unsupported dtypes
+ if test.unsupported_dtypes_only is True:
+ dtypes = [d for d in get_all_dtypes() if d not in dtypes]
+
+ dtypes = dtypes if dtypes is not None else (None,)
+ for dtype in dtypes:
+ instantiate_test_helper(cls,
+ name,
+ test=test,
+ dtype=dtype,
+ op=op)
+ else:
+ # Handles tests that don't use the ops decorator
+ dtypes = cls._get_dtypes(test)
+ dtypes = tuple(dtypes) if dtypes is not None else (None,)
+ for dtype in dtypes:
+ instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)
class CPUTestBase(DeviceTypeTestBase):
@@ -385,6 +429,23 @@
scope[class_name] = device_type_test_class
+# Decorator that defines the ops a test should be run with
+# The test signature must be:
+# <test_name>(self, device, dtype, op)
+# For example:
+# @ops(unary_ufuncs)
+# test_numerics(self, device, dtype, op):
+# <test_code>
+class ops(object):
+ def __init__(self, op_list, *, unsupported_dtypes_only=False):
+ self.op_list = op_list
+ self.unsupported_dtypes_only = unsupported_dtypes_only
+
+ def __call__(self, fn):
+ fn.op_list = self.op_list
+ fn.unsupported_dtypes_only = self.unsupported_dtypes_only
+ return fn
+
# Decorator that skips a test if the given condition is true.
# Notes:
# (1) Skip conditions stack.
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 116db00..772133b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4,14 +4,116 @@
from operator import mul, itemgetter
import collections
from torch.autograd import Variable
-from torch.testing import make_non_contiguous
-from torch.testing._internal.common_device_type import (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA,
- expectedAlertNondeterministic)
-from torch.testing._internal.common_utils import (prod_single_zero, random_square_matrix_of_rank,
- random_symmetric_matrix, random_symmetric_psd_matrix,
- random_symmetric_pd_matrix, make_nonzero_det,
- random_fullrank_matrix_distinct_singular_value, set_rng_seed)
+from torch.testing import \
+ (make_non_contiguous,
+ _dispatch_dtypes,
+ floating_types, floating_types_and,
+ floating_and_complex_types, floating_and_complex_types_and)
+from torch.testing._internal.common_device_type import \
+ (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA,
+ expectedAlertNondeterministic)
+from torch.testing._internal.common_utils import \
+ (prod_single_zero, random_square_matrix_of_rank,
+ random_symmetric_matrix, random_symmetric_psd_matrix,
+ random_symmetric_pd_matrix, make_nonzero_det,
+ random_fullrank_matrix_distinct_singular_value, set_rng_seed)
+
+# Classes and methods for the operator database
+class OpInfo(object):
+ """Operator information and helper functions for acquiring it."""
+
+ def __init__(self,
+ name, # the string name of the function
+ *,
+ dtypes=floating_types(), # dtypes this function is expected to work with
+ dtypesIfCPU=None, # dtypes this function is expected to work with on CPU
+ dtypesIfCUDA=None, # dtypes this function is expected to work with on CUDA
+ dtypesIfROCM=None, # dtypes this function is expected to work with on ROCM
+ decorators=None): # decorators to apply to generated tests
+ # Validates the dtypes are generated from the dispatch-related functions
+ for dtype_list in (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM):
+ assert isinstance(dtype_list, _dispatch_dtypes)
+
+ self.name = name
+
+ self.dtypes = dtypes
+ self.dtypesIfCPU = dtypesIfCPU if dtypesIfCPU is not None else dtypes
+ self.dtypesIfCUDA = dtypesIfCUDA if dtypesIfCUDA is not None else dtypes
+ self.dtypesIfROCM = dtypesIfROCM if dtypesIfROCM is not None else dtypes
+
+ self.op = getattr(torch, self.name)
+ self.method_variant = getattr(torch.Tensor, name) if hasattr(torch.Tensor, name) else None
+ inplace_name = name + "_"
+ self.inplace_variant = getattr(torch.Tensor, inplace_name) if hasattr(torch.Tensor, name) else None
+ self.decorators = decorators
+
+ def __call__(self, *args, **kwargs):
+ """Calls the function variant of the operator."""
+ return self.op(*args, **kwargs)
+
+ def get_op(self):
+ """Returns the function variant of the operator, torch.<op_name>."""
+ return self.op
+
+ def get_method(self):
+ """Returns the method variant of the operator, torch.Tensor.<op_name>.
+ Returns None if the operator has no method variant.
+ """
+ return self.method_variant
+
+ def get_inplace(self):
+ """Returns the inplace variant of the operator, torch.Tensor.<op_name>_.
+ Returns None if the operator has no inplace variant.
+ """
+ return self.inplace_variant
+
+
+# Metadata class for unary "universal functions (ufuncs)" that accept a single
+# tensor and have common properties like:
+class UnaryUfuncInfo(OpInfo):
+ """Operator information for 'universal unary functions (unary ufuncs).'
+ These are functions of a single tensor with common properties like:
+ - they are elementwise functions
+ - the input shape is the output shape
+ - they typically have method and inplace variants
+ - they typically support the out kwarg
+ - they typically have NumPy or SciPy references
+
+ See NumPy's universal function documentation
+ (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+ about the concept of ufuncs.
+ """
+
+ def __init__(self,
+ name, # the string name of the function
+ *,
+ dtypes=floating_types(),
+ dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16),
+ dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+ dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16),
+ **kwargs):
+ super(UnaryUfuncInfo, self).__init__(name,
+ dtypes=dtypes,
+ dtypesIfCPU=dtypesIfCPU,
+ dtypesIfCUDA=dtypesIfCUDA,
+ dtypesIfROCM=dtypesIfROCM,
+ **kwargs)
+
+L = 20
+M = 10
+S = 5
+
+# Operator database
+op_db = [
+ UnaryUfuncInfo('cos',
+ dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16)),
+ UnaryUfuncInfo('cosh',
+ dtypesIfCPU=floating_and_complex_types()),
+]
+
+# Common operator groupings
+unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)]
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
@@ -93,9 +195,6 @@
return 0
NO_ARGS = NoArgsClass()
-L = 20
-M = 10
-S = 5
def ident(x):
return x