Manually enable `capture_func_transforms` for testing (#107122)

Manually enable `capture_func_transforms` for testing as plan is to default `capture_func_transforms` to False in 2.1. (enable it so that we still test the support on release branch).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107122
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 18c65b8..7970155 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -75,16 +75,6 @@
     return result
 
 
-@contextlib.contextmanager
-def disable_functorch_capture():
-    org_val = torch._dynamo.config.capture_func_transforms
-    torch._dynamo.config.capture_func_transforms = False
-    try:
-        yield
-    finally:
-        torch._dynamo.config.capture_func_transforms = org_val
-
-
 # Checks that a dict matches a dict with "regex keys". That is,
 # the keys are regex expressions.
 def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
@@ -1619,6 +1609,13 @@
 
 
 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
+    def run(self, result=None):
+        # capture_func_transform will be set to False (for 2.1) till we
+        # support all transforms, so manually patch it to `True`` for
+        # testing on release branch.
+        with config.patch(capture_func_transforms=True):
+            super().run(result)
+
     def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
         backend = EagerAndRecordGraphs()
         actual = fn(*inputs)
@@ -2139,7 +2136,7 @@
     def test_grad_disable_capture(self):
         counters.clear()
 
-        with disable_functorch_capture():
+        with config.patch(capture_func_transforms=False):
             # We have verified above that this
             # function compiles
             def fn(x):
@@ -2639,7 +2636,7 @@
     def test_vmap_disable_capture(self):
         counters.clear()
 
-        with disable_functorch_capture():
+        with config.patch(capture_func_transforms=False):
             # We have verified above that this
             # function compiles
             def wrapper_fn(x):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 73279be..fd5693e 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -4769,6 +4769,7 @@
     # torch.compile is not supported on Windows
     @expectedFailureIf(IS_WINDOWS)
     @torch._dynamo.config.patch(suppress_errors=False)
+    @torch._dynamo.config.patch(capture_func_transforms=True)
     @skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
     def test_grad_deprecated_api(self, device):
         x = torch.randn((), device=device)