[functorch] Move to PyTorch core's parametrize testing mechanism
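
This removes the hand-rolled @parameterized / @parameterized_with_device /
instantiate_parameterized_methods helpers from functorch/test/common_utils.py
and switches test_ops.py, test_pythonkey.py, and test_vmap.py over to
torch.testing._internal.common_utils.parametrize, subtest, and
instantiate_parametrized_tests. The core mechanism composes with
instantiate_device_type_tests directly, so the device-specific test classes no
longer need a separate instantiation step, and the unary pointwise subtest
names are now derived from each op's __name__ (which also drops the old
'reciprocol' misspelling).

For reference, usage of the core mechanism looks roughly like the sketch
below; the class, test, and parameter names are illustrative, not taken from
this patch:

    import torch
    from torch.testing._internal.common_utils import (
        TestCase, parametrize, subtest, instantiate_parametrized_tests)
    from torch.testing._internal.common_device_type import (
        instantiate_device_type_tests)

    class MyTestCase(TestCase):
        # name_fn controls the generated suffix, e.g. test_unary_abs
        @parametrize('op', [torch.abs, torch.cos], name_fn=lambda f: f.__name__)
        def test_unary(self, op):
            pass

        # subtest attaches an explicit name when the value itself has no
        # usable __name__ (e.g. a tuple of op and input getter)
        @parametrize('case', [
            subtest((torch.add, 'rand'), name='add'),
            subtest((torch.sub, 'rand'), name='sub'),
        ])
        def test_binary(self, case):
            pass

    instantiate_parametrized_tests(MyTestCase)

    # parametrize composes with instantiate_device_type_tests; the device
    # argument stays first and no extra instantiation call is needed.
    class MyDeviceTestCase(TestCase):
        @parametrize('param', ['foo', 'bar'])
        def test_with_device(self, device, param):
            pass

    instantiate_device_type_tests(MyDeviceTestCase, globals(), only_for=('cpu',))
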
diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index 79a069b..8959dff 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -13,141 +13,6 @@
 from functorch_additional_op_db import additional_op_db
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 import unittest
-import warnings
-import re
-
-"""
-Usage:
-
-class MyTestCase(TestCase):
-    @parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
-    def test_single_param(self, param):
-        pass
-
-    @parameterized('param1', {'sin': torch.sin, 'tan': torch.tan})
-    @parameterized('param2', {'abs': torch.abs, 'cos': torch.cos})
-    def test_multiple_param(self, param1, param2):
-        pass
-
-# The following creates:
-# - MyTestCase.test_single_param_abs
-# - MyTestCase.test_single_param_cos
-# - MyTestCase.test_multiple_param_abs_sin
-# - MyTestCase.test_multiple_param_cos_sin
-# - MyTestCase.test_multiple_param_abs_tan
-# - MyTestCase.test_multiple_param_cos_tan
-instantiate_parameterized_methods(MyTestCase)
-
-# This is also composable with PyTorch testing's instantiate_device_type_tests
-# Make sure the param is after the device arg
-class MyDeviceSpecificTest(TestCase):
-    @parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
-    def test_single_param(self, device, param):
-        pass
-
-# The following creates:
-# - MyDeviceSpecificTestCPU.test_single_param_abs_cpu
-# - MyDeviceSpecificTestCPU.test_single_param_cos_cpu
-# - MyDeviceSpecificTestCUDA.test_single_param_abs_cuda
-# - MyDeviceSpecificTestCUDA.test_single_param_cos_cpu
-instantiate_parameterized_methods(MyDeviceSpecificTest)
-instantiate_device_type_tests(MyDeviceSpecificTest, globals())
-
-# !!!!! warning !!!!!
-# 1. The method being parameterized over MUST NOT HAVE A DOCSTRING. We'll
-# error out nicely if this happens.
-# 2. All other decorators MUST USE functools.wraps (they must propagate the docstring)
-# `@parameterized` works by storing some metadata in place of the docstring.
-# This takes advantage of how other decorators work (other decorators usually
-# propagate the docstring via functools.wrap).
-# 3. We might not compose with PyTorch testing's @dtypes and @precision
-# decorators. But that is easily fixable. TODO.
-# I think this composes with PyTorch testing's instantiate_device_type_tests.
-"""
-
-PARAM_META = '_torch_parameterized_meta'
-
-class ParamMeta():
-    def __init__(self):
-        self.stack = []
-
-    def push(self, elt):
-        self.stack.append(elt)
-
-    def pop(self, elt):
-        return self.stack.pop()
-
-def has_param_meta(method):
-    param_meta = getattr(method, '__doc__', None)
-    return param_meta is not None and isinstance(param_meta, ParamMeta)
-
-def get_param_meta(method):
-    param_meta = getattr(method, '__doc__', None)
-    if param_meta is None:
-        method.__doc__ = ParamMeta()
-    if not isinstance(method.__doc__, ParamMeta):
-        raise RuntimeError('Tried to use @parameterized on a method that has '
-                           'a docstring. This is not supported. Please remove '
-                           'the docstring.')
-    return method.__doc__
-
-def parameterized(arg_name, case_dict):
-    def decorator(fn):
-        param_meta = get_param_meta(fn)
-        param_meta.push((arg_name, case_dict))
-        return fn
-    return decorator
-
-def parameterized_with_device(arg_name, case_dict):
-    def decorator(fn):
-        param_meta = get_param_meta(fn)
-        param_meta.push((arg_name, case_dict))
-        fn._has_device = True
-        return fn
-    return decorator
-
-
-def _set_parameterized_method(test_base, fn, instantiated_cases, extension_name):
-    new_name = f'{fn.__name__}_{extension_name}'
-
-    def wrapped_no_device(self, *args, **kwargs):
-        for arg_name, case in instantiated_cases:
-            kwargs[arg_name] = case
-        return fn(self, *args, **kwargs)
-
-    def wrapped_with_device(self, device, *args, **kwargs):
-        for arg_name, case in instantiated_cases:
-            kwargs[arg_name] = case
-        return fn(self, device, *args, **kwargs)
-
-    if getattr(fn, '_has_device', False):
-        wrapped = wrapped_with_device
-    else:
-        wrapped = wrapped_no_device
-
-    wrapped.__name__ = new_name
-    setattr(test_base, new_name, wrapped)
-
-def to_tuples(dct):
-    return [(k, v) for k, v in dct.items()]
-
-def instantiate_parameterized_methods(test_base):
-    allattrs = tuple(dir(test_base))
-    for attr_name in allattrs:
-        attr = getattr(test_base, attr_name)
-        if not has_param_meta(attr):
-            continue
-
-        param_meta = get_param_meta(attr)
-        arg_names, case_dicts = zip(*param_meta.stack)
-        case_dicts = [to_tuples(cd) for cd in case_dicts]
-        for list_of_name_and_case in itertools.product(*case_dicts):
-            case_names, cases = zip(*list_of_name_and_case)
-            extension_name = '_'.join(case_names)
-            instantiated_cases = list(zip(arg_names, cases))
-            _set_parameterized_method(test_base, attr, instantiated_cases, extension_name)
-        # Remove the base fn from the testcase
-        delattr(test_base, attr_name)
 
 
 def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 1580644..be17f71 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -19,8 +19,6 @@
 from functorch_lagging_op_db import functorch_lagging_op_db
 from functorch_additional_op_db import additional_op_db
 from common_utils import (
-    parameterized,
-    instantiate_parameterized_methods,
     get_fallback_and_vmap_exhaustive,
     get_exhaustive_batched_inputs,
     opinfo_in_dict,
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 40f82bc..e972b72 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -30,9 +30,6 @@
 from functorch_lagging_op_db import functorch_lagging_op_db
 from functorch_additional_op_db import additional_op_db
 from common_utils import (
-    parameterized,
-    parameterized_with_device,
-    instantiate_parameterized_methods,
     get_fallback_and_vmap_exhaustive,
     opinfo_in_dict,
     xfail,
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 9469f01..de08cd5 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -18,13 +18,15 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma
 from torch.testing._internal.common_device_type import ops, onlyCPU
+from torch.testing._internal.common_utils import (
+    parametrize,
+    instantiate_parametrized_tests,
+    subtest
+)
 from functorch_lagging_op_db import functorch_lagging_op_db
 from functorch_additional_op_db import additional_op_db
 from torch.utils._pytree import tree_map
 from common_utils import (
-    parameterized,
-    parameterized_with_device,
-    instantiate_parameterized_methods,
     get_fallback_and_vmap_exhaustive,
     opinfo_in_dict,
     xfail,
@@ -1162,38 +1164,38 @@
         test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
              in_dims=2, out_dims=2)
 
-    @parameterized("case", {
-        'abs': (torch.abs, TensorFactory.randn),
-        'acos': (torch.acos, TensorFactory.rand),
-        'asin': (torch.asin, TensorFactory.rand),
-        'atan': (torch.atan, TensorFactory.rand),
-        'ceil': (torch.ceil, TensorFactory.randn),
-        'cos': (torch.cos, TensorFactory.rand),
-        'cosh': (torch.cosh, TensorFactory.rand),
-        'digamma': (torch.digamma, TensorFactory.rand),
-        'exp': (torch.exp, TensorFactory.randn),
-        'expm1': (torch.expm1, TensorFactory.randn),
-        'floor': (torch.floor, TensorFactory.randn),
-        'frac': (torch.frac, TensorFactory.randn),
-        'lgamma': (torch.lgamma, TensorFactory.rand),
-        'log': (torch.log, TensorFactory.randp1),
-        'log10': (torch.log10, TensorFactory.randp1),
-        'log1p': (torch.log1p, TensorFactory.randp1),
-        'log2': (torch.log2, TensorFactory.randp1),
-        'neg': (torch.neg, TensorFactory.randn),
-        'reciprocol': (torch.reciprocal, TensorFactory.randp1),
-        'relu': (torch.relu, TensorFactory.randn),
-        'round': (torch.round, TensorFactory.randn),
-        'rsqrt': (torch.rsqrt, TensorFactory.randp1),
-        'sigmoid': (torch.sigmoid, TensorFactory.randn),
-        'sign': (torch.sign, TensorFactory.randn),
-        'sin': (torch.sin, TensorFactory.rand),
-        'sinh': (torch.sinh, TensorFactory.rand),
-        'sqrt': (torch.sqrt, TensorFactory.rand),
-        'tan': (torch.tan, TensorFactory.rand),
-        'tanh': (torch.tanh, TensorFactory.rand),
-        'trunc': (torch.trunc, TensorFactory.randn),
-    })
+    @parametrize("case", [
+        (torch.abs, TensorFactory.randn),
+        (torch.acos, TensorFactory.rand),
+        (torch.asin, TensorFactory.rand),
+        (torch.atan, TensorFactory.rand),
+        (torch.ceil, TensorFactory.randn),
+        (torch.cos, TensorFactory.rand),
+        (torch.cosh, TensorFactory.rand),
+        (torch.digamma, TensorFactory.rand),
+        (torch.exp, TensorFactory.randn),
+        (torch.expm1, TensorFactory.randn),
+        (torch.floor, TensorFactory.randn),
+        (torch.frac, TensorFactory.randn),
+        (torch.lgamma, TensorFactory.rand),
+        (torch.log, TensorFactory.randp1),
+        (torch.log10, TensorFactory.randp1),
+        (torch.log1p, TensorFactory.randp1),
+        (torch.log2, TensorFactory.randp1),
+        (torch.neg, TensorFactory.randn),
+        (torch.reciprocal, TensorFactory.randp1),
+        (torch.relu, TensorFactory.randn),
+        (torch.round, TensorFactory.randn),
+        (torch.rsqrt, TensorFactory.randp1),
+        (torch.sigmoid, TensorFactory.randn),
+        (torch.sign, TensorFactory.randn),
+        (torch.sin, TensorFactory.rand),
+        (torch.sinh, TensorFactory.rand),
+        (torch.sqrt, TensorFactory.rand),
+        (torch.tan, TensorFactory.rand),
+        (torch.tanh, TensorFactory.rand),
+        (torch.trunc, TensorFactory.randn),
+    ], name_fn=lambda x: x[0].__name__)
     def test_unary_pointwise(self, case):
         op, getter = case
         self._test_unary(op, getter, 'cpu')
@@ -1232,10 +1234,10 @@
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
 
-    @parameterized('case', {
-        'clamp_min': _make_case(torch.clamp_min),
-        'clamp_max': _make_case(torch.clamp_max),
-    })
+    @parametrize('case', [
+        subtest(_make_case(torch.clamp_min), name='clamp_min'),
+        subtest(_make_case(torch.clamp_max), name='clamp_max'),
+    ])
     def test_clamp_variant(self, case):
         test = self._vmap_test
 
@@ -1264,18 +1266,18 @@
         number = get_number(getter)
         self._test_unary(lambda t: op(t, number), getter, device)
 
-    @parameterized('case', {
-        'add': _make_case(torch.add),
-        'add_dunder': _make_case(lambda x, y: x + y),
-        'sub': _make_case(torch.sub),
-        'sub_dunder': _make_case(lambda x, y: x - y),
-        'mul': _make_case(torch.mul),
-        'mul_dunder': _make_case(lambda x, y: x * y),
-        'div': _make_case(torch.div, input_getter=TensorFactory.randp1),
-        'div_dunder': _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
-        'pow': _make_case(torch.pow, input_getter=TensorFactory.randp1),
-        'pow_dunder': _make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
-    })
+    @parametrize('case', [
+        subtest(_make_case(torch.add), name='add'),
+        subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
+        subtest(_make_case(torch.sub), name='sub'),
+        subtest(_make_case(lambda x, y: x - y), name='sub_dunder'),
+        subtest(_make_case(torch.mul), name='mul'),
+        subtest(_make_case(lambda x, y: x * y), name='mul_dunder'),
+        subtest(_make_case(torch.div, input_getter=TensorFactory.randp1), name='div'),
+        subtest(_make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), name='div_dunder'),
+        subtest(_make_case(torch.pow, input_getter=TensorFactory.randp1), name='pow'),
+        subtest(_make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), name='pow_dunder'),
+    ])
     def test_arithmetic(self, case):
         test = self._vmap_test
 
@@ -2587,16 +2589,16 @@
 
         self.assertTrue(torch.randn(()).dim() == 0)
 
-    @parameterized('op', {'abs': torch.abs, 'acos': torch.acos})
-    def test_parameterize(self, op):
+    @parametrize('op', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
+    def test_foobar_parametrize(self, op):
         pass
 
-    @parameterized('op2', {'cos': torch.cos, 'cosh': torch.cosh})
-    @parameterized('op1', {'sin': torch.sin, 'sinh': torch.sinh})
-    def test_parameterize_multiple(self, op1, op2):
+    @parametrize('op2', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
+    @parametrize('op1', [torch.abs, torch.acos], name_fn=lambda f: f.__name__)
+    def test_parametrize_multiple(self, op1, op2):
         pass
 
-instantiate_parameterized_methods(TestVmapOperators)
+instantiate_parametrized_tests(TestVmapOperators)
 
 
 def construct_v(output, batch_size):
@@ -2771,7 +2773,7 @@
         self._batched_grad_test(torch.log1p, (x,))
         self._batched_grad_grad_test(torch.log1p, (x,))
 
-    @parameterized_with_device('param', {'foo': None, 'bar': None})
+    @parametrize('param', ['foo', 'bar'])
     def test_param_device(self, device, param):
         pass
 
@@ -3136,9 +3138,9 @@
         self.assertEqual(vmap(f, in_dims=(0,None,0))(x, y[0], z)[0], base)
         self.assertEqual(vmap(f, in_dims=(0,0,None))(x, y, z[0])[0], base)
 
-    @parameterized_with_device('training', {'train': True, 'eval': False})
-    @parameterized_with_device('track_running_stats', {'running_stats1': True, 'running_stats0': False})
-    @parameterized_with_device('affine', {'affine1': True, 'affine0': False})
+    @parametrize('training', [True, False])
+    @parametrize('track_running_stats', [True, False])
+    @parametrize('affine', [True, False])
     def test_batch_norm(self, device, affine, track_running_stats, training):
         if not track_running_stats and not training:
             return
@@ -3187,10 +3189,8 @@
 
 
 only_for = ("cpu", "cuda")
-instantiate_parameterized_methods(TestVmapOperatorsOpInfo)
 instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
 
-instantiate_parameterized_methods(TestVmapBatchedGradient)
 instantiate_device_type_tests(
     TestVmapBatchedGradient,
     globals(),