inductor: don't raise an assertion error when doing CPU FX fusion in training mode (#93837)
This PR will do:
1. Skip CPU FX fusion in training mode.
2. Skip Linear weight packing when the input has fewer than 2 dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93837
Approved by: https://github.com/jgong5, https://github.com/desertfire, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index fc6e3f8..4fa7dc3 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -1775,13 +1775,16 @@
raise unittest.SkipTest("only support cpu conv2d packed test")
x_shape = (1, 3, 56, 56)
- mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).eval()
- v = torch.randn(x_shape, dtype=torch.float32)
- with torch.no_grad():
- self.common(
- mod,
- (v,),
+ for mode_train in [True, False]:
+ mod = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 3, 3)).train(
+ mode=mode_train
)
+ v = torch.randn(x_shape, dtype=torch.float32)
+ with torch.no_grad():
+ self.common(
+ mod,
+ (v,),
+ )
@slow()
def test_conv2d_unary(self):
@@ -1819,6 +1822,7 @@
[1, 4],
["same", 0],
test_memory_format,
+ [True, False],
)
for (
@@ -1829,6 +1833,7 @@
groups,
padding,
memory_format,
+ mode_train,
) in options:
oC = 32 * groups
iC = 3 * groups
@@ -1842,7 +1847,7 @@
dilation=dilation,
groups=groups,
bias=bias,
- ).eval()
+ ).train(mode=mode_train)
# TODO: add bf16 test for cpu path?
# TODO: this test fails when requires_grad=False
@@ -1916,6 +1921,7 @@
[1, 4],
["same", 0],
test_memory_format,
+ [True, False],
)
for (
@@ -1927,6 +1933,7 @@
groups,
padding,
memory_format,
+ mode_train,
) in options:
oC = 32 * groups
iC = 3 * groups
@@ -1941,7 +1948,7 @@
padding,
bias,
kernel_size=kernel_size,
- ).eval()
+ ).train(mode=mode_train)
mod = mod.to(memory_format=memory_format)
# TODO: add bf16 test
v = torch.randn(x_shape, dtype=torch.float32).to(
@@ -1954,7 +1961,7 @@
)
def test_linear_packed(self):
- options = itertools.product([[2, 3, 10], [2, 10]], [True, False])
+ options = itertools.product([[2, 3, 10], [2, 10], [10]], [True, False])
for input_shape, bias in options:
mod = torch.nn.Sequential(
torch.nn.Linear(input_shape[-1], 30, bias=bias)
diff --git a/torch/_inductor/mkldnn.py b/torch/_inductor/mkldnn.py
index b4354e0..c451493 100644
--- a/torch/_inductor/mkldnn.py
+++ b/torch/_inductor/mkldnn.py
@@ -756,6 +756,9 @@
if len(node.args[index_node].users) > 1:
continue
computation_node = modules[node.args[index_node].target]
+ if computation_node.training:
+ continue
+
# TODO: support padding str input("valid", "same").
if type(computation_node) in [nn.Conv2d] and isinstance(
computation_node.padding, str
@@ -805,6 +808,8 @@
if node.args[1].args[0] == node.args[0]:
continue
computation_node = modules[node.args[1].target]
+ if computation_node.training:
+ continue
# TODO: support padding str input("valid", "same").
if type(computation_node) in [nn.Conv2d] and isinstance(
computation_node.padding, str
@@ -835,12 +840,19 @@
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in computation_op_packed_map:
+ if cur_module.training:
+ continue
computation_node_input_meta = node.args[0].meta.get("tensor_meta")
if computation_node_input_meta.dtype != torch.float32:
continue
if type(cur_module) in [torch.nn.Linear] and not torch._C.has_mkl:
continue
computation_node_input_size = computation_node_input_meta.shape
+ if (
+ type(cur_module) in [torch.nn.Linear]
+ and len(computation_node_input_size) < 2
+ ):
+ continue
if type(cur_module) in [nn.Conv2d] and isinstance(
cur_module.padding, str
):