fix conv+bn folding issue when bn hasn't running states (#71259)

Summary:
Doing conv+bn folding which bn hasn't a running stats, there have error for JIT and FX path:

```
import torch

import torch.nn as nn

import torch.fx.experimental.optimization as optimization

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = nn.Conv2d(32, 64, 3, stride=2)
        self.bn = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

x = torch.randn([1, 32, 50, 50])

model = M().eval()

'''
# jit path
with torch.no_grad():
    traced = torch.jit.trace(model, x).eval()
    traced = torch.jit.freeze(traced)
'''

# FX path
fused_model = optimization.fuse(model)
```

expected result:
1. JIT path
```
Traceback (most recent call last):
  File "bn_test.py", line 27, in <module>
    traced = torch.jit.freeze(traced)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 119, in freeze
    run_frozen_optimizations(out, optimize_numerics, preserved_methods)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/jit/_freeze.py", line 167, in run_frozen_optimizations
    torch._C._jit_pass_optimize_frozen_graph(mod.graph, optimize_numerics)
RuntimeError: Expected Tensor but got None
```
2. FX path
```
Traceback (most recent call last):
  File "bn_test.py", line 31, in <module>
    model = optimization.fuse(model, inplace=True)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/fx/experimental/optimization.py", line 71, in fuse
    fused_conv = fuse_conv_bn_eval(conv, bn)
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 11, in fuse_conv_bn_eval
    fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
  File "/home/xiaobinz/miniconda3/envs/pytorch-master/lib/python3.8/site-packages/torch/nn/utils/fusion.py", line 23, in fuse_conv_bn_weights
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'float'
```

This PR will fix this issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71259

Reviewed By: anjali411

Differential Revision: D33595049

Pulled By: davidberard98

fbshipit-source-id: 0fe56bb2bb25d6d54ebc53789d2ad22458da9012
(cherry picked from commit 5672c083784585e6e1ec5657f02bd3051afb2b50)
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index a5b02c0..331fd38 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1606,13 +1606,14 @@
         conv_bias = [True, False]
         module_pairs = [(nn.Conv1d, nn.BatchNorm1d), (nn.Conv2d, nn.BatchNorm2d), (nn.Conv3d, nn.BatchNorm3d)]
         use_tracing = [True, False]
+        bn_running_stats = [True, False]
 
-        for use_bias, modules, tracing in product(conv_bias, module_pairs, use_tracing):
+        for use_bias, modules, tracing, track_stats in product(conv_bias, module_pairs, use_tracing, bn_running_stats):
             class ConvBN(torch.nn.Module):
                 def __init__(self, in_channels, out_channels, **kwargs):
                     super(ConvBN, self).__init__()
                     self.conv = modules[0](in_channels, out_channels, bias=use_bias, **kwargs)
-                    self.bn = modules[1](out_channels, eps=0.001)
+                    self.bn = modules[1](out_channels, eps=0.001, track_running_stats=track_stats)
 
                 def forward(self, x):
                     x = self.conv(x)
@@ -1644,7 +1645,10 @@
 
             scripted_mod = torch.jit.freeze(scripted_mod)
             self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
-            FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
+            if track_stats:
+                FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
+            else:
+                FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)
 
             self.assertEqual(mod_eager(inp), scripted_mod(inp))
             self.assertEqual(mod_eager(inp), scripted_mod(inp))
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 75283ad..a35c94a 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -667,6 +667,30 @@
 
         self.assertEqual(fused(inp), rn18(inp))
 
+    def test_conv_bn_fusion_not_running_state(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
+                self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.bn(x)
+                return x
+
+        model = M().eval()
+
+        traced = symbolic_trace(model)
+        fused = optimization.fuse(traced)
+        inp = torch.randn([1, 32, 50, 50])
+
+        # bn need not be folded in conv
+        self.assertTrue(
+            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
+        )
+        self.assertEqual(fused(inp), model(inp))
+
     def test_call_to_assert_no_msg(self):
         class M(torch.nn.Module):
             def forward(self, a, b):
diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp
index 0aa674e..4b5535a 100644
--- a/torch/csrc/jit/passes/frozen_conv_folding.cpp
+++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp
@@ -62,6 +62,15 @@
         continue;
       }
 
+      auto bn_rm_ivalue = bn->namedInput("running_mean");
+      auto bn_rv_ivalue = bn->namedInput("running_var");
+      // check running_mean and running_var has value, if they are
+      // None(track_running_stats=False), skiping the folding path.
+      if (bn_rm_ivalue->type() == NoneType::get() &&
+          bn_rv_ivalue->type() == NoneType::get()) {
+        continue;
+      }
+
       auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
       auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
       auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py
index 7016556..595dbfa 100644
--- a/torch/fx/experimental/optimization.py
+++ b/torch/fx/experimental/optimization.py
@@ -68,6 +68,8 @@
                     continue
                 conv = modules[node.args[0].target]
                 bn = modules[node.target]
+                if not bn.track_running_stats:
+                    continue
                 fused_conv = fuse_conv_bn_eval(conv, bn)
                 replace_node_module(node.args[0], modules, fused_conv)
                 node.replace_all_uses_with(node.args[0])