Manually enable `capture_func_transforms` for testing (#107122)
Manually enable `capture_func_transforms` for testing as plan is to default `capture_func_transforms` to False in 2.1. (enable it so that we still test the support on release branch).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107122
Approved by: https://github.com/zou3519
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 18c65b8..7970155 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -75,16 +75,6 @@
return result
-@contextlib.contextmanager
-def disable_functorch_capture():
- org_val = torch._dynamo.config.capture_func_transforms
- torch._dynamo.config.capture_func_transforms = False
- try:
- yield
- finally:
- torch._dynamo.config.capture_func_transforms = org_val
-
-
# Checks that a dict matches a dict with "regex keys". That is,
# the keys are regex expressions.
def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
@@ -1619,6 +1609,13 @@
class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
+ def run(self, result=None):
+ # capture_func_transform will be set to False (for 2.1) till we
+ # support all transforms, so manually patch it to `True`` for
+ # testing on release branch.
+ with config.patch(capture_func_transforms=True):
+ super().run(result)
+
def _compile_check(self, fn, inputs, fullgraph=True, graph_idx=0):
backend = EagerAndRecordGraphs()
actual = fn(*inputs)
@@ -2139,7 +2136,7 @@
def test_grad_disable_capture(self):
counters.clear()
- with disable_functorch_capture():
+ with config.patch(capture_func_transforms=False):
# We have verified above that this
# function compiles
def fn(x):
@@ -2639,7 +2636,7 @@
def test_vmap_disable_capture(self):
counters.clear()
- with disable_functorch_capture():
+ with config.patch(capture_func_transforms=False):
# We have verified above that this
# function compiles
def wrapper_fn(x):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 73279be..fd5693e 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -4769,6 +4769,7 @@
# torch.compile is not supported on Windows
@expectedFailureIf(IS_WINDOWS)
@torch._dynamo.config.patch(suppress_errors=False)
+ @torch._dynamo.config.patch(capture_func_transforms=True)
@skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
def test_grad_deprecated_api(self, device):
x = torch.randn((), device=device)