pt2 dper passes: run shape prop before each pass (#122451)
Summary: Most passes rely on shape info, so we need to run shape prop before each pass.
Reviewed By: frank-wei
Differential Revision: D55221119
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122451
Approved by: https://github.com/frank-wei
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index d290455..c868899 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -131,21 +131,25 @@
pass_execution_and_save(
normalization_pass_aten.apply,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)]Apply normalization pass",
)
pass_execution_and_save(
group_batch_fusion_passes,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply group_batch_fusion",
)
pass_execution_and_save(
fuse_chunk_squeeze_cat_pass.apply,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply fuse_chunk_squeeze_cat_pass",
)
pass_execution_and_save(
fuse_split_linear_add_pass.apply,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply fuse_split_linear_add_pass",
)
@@ -159,21 +163,25 @@
pass_execution_and_save(
pattern_matcher_pass_aten.apply,
gm,
+ example_inputs,
f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
)
pass_execution_and_save(
remove_reshape_pass.apply,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply remove_reshape_pass",
)
pass_execution_and_save(
fuse_parallel_linear_pass,
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply fuse_parallel_linear_pass",
)
pass_execution_and_save(
lambda graph: remove_split_ops(graph.owning_module, shape_prop),
gm,
+ example_inputs,
"[Pre grad(predispatch IR)] Apply remove_split_ops",
)
shape_prop(gm)
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 61e19d9..5c6b2e2 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -46,8 +46,10 @@
import torch
from torch._dynamo.device_interface import get_interface_for_device
+from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
from torch.autograd.profiler_util import EventList
+from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from . import config
@@ -1327,7 +1329,7 @@
DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
-def pass_execution_and_save(func, gm, msg):
+def pass_execution_and_save(func, gm, inp, msg):
from .pattern_matcher import stable_topological_sort
with tempfile.NamedTemporaryFile(
@@ -1337,6 +1339,7 @@
) as f:
before_io = io.StringIO()
after_io = io.StringIO()
+ ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
print(f"Before:\n{gm.graph}", file=f)
print(gm.graph, file=before_io)
start_time = datetime.now()