jit: Conv3d + BatchNorm3d fusion (#40082)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40082
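
This change extends the graph-mode Conv+BN folding pass from Conv2d+BatchNorm2d to also cover Conv3d+BatchNorm3d: FoldConvBatchNorm2d becomes FoldConvBatchNorm, the Conv-BN pattern is parameterized over the conv/batchnorm module types, addBiasForConvIfNone handles both "Conv2d" and "Conv3d" module names, and computeUpdatedConvWeightAndBias reshapes the per-channel scale to match the conv weight's dimensionality instead of hard-coding {-1, 1, 1, 1}. The tests are parameterized over dim in {2, 3} accordingly.

The folding arithmetic itself is unchanged and mirrors torch/nn/utils/fusion.py. A minimal Python sketch of what the updated pass computes (the helper name and the numeric check below are illustrative only, not part of this patch):

    import torch

    def fold_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
        # 1 / sqrt(running_var + eps); the per-channel scale is bn_w times this.
        bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
        # Broadcast shape over the conv weight: (-1, 1, 1, 1) for Conv2d,
        # (-1, 1, 1, 1, 1) for Conv3d -- the generalization made in this patch.
        shape = [-1] + [1] * (conv_w.dim() - 1)
        new_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape)
        # A missing conv bias is treated as zero (cf. addBiasForConvIfNone).
        if conv_b is None:
            conv_b = torch.zeros_like(bn_rm)
        new_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
        return new_w, new_b

    # Sanity check: folding a Conv3d/BatchNorm3d pair in eval mode preserves numerics.
    conv = torch.nn.Conv3d(1, 20, 5, 1).eval()
    bn = torch.nn.BatchNorm3d(20).eval()
    w, b = fold_conv_bn_weights(conv.weight, conv.bias, bn.running_mean,
                                bn.running_var, bn.eps, bn.weight, bn.bias)
    fused = torch.nn.Conv3d(1, 20, 5, 1).eval()
    fused.weight = torch.nn.Parameter(w.detach())
    fused.bias = torch.nn.Parameter(b.detach())
    x = torch.rand(1, 1, 6, 6, 6)
    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)

From Python the pass is reached through the renamed _jit_pass_fold_convbn binding; the tests exercise it via fuse_conv_bn_jit on a scripted or traced module in eval mode.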

Differential Revision: D22120340

Pulled By: jerryzh168

fbshipit-source-id: fce6c5f03fe7ab6c60620cbdf547d5a466a470e3
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index cd55347..22431b4 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -79,12 +79,15 @@
     """ Test graph mode quantization passes used by quantize_jit
     """
     def test_foldbn_trivial(self):
+        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+        conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
         # Test trivial case
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dim):
                 super(TestModule, self).__init__()
-                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
-                self.bn = torch.nn.BatchNorm2d(num_features=20)
+                self.conv = conv_module[dim](1, 20, 5, 1)
+                self.bn = bn_module[dim](num_features=20)
                 self.bn.eps = 0.0023
 
             def forward(self, x):
@@ -92,24 +95,20 @@
                 x = self.bn(x)
                 return x
 
+        options = itertools.product([True, False], [2, 3])
+        data = {2 : torch.rand(1, 1, 6, 6), 3 : torch.rand(1, 1, 6, 6, 6)}
         # Check that the transformation doesn't change numerics
-        for tracing_mode in [True, False]:
-            eager = TestModule()
-            eager.eval()
-            if tracing_mode:
-                x = torch.rand(1, 1, 6, 6)
-                scripted_or_traced = torch.jit.trace(eager, x)
-            else:
-                scripted_or_traced = torch.jit.script(eager)
-            scripted_or_traced.eval()
-
+        for tracing, dim in options:
+            eager = TestModule(dim).eval()
+            x = data[dim]
+            scripted_or_traced = get_script_module(eager, tracing, x).eval()
             # Check that in the original script module's forward we have two
             # CallMethod nodes. One of them should be for conv.forward and the other
             # for bn.forward.
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                 .run(str(get_forward(scripted_or_traced._c).graph))
 
-            # Run FoldConvBatchnorm2d pass.
+            # Run FoldConvBatchNorm pass.
             scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
 
             # Check that after the pass one of the CallMethods is gone (supposedly,
@@ -118,16 +117,18 @@
                 .run(str(get_forward_graph(scripted_or_traced._c)))
 
             # Check that the transformation doesn't change numerics
-            x = torch.rand(1, 1, 6, 6)
             self.assertEqual(eager(x), scripted_or_traced(x))
 
     def test_foldbn_trivial_nobias(self):
+        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+        conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
         # Test trivial case
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dim):
                 super(TestModule, self).__init__()
-                self.conv = torch.nn.Conv2d(1, 20, 5, 1, bias=False)
-                self.bn = torch.nn.BatchNorm2d(num_features=20)
+                self.conv = conv_module[dim](1, 20, 5, 1, bias=False)
+                self.bn = bn_module[dim](num_features=20)
                 # to make sure new bias is not zero
                 self.bn.eps = 0.0027
                 self.bn.bias = torch.nn.Parameter(torch.rand([20]))
@@ -137,23 +138,19 @@
                 x = self.bn(x)
                 return x
 
-        for tracing_mode in [True, False]:
-            eager = TestModule()
-            eager.eval()
-            if tracing_mode:
-                x = torch.rand(1, 1, 6, 6)
-                scripted_or_traced = torch.jit.trace(eager, x)
-            else:
-                scripted_or_traced = torch.jit.script(eager)
-            scripted_or_traced.eval()
-
+        options = itertools.product([True, False], [2, 3])
+        data = {2 : torch.rand(1, 1, 6, 6), 3 : torch.rand(1, 1, 6, 6, 6)}
+        for tracing, dim in options:
+            eager = TestModule(dim).eval()
+            x = data[dim]
+            scripted_or_traced = get_script_module(eager, tracing, x).eval()
             # Check that in the original script module's forward we have two
             # CallMethod nodes. One of them should be for conv.forward and the other
             # for bn.forward.
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                 .run(str(get_forward_graph(scripted_or_traced._c)))
 
-            # Run FoldConvBatchnorm2d pass.
+            # Run FoldConvBatchNorm pass.
             scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
 
             # Check that after the pass one of the CallMethods is gone (supposedly,
@@ -162,16 +159,18 @@
                 .run(str(get_forward_graph(scripted_or_traced._c)))
 
             # Check that the transformation doesn't change numerics
-            x = torch.rand(1, 1, 6, 6)
             self.assertEqual(eager(x), scripted_or_traced(x))
 
     def test_foldbn_in_submodule(self):
+        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+        conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
         # Test that we find Conv-BN patterns in submodules
         class SubModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dim):
                 super(SubModule, self).__init__()
-                self.conv = torch.nn.Conv2d(1, 20, 5, 1)
-                self.bn = torch.nn.BatchNorm2d(num_features=20)
+                self.conv = conv_module[dim](1, 20, 5, 1)
+                self.bn = bn_module[dim](num_features=20)
 
             def forward(self, x):
                 x = self.conv(x)
@@ -179,24 +178,20 @@
                 return x
 
         class TestModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, dim):
                 super(TestModule, self).__init__()
-                self.sub = SubModule()
+                self.sub = SubModule(dim)
 
             def forward(self, x):
                 x = self.sub(x)
                 return x
 
-        for tracing_mode in [True, False]:
-            eager = TestModule()
-            eager.eval()
-            if tracing_mode:
-                x = torch.rand(1, 1, 10, 10)
-                scripted_or_traced = torch.jit.trace(eager, x)
-            else:
-                scripted_or_traced = torch.jit.script(eager)
-            scripted_or_traced.eval()
-
+        options = itertools.product([True, False], [2, 3])
+        data = {2 : torch.rand(1, 1, 10, 10), 3 : torch.rand(1, 1, 10, 10, 10)}
+        for tracing, dim in options:
+            eager = TestModule(dim).eval()
+            x = data[dim]
+            scripted_or_traced = get_script_module(eager, tracing, x).eval()
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
                 .run(str(get_forward_graph(scripted_or_traced.sub._c)))
 
@@ -205,72 +200,23 @@
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
                 .run(str(get_forward_graph(scripted_or_traced.sub._c)))
 
-            x = torch.rand(1, 1, 10, 10)
-            self.assertEqual(eager(x), scripted_or_traced(x))
-
-    def test_foldbn_in_customConv2D(self):
-        # Make sure a custom Conv2D class is not folded
-        # as we do not know it does.
-        class CustomConv2D(torch.nn.Module):
-            def __init__(self, a, b, c, d):
-                super(CustomConv2D, self).__init__()
-
-            def forward(self, x):
-                return F.relu(x)
-
-        class SubModule(torch.nn.Module):
-            def __init__(self):
-                super(SubModule, self).__init__()
-                self.conv = CustomConv2D(1, 20, 5, 1)
-                self.bn = torch.nn.BatchNorm2d(num_features=20)
-
-            def forward(self, x):
-                x = self.conv(x)
-                x = self.bn(x)
-                return x
-
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super(TestModule, self).__init__()
-                self.sub = SubModule()
-
-            def forward(self, x):
-                x = self.sub(x)
-                return x
-
-        for tracing_mode in [True, False]:
-            eager = TestModule()
-            eager.eval()
-            if tracing_mode:
-                x = torch.rand(1, 20, 10, 10)
-                scripted_or_traced = torch.jit.trace(eager, x)
-            else:
-                scripted_or_traced = torch.jit.script(eager)
-            scripted_or_traced.eval()
-
-            FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
-                .run(str(get_forward_graph(scripted_or_traced.sub._c)))
-
-            scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
-
-            FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
-                .run(str(get_forward_graph(scripted_or_traced.sub._c)))
-
-            x = torch.rand(1, 20, 10, 10)
             self.assertEqual(eager(x), scripted_or_traced(x))
 
     def test_foldbn_shared_classtype(self):
+        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+        conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
         class TestModule(torch.nn.Module):
-            def __init__(self, bias=False):
+            def __init__(self, dim, bias=False):
                 super(TestModule, self).__init__()
-                self.conv1 = torch.nn.Conv2d(5, 5, 3, bias=bias)
-                self.bn1 = torch.nn.BatchNorm2d(num_features=5)
+                self.conv1 = conv_module[dim](5, 5, 3, bias=bias)
+                self.bn1 = bn_module[dim](num_features=5)
                 self.bn1.running_mean.fill_(-0.2)
                 self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
                 # to make sure new bias is not zero
                 self.bn1.eps = 0.0023
-                self.conv2 = torch.nn.Conv2d(5, 5, 3, bias=bias)
-                self.bn2 = torch.nn.BatchNorm2d(num_features=5)
+                self.conv2 = conv_module[dim](5, 5, 3, bias=bias)
+                self.bn2 = bn_module[dim](num_features=5)
                 self.bn2.eps = 0.0029
                 self.relu = torch.nn.ReLU()
 
@@ -283,33 +229,32 @@
                 x = self.relu(x)
                 return x
 
-        for tracing_mode in [True, False]:
-            for bias in [True, False]:
-                eager = TestModule(bias).eval()
-                if tracing_mode:
-                    x = torch.rand(1, 5, 6, 6)
-                    scripted_or_traced = torch.jit.trace(eager, x).copy()
-                else:
-                    scripted_or_traced = torch.jit.script(eager).copy()
-                torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
-                folded = fuse_conv_bn_jit(scripted_or_traced)
-                x = torch.rand(1, 5, 6, 6)
-                self.assertEqual(eager(x), scripted_or_traced(x))
+        options = itertools.product([True, False], [2, 3], [True, False])
+        data = {2 : torch.rand(1, 5, 6, 6), 3 : torch.rand(1, 5, 6, 6, 6)}
+        for tracing, dim, bias in options:
+            eager = TestModule(dim, bias).eval()
+            x = data[dim]
+            scripted_or_traced = get_script_module(eager, tracing, x).copy()
+            torch._C._jit_pass_dedup_module_uses(scripted_or_traced._c)
+            folded = fuse_conv_bn_jit(scripted_or_traced)
+            self.assertEqual(eager(x), folded(x))
 
     def test_foldbn_complex_cases(self):
-        # This test case attempt to try combinations of conv2d with bias/nobias
+        # This test case attempts combinations of conv2d/conv3d with bias/nobias
         # as well as BatchNorm with affine/no-affine along with varying the
         # number of layers.
         # this only works when default dtype is double
         torch.set_default_dtype(torch.double)
+        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+        conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
 
         class SubModule(torch.nn.Module):
-            def __init__(self, num_blocks, enable_bias, enable_affine):
+            def __init__(self, dim, num_blocks, enable_bias, enable_affine):
                 super(SubModule, self).__init__()
                 layers = []
                 for i in range(num_blocks):
-                    layers.append(torch.nn.Conv2d(20, 20, 5, 1, bias=enable_bias))
-                    bn_obj = torch.nn.BatchNorm2d(num_features=20, affine=enable_affine)
+                    layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias))
+                    bn_obj = bn_module[dim](num_features=20, affine=enable_affine)
                     if enable_affine:
                         bn_obj.weight = torch.nn.Parameter(torch.rand_like(bn_obj.weight))
                         bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
@@ -322,25 +267,20 @@
                 return self.layers(x)
 
         class TestModule(torch.nn.Module):
-            def __init__(self, num_blocks, enable_bias, enable_affine):
+            def __init__(self, dim, num_blocks, enable_bias, enable_affine):
                 super(TestModule, self).__init__()
-                self.sub = SubModule(num_blocks, enable_bias, enable_affine)
+                self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine)
 
             def forward(self, x):
                 x = self.sub(x)
                 return x
 
-        bias_affine_options = itertools.product([True, False], [True, False], [True, False], [1, 2])
-        for (tracing_mode, enable_bias, enable_bn_affine, num_layers) in bias_affine_options:
-            eager = TestModule(num_layers, enable_bias, enable_bn_affine)
-            eager.eval()
-
-            if tracing_mode:
-                x = torch.rand(1, 20, 10, 10)
-                scripted_or_traced = torch.jit.trace(eager, x)
-            else:
-                scripted_or_traced = torch.jit.script(eager)
-            scripted_or_traced.eval()
+        options = itertools.product([True, False], [2, 3], [True, False], [True, False], [1, 2])
+        data = {2 : torch.rand(1, 20, 10, 10), 3 : torch.rand(1, 20, 10, 10, 10)}
+        for tracing, dim, enable_bias, enable_bn_affine, num_layers in options:
+            eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval()
+            x = data[dim]
+            scripted_or_traced = get_script_module(eager, tracing, x).eval()
 
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
                 .run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
@@ -350,8 +290,8 @@
             FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
                 .run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
 
-            x = torch.rand(1, 20, 10, 10)
             self.assertEqual(eager(x), scripted_or_traced(x))
+
         torch.set_default_dtype(torch.float)
 
     def test_fuse_linear(self):
diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp
index 426eeae..19d811b 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.cpp
+++ b/torch/csrc/jit/passes/fold_conv_bn.cpp
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/quantization/helper.h>
 
 #include <stack>
 
@@ -25,11 +26,11 @@
   return m.hasattr(name) && m.attr(name).isTensor();
 }
 
-void replaceConv2dBiasWithGetAttr(Module& module) {
+void replaceConvBiasWithGetAttr(Module& module) {
   auto graph = module.get_method("forward").graph();
   // Only looks fors _convolution pattern.
-  // Thus assumes that tracing will have always gotten rid of aten::conv2d.
-  // If it did not, BN folding will fail.
+  // Thus assumes that tracing will have always gotten rid of aten::conv2d or
+  // aten::conv3d. If it did not, BN folding will fail.
   const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
           %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@@ -56,10 +57,9 @@
   }
 }
 
-void addBiasForConv2dIfNone(Module& module) {
+void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
   auto t = module.type()->expect<ClassType>();
   auto real_typename = t->name()->qualifiedName();
-  const std::string pattern_name("Conv2d");
   if (real_typename.size() >= pattern_name.size() &&
       (0 ==
        real_typename.compare(
@@ -71,23 +71,23 @@
       t->addAttribute("bias", optional_tensor_type, true);
       auto optional_tensor = c10::optional<at::Tensor>();
       module.setattr("bias", optional_tensor);
-      replaceConv2dBiasWithGetAttr(module);
+      replaceConvBiasWithGetAttr(module);
     }
   }
   for (Module m : module.children()) {
-    addBiasForConv2dIfNone(m);
+    addBiasForConvIfNone(m, pattern_name);
   }
 }
 
-class FoldConvBatchNorm2dHelper {
+class FoldConvBatchNormHelper {
  public:
   /**
-   * In this step we find all Conv2d - BatchNorm2d patterns in the graph
+   * In this step we find all Conv - BatchNorm patterns in the graph
    * and extract the corresponding parameters for these two modules,
    * and record informations for the modifications of the graph without
    * actually performing these modifications.
    */
-  void analyze(Module& module);
+  void analyze(Module& module, const PatternInfo& pattern);
   /**
    * In this step we perform all the modifications including
    * setting the attributes for conv module, rewriting values
@@ -102,8 +102,8 @@
       ConvBNParameters& r);
 
   /**
-   * Given the current weight and bias tensors of a Conv2d module and parameters
-   * of the BatchNorm2d module we're folding with, compute the updated values
+   * Given the current weight and bias tensors of a Conv module and parameters
+   * of the BatchNorm module we're folding with, compute the updated values
    * for the weight and bias.
    *
    * The function is basically copied from torch/nn/utils/fusion.py
@@ -120,10 +120,13 @@
   std::unordered_set<Node*> nodes_to_delete_;
 };
 
-std::tuple<at::Tensor, at::Tensor> FoldConvBatchNorm2dHelper::
+std::tuple<at::Tensor, at::Tensor> FoldConvBatchNormHelper::
     computeUpdatedConvWeightAndBias(const ConvBNParameters& p) {
   at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps);
-  at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape({-1, 1, 1, 1});
+  const int64_t ndim = p.conv_w.dim();
+  at::DimVector sizes(ndim, 1);
+  sizes.at(0) = -1;
+  at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes);
   at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b;
   return std::make_tuple(new_w, new_b);
 }
@@ -189,7 +192,7 @@
   return true;
 }
 
-bool FoldConvBatchNorm2dHelper::tryExtractingConvBNParameters(
+bool FoldConvBatchNormHelper::tryExtractingConvBNParameters(
     Module& conv,
     Module& bn,
     ConvBNParameters& r) {
@@ -214,24 +217,15 @@
   return true;
 }
 
-void FoldConvBatchNorm2dHelper::analyze(Module& module) {
-  // Dot in the ".Conv2d" and ".BatchNorm2d" is an attempt to
-  // prevent matching module's whose name might end with Conv2d
-  // But are user defined modules.
-  const PatternInfo pattern = PatternInfo::parse_from_str(R"IR(
-graph(%self, %x):
-    %conv_submodule = match::module[name=".Conv2d"](%self)
-    %conv_out = prim::CallMethod[name="forward"](%conv_submodule, %x)
-    %bn_submodule = match::module[name=".BatchNorm2d"](%self)
-    %bn_out = prim::CallMethod[name="forward"](%bn_submodule, %conv_out)
-    return (%bn_out))IR");
-
+void FoldConvBatchNormHelper::analyze(
+    Module& module,
+    const PatternInfo& pattern) {
   const Graph& pattern_graph = *pattern.pattern_graph;
   const auto& vmap = pattern.vmap;
   Value* pattern_conv_out = vmap.at("conv_out");
   Value* pattern_bn_out = vmap.at("bn_out");
-  Value* pattern_conv_submodule = vmap.at("conv_submodule");
-  Value* pattern_bn_submodule = vmap.at("bn_submodule");
+  Value* pattern_conv_submodule = vmap.at("conv");
+  Value* pattern_bn_submodule = vmap.at("batchnorm");
   Node* pattern_conv = pattern_conv_out->node();
   Node* pattern_bn = pattern_bn_out->node();
 
@@ -251,11 +245,11 @@
     for (auto& method : current.get_methods()) {
       GRAPH_DUMP(
           current.type()->name()->name() + "::" + method.name() +
-              "() before Conv2d-BatchNorm2d folding",
+              "() before Conv-BatchNorm folding",
           method.graph());
       const auto& matches = findPatternMatches(pattern_graph, *method.graph());
 
-      GRAPH_DEBUG("number of Conv2d-BatchNorm2d matches: ", matches.size());
+      GRAPH_DEBUG("number of Conv-BatchNorm matches: ", matches.size());
       Graph* g = method.graph().get();
       if (!conv_bn_names_.count(g)) {
         // This is to make sure we don't visit one graph multiple times
@@ -329,7 +323,7 @@
   } // while
 }
 
-void FoldConvBatchNorm2dHelper::transform() {
+void FoldConvBatchNormHelper::transform() {
   for (const auto& item : conv_module_and_params_) {
     Module conv(item.first);
     auto w_b = item.second;
@@ -353,12 +347,35 @@
 
 } // namespace
 
-Module FoldConvBatchNorm2d(const Module& module) {
-  FoldConvBatchNorm2dHelper h;
+Module FoldConvBatchNorm(const Module& module) {
   Module m = module.clone();
-  addBiasForConv2dIfNone(m);
-  h.analyze(m);
-  h.transform();
+
+  addBiasForConvIfNone(m, "Conv2d");
+  addBiasForConvIfNone(m, "Conv3d");
+  // Conv2d + BatchNorm2d
+  const PatternInfo pattern2d = PatternInfo::parse_from_str(
+      R"(
+graph(%self, %input, %conv, %batchnorm):
+    %conv_out = prim::CallMethod[name="forward"](%conv, %input)
+    %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
+    return (%bn_out))",
+      {is_conv2d_module, is_batchnorm2d_module});
+  // Conv3d + BatchNorm3d
+  const PatternInfo pattern3d = PatternInfo::parse_from_str(
+      R"(
+graph(%self, %input, %conv, %batchnorm):
+    %conv_out = prim::CallMethod[name="forward"](%conv, %input)
+    %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
+    return (%bn_out))",
+      {is_conv3d_module, is_batchnorm3d_module});
+
+  const std::vector<std::reference_wrapper<const PatternInfo>> patterns = {
+      pattern2d, pattern3d};
+  for (const auto& pattern : patterns) {
+    FoldConvBatchNormHelper h;
+    h.analyze(m, pattern);
+    h.transform();
+  }
   return m;
 }
 
diff --git a/torch/csrc/jit/passes/fold_conv_bn.h b/torch/csrc/jit/passes/fold_conv_bn.h
index 1a36b38..54b27c6 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.h
+++ b/torch/csrc/jit/passes/fold_conv_bn.h
@@ -11,7 +11,7 @@
  * The weight and bias of the Conv2d are correspondingly updated. Should only be
  * used on modules in eval mode.
  */
-TORCH_API Module FoldConvBatchNorm2d(const Module& module);
+TORCH_API Module FoldConvBatchNorm(const Module& module);
 
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
index 6d0dc6c..92f3d76 100644
--- a/torch/csrc/jit/passes/vulkan_rewrite.cpp
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -162,7 +162,7 @@
 script::Module vulkanOptimizeForMobile(const script::Module& m) {
   auto cloned_module = m.clone();
   cloned_module.eval();
-  cloned_module = FoldConvBatchNorm2d(cloned_module);
+  cloned_module = FoldConvBatchNorm(cloned_module);
   vulkanInsertPrePackedOps(cloned_module);
   cloned_module = freeze_module(cloned_module);
   vulkanFusePrePackedConvWithClamp(cloned_module);
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 153b4e2..4b23705 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -281,7 +281,7 @@
   cloned_module.eval();
 
   if (!optimization_blacklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
-    cloned_module = FoldConvBatchNorm2d(cloned_module);
+    cloned_module = FoldConvBatchNorm(cloned_module);
   }
 
   if (!optimization_blacklist.count(
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 910d5cd..9c3132f 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -230,7 +230,7 @@
       .def(
           "_jit_pass_quant_fusion",
           [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
-      .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
+      .def("_jit_pass_fold_convbn", &FoldConvBatchNorm)
       .def(
           "_freeze_module",
           [](Module& module, std::vector<std::string>& preservedAttrs) {