Adds list of operator-related information for testing (#41662)

Summary:
This PR adds:

- an "OpInfo" class in common_method_invocations that can contain useful information about an operator, like what dtypes it supports
- a more specialized "UnaryUfuncInfo" class designed to help test the unary ufuncs
- the `ops` decorator, which can generate test variants from lists of OpInfos (see the sketch after this list)
- test_unary_ufuncs.py, a new test suite stub that shows how the `ops` decorator and operator information can be used to improve the thoroughness of our testing
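
For illustration, a minimal sketch of how these pieces compose (the `neg` entry, the class name, and the test body below are hypothetical examples, not code added by this PR):

```python
import torch

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_methods_invocations import UnaryUfuncInfo
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops

# A hypothetical OpInfo list; UnaryUfuncInfo supplies default per-backend dtype
# lists, which can be overridden with dtypesIfCPU/dtypesIfCUDA/dtypesIfROCM.
example_ops = [UnaryUfuncInfo('neg')]

class TestExample(TestCase):
    # @ops generates one variant per (operator, device, dtype) combination,
    # e.g. test_preserves_shape_neg_cpu_float32, test_preserves_shape_neg_cuda_float16, ...
    @ops(example_ops)
    def test_preserves_shape(self, device, dtype, op):
        t = torch.ones(3, device=device, dtype=dtype)
        self.assertEqual(op(t).shape, t.shape)

instantiate_device_type_tests(TestExample, globals())

if __name__ == '__main__':
    run_tests()
```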

The single test in test_unary_ufuncs.py verifies that the dtypes listed in each unary ufunc's OpInfo entry are registered correctly by checking that every unlisted dtype throws a RuntimeError. Previously, writing a test like this would have required manually constructing test-specific operator information and writing a custom test generator. The `ops` decorator and a common place to put operator information make writing tests like this easier and allow what would have been test-specific information to be reused.

The `ops` decorator extends and composes with the existing device-generic test framework, allowing its decorators to be reused. For example, the `onlyOnCPUAndCUDA` decorator works with the new `ops` decorator. This should keep the tests readable and consistent.
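
A hypothetical sketch of that composition, also using the `@dtypes` override described in common_device_type.py (the class name and test body are illustrative only):

```python
import torch

from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_methods_invocations import unary_ufuncs
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes, onlyOnCPUAndCUDA)

class TestComposition(TestCase):
    # Runs only on CPU and CUDA; the @dtypes decorator overrides the
    # dtype lists that each OpInfo entry would otherwise supply.
    @onlyOnCPUAndCUDA
    @dtypes(torch.float32, torch.float64)
    @ops(unary_ufuncs)
    def test_float_result_dtype(self, device, dtype, op):
        t = torch.ones(3, device=device, dtype=dtype)
        self.assertEqual(op(t).dtype, dtype)

instantiate_device_type_tests(TestComposition, globals())
```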

Future PRs will likely:

- continue refactoring the overly large test_torch.py into more verticals (unary ufuncs, binary ufuncs, reductions, ...)
- add more operator information to common_methods_invocations.py
- refactor tests for unary ufuncs into test_unary_ufuncs.py

Examples of possible future extensions are [here](https://github.com/pytorch/pytorch/pull/41662/commits/616747e50dbb5a3338deedf41ff44957b162ab51), where an example unary ufunc test is added, and [here](https://github.com/pytorch/pytorch/pull/41662/commits/d0b624f110d470b9a37ad02b389d2f4258c3d632), where example autograd tests are added. Both tests leverage the operator info in common_methods_invocations to simplify testing.
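
As a rough illustration of the autograd direction (a hypothetical sketch, not the code in the linked commits), an OpInfo-driven gradcheck test could look like:

```python
import torch
from torch.autograd import gradcheck

from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_methods_invocations import unary_ufuncs
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, ops, dtypes)

class TestOpGrads(TestCase):
    # gradcheck prefers double precision, so @dtypes overrides the
    # dtype lists supplied by each OpInfo entry.
    @dtypes(torch.double)
    @ops(unary_ufuncs)
    def test_fn_grad(self, device, dtype, op):
        t = torch.randn(5, device=device, dtype=dtype, requires_grad=True)
        self.assertTrue(gradcheck(op.get_op(), (t,)))

instantiate_device_type_tests(TestOpGrads, globals())
```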

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41662

Reviewed By: ngimel

Differential Revision: D23048416

Pulled By: mruberry

fbshipit-source-id: ecce279ac8767f742150d45854404921a6855f2c
diff --git a/test/run_test.py b/test/run_test.py
index 32cbebb..63ae5b3 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -61,6 +61,7 @@
     'test_torch',
     'test_type_info',
     'test_type_hints',
+    'test_unary_ufuncs',
     'test_utils',
     'test_namedtuple_return_api',
     'test_jit_profiling',
diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py
new file mode 100644
index 0000000..fa19434
--- /dev/null
+++ b/test/test_unary_ufuncs.py
@@ -0,0 +1,40 @@
+import torch
+
+from torch.testing._internal.common_utils import \
+    (TestCase, run_tests)
+from torch.testing._internal.common_methods_invocations import \
+    (unary_ufuncs)
+from torch.testing._internal.common_device_type import \
+    (instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm)
+
+
+# Tests for unary "universal functions (ufuncs)" that accept a single
+# tensor and have common properties like:
+#   - they are elementwise functions
+#   - the input shape is the output shape
+#   - they typically have method and inplace variants
+#   - they typically support the out kwarg
+#   - they typically have NumPy or SciPy references
+
+# See NumPy's universal function documentation
+# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+# about the concept of ufuncs.
+class TestUnaryUfuncs(TestCase):
+    exact_dtype = True
+
+    # Verifies that the unary ufuncs have their supported dtypes
+    #   registered correctly by testing that each unlisted dtype
+    #   throws a runtime error
+    @skipCUDAIfRocm
+    @onlyOnCPUAndCUDA
+    @ops(unary_ufuncs, unsupported_dtypes_only=True)
+    def test_unsupported_dtypes(self, device, dtype, op):
+        t = torch.empty(1, device=device, dtype=dtype)
+        with self.assertRaises(RuntimeError):
+            op(t)
+
+
+instantiate_device_type_tests(TestUnaryUfuncs, globals())
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py
index 02ec017..c70abed 100644
--- a/torch/testing/__init__.py
+++ b/torch/testing/__init__.py
@@ -251,6 +251,68 @@
     return input.data
 
 
+# Functions and classes for describing the dtypes a function supports
+# NOTE: these helpers should correspond to PyTorch's C++ dispatch macros
+
+# Verifies each given dtype is a torch.dtype
+def _validate_dtypes(*dtypes):
+    for dtype in dtypes:
+        assert isinstance(dtype, torch.dtype)
+    return dtypes
+
+# class for tuples corresponding to a PyTorch dispatch macro
+class _dispatch_dtypes(tuple):
+    def __add__(self, other):
+        assert isinstance(other, tuple)
+        return _dispatch_dtypes(tuple.__add__(self, other))
+
+_floating_types = _dispatch_dtypes((torch.float32, torch.float64))
+def floating_types():
+    return _floating_types
+
+_floating_types_and_half = _floating_types + (torch.half,)
+def floating_types_and_half():
+    return _floating_types_and_half
+
+def floating_types_and(*dtypes):
+    return _floating_types + _validate_dtypes(*dtypes)
+
+_floating_and_complex_types = _floating_types + (torch.cfloat, torch.cdouble)
+def floating_and_complex_types():
+    return _floating_and_complex_types
+
+def floating_and_complex_types_and(*dtypes):
+    return _floating_and_complex_types + _validate_dtypes(*dtypes)
+
+_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
+def integral_types():
+    return _integral_types
+
+def integral_types_and(*dtypes):
+    return _integral_types + _validate_dtypes(*dtypes)
+
+_all_types = _floating_types + _integral_types
+def all_types():
+    return _all_types
+
+def all_types_and(*dtypes):
+    return _all_types + _validate_dtypes(*dtypes)
+
+_complex_types = _dispatch_dtypes((torch.cfloat, torch.cdouble))
+def complex_types():
+    return _complex_types
+
+_all_types_and_complex = _all_types + _complex_types
+def all_types_and_complex():
+    return _all_types_and_complex
+
+def all_types_and_complex_and(*dtypes):
+    return _all_types_and_complex + _validate_dtypes(*dtypes)
+
+_all_types_and_half = _all_types + (torch.half,)
+def all_types_and_half():
+    return _all_types_and_half
+
 def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True) -> List[torch.dtype]:
     dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16)
     if include_bool:
@@ -259,12 +321,10 @@
         dtypes += get_all_complex_dtypes()
     return dtypes
 
-
 def get_all_math_dtypes(device) -> List[torch.dtype]:
     return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'),
                                                     include_bfloat16=False) + get_all_complex_dtypes()
 
-
 def get_all_complex_dtypes() -> List[torch.dtype]:
     return [torch.complex64, torch.complex128]
 
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 034eeea..ad64ade 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -9,6 +9,8 @@
 import torch
 from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN
+from torch.testing import \
+    (get_all_dtypes)
 
 try:
     import psutil
@@ -21,7 +23,7 @@
 # [WRITING TESTS]
 #
 # Write your test class as usual except:
-#   (1) Each test method should have one of four signatures:
+#   (1) Each test method should have one of the following five signatures:
 #
 #           (1a) testX(self, device)
 #
@@ -35,9 +37,10 @@
 #                @dtypes(<list of dtypes> or <list of tuples of dtypes>)
 #                testX(self, devices, dtype)
 #
+#           (1e) @ops(<list of OpInfo instances>)
+#                testX(self, device, dtype, op)
 #
-#       Note that the decorators are required for signatures (1b), (1c) and
-#       (1d).
+#       Note that the decorators are required for signatures (1b)--(1e).
 #
 #       When a test like (1a) is called it will be given a device string,
 #       like 'cpu' or 'cuda:0.'
@@ -55,6 +58,10 @@
 #       Tests like (1d) take a devices argument like (1b) and a dtype
 #       argument from (1c).
 #
+#       Tests like (1e) are instantiated for each provided OpInfo instance,
+#       with dtypes specified by the OpInfo instance (unless overridden with
+#       an additional @dtypes decorator).
+#
 #   (2) Prefer using test decorators defined in this file to others.
 #       For example, using the @skipIfNoLapack decorator instead of the
 #       @skipCPUIfNoLapack will cause the test to not run on CUDA if
@@ -205,46 +212,83 @@
     # Creates device-specific tests.
     @classmethod
     def instantiate_test(cls, name, test):
-        test_name = name + "_" + cls.device_type
 
-        dtypes = cls._get_dtypes(test)
-        if dtypes is None:  # Test has no dtype variants
-            assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
+        def instantiate_test_helper(cls, name, *, test, dtype, op):
 
-            @wraps(test)
-            def instantiated_test(self, test=test):
-                device_arg = cls.get_primary_device() if not hasattr(test, 'num_required_devices') else cls.get_all_devices()
-                return test(self, device_arg)
-
-            setattr(cls, test_name, instantiated_test)
-        else:  # Test has dtype variants
-            for dtype in dtypes:
-                # Constructs dtype suffix
+            # Constructs the test's name
+            test_name = name
+            if op is not None:
+                test_name += "_" + op.name
+            test_name += "_" + cls.device_type
+            if dtype is not None:
                 if isinstance(dtype, (list, tuple)):
-                    dtype_str = ""
                     for d in dtype:
-                        dtype_str += "_" + str(d).split('.')[1]
+                        test_name += "_" + str(d).split('.')[1]
                 else:
-                    dtype_str = "_" + str(dtype).split('.')[1]
+                    test_name += "_" + str(dtype).split('.')[1]
 
-                dtype_test_name = test_name + dtype_str
-                assert not hasattr(cls, dtype_test_name), "Redefinition of test {0}".format(dtype_test_name)
+            # Constructs the test
+            @wraps(test)
+            def instantiated_test(self, test=test, dtype=dtype, op=op):
+                device_arg = cls.get_primary_device()
+                if hasattr(test, 'num_required_devices'):
+                    device_arg = cls.get_all_devices()
 
-                @wraps(test)
-                def instantiated_test(self, test=test, dtype=dtype):
-                    device_arg = cls.get_primary_device() if not hasattr(test, 'num_required_devices') else cls.get_all_devices()
-                    # Sets precision and runs test
-                    # Note: precision is reset after the test is run
-                    guard_precision = self.precision
-                    try :
-                        self.precision = self._get_precision_override(test, dtype)
-                        result = test(self, device_arg, dtype)
-                    finally:
-                        self.precision = guard_precision
+                # Sets precision and runs test
+                # Note: precision is reset after the test is run
+                guard_precision = self.precision
+                try:
+                    self.precision = self._get_precision_override(test, dtype)
+                    args = (device_arg, dtype, op)
+                    args = (arg for arg in args if arg is not None)
+                    result = test(self, *args)
+                finally:
+                    self.precision = guard_precision
 
-                    return result
+                return result
 
-                setattr(cls, dtype_test_name, instantiated_test)
+            # wraps with op decorators
+            if op is not None and op.decorators is not None:
+                for decorator in op.decorators:
+                    instantiated_test = decorator(instantiated_test)
+
+            assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)
+            setattr(cls, test_name, instantiated_test)
+
+        # Handles tests using the ops decorator
+        if hasattr(test, "op_list"):
+            for op in test.op_list:
+                # Acquires dtypes, using the op data if unspecified
+                dtypes = cls._get_dtypes(test)
+                if dtypes is None:
+                    if cls.device_type == 'cpu' and op.dtypesIfCPU is not None:
+                        dtypes = op.dtypesIfCPU
+                    elif (cls.device_type == 'cuda' and not TEST_WITH_ROCM
+                          and op.dtypesIfCUDA is not None):
+                        dtypes = op.dtypesIfCUDA
+                    elif (cls.device_type == 'cuda' and TEST_WITH_ROCM
+                          and op.dtypesIfROCM is not None):
+                        dtypes = op.dtypesIfROCM
+                    else:
+                        dtypes = op.dtypes
+
+                # Inverts dtypes if the function wants unsupported dtypes
+                if test.unsupported_dtypes_only is True:
+                    dtypes = [d for d in get_all_dtypes() if d not in dtypes]
+
+                dtypes = dtypes if dtypes is not None else (None,)
+                for dtype in dtypes:
+                    instantiate_test_helper(cls,
+                                            name,
+                                            test=test,
+                                            dtype=dtype,
+                                            op=op)
+        else:
+            # Handles tests that don't use the ops decorator
+            dtypes = cls._get_dtypes(test)
+            dtypes = tuple(dtypes) if dtypes is not None else (None,)
+            for dtype in dtypes:
+                instantiate_test_helper(cls, name, test=test, dtype=dtype, op=None)
 
 
 class CPUTestBase(DeviceTypeTestBase):
@@ -385,6 +429,23 @@
         scope[class_name] = device_type_test_class
 
 
+# Decorator that defines the ops a test should be run with
+# The test signature must be:
+#   <test_name>(self, device, dtype, op)
+# For example:
+# @ops(unary_ufuncs)
+# def test_numerics(self, device, dtype, op):
+#   <test_code>
+class ops(object):
+    def __init__(self, op_list, *, unsupported_dtypes_only=False):
+        self.op_list = op_list
+        self.unsupported_dtypes_only = unsupported_dtypes_only
+
+    def __call__(self, fn):
+        fn.op_list = self.op_list
+        fn.unsupported_dtypes_only = self.unsupported_dtypes_only
+        return fn
+
 # Decorator that skips a test if the given condition is true.
 # Notes:
 #   (1) Skip conditions stack.
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 116db00..772133b 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -4,14 +4,116 @@
 from operator import mul, itemgetter
 import collections
 from torch.autograd import Variable
-from torch.testing import make_non_contiguous
-from torch.testing._internal.common_device_type import (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA,
-                                                        expectedAlertNondeterministic)
-from torch.testing._internal.common_utils import (prod_single_zero, random_square_matrix_of_rank,
-                                                  random_symmetric_matrix, random_symmetric_psd_matrix,
-                                                  random_symmetric_pd_matrix, make_nonzero_det,
-                                                  random_fullrank_matrix_distinct_singular_value, set_rng_seed)
 
+from torch.testing import \
+    (make_non_contiguous,
+     _dispatch_dtypes,
+     floating_types, floating_types_and,
+     floating_and_complex_types, floating_and_complex_types_and)
+from torch.testing._internal.common_device_type import \
+    (skipCUDAIfNoMagma, skipCPUIfNoLapack, expectedFailureCUDA,
+     expectedAlertNondeterministic)
+from torch.testing._internal.common_utils import \
+    (prod_single_zero, random_square_matrix_of_rank,
+     random_symmetric_matrix, random_symmetric_psd_matrix,
+     random_symmetric_pd_matrix, make_nonzero_det,
+     random_fullrank_matrix_distinct_singular_value, set_rng_seed)
+
+# Classes and methods for the operator database
+class OpInfo(object):
+    """Operator information and helper functions for acquiring it."""
+
+    def __init__(self,
+                 name,  # the string name of the function
+                 *,
+                 dtypes=floating_types(),  # dtypes this function is expected to work with
+                 dtypesIfCPU=None,  # dtypes this function is expected to work with on CPU
+                 dtypesIfCUDA=None,  # dtypes this function is expected to work with on CUDA
+                 dtypesIfROCM=None,  # dtypes this function is expected to work with on ROCM
+                 decorators=None):  # decorators to apply to generated tests
+        # Validates the dtypes are generated from the dispatch-related functions
+        for dtype_list in (dtypes, dtypesIfCPU, dtypesIfCUDA, dtypesIfROCM):
+            assert dtype_list is None or isinstance(dtype_list, _dispatch_dtypes)
+
+        self.name = name
+
+        self.dtypes = dtypes
+        self.dtypesIfCPU = dtypesIfCPU if dtypesIfCPU is not None else dtypes
+        self.dtypesIfCUDA = dtypesIfCUDA if dtypesIfCUDA is not None else dtypes
+        self.dtypesIfROCM = dtypesIfROCM if dtypesIfROCM is not None else dtypes
+
+        self.op = getattr(torch, self.name)
+        self.method_variant = getattr(torch.Tensor, name) if hasattr(torch.Tensor, name) else None
+        inplace_name = name + "_"
+        self.inplace_variant = getattr(torch.Tensor, inplace_name) if hasattr(torch.Tensor, inplace_name) else None
+        self.decorators = decorators
+
+    def __call__(self, *args, **kwargs):
+        """Calls the function variant of the operator."""
+        return self.op(*args, **kwargs)
+
+    def get_op(self):
+        """Returns the function variant of the operator, torch.<op_name>."""
+        return self.op
+
+    def get_method(self):
+        """Returns the method variant of the operator, torch.Tensor.<op_name>.
+        Returns None if the operator has no method variant.
+        """
+        return self.method_variant
+
+    def get_inplace(self):
+        """Returns the inplace variant of the operator, torch.Tensor.<op_name>_.
+        Returns None if the operator has no inplace variant.
+        """
+        return self.inplace_variant
+
+
+# Metadata class for unary "universal functions (ufuncs)," which accept a single
+# tensor and share the common properties described in the docstring below.
+class UnaryUfuncInfo(OpInfo):
+    """Operator information for 'universal unary functions (unary ufuncs).'
+    These are functions of a single tensor with common properties like:
+      - they are elementwise functions
+      - the input shape is the output shape
+      - they typically have method and inplace variants
+      - they typically support the out kwarg
+      - they typically have NumPy or SciPy references
+
+    See NumPy's universal function documentation
+    (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+    about the concept of ufuncs.
+    """
+
+    def __init__(self,
+                 name,  # the string name of the function
+                 *,
+                 dtypes=floating_types(),
+                 dtypesIfCPU=floating_and_complex_types_and(torch.bfloat16),
+                 dtypesIfCUDA=floating_and_complex_types_and(torch.half),
+                 dtypesIfROCM=floating_types_and(torch.half, torch.bfloat16),
+                 **kwargs):
+        super(UnaryUfuncInfo, self).__init__(name,
+                                             dtypes=dtypes,
+                                             dtypesIfCPU=dtypesIfCPU,
+                                             dtypesIfCUDA=dtypesIfCUDA,
+                                             dtypesIfROCM=dtypesIfROCM,
+                                             **kwargs)
+
+L = 20
+M = 10
+S = 5
+
+# Operator database
+op_db = [
+    UnaryUfuncInfo('cos',
+                   dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16)),
+    UnaryUfuncInfo('cosh',
+                   dtypesIfCPU=floating_and_complex_types()),
+]
+
+# Common operator groupings
+unary_ufuncs = [op for op in op_db if isinstance(op, UnaryUfuncInfo)]
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
@@ -93,9 +195,6 @@
         return 0
 
 NO_ARGS = NoArgsClass()
-L = 20
-M = 10
-S = 5
 
 def ident(x):
     return x