TorchDynamo: enable convolution bn folding for functional bn (#89746)

Motivation: Timm models commonly use a custom-defined BN layer that calls F.batch_norm directly: https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/norm_act.py#L26. The captured fx graph then looks like the following (columns: opcode, name, target, args, kwargs):
```
-------------  ----------------------  ---------------------------------------  ---------------------------------------------------------------------------------------------------------  --------
placeholder    x                       x                                        ()                                                                                                         {}
call_module    self_conv               self_conv                                (x,)                                                                                                       {}
get_attr       self_bn_running_mean_1  self_bn_running_mean                     ()                                                                                                         {}
get_attr       self_bn_running_var     self_bn_running_var                      ()                                                                                                         {}
get_attr       self_bn_weight          self_bn_weight                           ()                                                                                                         {}
get_attr       self_bn_bias            self_bn_bias                             ()                                                                                                         {}
call_function  batch_norm              <function batch_norm at 0x7f07196cdf70>  (self_conv, self_bn_running_mean_1, self_bn_running_var, self_bn_weight, self_bn_bias, False, 0.1, 1e-05)  {}
call_module    self_bn_drop            self_bn_drop                             (batch_norm,)
```
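
For reference, here is a minimal sketch (hypothetical FunctionalBN/ConvBN names, not timm's actual code) of how such a graph arises; it is traced with torch.fx.symbolic_trace just to illustrate the shape, while TorchDynamo produces an equivalent call_function node for F.batch_norm:
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalBN(nn.BatchNorm2d):
    # Same parameters/buffers as nn.BatchNorm2d, but forward calls
    # F.batch_norm directly, so the traced graph contains a
    # call_function node rather than a call_module node for the BN.
    def forward(self, x):
        return F.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training,  # bn_training: False in eval mode
            self.momentum,
            self.eps,
        )


class ConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, bias=True)
        self.bn = FunctionalBN(64)

    def forward(self, x):
        return self.bn(self.conv(x))


gm = torch.fx.symbolic_trace(ConvBN().eval())
gm.graph.print_tabular()  # call_module (conv) feeding call_function (batch_norm)
```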

The original conv+bn folding path doesn't handle **F.batch_norm**. However, in the **F.batch_norm** case, if its parameters are constants (attributes of the module that will not be updated), the same constant-folding optimization can be applied. This PR enables it and improves the performance of Timm models.
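
For context, the folding relies on the standard eval-mode identity W_fused = W * gamma / sqrt(var + eps) and b_fused = (b - mean) * gamma / sqrt(var + eps) + beta, which torch.nn.utils.fusion.fuse_conv_bn_weights computes. A minimal self-contained sketch (not part of this PR's diff) checking that batch_norm(conv(x)) matches the fused conv:
```
import copy

import torch
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 64, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(64).eval()
# Give the BN non-trivial constant statistics so the check is meaningful.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 1.5)

# Fold the BN constants into a copy of the conv's weight and bias.
fused_conv = copy.deepcopy(conv)
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
    conv.weight,
    conv.bias,
    bn.running_mean,
    bn.running_var,
    bn.eps,
    bn.weight,
    bn.bias,
)

x = torch.randn(1, 3, 56, 56)
with torch.no_grad():
    ref = F.batch_norm(
        conv(x),
        bn.running_mean,
        bn.running_var,
        bn.weight,
        bn.bias,
        False,  # training=False: use the running statistics
        0.1,    # momentum, unused in eval mode
        bn.eps,
    )
    torch.testing.assert_close(fused_conv(x), ref, rtol=1e-4, atol=1e-4)
```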

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89746
Approved by: https://github.com/jgong5, https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 41fce39..2d532c8 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -1424,7 +1424,7 @@
         )
 
     # For gpu path, there has a accurcy issue,
-    @unittest.skipIf(HAS_CUDA, "only support cpu conv  bn test")
+    @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test")
     def test_conv_bn_fuse(self):
         input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)}
         conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
@@ -1479,6 +1479,88 @@
                         (v,),
                     )
 
+    # For the GPU path, there is an accuracy issue.
+    @unittest.skipIf(HAS_CUDA, "only support cpu conv bn test")
+    def test_conv_functional_bn_fuse(self):
+        # Define a BatchNorm using functional BN.
+        class BatchNorm(torch.nn.BatchNorm2d):
+            def __init__(
+                self,
+                num_features,
+                eps=1e-5,
+                momentum=0.1,
+                affine=True,
+                track_running_stats=True,
+                device=None,
+                dtype=None,
+            ):
+                factory_kwargs = {"device": device, "dtype": dtype}
+                super(BatchNorm, self).__init__(
+                    num_features,
+                    eps=eps,
+                    momentum=momentum,
+                    affine=affine,
+                    track_running_stats=track_running_stats,
+                    **factory_kwargs,
+                )
+
+            def forward(self, x):
+                if self.momentum is None:
+                    exponential_average_factor = 0.0
+                else:
+                    exponential_average_factor = self.momentum
+
+                if self.training and self.track_running_stats:
+                    # TODO: if statement only here to tell the jit to skip emitting this when it is None
+                    if self.num_batches_tracked is not None:  # type: ignore[has-type]
+                        self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
+                        if self.momentum is None:  # use cumulative moving average
+                            exponential_average_factor = 1.0 / float(
+                                self.num_batches_tracked
+                            )
+                        else:  # use exponential moving average
+                            exponential_average_factor = self.momentum
+                if self.training:
+                    bn_training = True
+                else:
+                    bn_training = (self.running_mean is None) and (
+                        self.running_var is None
+                    )
+                x = F.batch_norm(
+                    x,
+                    # If buffers are not to be tracked, ensure that they won't be updated
+                    self.running_mean
+                    if not self.training or self.track_running_stats
+                    else None,
+                    self.running_var
+                    if not self.training or self.track_running_stats
+                    else None,
+                    self.weight,
+                    self.bias,
+                    bn_training,
+                    exponential_average_factor,
+                    self.eps,
+                )
+                return x
+
+        v = torch.randn(1, 3, 556, 56, dtype=torch.float32)
+        mod = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                3,
+                64,
+                kernel_size=3,
+                dilation=1,
+                groups=1,
+                bias=True,
+            ),
+            BatchNorm(64),
+        ).eval()
+        with torch.no_grad():
+            self.common(
+                mod,
+                (v,),
+            )
+
     @unittest.skipIf(HAS_CUDA, "only support cpu conv2d unary test")
     def test_conv2d_packed(self):
         x_shape = (1, 3, 56, 56)
diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py
index e1dfdea..8d95971 100644
--- a/torch/_inductor/overrides.py
+++ b/torch/_inductor/overrides.py
@@ -19,7 +19,7 @@
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.nn import functional as F
 from torch.nn.modules.utils import _pair
-from torch.nn.utils.fusion import fuse_conv_bn_eval
+from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
 from torch.overrides import TorchFunctionMode
 
 from . import config
@@ -545,6 +545,7 @@
         return gm
     if not is_cpu:
         return gm
+    gm = remove_identity(gm)
     gm = fuse_conv_bn(gm)
     # For binary fusion, we need to check inputs info to make sure
     # the binary inputs have same tensor info(device, dtype, and layout).
@@ -559,18 +560,78 @@
     return gm
 
 
+# Check whether the (nn.Module, F.function) pattern matches: a call_module node feeding a call_function node.
+def matches_module_function_pattern(pattern, node, modules):
+    if len(node.args) == 0:
+        return False
+    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
+        node, torch.fx.Node
+    ):
+        return False
+    # the first node is call_module
+    if node.args[0].op != "call_module":
+        return False
+    if not isinstance(node.args[0].target, str):
+        return False
+    if node.args[0].target not in modules:
+        return False
+    if type(modules[node.args[0].target]) is not pattern[0]:
+        return False
+    # the second node is call_function
+    if node.op != "call_function":
+        return False
+    if node.target != pattern[1]:
+        return False
+    # make sure node.args[0] output is only used by current node.
+    if len(node.args[0].users) > 1:
+        return False
+    return True
+
+
+def fetch_attr(target: str, mod):
+    target_atoms = target.split(".")
+    attr_itr = mod
+    for i, atom in enumerate(target_atoms):
+        if not hasattr(attr_itr, atom):
+            raise RuntimeError(
+                f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}"
+            )
+        attr_itr = getattr(attr_itr, atom)
+    return attr_itr
+
+
+def remove_identity(gm: torch.fx.GraphModule):
+    """
+    Removes all identity layers from the module.
+    """
+
+    class IdentityRemover(torch.fx.Transformer):
+        def call_module(self, target, args, kwargs):
+            if isinstance(self.submodules[target], nn.Identity):
+                assert len(args) == 1
+                return args[0]
+            else:
+                return super().call_module(target, args, kwargs)
+
+    return IdentityRemover(gm).transform()
+
+
 def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False):
     """
     Fuses Convolution/BN layers for inference purposes.
     """
-    patterns = [
+    modules_patterns = [
         (torch.nn.Conv1d, torch.nn.BatchNorm1d),
         (torch.nn.Conv2d, torch.nn.BatchNorm2d),
         (torch.nn.Conv3d, torch.nn.BatchNorm3d),
     ]
+    module_function_patterns = [
+        (torch.nn.Conv1d, F.batch_norm),
+        (torch.nn.Conv2d, F.batch_norm),
+        (torch.nn.Conv3d, F.batch_norm),
+    ]
     modules = dict(gm.named_modules())
-
-    for pattern in patterns:
+    for pattern in modules_patterns:
         for node in gm.graph.nodes:
             if matches_module_pattern(pattern, node, modules):
                 if len(node.args[0].users) > 1:  # Output of conv is used by other nodes
@@ -587,7 +648,46 @@
                 node.replace_all_uses_with(node.args[0])
                 gm.graph.erase_node(node)
                 gm.graph.lint()
+    for pattern in module_function_patterns:
+        for node in gm.graph.nodes:
+            if matches_module_function_pattern(pattern, node, modules):
+                # TODO: support kwargs.
+                if len(node.args) != 8:
+                    continue
+                conv = modules[node.args[0].target]
+                bn_training = node.args[5]
+                bn_eps = node.args[7]
+                if conv.training or bn_training:
+                    continue
+                if type(bn_eps) is not float:
+                    continue
+                bn_args_is_constant = all(
+                    n.op == "get_attr" and len(n.users) == 1 for n in node.args[1:5]
+                )
+                if not bn_args_is_constant:
+                    continue
+                bn_running_mean = fetch_attr(node.args[1].target, gm)
+                bn_running_var = fetch_attr(node.args[2].target, gm)
+                bn_weight = fetch_attr(node.args[3].target, gm)
+                bn_bias = fetch_attr(node.args[4].target, gm)
+                if bn_running_mean is None or bn_running_var is None:
+                    continue
+                fused_conv = copy.deepcopy(conv)
+                fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
+                    fused_conv.weight,
+                    fused_conv.bias,
+                    bn_running_mean,
+                    bn_running_var,
+                    bn_eps,
+                    bn_weight,
+                    bn_bias,
+                )
+                replace_node_module(node.args[0], modules, fused_conv)
+                node.replace_all_uses_with(node.args[0])
+                gm.graph.erase_node(node)
+                gm.graph.lint()
     gm.recompile()
+
     return gm