[functorch] Move to PyTorch core's parametrize testing mechanism
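
functorch's test suite rolled its own parameterization helpers
(parameterized, parameterized_with_device, and
instantiate_parameterized_methods in test/common_utils.py), which worked by
stashing metadata where the test method's docstring would go. PyTorch core
now ships an equivalent mechanism in torch.testing._internal.common_utils
(parametrize, subtest, and instantiate_parametrized_tests), so this change
deletes the custom helpers and ports their callers in test_vmap.py,
test_ops.py, and test_pythonkey.py.

The gist of the migration, as a minimal sketch (MyTestCase and test_op are
illustrative names, not tests touched by this diff):

    import torch
    from torch.testing._internal.common_utils import (
        TestCase, parametrize, instantiate_parametrized_tests,
    )

    class MyTestCase(TestCase):  # illustrative class, not part of this diff
        # Cases are given as a list; subtest names come from name_fn, or by
        # wrapping a value in subtest(value, name=...) when an explicit
        # name is needed.
        @parametrize('op', [torch.abs, torch.cos], name_fn=lambda f: f.__name__)
        def test_op(self, op):  # generates test_op_abs and test_op_cos
            self.assertTrue(callable(op))

    instantiate_parametrized_tests(MyTestCase)

    # Previously (removed in this diff) the same test was written as:
    #   @parameterized('op', {'abs': torch.abs, 'cos': torch.cos})
    #   def test_op(self, op):
    #       ...
    #   instantiate_parameterized_methods(MyTestCase)

There is no separate device-aware decorator anymore: core's parametrize
composes with instantiate_device_type_tests, so device-specific tests keep
taking device as their first argument and use plain @parametrize.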
diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index 79a069b..8959dff 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -13,141 +13,6 @@
from functorch_additional_op_db import additional_op_db
from torch.testing._internal.common_methods_invocations import DecorateInfo
import unittest
-import warnings
-import re
-
-"""
-Usage:
-
-class MyTestCase(TestCase):
- @parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
- def test_single_param(self, param):
- pass
-
- @parameterized('param1', {'sin': torch.sin, 'tan': torch.tan})
- @parameterized('param2', {'abs': torch.abs, 'cos': torch.cos})
- def test_multiple_param(self, param1, param2):
- pass
-
-# The following creates:
-# - MyTestCase.test_single_param_abs
-# - MyTestCase.test_single_param_cos
-# - MyTestCase.test_multiple_param_abs_sin
-# - MyTestCase.test_multiple_param_cos_sin
-# - MyTestCase.test_multiple_param_abs_tan
-# - MyTestCase.test_multiple_param_cos_tan
-instantiate_parameterized_methods(MyTestCase)
-
-# This is also composable with PyTorch testing's instantiate_device_type_tests
-# Make sure the param is after the device arg
-class MyDeviceSpecificTest(TestCase):
- @parameterized('param', {'abs': torch.abs, 'cos': torch.cos})
- def test_single_param(self, device, param):
- pass
-
-# The following creates:
-# - MyDeviceSpecificTestCPU.test_single_param_abs_cpu
-# - MyDeviceSpecificTestCPU.test_single_param_cos_cpu
-# - MyDeviceSpecificTestCUDA.test_single_param_abs_cuda
-# - MyDeviceSpecificTestCUDA.test_single_param_cos_cpu
-instantiate_parameterized_methods(MyDeviceSpecificTest)
-instantiate_device_type_tests(MyDeviceSpecificTest, globals())
-
-# !!!!! warning !!!!!
-# 1. The method being parameterized over MUST NOT HAVE A DOCSTRING. We'll
-# error out nicely if this happens.
-# 2. All other decorators MUST USE functools.wraps (they must propagate the docstring)
-# `@parameterized` works by storing some metadata in place of the docstring.
-# This takes advantage of how other decorators work (other decorators usually
-# propagate the docstring via functools.wrap).
-# 3. We might not compose with PyTorch testing's @dtypes and @precision
-# decorators. But that is easily fixable. TODO.
-# I think this composes with PyTorch testing's instantiate_device_type_tests.
-"""
-
-PARAM_META = '_torch_parameterized_meta'
-
-class ParamMeta():
- def __init__(self):
- self.stack = []
-
- def push(self, elt):
- self.stack.append(elt)
-
- def pop(self, elt):
- return self.stack.pop()
-
-def has_param_meta(method):
- param_meta = getattr(method, '__doc__', None)
- return param_meta is not None and isinstance(param_meta, ParamMeta)
-
-def get_param_meta(method):
- param_meta = getattr(method, '__doc__', None)
- if param_meta is None:
- method.__doc__ = ParamMeta()
- if not isinstance(method.__doc__, ParamMeta):
- raise RuntimeError('Tried to use @parameterized on a method that has '
- 'a docstring. This is not supported. Please remove '
- 'the docstring.')
- return method.__doc__
-
-def parameterized(arg_name, case_dict):
- def decorator(fn):
- param_meta = get_param_meta(fn)
- param_meta.push((arg_name, case_dict))
- return fn
- return decorator
-
-def parameterized_with_device(arg_name, case_dict):
- def decorator(fn):
- param_meta = get_param_meta(fn)
- param_meta.push((arg_name, case_dict))
- fn._has_device = True
- return fn
- return decorator
-
-
-def _set_parameterized_method(test_base, fn, instantiated_cases, extension_name):
- new_name = f'{fn.__name__}_{extension_name}'
-
- def wrapped_no_device(self, *args, **kwargs):
- for arg_name, case in instantiated_cases:
- kwargs[arg_name] = case
- return fn(self, *args, **kwargs)
-
- def wrapped_with_device(self, device, *args, **kwargs):
- for arg_name, case in instantiated_cases:
- kwargs[arg_name] = case
- return fn(self, device, *args, **kwargs)
-
- if getattr(fn, '_has_device', False):
- wrapped = wrapped_with_device
- else:
- wrapped = wrapped_no_device
-
- wrapped.__name__ = new_name
- setattr(test_base, new_name, wrapped)
-
-def to_tuples(dct):
- return [(k, v) for k, v in dct.items()]
-
-def instantiate_parameterized_methods(test_base):
- allattrs = tuple(dir(test_base))
- for attr_name in allattrs:
- attr = getattr(test_base, attr_name)
- if not has_param_meta(attr):
- continue
-
- param_meta = get_param_meta(attr)
- arg_names, case_dicts = zip(*param_meta.stack)
- case_dicts = [to_tuples(cd) for cd in case_dicts]
- for list_of_name_and_case in itertools.product(*case_dicts):
- case_names, cases = zip(*list_of_name_and_case)
- extension_name = '_'.join(case_names)
- instantiated_cases = list(zip(arg_names, cases))
- _set_parameterized_method(test_base, attr, instantiated_cases, extension_name)
- # Remove the base fn from the testcase
- delattr(test_base, attr_name)
def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 1580644..be17f71 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -19,8 +19,6 @@
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from common_utils import (
- parameterized,
- instantiate_parameterized_methods,
get_fallback_and_vmap_exhaustive,
get_exhaustive_batched_inputs,
opinfo_in_dict,
diff --git a/functorch/test/test_pythonkey.py b/functorch/test/test_pythonkey.py
index 40f82bc..e972b72 100644
--- a/functorch/test/test_pythonkey.py
+++ b/functorch/test/test_pythonkey.py
@@ -30,9 +30,6 @@
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from common_utils import (
- parameterized,
- parameterized_with_device,
- instantiate_parameterized_methods,
get_fallback_and_vmap_exhaustive,
opinfo_in_dict,
xfail,
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 9469f01..de08cd5 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -18,13 +18,15 @@
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops, onlyCPU
+from torch.testing._internal.common_utils import (
+ parametrize,
+ instantiate_parametrized_tests,
+ subtest
+)
from functorch_lagging_op_db import functorch_lagging_op_db
from functorch_additional_op_db import additional_op_db
from torch.utils._pytree import tree_map
from common_utils import (
- parameterized,
- parameterized_with_device,
- instantiate_parameterized_methods,
get_fallback_and_vmap_exhaustive,
opinfo_in_dict,
xfail,
@@ -1162,38 +1164,38 @@
test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
in_dims=2, out_dims=2)
- @parameterized("case", {
- 'abs': (torch.abs, TensorFactory.randn),
- 'acos': (torch.acos, TensorFactory.rand),
- 'asin': (torch.asin, TensorFactory.rand),
- 'atan': (torch.atan, TensorFactory.rand),
- 'ceil': (torch.ceil, TensorFactory.randn),
- 'cos': (torch.cos, TensorFactory.rand),
- 'cosh': (torch.cosh, TensorFactory.rand),
- 'digamma': (torch.digamma, TensorFactory.rand),
- 'exp': (torch.exp, TensorFactory.randn),
- 'expm1': (torch.expm1, TensorFactory.randn),
- 'floor': (torch.floor, TensorFactory.randn),
- 'frac': (torch.frac, TensorFactory.randn),
- 'lgamma': (torch.lgamma, TensorFactory.rand),
- 'log': (torch.log, TensorFactory.randp1),
- 'log10': (torch.log10, TensorFactory.randp1),
- 'log1p': (torch.log1p, TensorFactory.randp1),
- 'log2': (torch.log2, TensorFactory.randp1),
- 'neg': (torch.neg, TensorFactory.randn),
- 'reciprocol': (torch.reciprocal, TensorFactory.randp1),
- 'relu': (torch.relu, TensorFactory.randn),
- 'round': (torch.round, TensorFactory.randn),
- 'rsqrt': (torch.rsqrt, TensorFactory.randp1),
- 'sigmoid': (torch.sigmoid, TensorFactory.randn),
- 'sign': (torch.sign, TensorFactory.randn),
- 'sin': (torch.sin, TensorFactory.rand),
- 'sinh': (torch.sinh, TensorFactory.rand),
- 'sqrt': (torch.sqrt, TensorFactory.rand),
- 'tan': (torch.tan, TensorFactory.rand),
- 'tanh': (torch.tanh, TensorFactory.rand),
- 'trunc': (torch.trunc, TensorFactory.randn),
- })
+ @parametrize("case", [
+ (torch.abs, TensorFactory.randn),
+ (torch.acos, TensorFactory.rand),
+ (torch.asin, TensorFactory.rand),
+ (torch.atan, TensorFactory.rand),
+ (torch.ceil, TensorFactory.randn),
+ (torch.cos, TensorFactory.rand),
+ (torch.cosh, TensorFactory.rand),
+ (torch.digamma, TensorFactory.rand),
+ (torch.exp, TensorFactory.randn),
+ (torch.expm1, TensorFactory.randn),
+ (torch.floor, TensorFactory.randn),
+ (torch.frac, TensorFactory.randn),
+ (torch.lgamma, TensorFactory.rand),
+ (torch.log, TensorFactory.randp1),
+ (torch.log10, TensorFactory.randp1),
+ (torch.log1p, TensorFactory.randp1),
+ (torch.log2, TensorFactory.randp1),
+ (torch.neg, TensorFactory.randn),
+ (torch.reciprocal, TensorFactory.randp1),
+ (torch.relu, TensorFactory.randn),
+ (torch.round, TensorFactory.randn),
+ (torch.rsqrt, TensorFactory.randp1),
+ (torch.sigmoid, TensorFactory.randn),
+ (torch.sign, TensorFactory.randn),
+ (torch.sin, TensorFactory.rand),
+ (torch.sinh, TensorFactory.rand),
+ (torch.sqrt, TensorFactory.rand),
+ (torch.tan, TensorFactory.rand),
+ (torch.tanh, TensorFactory.rand),
+ (torch.trunc, TensorFactory.randn),
+ ], name_fn=lambda x: x[0].__name__)
def test_unary_pointwise(self, case):
op, getter = case
self._test_unary(op, getter, 'cpu')
@@ -1232,10 +1234,10 @@
with self.assertRaisesRegex(RuntimeError, msg):
vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
- @parameterized('case', {
- 'clamp_min': _make_case(torch.clamp_min),
- 'clamp_max': _make_case(torch.clamp_max),
- })
+ @parametrize('case', [
+ subtest(_make_case(torch.clamp_min), name='clamp_min'),
+ subtest(_make_case(torch.clamp_max), name='clamp_max'),
+ ])
def test_clamp_variant(self, case):
test = self._vmap_test
@@ -1264,18 +1266,18 @@
number = get_number(getter)
self._test_unary(lambda t: op(t, number), getter, device)
- @parameterized('case', {
- 'add': _make_case(torch.add),
- 'add_dunder': _make_case(lambda x, y: x + y),
- 'sub': _make_case(torch.sub),
- 'sub_dunder': _make_case(lambda x, y: x - y),
- 'mul': _make_case(torch.mul),
- 'mul_dunder': _make_case(lambda x, y: x * y),
- 'div': _make_case(torch.div, input_getter=TensorFactory.randp1),
- 'div_dunder': _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
- 'pow': _make_case(torch.pow, input_getter=TensorFactory.randp1),
- 'pow_dunder': _make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
- })
+ @parametrize('case', [
+ subtest(_make_case(torch.add), name='add'),
+ subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
+ subtest(_make_case(torch.sub), name='sub'),
+ subtest(_make_case(lambda x, y: x - y), name='sub_dunder'),
+ subtest(_make_case(torch.mul), name='mul'),
+ subtest(_make_case(lambda x, y: x * y), name='mul_dunder'),
+ subtest(_make_case(torch.div, input_getter=TensorFactory.randp1), name='div'),
+ subtest(_make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), name='div_dunder'),
+ subtest(_make_case(torch.pow, input_getter=TensorFactory.randp1), name='pow'),
+ subtest(_make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), name='pow_dunder'),
+ ])
def test_arithmetic(self, case):
test = self._vmap_test
@@ -2587,16 +2589,16 @@
self.assertTrue(torch.randn(()).dim() == 0)
- @parameterized('op', {'abs': torch.abs, 'acos': torch.acos})
- def test_parameterize(self, op):
+ @parametrize('op', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
+ def test_foobar_parametrize(self, op):
pass
- @parameterized('op2', {'cos': torch.cos, 'cosh': torch.cosh})
- @parameterized('op1', {'sin': torch.sin, 'sinh': torch.sinh})
- def test_parameterize_multiple(self, op1, op2):
+ @parametrize('op2', [torch.cos, torch.sinh], name_fn=lambda f: f.__name__)
+ @parametrize('op1', [torch.abs, torch.acos], name_fn=lambda f: f.__name__)
+ def test_parametrize_multiple(self, op1, op2):
pass
-instantiate_parameterized_methods(TestVmapOperators)
+instantiate_parametrized_tests(TestVmapOperators)
def construct_v(output, batch_size):
@@ -2771,7 +2773,7 @@
self._batched_grad_test(torch.log1p, (x,))
self._batched_grad_grad_test(torch.log1p, (x,))
- @parameterized_with_device('param', {'foo': None, 'bar': None})
+ @parametrize('param', ['foo', 'bar'])
def test_param_device(self, device, param):
pass
@@ -3136,9 +3138,9 @@
self.assertEqual(vmap(f, in_dims=(0,None,0))(x, y[0], z)[0], base)
self.assertEqual(vmap(f, in_dims=(0,0,None))(x, y, z[0])[0], base)
- @parameterized_with_device('training', {'train': True, 'eval': False})
- @parameterized_with_device('track_running_stats', {'running_stats1': True, 'running_stats0': False})
- @parameterized_with_device('affine', {'affine1': True, 'affine0': False})
+ @parametrize('training', [True, False])
+ @parametrize('track_running_stats', [True, False])
+ @parametrize('affine', [True, False])
def test_batch_norm(self, device, affine, track_running_stats, training):
if not track_running_stats and not training:
return
@@ -3187,10 +3189,8 @@
only_for = ("cpu", "cuda")
-instantiate_parameterized_methods(TestVmapOperatorsOpInfo)
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
-instantiate_parameterized_methods(TestVmapBatchedGradient)
instantiate_device_type_tests(
TestVmapBatchedGradient,
globals(),