[export] Disable backend decomps for capture_pre_autograd (#127120)

Differential Revision: D57785713

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127120
Approved by: https://github.com/ydwu4
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 07f1592..426d3c8 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -309,6 +309,17 @@
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(exported_program.module()(*args), m(*args))
 
+        from torch._export import capture_pre_autograd_graph
+
+        gm: torch.fx.GraphModule = capture_pre_autograd_graph(
+            m, args=example_args, dynamic_shapes=dynamic_shapes
+        )
+
+        args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
+        self.assertEqual(gm(*args), m(*args))
+        args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
+        self.assertEqual(gm(*args), m(*args))
+
     def test_basic_non_strict_real_tensor(self):
         class Basic(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index 105a7ee..d41ff4b 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -138,7 +138,7 @@
         An nn.Module containing the traced method.
 
     """
-    from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
+    from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
     from torch._utils_internal import export_api_rollout_check
 
     capture_pre_autograd_graph_warning()
@@ -165,7 +165,7 @@
             for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
             if op != torch.ops.aten.dropout.default
         }
-        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
+        with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
             m = torch._dynamo.export(
                 f,
                 dynamic_shapes=dynamic_shapes,