[PT2] Directly set meta.val in group_batch_fusion_aten (#135078)
Summary: set meta.val on nodes directly in group_batch_fusion_aten, instead of re-running FakeTensorProp over the whole graph after the pass
Differential Revision: D62162640
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135078
Approved by: https://github.com/frank-wei
diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py
index e9cc56b..eef058a 100644
--- a/torch/_inductor/fx_passes/pre_grad.py
+++ b/torch/_inductor/fx_passes/pre_grad.py
@@ -12,7 +12,6 @@
matches_module_pattern,
replace_node_module,
)
-from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import ShapeProp
from torch.nn import functional as F
@@ -169,10 +168,6 @@
example_inputs,
"[Pre grad(predispatch IR)] Apply group_batch_fusion",
)
- # update node.meta after group batch fusion
- FakeTensorProp(module=gm, mode=detect_fake_mode(example_inputs)).propagate(
- *example_inputs
- )
pass_execution_and_save(
normalize_node_kwargs_pass,
gm,