Disable tf32 in functorch transform tests (#86799)
This PR applies a large hammer and disables TF32 in specific functorch transform tests; TF32 isn't precise enough for the correctness checks those tests perform.
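
For context, here is a minimal sketch (not part of this PR, and assuming an Ampere-class GPU where float32 matmuls dispatch to TF32 kernels) of the precision gap involved:

```python
# Minimal sketch, not part of this PR: compare a float32 matmul against a
# float64 reference with TF32 enabled vs. disabled. On an Ampere-class GPU
# the TF32 error is typically well above the ~1e-4 tolerances these tests use.
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
ref = (a.double() @ b.double()).float()

torch.backends.cuda.matmul.allow_tf32 = True
err_tf32 = (a @ b - ref).abs().max().item()

torch.backends.cuda.matmul.allow_tf32 = False
err_fp32 = (a @ b - ref).abs().max().item()

print(f"max abs error: tf32={err_tf32:.2e} fp32={err_fp32:.2e}")
```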
We could have applied a smaller hammer by disabling TF32 per-OpInfo, but that doesn't seem to offer much additional benefit: if, say, a convolution batching rule is correct under fp32, I would expect it to be correct under TF32 modulo precision issues, because the actual sequence of PyTorch operators we invoke has not changed; only the backend has.
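
Mechanically, the change just adds the `with_tf32_off` decorator from `torch.testing._internal.common_cuda` to the affected tests. As a rough mental model only (the actual helper may be implemented differently), such a decorator amounts to flipping the TF32 backend flags around the test:

```python
# Rough sketch of a with_tf32_off-style decorator; the real helper in
# torch.testing._internal.common_cuda may differ in details. It disables TF32
# for cuBLAS matmuls and cuDNN convolutions for the duration of the wrapped
# test, then restores the previous settings.
import functools
import torch

def with_tf32_off_sketch(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        old_matmul = torch.backends.cuda.matmul.allow_tf32
        old_cudnn = torch.backends.cudnn.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        try:
            return fn(*args, **kwargs)
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_matmul
            torch.backends.cudnn.allow_tf32 = old_cudnn
    return wrapper
```

Because the decorator only toggles backend flags, the operator sequence under test is unchanged, which is why finer per-OpInfo granularity would add little.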
Test Plan:
- I tested this locally on a machine with A100 GPUs.
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86799
Approved by: https://github.com/malfet
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 9c321dc..7895d30 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -13,6 +13,7 @@
import torch
from torch import Tensor
import functools
+from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import \
@@ -334,6 +335,7 @@
class TestOperators(TestCase):
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@skipOps('TestOperators', 'test_grad', vjp_fail.union({
xfail('linalg.eig'), # diagonal_scatter does not support complex
@@ -576,6 +578,7 @@
return op.inplace_variant(inp.clone(), *args, **kwargs)
test(fn, inplace=True)
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@skipOps('TestOperators', 'test_vmapvjpvjp', vjp_fail.union({
skip("atleast_1d"), # Takes too long
skip("atleast_2d"), # Takes too long
@@ -766,6 +769,7 @@
# ---------------------------------------------------------------------
})
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride('TestOperators', 'test_vmapvjp', (
@@ -852,6 +856,7 @@
# ----------------------------------------------------------------------
}
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
tol1('nn.functional.conv_transpose3d',
@@ -1276,6 +1281,7 @@
expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
self.assertEqual(result, expected)
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
# Following operatos take too long, hence skipped
skip('atleast_1d'),
@@ -1584,6 +1590,7 @@
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
@skipOps('TestOperators', 'test_vmap_autograd_grad', {
xfail('linalg.eig'), # all close?
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index bf066c8..53b2287 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -17,6 +17,7 @@
import warnings
import unittest
from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma
from torch.testing._internal.common_device_type import ops
@@ -3272,6 +3273,7 @@
# ---------------------------------------------------------------------
}
+ @with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
tol1('linalg.det',