inductor: convert view to reshape before doing fake_tensor_prop at freezing step (#104612)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104612
Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/shunting314
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index ed87c2a..5c8593e 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -329,6 +329,29 @@
self.assertEqual(eager, compiled)
self.assertTrue(weight_ref() is None)
+ def test_conv_layout_convert_with_view(self):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ 3, 128, kernel_size=3, padding=1, stride=1, bias=False
+ )
+
+ def forward(self, x):
+ x = self.conv(x)
+ return torch.flatten(x, 1)
+
+ mod = Model().to(self.device).eval()
+
+ @torch.compile()
+ def foo(mod, inp):
+ return mod(inp)
+
+ with torch.no_grad():
+ x = torch.rand(2, 3, 5, 5).to(self.device)
+ mod_eager = mod(x)
+ self.assertEqual(foo(mod, x), mod_eager)
+
def test_conv_weight_layout_convert(self):
class Model(torch.nn.Module):
def __init__(self):
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index c5e41f9..06b8e86 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -7,10 +7,12 @@
import torch
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
-
from torch._dynamo.utils import detect_fake_mode, dynamo_timed
from torch._functorch.compile_utils import fx_graph_cse
+
+from torch._inductor.compile_fx import fake_tensor_prop
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
+from torch._inductor.fx_passes.post_grad import view_to_reshape
from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
from torch.fx.experimental.proxy_tensor import make_fx
from . import config
@@ -221,8 +223,10 @@
aot_autograd_gm.graph = cse_graph
aot_autograd_gm.recompile()
- from torch._inductor.compile_fx import fake_tensor_prop
-
+ # We have convert conv's weight to channels last which may meet error for .view
+ # when doing fake_tensor_prop. So we need to convert view to reshape first.
+ # See the details in fx_codegen_and_compile of compile_fx.py.
+ view_to_reshape(aot_autograd_gm)
# Make sure meta['val'] is properly setup(weight conversion
# or decompose_unfused_batchnorms lost meta['val']).
aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index 6963b56..c485bc2 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -625,10 +625,10 @@
# convert reshape+linear+reshape to a single linear for applying fusion path.
@register_freezing_graph_pattern(
CallFunction(
- aten.view.default,
+ aten.reshape.default,
CallFunction(
mkldnn._linear_pointwise.default,
- CallFunction(aten.view.default, Arg(), KeywordArg("reshape_1")),
+ CallFunction(aten.reshape.default, Arg(), KeywordArg("reshape_1")),
Arg(),
Arg(),
Arg(),