[Testing] Adding reference tests to `OpInfo` class (#59369)

Summary:
This PR adds a `ref` argument to the `OpInfo` base class. The idea is to add reference checks for all eligible ops. For more discussion, see https://github.com/pytorch/pytorch/issues/58294
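
As a sketch of the intended usage (mirroring the `amax` entry added in this PR), an eligible op opts into reference testing by declaring a NumPy reference in its `OpInfo` entry:

```python
OpInfo('amax',
       # NumPy reference; maps torch's dim/keepdim onto NumPy's axis/keepdims
       ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs),
       dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
       sample_inputs_func=sample_inputs_amax_amin)
```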

* [x] Migrate (without removing them yet) and adapt the helper functions from the `UnaryUfuncInfo` class to the `OpInfo` base class.
* [x] Test the reference checks on multiple ops (and decide on a list of diverse, eligible ops for this).
* [x] Handle possible edge cases (for example, `uint64` exists in NumPy but isn't implemented in PyTorch, and this needs to be handled -- more on this later). _Update_: We decided that these reference tests should only compare values, not types.
* [x] Create a sample PR adding reference functions to eligible ops (a single op from each of the different categories?). _Update_: This is being done in this PR itself.
* [x] ~Remove reference tests from `test_unary_ufuncs.py` and verify that nothing breaks.~ _Update_: We won't be touching the Unary Ufunc reference tests in this PR.
* [x] Add comments and remove unnecessary prints/comments that were added for debugging.

Note: To keep the PR description short, examples of the edge cases encountered are described in the comments below.
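
Condensed from the changes below, the new check in `test/test_ops.py` looks like this; ops without an eligible `ref` are filtered out beforehand into `_ref_test_ops`:

```python
@onlyOnCPUAndCUDA
@suppress_warnings
@ops(_ref_test_ops, allowed_dtypes=(torch.float32, torch.long))
def test_reference_testing(self, device, dtype, op):
    for sample_input in op.sample_inputs(device, dtype):
        # SampleInput.numpy() converts tensors to ndarrays internally;
        # only values are compared, not dtypes/devices.
        self.compare_with_reference(op, op.ref, sample_input)
```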

cc: mruberry pmeier kshitij12345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59369

Reviewed By: ngimel

Differential Revision: D29347252

Pulled By: mruberry

fbshipit-source-id: 69719deddb1d23c53db45287a7e66c1bfe7e65bb
diff --git a/test/test_ops.py b/test/test_ops.py
index 8d5c4ec..fa0ba7a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -8,9 +8,9 @@
     (FileCheck, floating_and_complex_types_and, get_all_dtypes)
 from torch.testing._internal.common_utils import \
     (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor,
-     gradcheck, gradgradcheck, IS_IN_CI)
+     gradcheck, gradgradcheck, IS_IN_CI, suppress_warnings)
 from torch.testing._internal.common_methods_invocations import \
-    (op_db,)
+    (op_db, _NOTHING, UnaryUfuncInfo, SpectralFuncInfo)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
 from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
@@ -24,6 +24,12 @@
 _variant_ops = partial(ops, dtypes=OpDTypes.supported,
                        allowed_dtypes=(torch.float, torch.cfloat))
 
+# Get all operators that have a ref defined in their OpInfo entry (testing infra),
+#   except for Unary Ufuncs (their references are tested separately in test/test_unary_ufuncs.py)
+#   and Spectral Functions (tested separately, and for 1D only as of now, in test/test_spectral_ops.py)
+_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, SpectralFuncInfo)) and
+                     op.ref is not None and op.ref is not _NOTHING, op_db))
+
 
 # Tests that apply to all operators and aren't related to any particular
 #   system
@@ -165,6 +171,16 @@
 
         self.assertEqual(supported_backward_dtypes, claimed_backward_supported, msg=msg)
 
+    # Tests that the function and its (ndarray-accepting) reference produce the same
+    #   values on the tensors generated by the corresponding op's sample_inputs function.
+    @onlyOnCPUAndCUDA
+    @suppress_warnings
+    @ops(_ref_test_ops, allowed_dtypes=(torch.float32, torch.long))
+    def test_reference_testing(self, device, dtype, op):
+        sample_inputs = op.sample_inputs(device, dtype)
+        for sample_input in sample_inputs:
+            self.compare_with_reference(op, op.ref, sample_input)
+
     # Validates ops implement the correct out= behavior
     # See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-does-out-work-in-pytorch
     #   for a description of the correct behavior
diff --git a/test/test_shape_ops.py b/test/test_shape_ops.py
index d5db714..916adee 100644
--- a/test/test_shape_ops.py
+++ b/test/test_shape_ops.py
@@ -9,10 +9,9 @@
 from torch._six import nan
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
-from torch.testing._internal.common_methods_invocations import shape_funcs
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests, onlyCPU, onlyCUDA, dtypes, onlyOnCPUAndCUDA,
-    dtypesIfCPU, dtypesIfCUDA, largeTensorTest, ops)
+    dtypesIfCPU, dtypesIfCUDA, largeTensorTest)
 
 # TODO: replace with make_tensor
 def _generate_input(shape, dtype, device, with_extremal):
@@ -673,21 +672,7 @@
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
-class TestShapeFuncs(TestCase):
-    """Test suite for Shape manipulating operators using the ShapeFuncInfo."""
-
-    @dtypes(*(torch.uint8, torch.int64, torch.double, torch.complex128))
-    @ops([op for op in shape_funcs if op.name in ['tile', 'repeat']])
-    def test_repeat_tile_vs_numpy(self, device, dtype, op):
-        samples = op.sample_inputs(device, dtype, requires_grad=False)
-        for sample in samples:
-            assert isinstance(sample.input, torch.Tensor)
-            expected = op.ref(sample.input.cpu().numpy(), *sample.args, **sample.kwargs)
-            result = op(sample.input, *sample.args, **sample.kwargs).cpu().numpy()
-            self.assertEqual(expected, result)
-
 instantiate_device_type_tests(TestShapeOps, globals())
-instantiate_device_type_tests(TestShapeFuncs, globals())
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 67d1463..136c684 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -5,6 +5,7 @@
 import copy
 import operator
 import random
+import numbers
 
 import torch
 import numpy as np
@@ -147,6 +148,30 @@
 
         return self._repr_helper(formatter)
 
+    # Returns the NumPy version of the sample input object in the form of a tuple: (input, args, kwargs)
+    def numpy(self):
+        # Converts tensors to ndarrays by calling .detach().cpu().numpy() on them
+        # Numbers, strings, and bools are preserved as-is
+        # Lists, tuples, and dicts are handled by calling this function recursively
+        def to_numpy(x):
+            def _np(t):
+                return t.detach().cpu().numpy()
+
+            if isinstance(x, torch.Tensor):
+                return _np(x)
+            elif isinstance(x, list):
+                return list(map(to_numpy, x))
+            elif isinstance(x, tuple):
+                return tuple(map(to_numpy, x))
+            elif isinstance(x, dict):
+                return {k: to_numpy(v) for k, v in x.items()}
+            elif isinstance(x, (numbers.Number, bool, str)):
+                return x
+
+            raise ValueError("Unknown type {0}!".format(type(x)))
+
+        sample_np_input, np_args, np_kwargs = to_numpy(self.input), to_numpy(self.args), to_numpy(self.kwargs)
+        return (sample_np_input, np_args, np_kwargs)
 
 class AliasInfo(object):
     """Class holds alias information. For example, torch.abs ->
@@ -412,6 +437,8 @@
     def __init__(self,
                  name,  # the string name of the function
                  *,
+                 ref=None,  # An optional reference function that accepts ndarrays (AKA "NumPy arrays").
+                            # If given, the op will be compared with its reference on each of its sample inputs.
                  # the following metadata describes the operator, its variants,
                  #   and its aliases, if any
                  aliases=None,  # iterable of aliases, e.g. ("absolute",) for torch.abs
@@ -499,6 +526,7 @@
             assert isinstance(dtype_list, (_dispatch_dtypes, type(None)))
 
         self.name = name
+        self.ref = ref
         self.aten_name = aten_name if aten_name is not None else name
         self.variant_test_name = variant_test_name
 
@@ -1790,20 +1818,21 @@
             )
 
 def sample_inputs_amax_amin(op_info, device, dtype, requires_grad, **kwargs):
-    test_cases = (
-        ((S, S, S), ()),
-        ((S, S, S), (1,)),
-        ((S, S, S), ((1, 2,),)),
-        ((S, S, S), (1, True,)),
-        ((), (0,)),
-        ((), ()),
-        ((), (0, True,)),
+    # Ordered as (shape, positional args, kwargs)
+    test_cases: Tuple[Tuple[tuple, tuple, dict], ...] = (
+        ((S, S, S), (), {}),
+        ((S, S, S), (1,), {}),
+        ((S, S, S), ((1, 2,),), {}),
+        ((S, S, S), (1,), {'keepdim': True}),
+        ((), (0,), {}),
+        ((), (), {}),
+        ((), (0,), {'keepdim': True}),
     )
     return tuple(SampleInput((make_tensor(size, device, dtype,
                                           low=None, high=None,
                                           requires_grad=requires_grad)),
-                             args=args)
-                 for size, args in test_cases)
+                             args=args, kwargs=kwargs)
+                 for size, args, kwargs in test_cases)
 
 def sample_inputs_argmax_argmin(op_info, device, dtype, requires_grad, **kwargs):
     test_cases = (
@@ -4913,9 +4942,11 @@
                SkipInfo('TestCommon', 'test_variant_consistency_eager'),),
            sample_inputs_func=sample_inputs_addcmul_addcdiv),
     OpInfo('amax',
+           ref=lambda a, dim=None, keepdim=False, **kwargs: np.amax(a, axis=dim, keepdims=keepdim, **kwargs),
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_amax_amin,),
     OpInfo('amin',
+           ref=lambda a, dim=None, keepdim=False, **kwargs: np.amin(a, axis=dim, keepdims=keepdim, **kwargs),
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_amax_amin),
     OpInfo('argmax',
@@ -6305,6 +6336,7 @@
                        SkipInfo('TestUnaryUfuncs', 'test_out_arg_all_dtypes'),
                    )),
     OpInfo('roll',
+           ref=np.roll,
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half),
            supports_out=False,
            sample_inputs_func=sample_inputs_roll),
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 21fdce6..74762cc 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1152,6 +1152,17 @@
     def safeToDense(self, t):
         return t.coalesce().to_dense()
 
+    # Compares a torch function with a reference function for a given sample input (a SampleInput object)
+    # Note: only values are compared, type comparison is not done here
+    def compare_with_reference(self, torch_fn, ref_fn, sample_input, **kwargs):
+        n_inp, n_args, n_kwargs = sample_input.numpy()
+        t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+
+        actual = torch_fn(t_inp, *t_args, **t_kwargs)
+        expected = ref_fn(n_inp, *n_args, **n_kwargs)
+
+        self.assertEqual(actual, expected, exact_device=False)
+
     # Compares the given Torch and NumPy functions on the given tensor-like object.
     # NOTE: both torch_fn and np_fn should be functions that take a single
     #   tensor (array). If the torch and/or NumPy function require additional