Stop disabling ShapeProp with dynamic_shapes for mkldnn (#103381)

Previously, the mkldnn weight-prepack pass skipped ShapeProp entirely
whenever dynamic shapes were enabled globally, even when the example
inputs at hand were fully static. Gate on free_symbols instead: run
ShapeProp when no example input carries symbolic shapes, and test the
shape recorded on each node when deciding whether a concrete input size
is available for prepacking.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103381
Approved by: https://github.com/anijain2305
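
For context, the new graph-level gate can be read in isolation. The
sketch below is a minimal, hypothetical rendering, not part of this
patch (maybe_shape_prop is an invented name): free_symbols(t) collects
the sympy symbols occurring in a value's sizes, so if no example input
has any, the graph is fully static and ShapeProp can run.

    from torch.fx.experimental.symbolic_shapes import free_symbols
    from torch.fx.passes.shape_prop import ShapeProp

    def maybe_shape_prop(gm, fake_mode, example_inputs):
        # Run ShapeProp only when no example input carries symbolic
        # dimensions; with free symbols present it cannot propagate,
        # see https://github.com/pytorch/pytorch/pull/103512
        if not any(free_symbols(e) for e in example_inputs):
            ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
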
diff --git a/torch/_inductor/mkldnn.py b/torch/_inductor/mkldnn.py
index c0fa6d9..ebbf9ac 100644
--- a/torch/_inductor/mkldnn.py
+++ b/torch/_inductor/mkldnn.py
@@ -3,12 +3,12 @@
 from typing import Optional

 import torch
-import torch._dynamo.config as dynamo_config
 import torch.nn as nn
 import torch.nn.functional as F
 from torch._dynamo.utils import detect_fake_mode
 from torch.fx.experimental.optimization import replace_node_module
+from torch.fx.experimental.symbolic_shapes import free_symbols
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn.modules.utils import _pair

 from . import config
@@ -256,9 +256,11 @@
     if not is_cpu:
         return gm
     fake_mode = detect_fake_mode(example_inputs)
-    if config.cpp.weight_prepack:
-        if not dynamo_config.dynamic_shapes:
-            ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
+    # NB: the free_symbols test here is a BIG hammer. ShapeProp doesn't
+    # work with symbolic shapes, though; see
+    # https://github.com/pytorch/pytorch/pull/103512
+    if config.cpp.weight_prepack and not any(free_symbols(e) for e in example_inputs):
+        ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs)
     gm = pack_module(gm)
     return gm

@@ -283,7 +285,7 @@
                     and not torch.ops.mkldnn._is_mkldnn_bf16_supported()
                 ):
                     continue
-                if dynamo_config.dynamic_shapes:
+                if free_symbols(node.args[0].meta.get("tensor_meta").shape):
                     computation_node_input_size = None
                     # Conv2d and ConvTranspose2d weight formats depend on input size,
                     # but ShapeProp may fail to get the input size, so we skip them.
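
Note: the second hunk applies the same free_symbols test per node rather
than consulting the global dynamic_shapes flag, using the shape that
ShapeProp recorded in tensor_meta. A hedged sketch of that lookup
(input_size_or_none is a hypothetical helper; the None-guard is an extra
defensive check, for the case where ShapeProp never ran, that the patch
itself does not add):

    from torch.fx.experimental.symbolic_shapes import free_symbols

    def input_size_or_none(node):
        # ShapeProp stores a TensorMetadata entry in node.meta; its
        # .shape may contain sympy symbols under dynamic shapes.
        tensor_meta = node.args[0].meta.get("tensor_meta")
        if tensor_meta is None or free_symbols(tensor_meta.shape):
            # Unknown or symbolic input size: no concrete size is
            # available for mkldnn weight prepacking.
            return None
        return tuple(tensor_meta.shape)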