inductor: don't assert error when do cpu fx fusion for training mode (#93837)

This PR will do:

1. skip CPU fx fusion for training mode.
2. skip Linear packed when input dim<2.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93837
Approved by: https://github.com/jgong5, https://github.com/desertfire, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index fc6e3f8..4fa7dc3 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -1775,13 +1775,16 @@
             raise unittest.SkipTest("only support cpu conv2d packed test")
 
         x_shape = (1, 3, 56, 56)
-        mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
-        v = torch.randn(x_shape, dtype=torch.float32)
-        with torch.no_grad():
-            self.common(
-                mod,
-                (v,),
+        for mode_train in [True, False]:
+            mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).train(
+                mode=mode_train
             )
+            v = torch.randn(x_shape, dtype=torch.float32)
+            with torch.no_grad():
+                self.common(
+                    mod,
+                    (v,),
+                )
 
     @slow()
     def test_conv2d_unary(self):
@@ -1819,6 +1822,7 @@
             [1, 4],
             ["same", 0],
             test_memory_format,
+            [True, False],
         )
 
         for (
@@ -1829,6 +1833,7 @@
             groups,
             padding,
             memory_format,
+            mode_train,
         ) in options:
             oC = 32 * groups
             iC = 3 * groups
@@ -1842,7 +1847,7 @@
                 dilation=dilation,
                 groups=groups,
                 bias=bias,
-            ).eval()
+            ).train(mode=mode_train)
 
             # TODO: add bf16 test for cpu path?
             # TODO: this test fails when requires_grad=False
@@ -1916,6 +1921,7 @@
             [1, 4],
             ["same", 0],
             test_memory_format,
+            [True, False],
         )
 
         for (
@@ -1927,6 +1933,7 @@
             groups,
             padding,
             memory_format,
+            mode_train,
         ) in options:
             oC = 32 * groups
             iC = 3 * groups
@@ -1941,7 +1948,7 @@
                 padding,
                 bias,
                 kernel_size=kernel_size,
-            ).eval()
+            ).train(mode=mode_train)
             mod = mod.to(memory_format=memory_format)
             # TODO: add bf16 test
             v = torch.randn(x_shape, dtype=torch.float32).to(
@@ -1954,7 +1961,7 @@
                 )
 
     def test_linear_packed(self):
-        options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
+        options = itertools.product([[2, 3, 10], [2, 10], [10]], [True, False])
         for input_shape, bias in options:
             mod = torch.nn.Sequential(
                 torch.nn.Linear(input_shape[-1], 30, bias=bias)
diff --git a/torch/_inductor/mkldnn.py b/torch/_inductor/mkldnn.py
index b4354e0..c451493 100644
--- a/torch/_inductor/mkldnn.py
+++ b/torch/_inductor/mkldnn.py
@@ -756,6 +756,9 @@
                         if len(node.args[index_node].users) > 1:
                             continue
                         computation_node = modules[node.args[index_node].target]
+                        if computation_node.training:
+                            continue
+
                         # TODO: support padding str input("valid", "same").
                         if type(computation_node) in [nn.Conv2d] and isinstance(
                             computation_node.padding, str
@@ -805,6 +808,8 @@
                     if node.args[1].args[0] == node.args[0]:
                         continue
                     computation_node = modules[node.args[1].target]
+                    if computation_node.training:
+                        continue
                     # TODO: support padding str input("valid", "same").
                     if type(computation_node) in [nn.Conv2d] and isinstance(
                         computation_node.padding, str
@@ -835,12 +840,19 @@
             assert isinstance(node.target, str)
             cur_module = modules[node.target]
             if type(cur_module) in computation_op_packed_map:
+                if cur_module.training:
+                    continue
                 computation_node_input_meta = node.args[0].meta.get("tensor_meta")
                 if computation_node_input_meta.dtype != torch.float32:
                     continue
                 if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
                     continue
                 computation_node_input_size = computation_node_input_meta.shape
+                if (
+                    type(cur_module) in [torch.nn.Linear]
+                    and len(computation_node_input_size) < 2
+                ):
+                    continue
                 if type(cur_module) in [nn.Conv2d] and isinstance(
                     cur_module.padding, str
                 ):