[export] Disable backend decomps for capture_pre_autograd (#127120)
Differential Revision: D57785713
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127120
Approved by: https://github.com/ydwu4
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 07f1592..426d3c8 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -309,6 +309,17 @@
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(exported_program.module()(*args), m(*args))
+ from torch._export import capture_pre_autograd_graph
+
+ gm: torch.fx.GraphModule = capture_pre_autograd_graph(
+ m, args=example_args, dynamic_shapes=dynamic_shapes
+ )
+
+ args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
+ self.assertEqual(gm(*args), m(*args))
+ args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
+ self.assertEqual(gm(*args), m(*args))
+
def test_basic_non_strict_real_tensor(self):
class Basic(torch.nn.Module):
def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 105a7ee..d41ff4b 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -138,7 +138,7 @@
An nn.Module containing the traced method.
"""
- from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
+ from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
from torch._utils_internal import export_api_rollout_check
capture_pre_autograd_graph_warning()
@@ -165,7 +165,7 @@
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
- with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
+ with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,