[functorch] Create "lagging op database", use it in our OpInfo tests
We have a problem where our tests fail every time we rebase onto the
most recent version of PyTorch. It would be nice to distinguish between
"PyTorch broke a previously passing test" and "PyTorch added a new test
that functorch would have failed anyway".
The solution this PR introduces is for functorch to maintain a
"lagging" OpInfo database. The lagging database gets updated every once
in a while with new OpInfos from pytorch/pytorch, so functorch does not
pick up new OpInfo tests unexpectedly on every rebase.
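For concreteness, here is a minimal sketch of the intended workflow. It
only uses names that appear in this diff (the regeneration command is
the one recorded in the generated file's header, and the test-side
usage mirrors the test_ops.py changes); the instantiate_device_type_tests
and run_tests boilerplate from the real test files is omitted.

    # Refresh the lagging database with new OpInfos from pytorch/pytorch
    # (run from the functorch repo root):
    #
    #   python codegen/gen_functorch_lagging_op_db.py > test/functorch_lagging_op_db.py
    #
    # Tests then parameterize over the lagging database instead of PyTorch's op_db:
    import torch
    from torch.testing._internal.common_utils import TestCase
    from torch.testing._internal.common_device_type import ops
    from functorch_lagging_op_db import functorch_lagging_op_db

    class TestOperators(TestCase):
        @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
        def test_grad(self, device, dtype, op):
            ...  # test body unchanged; only the op database it draws from changes
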
diff --git a/functorch/codegen/gen_functorch_lagging_op_db.py b/functorch/codegen/gen_functorch_lagging_op_db.py
new file mode 100644
index 0000000..0b421dd
--- /dev/null
+++ b/functorch/codegen/gen_functorch_lagging_op_db.py
@@ -0,0 +1,47 @@
+from torch.testing._internal.common_methods_invocations import op_db
+import pprint
+
+
+def num_leading_spaces(line: str) -> int:
+    result = len(line) - len(line.lstrip())
+    # Treat empty lines as maximally indented so they don't affect the minimum
+    if result == 0:
+        return 999999
+    return result
+
+
+def deindent(code: str) -> str:
+    lines = code.split('\n')
+    min_leading_spaces = min(map(num_leading_spaces, lines))
+    lines = [line[min_leading_spaces:] for line in lines]
+    return '\n'.join(lines)
+
+
+if __name__ == '__main__':
+    supported = {(opinfo.name, opinfo.variant_test_name) for opinfo in op_db}
+    print(deindent("""\
+    from torch.testing._internal.common_methods_invocations import op_db
+
+    # Generated from codegen/gen_functorch_lagging_op_db.py via
+    # python codegen/gen_functorch_lagging_op_db.py > test/functorch_lagging_op_db.py
+    #
+    # People add new OpInfos to PyTorch all the time.
+    # We want them to be able to add OpInfos without breaking our CI.
+    # To achieve this, we keep our OpInfo library behind PyTorch's and
+    # we periodically update our OpInfo library by regenerating this file"""))
+
+    print("_functorch_lagging_meta = {")
+    for name, variant in supported:
+        print(f'    {(name, variant)},')
+    print("}")
+
+    print(deindent("""\
+
+
+    def in_functorch_lagging_op_db(opinfo):
+        return (opinfo.name, opinfo.variant_test_name) in _functorch_lagging_meta
+
+
+    functorch_lagging_op_db = [
+        opinfo for opinfo in op_db if in_functorch_lagging_op_db(opinfo)
+    ]"""))
diff --git a/functorch/test/functorch_lagging_op_db.py b/functorch/test/functorch_lagging_op_db.py
new file mode 100644
index 0000000..918b1cc
--- /dev/null
+++ b/functorch/test/functorch_lagging_op_db.py
@@ -0,0 +1,322 @@
+from torch.testing._internal.common_methods_invocations import op_db
+
+# Generated from codegen/gen_functorch_lagging_op_db.py via
+# python codegen/gen_functorch_lagging_op_db.py > test/functorch_lagging_op_db.py
+#
+# People add new OpInfos to PyTorch all the time.
+# We want them to be able to add OpInfos without breaking our CI.
+# To achieve this, we keep our OpInfo library behind PyTorch's and
+# we periodically update our OpInfo library by regenerating this file
+_functorch_lagging_meta = {
+ ('clamp', 'scalar'),
+ ('__rmul__', ''),
+ ('polygamma', 'polygamma_n_1'),
+ ('cumsum', ''),
+ ('logcumsumexp', ''),
+ ('signbit', ''),
+ ('tensor_split', ''),
+ ('square', ''),
+ ('cholesky', ''),
+ ('gradient', ''),
+ ('deg2rad', ''),
+ ('var_mean', ''),
+ ('norm', 'nuc'),
+ ('geqrf', ''),
+ ('argmin', ''),
+ ('baddbmm', ''),
+ ('linalg.eigvals', ''),
+ ('angle', ''),
+ ('logit', ''),
+ ('linalg.matrix_power', ''),
+ ('cos', ''),
+ ('special.i0e', ''),
+ ('masked_fill', ''),
+ ('reciprocal', ''),
+ ('resize_as_', ''),
+ ('fill_', ''),
+ ('special.entr', ''),
+ ('min', 'reduction_with_dim'),
+ ('fliplr', ''),
+ ('flipud', ''),
+ ('fft.fftn', ''),
+ ('logical_not', ''),
+ ('abs', ''),
+ ('remainder', ''),
+ ('broadcast_to', ''),
+ ('conj', ''),
+ ('conj_physical', ''),
+ ('mvlgamma', 'mvlgamma_p_5'),
+ ('trunc', ''),
+ ('cummin', ''),
+ ('vsplit', ''),
+ ('transpose', ''),
+ ('true_divide', ''),
+ ('acos', ''),
+ ('linalg.multi_dot', ''),
+ ('rsub', 'rsub_tensor'),
+ ('max', 'reduction_no_dim'),
+ ('addr', ''),
+ ('mul', ''),
+ ('fft.ifft', ''),
+ ('amin', ''),
+ ('std_mean', ''),
+ ('mode', ''),
+ ('lu', ''),
+ ('linalg.inv', ''),
+ ('max', 'binary'),
+ ('view_as', ''),
+ ('pinverse', ''),
+ ('erf', ''),
+ ('mv', ''),
+ ('logaddexp2', ''),
+ ('fft.fft', ''),
+ ('linalg.svd', ''),
+ ('eig', ''),
+ ('erfc', ''),
+ ('imag', ''),
+ ('split', 'list_args'),
+ ('stack', ''),
+ ('linalg.eig', ''),
+ ('fft.rfftn', ''),
+ ('mvlgamma', 'mvlgamma_p_1'),
+ ('linalg.matrix_rank', 'hermitian'),
+ ('unsqueeze', ''),
+ ('dist', ''),
+ ('special.xlog1py', ''),
+ ('gt', ''),
+ ('floor_divide', ''),
+ ('addbmm', ''),
+ ('logdet', ''),
+ ('inverse', ''),
+ ('log_softmax', ''),
+ ('sign', ''),
+ ('__rmatmul__', ''),
+ ('linalg.lstsq', ''),
+ ('__rmod__', ''),
+ ('expand_as', ''),
+ ('renorm', ''),
+ ('sqrt', ''),
+ ('put', ''),
+ ('logsumexp', ''),
+ ('repeat', ''),
+ ('maximum', ''),
+ ('linalg.svdvals', ''),
+ ('fmax', ''),
+ ('real', ''),
+ ('minimum', ''),
+ ('roll', ''),
+ ('atan', ''),
+ ('linalg.vector_norm', ''),
+ ('einsum', ''),
+ ('polygamma', 'polygamma_n_4'),
+ ('log_softmax', 'dtype'),
+ ('remainder', 'autodiffed'),
+ ('addcmul', ''),
+ ('matrix_exp', ''),
+ ('__rdiv__', ''),
+ ('fft.irfftn', ''),
+ ('polygamma', 'polygamma_n_2'),
+ ('fft.hfft', ''),
+ ('tan', ''),
+ ('ceil', ''),
+ ('view_as_real', ''),
+ ('linalg.solve', ''),
+ ('addcdiv', ''),
+ ('kthvalue', ''),
+ ('linalg.cholesky_ex', ''),
+ ('resize_', ''),
+ ('min', 'binary'),
+ ('log', ''),
+ ('fft.ihfft', ''),
+ ('linalg.pinv', 'hermitian'),
+ ('topk', ''),
+ ('special.i1e', ''),
+ ('lt', ''),
+ ('frac', ''),
+ ('addmm', 'decomposed'),
+ ('linalg.cond', ''),
+ ('linalg.qr', ''),
+ ('view_as_complex', ''),
+ ('select', ''),
+ ('hsplit', ''),
+ ('div', 'trunc_rounding'),
+ ('lu_unpack', ''),
+ ('fmin', ''),
+ ('floor', ''),
+ ('linalg.matrix_rank', ''),
+ ('polar', ''),
+ ('where', ''),
+ ('atanh', ''),
+ ('split_with_sizes', ''),
+ ('gather', ''),
+ ('neg', ''),
+ ('masked_select', ''),
+ ('complex', ''),
+ ('nanquantile', ''),
+ ('permute', ''),
+ ('triu', ''),
+ ('ravel', ''),
+ ('fft.rfft', ''),
+ ('__getitem__', ''),
+ ('exp', ''),
+ ('frexp', ''),
+ ('index_fill', ''),
+ ('nansum', ''),
+ ('exp2', ''),
+ ('eq', ''),
+ ('fmod', ''),
+ ('erfinv', ''),
+ ('trace', ''),
+ ('bmm', ''),
+ ('nn.functional.leaky_relu', ''),
+ ('symeig', ''),
+ ('reshape', ''),
+ ('median', ''),
+ ('linalg.householder_product', ''),
+ ('addmv', ''),
+ ('flip', ''),
+ ('prod', ''),
+ ('sin', ''),
+ ('take', ''),
+ ('xlogy', ''),
+ ('lgamma', ''),
+ ('norm', 'fro'),
+ ('logaddexp', ''),
+ ('sigmoid', ''),
+ ('atan2', ''),
+ ('linalg.det', ''),
+ ('digamma', ''),
+ ('sub', ''),
+ ('split', ''),
+ ('min', 'reduction_no_dim'),
+ ('cummax', ''),
+ ('nn.functional.gelu', ''),
+ ('rsub', 'rsub_scalar'),
+ ('std', ''),
+ ('var', ''),
+ ('linalg.eigvalsh', ''),
+ ('div', 'floor_rounding'),
+ ('log10', ''),
+ ('float_power', ''),
+ ('__rpow__', ''),
+ ('lerp', ''),
+ ('nanmedian', ''),
+ ('hstack', ''),
+ ('hypot', ''),
+ ('linalg.eigh', ''),
+ ('linalg.inv_ex', ''),
+ ('solve', ''),
+ ('linalg.pinv', ''),
+ ('sinh', ''),
+ ('tensordot', ''),
+ ('outer', ''),
+ ('scatter', ''),
+ ('sort', ''),
+ ('cross', ''),
+ ('vdot', ''),
+ ('sinc', ''),
+ ('diag', ''),
+ ('addmm', ''),
+ ('inner', ''),
+ ('special.i1', ''),
+ ('norm', 'inf'),
+ ('linalg.matrix_norm', ''),
+ ('rad2deg', ''),
+ ('expand', ''),
+ ('tanh', ''),
+ ('mean', ''),
+ ('rot90', ''),
+ ('__rsub__', ''),
+ ('triangular_solve', ''),
+ ('diagonal', ''),
+ ('expm1', ''),
+ ('index_select', ''),
+ ('copysign', ''),
+ ('linalg.norm', ''),
+ ('asinh', ''),
+ ('polygamma', 'polygamma_n_3'),
+ ('fmod', 'autodiffed'),
+ ('fft.ifftn', ''),
+ ('nn.functional.hardswish', ''),
+ ('acosh', ''),
+ ('sum', ''),
+ ('__radd__', ''),
+ ('chunk', ''),
+ ('clamp', ''),
+ ('mvlgamma', 'mvlgamma_p_3'),
+ ('qr', ''),
+ ('index_put', ''),
+ ('squeeze', ''),
+ ('t', ''),
+ ('nn.functional.relu6', ''),
+ ('max', 'reduction_with_dim'),
+ ('cholesky_inverse', ''),
+ ('resolve_conj', ''),
+ ('dot', ''),
+ ('special.ndtr', ''),
+ ('pow', ''),
+ ('index_add', ''),
+ ('tile', ''),
+ ('contiguous', ''),
+ ('le', ''),
+ ('movedim', ''),
+ ('diag_embed', ''),
+ ('bitwise_not', ''),
+ ('log1p', ''),
+ ('mm', ''),
+ ('nn.functional.hardtanh', ''),
+ ('ormqr', ''),
+ ('vstack', ''),
+ ('tril', ''),
+ ('to_sparse', ''),
+ ('diff', ''),
+ ('nn.functional.hardshrink', ''),
+ ('take_along_dim', ''),
+ ('cdist', ''),
+ ('ne', ''),
+ ('dsplit', ''),
+ ('argmax', ''),
+ ('div', 'no_rounding_mode'),
+ ('positive', ''),
+ ('masked_scatter', ''),
+ ('narrow', ''),
+ ('svd', ''),
+ ('zero_', ''),
+ ('fft.irfft', ''),
+ ('scatter_add', ''),
+ ('asin', ''),
+ ('i0', ''),
+ ('nan_to_num', ''),
+ ('index_copy', ''),
+ ('matmul', ''),
+ ('norm', ''),
+ ('cumprod', ''),
+ ('cosh', ''),
+ ('log2', ''),
+ ('quantile', ''),
+ ('clone', ''),
+ ('sgn', ''),
+ ('rsqrt', ''),
+ ('view', ''),
+ ('dstack', ''),
+ ('unfold', ''),
+ ('kron', ''),
+ ('ge', ''),
+ ('msort', ''),
+ ('linalg.slogdet', ''),
+ ('add', ''),
+ ('amax', ''),
+ ('polygamma', 'polygamma_n_0'),
+ ('round', ''),
+ ('reshape_as', ''),
+ ('linalg.cholesky', ''),
+}
+
+
+def in_functorch_lagging_op_db(opinfo):
+    return (opinfo.name, opinfo.variant_test_name) in _functorch_lagging_meta
+
+
+functorch_lagging_op_db = [
+    opinfo for opinfo in op_db if in_functorch_lagging_op_db(opinfo)
+]
diff --git a/functorch/test/test_functorch_lagging_op_db.py b/functorch/test/test_functorch_lagging_op_db.py
new file mode 100644
index 0000000..b96013d
--- /dev/null
+++ b/functorch/test/test_functorch_lagging_op_db.py
@@ -0,0 +1,33 @@
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    ops,
+)
+from torch.testing._internal.common_utils import TestCase, run_tests
+from functorch_lagging_op_db import (
+    functorch_lagging_op_db,
+    in_functorch_lagging_op_db,
+)
+import torch
+
+
+class TestFuncTorchLaggingOpDb(TestCase):
+    def test_functorch_lagging_op_db_has_opinfos(self, device):
+        self.assertEqual(len(functorch_lagging_op_db), len(op_db))
+
+    @ops(op_db, allowed_dtypes=(torch.float,))
+    def test_coverage(self, device, dtype, op):
+        if in_functorch_lagging_op_db(op):
+            return
+        raise RuntimeError(
+            f"{(op.name, op.variant_test_name)} is in PyTorch's OpInfo db "
+            "but is not in functorch's OpInfo db. Please regenerate "
+            "test/functorch_lagging_op_db.py and add the new tests to "
+            "denylists if necessary.")
+
+
+instantiate_device_type_tests(
+    TestFuncTorchLaggingOpDb, globals(), only_for=['cpu'])
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py
index 2f2ec25..8d09f9c 100644
--- a/functorch/test/test_ops.py
+++ b/functorch/test/test_ops.py
@@ -16,7 +16,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma
 from torch.testing._internal.common_device_type import ops, onlyCPU
-from torch.testing._internal.common_methods_invocations import op_db
+from functorch_lagging_op_db import functorch_lagging_op_db
 from common_utils import (
     parameterized,
     instantiate_parameterized_methods,
@@ -136,7 +136,7 @@
 class TestOperators(TestCase):
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_grad(self, device, dtype, op):
         op_skip = {
             '__getitem__',
@@ -189,7 +189,7 @@
         self.assertEqual(result, expected)
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_vjp(self, device, dtype, op):
         op_skip = {
             '__getitem__',
@@ -231,7 +231,7 @@
         self.assertEqual(result_vjps, expected_vjps)
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_vjpvjp(self, device, dtype, op):
         op_skip = {
             '__getitem__',
@@ -280,7 +280,7 @@
         self.assertEqual(result_vjps, expected_vjps)
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_vmapvjp(self, device, dtype, op):
         op_skip = {
             '__getitem__',
@@ -326,7 +326,7 @@
             for loop_out, batched_out in get_fallback_and_vmap_exhaustive(fn, args, {}):
                 self.assertEqual(loop_out, batched_out)
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_vjpvmap(self, device, dtype, op):
         op_skip = {
             '__getitem__',
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index d95dd18..a663843 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -15,7 +15,7 @@
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma
 from torch.testing._internal.common_device_type import ops, onlyCPU
-from torch.testing._internal.common_methods_invocations import op_db
+from functorch_lagging_op_db import functorch_lagging_op_db
 from common_utils import (
     parameterized,
     instantiate_parameterized_methods,
@@ -2825,7 +2825,7 @@
 class TestVmapOperatorsOpInfo(TestCase):
     @onlyCPU
-    @ops(op_db, allowed_dtypes=(torch.float,))
+    @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
     def test_vmap_exhaustive(self, device, dtype, op):
         # These are ops that we can't generate fallbacks for
         op_skip = {
@@ -2842,6 +2842,7 @@
             'resize_as_',
             'resolve_conj',
             'resize_',
+            'to_sparse',
         }
         # Unsupported input types
         if op.name in op_skip: