Disable tf32 in functorch transform tests (#86799)

This PR applies a large hammer and disables TF32 in specific functorch transform tests; TF32 is not precise enough for the correctness checks these tests perform.
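For context, the decorator applied below is `with_tf32_off` from `torch.testing._internal.common_cuda`. A minimal sketch of how such a decorator can work, assuming only the public `torch.backends.*.allow_tf32` flags (illustrative only, not the internal implementation):

```python
import functools
import torch

def tf32_off(fn):
    """Illustrative sketch: run the wrapped test with TF32 disabled for
    cuBLAS matmuls and cuDNN convolutions, restoring the flags afterwards.
    The real decorator is torch.testing._internal.common_cuda.with_tf32_off."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        old_matmul = torch.backends.cuda.matmul.allow_tf32
        old_cudnn = torch.backends.cudnn.allow_tf32
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore whatever the caller had configured.
            torch.backends.cuda.matmul.allow_tf32 = old_matmul
            torch.backends.cudnn.allow_tf32 = old_cudnn
    return wrapper
```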

We could have applied a smaller hammer by disabling TF32 per-OpInfo, but that does not seem to offer much additional benefit: if a convolution batching rule is correct under fp32, we would expect it to be correct under TF32 modulo precision issues, because the sequence of PyTorch operators being invoked is unchanged; only the backend math mode differs.
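For comparison, the rejected smaller hammer would attach the decorator to individual OpInfo entries rather than to whole test methods. A hedged sketch of that shape, assuming the `DecorateInfo` mechanism from `torch.testing._internal.common_methods_invocations` (the op and test names here are examples only, not part of this PR):

```python
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_methods_invocations import DecorateInfo

# Hypothetical per-OpInfo opt-out: only convolution-style ops would run
# their functorch transform tests with TF32 disabled.
conv_tf32_decorators = (
    DecorateInfo(with_tf32_off, 'TestOperators', 'test_vmapvjp'),
    DecorateInfo(with_tf32_off, 'TestOperators', 'test_vmapjvpall'),
)
# These tuples would then be passed as `decorators=...` on the relevant
# OpInfo entries (e.g. nn.functional.conv2d) instead of decorating the
# whole test method, as this PR does.
```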

Test Plan:
- I tested this locally on a machine with A100 GPUs (which support TF32).

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86799
Approved by: https://github.com/malfet
diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 9c321dc..7895d30 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -13,6 +13,7 @@
 import torch
 from torch import Tensor
 import functools
+from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_device_type import \
@@ -334,6 +335,7 @@
 
 
 class TestOperators(TestCase):
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @skipOps('TestOperators', 'test_grad', vjp_fail.union({
         xfail('linalg.eig'),  # diagonal_scatter does not support complex
@@ -576,6 +578,7 @@
                 return op.inplace_variant(inp.clone(), *args, **kwargs)
             test(fn, inplace=True)
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @skipOps('TestOperators', 'test_vmapvjpvjp', vjp_fail.union({
         skip("atleast_1d"),  # Takes too long
         skip("atleast_2d"),  # Takes too long
@@ -766,6 +769,7 @@
         # ---------------------------------------------------------------------
     })
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     @opsToleranceOverride('TestOperators', 'test_vmapvjp', (
@@ -852,6 +856,7 @@
         # ----------------------------------------------------------------------
     }
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @opsToleranceOverride('TestOperators', 'test_vmapjvpall', (
         tol1('nn.functional.conv_transpose3d',
@@ -1276,6 +1281,7 @@
             expected = reference(primals, cotangents, primals_tangents, cotangents_tangents)
             self.assertEqual(result, expected)
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @skipOps('TestOperators', 'test_vmapjvpvjp', vjp_fail.union({
         # Following operatos take too long, hence skipped
         skip('atleast_1d'),
@@ -1584,6 +1590,7 @@
                 cotangents = torch.randn_like(result, device=device)
                 self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float32, torch.double))
     @skipOps('TestOperators', 'test_vmap_autograd_grad', {
         xfail('linalg.eig'),  # all close?
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index bf066c8..53b2287 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -17,6 +17,7 @@
 import warnings
 import unittest
 from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_cuda import with_tf32_off
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
     skipCUDAIfNoMagma
 from torch.testing._internal.common_device_type import ops
@@ -3272,6 +3273,7 @@
         # ---------------------------------------------------------------------
     }
 
+    @with_tf32_off  # https://github.com/pytorch/pytorch/issues/86798
     @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
     @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
         tol1('linalg.det',