pt2 dper passes: run shape prop before each pass (#122451)

Summary: Most passes rely on shape info. We need to run shape prop before each pass.

Reviewed By: frank-wei

Differential Revision: D55221119

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122451
Approved by: https://github.com/frank-wei
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index d290455..c868899 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -131,21 +131,25 @@
             pass_execution_and_save(
                 normalization_pass_aten.apply,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)]Apply normalization pass",
             )
             pass_execution_and_save(
                 group_batch_fusion_passes,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply group_batch_fusion",
             )
             pass_execution_and_save(
                 fuse_chunk_squeeze_cat_pass.apply,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
             )
             pass_execution_and_save(
                 fuse_split_linear_add_pass.apply,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
             )
 
@@ -159,21 +163,25 @@
                 pass_execution_and_save(
                     pattern_matcher_pass_aten.apply,
                     gm,
+                    example_inputs,
                     f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
                 )
             pass_execution_and_save(
                 remove_reshape_pass.apply,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply remove_reshape_pass",
             )
             pass_execution_and_save(
                 fuse_parallel_linear_pass,
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
             )
             pass_execution_and_save(
                 lambda graph: remove_split_ops(graph.owning_module, shape_prop),
                 gm,
+                example_inputs,
                 "[Pre grad(predispatch IR)] Apply remove_split_ops",
             )
             shape_prop(gm)
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 61e19d9..5c6b2e2 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -46,8 +46,10 @@
 
 import torch
 from torch._dynamo.device_interface import get_interface_for_device
+from torch._dynamo.utils import detect_fake_mode
 from torch.autograd import DeviceType
 from torch.autograd.profiler_util import EventList
+from torch.fx.passes.shape_prop import ShapeProp
 from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
 from . import config
 
@@ -1327,7 +1329,7 @@
     DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
 
 
-def pass_execution_and_save(func, gm, msg):
+def pass_execution_and_save(func, gm, inp, msg):
     from .pattern_matcher import stable_topological_sort
 
     with tempfile.NamedTemporaryFile(
@@ -1337,6 +1339,7 @@
     ) as f:
         before_io = io.StringIO()
         after_io = io.StringIO()
+        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
         print(f"Before:\n{gm.graph}", file=f)
         print(gm.graph, file=before_io)
         start_time = datetime.now()