inductor: convert view to reshape before doing fake_tensor_prop at freezing step (#104612)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104612
Approved by: https://github.com/jgong5, https://github.com/eellison, https://github.com/shunting314
diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py
index ed87c2a..5c8593e 100644
--- a/test/inductor/test_inductor_freezing.py
+++ b/test/inductor/test_inductor_freezing.py
@@ -329,6 +329,29 @@
         self.assertEqual(eager, compiled)
         self.assertTrue(weight_ref() is None)
 
+    def test_conv_layout_convert_with_view(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(
+                    3, 128, kernel_size=3, padding=1, stride=1, bias=False
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                return torch.flatten(x, 1)
+
+        mod = Model().to(self.device).eval()
+
+        @torch.compile()
+        def foo(mod, inp):
+            return mod(inp)
+
+        with torch.no_grad():
+            x = torch.rand(2, 3, 5, 5).to(self.device)
+            mod_eager = mod(x)
+            self.assertEqual(foo(mod, x), mod_eager)
+
     def test_conv_weight_layout_convert(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py
index c5e41f9..06b8e86 100644
--- a/torch/_inductor/freezing.py
+++ b/torch/_inductor/freezing.py
@@ -7,10 +7,12 @@
 import torch
 import torch.fx.traceback as fx_traceback
 import torch.utils._pytree as pytree
-
 from torch._dynamo.utils import detect_fake_mode, dynamo_timed
 from torch._functorch.compile_utils import fx_graph_cse
+
+from torch._inductor.compile_fx import fake_tensor_prop
 from torch._inductor.fx_passes.freezing_patterns import freezing_passes
+from torch._inductor.fx_passes.post_grad import view_to_reshape
 from torch.ao.quantization._pt2e.utils import _fuse_conv_bn_
 from torch.fx.experimental.proxy_tensor import make_fx
 from . import config
@@ -221,8 +223,10 @@
     aot_autograd_gm.graph = cse_graph
     aot_autograd_gm.recompile()
 
-    from torch._inductor.compile_fx import fake_tensor_prop
-
+    # The conv weights have been converted to channels last, so .view may raise an
+    # error during fake_tensor_prop. Convert view to reshape first to avoid that.
+    # See fx_codegen_and_compile in compile_fx.py for details.
+    view_to_reshape(aot_autograd_gm)
     # Make sure meta['val'] is properly setup(weight conversion
     # or decompose_unfused_batchnorms lost meta['val']).
     aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py
index 6963b56..c485bc2 100644
--- a/torch/_inductor/fx_passes/mkldnn_fusion.py
+++ b/torch/_inductor/fx_passes/mkldnn_fusion.py
@@ -625,10 +625,10 @@
         # convert reshape+linear+reshape to a single linear for applying fusion path.
         @register_freezing_graph_pattern(
             CallFunction(
-                aten.view.default,
+                aten.reshape.default,
                 CallFunction(
                     mkldnn._linear_pointwise.default,
-                    CallFunction(aten.view.default, Arg(), KeywordArg("reshape_1")),
+                    CallFunction(aten.reshape.default, Arg(), KeywordArg("reshape_1")),
                     Arg(),
                     Arg(),
                     Arg(),