jit: Conv3d + BatchNorm3d fusion (#40082)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40082
Differential Revision: D22120340
Pulled By: jerryzh168
fbshipit-source-id: fce6c5f03fe7ab6c60620cbdf547d5a466a470e3
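The pass folds a BatchNorm module into the preceding Conv module by rescaling the conv weight per output channel and absorbing the BatchNorm shift into the conv bias. A rough Python sketch of that arithmetic, mirroring computeUpdatedConvWeightAndBias in the diff below (the helper name and signature here are illustrative only):

    import torch

    def fold_bn_into_conv(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
        # Per-output-channel scale derived from the BatchNorm statistics.
        bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
        # Broadcast the scale over every non-output-channel dim, so the same
        # code covers a Conv2d (4-D) weight and a Conv3d (5-D) weight.
        scale_shape = [-1] + [1] * (conv_w.dim() - 1)
        new_w = conv_w * (bn_w * bn_var_rsqrt).reshape(scale_shape)
        new_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
        return new_w, new_b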
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index cd55347..22431b4 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -79,12 +79,15 @@
""" Test graph mode quantization passes used by quantize_jit
"""
def test_foldbn_trivial(self):
+ bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+ conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
# Test trivial case
class TestModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self, dim):
super(TestModule, self).__init__()
- self.conv = torch.nn.Conv2d(1, 20, 5, 1)
- self.bn = torch.nn.BatchNorm2d(num_features=20)
+ self.conv = conv_module[dim](1, 20, 5, 1)
+ self.bn = bn_module[dim](num_features=20)
self.bn.eps = 0.0023
def forward(self, x):
@@ -92,24 +95,20 @@
x = self.bn(x)
return x
+ options = itertools.product([True, False], [2, 3])
+ data = {2 : torch.rand(1, 1, 6, 6), 3 : torch.rand(1, 1, 6, 6, 6)}
# Check that the transformation doesn't change numerics
- for tracing_mode in [True, False]:
- eager = TestModule()
- eager.eval()
- if tracing_mode:
- x = torch.rand(1, 1, 6, 6)
- scripted_or_traced = torch.jit.trace(eager, x)
- else:
- scripted_or_traced = torch.jit.script(eager)
- scripted_or_traced.eval()
-
+ for tracing, dim in options:
+ eager = TestModule(dim).eval()
+ x = data[dim]
+ scripted_or_traced = get_script_module(eager, tracing, x).eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(scripted_or_traced._c).graph))
- # Run FoldConvBatchnorm2d pass.
+ # Run FoldConvBatchnorm pass.
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
# Check that after the pass one of the CallMethods is gone (supposedly,
@@ -118,16 +117,18 @@
.run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
- x = torch.rand(1, 1, 6, 6)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_trivial_nobias(self):
+ bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+ conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
# Test trivial case
class TestModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self, dim):
super(TestModule, self).__init__()
- self.conv = torch.nn.Conv2d(1, 20, 5, 1, bias=False)
- self.bn = torch.nn.BatchNorm2d(num_features=20)
+ self.conv = conv_module[dim](1, 20, 5, 1, bias=False)
+ self.bn = bn_module[dim](num_features=20)
# to make sure new bias is not zero
self.bn.eps = 0.0027
self.bn.bias = torch.nn.Parameter(torch.rand([20]))
@@ -137,23 +138,19 @@
x = self.bn(x)
return x
- for tracing_mode in [True, False]:
- eager = TestModule()
- eager.eval()
- if tracing_mode:
- x = torch.rand(1, 1, 6, 6)
- scripted_or_traced = torch.jit.trace(eager, x)
- else:
- scripted_or_traced = torch.jit.script(eager)
- scripted_or_traced.eval()
-
+ options = itertools.product([True, False], [2, 3])
+ data = {2 : torch.rand(1, 1, 6, 6), 3 : torch.rand(1, 1, 6, 6, 6)}
+ for tracing, dim in options:
+ eager = TestModule(dim).eval()
+ x = data[dim]
+ scripted_or_traced = get_script_module(eager, tracing, x).eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced._c)))
- # Run FoldConvBatchnorm2d pass.
+ # Run FoldConvBatchnorm pass.
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
# Check that after the pass one of the CallMethods is gone (supposedly,
@@ -162,16 +159,18 @@
.run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
- x = torch.rand(1, 1, 6, 6)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_in_submodule(self):
+ bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+ conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
# Test that we find Conv-BN patterns in submodules
class SubModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self, dim):
super(SubModule, self).__init__()
- self.conv = torch.nn.Conv2d(1, 20, 5, 1)
- self.bn = torch.nn.BatchNorm2d(num_features=20)
+ self.conv = conv_module[dim](1, 20, 5, 1)
+ self.bn = bn_module[dim](num_features=20)
def forward(self, x):
x = self.conv(x)
@@ -179,24 +178,20 @@
return x
class TestModule(torch.nn.Module):
- def __init__(self):
+ def __init__(self, dim):
super(TestModule, self).__init__()
- self.sub = SubModule()
+ self.sub = SubModule(dim)
def forward(self, x):
x = self.sub(x)
return x
- for tracing_mode in [True, False]:
- eager = TestModule()
- eager.eval()
- if tracing_mode:
- x = torch.rand(1, 1, 10, 10)
- scripted_or_traced = torch.jit.trace(eager, x)
- else:
- scripted_or_traced = torch.jit.script(eager)
- scripted_or_traced.eval()
-
+ options = itertools.product([True, False], [2, 3])
+ data = {2 : torch.rand(1, 1, 10, 10), 3 : torch.rand(1, 1, 10, 10, 10)}
+ for tracing, dim in options:
+ eager = TestModule(dim).eval()
+ x = data[dim]
+ scripted_or_traced = get_script_module(eager, tracing, x).eval()
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
@@ -205,72 +200,23 @@
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
- x = torch.rand(1, 1, 10, 10)
- self.assertEqual(eager(x), scripted_or_traced(x))
-
- def test_foldbn_in_customConv2D(self):
- # Make sure a custom Conv2D class is not folded
- # as we do not know it does.
- class CustomConv2D(torch.nn.Module):
- def __init__(self, a, b, c, d):
- super(CustomConv2D, self).__init__()
-
- def forward(self, x):
- return F.relu(x)
-
- class SubModule(torch.nn.Module):
- def __init__(self):
- super(SubModule, self).__init__()
- self.conv = CustomConv2D(1, 20, 5, 1)
- self.bn = torch.nn.BatchNorm2d(num_features=20)
-
- def forward(self, x):
- x = self.conv(x)
- x = self.bn(x)
- return x
-
- class TestModule(torch.nn.Module):
- def __init__(self):
- super(TestModule, self).__init__()
- self.sub = SubModule()
-
- def forward(self, x):
- x = self.sub(x)
- return x
-
- for tracing_mode in [True, False]:
- eager = TestModule()
- eager.eval()
- if tracing_mode:
- x = torch.rand(1, 20, 10, 10)
- scripted_or_traced = torch.jit.trace(eager, x)
- else:
- scripted_or_traced = torch.jit.script(eager)
- scripted_or_traced.eval()
-
- FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
- .run(str(get_forward_graph(scripted_or_traced.sub._c)))
-
- scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
-
- FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
- .run(str(get_forward_graph(scripted_or_traced.sub._c)))
-
- x = torch.rand(1, 20, 10, 10)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_shared_classtype(self):
+ bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+ conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
class TestModule(torch.nn.Module):
- def __init__(self, bias=False):
+ def __init__(self, dim, bias=False):
super(TestModule, self).__init__()
- self.conv1 = torch.nn.Conv2d(5, 5, 3, bias=bias)
- self.bn1 = torch.nn.BatchNorm2d(num_features=5)
+ self.conv1 = conv_module[dim](5, 5, 3, bias=bias)
+ self.bn1 = bn_module[dim](num_features=5)
self.bn1.running_mean.fill_(-0.2)
self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
# to make sure new bias is not zero
self.bn1.eps = 0.0023
- self.conv2 = torch.nn.Conv2d(5, 5, 3, bias=bias)
- self.bn2 = torch.nn.BatchNorm2d(num_features=5)
+ self.conv2 = conv_module[dim](5, 5, 3, bias=bias)
+ self.bn2 = bn_module[dim](num_features=5)
self.bn2.eps = 0.0029
self.relu = torch.nn.ReLU()
@@ -283,33 +229,32 @@
x = self.relu(x)
return x
- for tracing_mode in [True, False]:
- for bias in [True, False]:
- eager = TestModule(bias).eval()
- if tracing_mode:
- x = torch.rand(1, 5, 6, 6)
- scripted_or_traced = torch.jit.trace(eager, x).copy()
- else:
- scripted_or_traced = torch.jit.script(eager).copy()
- torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
- folded = fuse_conv_bn_jit(scripted_or_traced)
- x = torch.rand(1, 5, 6, 6)
- self.assertEqual(eager(x), scripted_or_traced(x))
+ options = itertools.product([True, False], [2, 3], [True, False])
+ data = {2 : torch.rand(1, 5, 6, 6), 3 : torch.rand(1, 5, 6, 6, 6)}
+ for tracing, dim, bias in options:
+ eager = TestModule(dim, bias).eval()
+ x = data[dim]
+ scripted_or_traced = get_script_module(eager, tracing, x).copy()
+ torch._C._jit_pass_dedup_module_uses(scripted_or_traced._c)
+ folded = fuse_conv_bn_jit(scripted_or_traced)
+ self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_complex_cases(self):
- # This test case attempt to try combinations of conv2d with bias/nobias
+ # This test case tries combinations of conv2d/conv3d with bias/nobias
# as well as BatchNorm with affine/no-affine along with varying the
# number of layers.
# this only works when default dtype is double
torch.set_default_dtype(torch.double)
+ bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
+ conv_module = {2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
class SubModule(torch.nn.Module):
- def __init__(self, num_blocks, enable_bias, enable_affine):
+ def __init__(self, dim, num_blocks, enable_bias, enable_affine):
super(SubModule, self).__init__()
layers = []
for i in range(num_blocks):
- layers.append(torch.nn.Conv2d(20, 20, 5, 1, bias=enable_bias))
- bn_obj = torch.nn.BatchNorm2d(num_features=20, affine=enable_affine)
+ layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias))
+ bn_obj = bn_module[dim](num_features=20, affine=enable_affine)
if enable_affine:
bn_obj.weight = torch.nn.Parameter(torch.rand_like(bn_obj.weight))
bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
@@ -322,25 +267,20 @@
return self.layers(x)
class TestModule(torch.nn.Module):
- def __init__(self, num_blocks, enable_bias, enable_affine):
+ def __init__(self, dim, num_blocks, enable_bias, enable_affine):
super(TestModule, self).__init__()
- self.sub = SubModule(num_blocks, enable_bias, enable_affine)
+ self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine)
def forward(self, x):
x = self.sub(x)
return x
- bias_affine_options = itertools.product([True, False], [True, False], [True, False], [1, 2])
- for (tracing_mode, enable_bias, enable_bn_affine, num_layers) in bias_affine_options:
- eager = TestModule(num_layers, enable_bias, enable_bn_affine)
- eager.eval()
-
- if tracing_mode:
- x = torch.rand(1, 20, 10, 10)
- scripted_or_traced = torch.jit.trace(eager, x)
- else:
- scripted_or_traced = torch.jit.script(eager)
- scripted_or_traced.eval()
+ options = itertools.product([True, False], [2, 3], [True, False], [True, False], [1, 2])
+ data = {2 : torch.rand(1, 20, 10, 10), 3 : torch.rand(1, 20, 10, 10, 10)}
+ for tracing, dim, enable_bias, enable_bn_affine, num_layers in options:
+ eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval()
+ x = data[dim]
+ scripted_or_traced = get_script_module(eager, tracing, x).eval()
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
@@ -350,8 +290,8 @@
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
- x = torch.rand(1, 20, 10, 10)
self.assertEqual(eager(x), scripted_or_traced(x))
+
torch.set_default_dtype(torch.float)
def test_fuse_linear(self):
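The rewritten tests above all go through a shared get_script_module helper instead of branching on tracing inline. That helper lives in the quantization test utilities and is not part of this diff; judging from the trace/script branches it replaces, a minimal stand-in would look roughly like this (the exact signature is an assumption):

    import torch

    def get_script_module(model, tracing, inputs):
        # Trace with the example inputs when tracing is requested,
        # otherwise script the eager module directly.
        return torch.jit.trace(model, inputs) if tracing else torch.jit.script(model)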
diff --git a/torch/csrc/jit/passes/fold_conv_bn.cpp b/torch/csrc/jit/passes/fold_conv_bn.cpp
index 426eeae..19d811b 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.cpp
+++ b/torch/csrc/jit/passes/fold_conv_bn.cpp
@@ -2,6 +2,7 @@
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/quantization/helper.h>
#include <stack>
@@ -25,11 +26,11 @@
return m.hasattr(name) && m.attr(name).isTensor();
}
-void replaceConv2dBiasWithGetAttr(Module& module) {
+void replaceConvBiasWithGetAttr(Module& module) {
auto graph = module.get_method("forward").graph();
// Only looks for the _convolution pattern.
- // Thus assumes that tracing will have always gotten rid of aten::conv2d.
- // If it did not, BN folding will fail.
+ // Thus assumes that tracing will have always gotten rid of aten::conv2d or
+ // aten::conv3d. If it did not, BN folding will fail.
const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
@@ -56,10 +57,9 @@
}
}
-void addBiasForConv2dIfNone(Module& module) {
+void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
auto t = module.type()->expect<ClassType>();
auto real_typename = t->name()->qualifiedName();
- const std::string pattern_name("Conv2d");
if (real_typename.size() >= pattern_name.size() &&
(0 ==
real_typename.compare(
@@ -71,23 +71,23 @@
t->addAttribute("bias", optional_tensor_type, true);
auto optional_tensor = c10::optional<at::Tensor>();
module.setattr("bias", optional_tensor);
- replaceConv2dBiasWithGetAttr(module);
+ replaceConvBiasWithGetAttr(module);
}
}
for (Module m : module.children()) {
- addBiasForConv2dIfNone(m);
+ addBiasForConvIfNone(m, pattern_name);
}
}
-class FoldConvBatchNorm2dHelper {
+class FoldConvBatchNormHelper {
public:
/**
- * In this step we find all Conv2d - BatchNorm2d patterns in the graph
+ * In this step we find all Conv - BatchNorm patterns in the graph
* and extract the corresponding parameters for these two modules,
* and record information for the modifications of the graph without
* actually performing these modifications.
*/
- void analyze(Module& module);
+ void analyze(Module& module, const PatternInfo& pattern);
/**
* In this step we perform all the modifications including
* setting the attributes for conv module, rewriting values
@@ -102,8 +102,8 @@
ConvBNParameters& r);
/**
- * Given the current weight and bias tensors of a Conv2d module and parameters
- * of the BatchNorm2d module we're folding with, compute the updated values
+ * Given the current weight and bias tensors of a Conv module and parameters
+ * of the BatchNorm module we're folding with, compute the updated values
* for the weight and bias.
*
* The function is basically copied from torch/nn/utils/fusion.py
@@ -120,10 +120,13 @@
std::unordered_set<Node*> nodes_to_delete_;
};
-std::tuple<at::Tensor, at::Tensor> FoldConvBatchNorm2dHelper::
+std::tuple<at::Tensor, at::Tensor> FoldConvBatchNormHelper::
computeUpdatedConvWeightAndBias(const ConvBNParameters& p) {
at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps);
- at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape({-1, 1, 1, 1});
+ const int64_t ndim = p.conv_w.dim();
+ at::DimVector sizes(ndim, 1);
+ sizes.at(0) = -1;
+ at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes);
at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b;
return std::make_tuple(new_w, new_b);
}
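The only functional change in this hunk is the reshape target: the hard-coded {-1, 1, 1, 1} only fits a 4-D Conv2d weight, so the per-channel scale is now reshaped to [-1, 1, ..., 1] with one trailing 1 per non-output dimension of the weight. A quick Python illustration of the equivalent broadcasting, with made-up shapes:

    import torch

    for w in (torch.randn(20, 1, 5, 5), torch.randn(20, 1, 5, 5, 5)):  # Conv2d / Conv3d weight
        scale = torch.randn(20)             # stands in for bn_w * rsqrt(bn_rv + eps)
        shape = [-1] + [1] * (w.dim() - 1)  # [-1, 1, 1, 1] or [-1, 1, 1, 1, 1]
        assert (w * scale.reshape(shape)).shape == w.shape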
@@ -189,7 +192,7 @@
return true;
}
-bool FoldConvBatchNorm2dHelper::tryExtractingConvBNParameters(
+bool FoldConvBatchNormHelper::tryExtractingConvBNParameters(
Module& conv,
Module& bn,
ConvBNParameters& r) {
@@ -214,24 +217,15 @@
return true;
}
-void FoldConvBatchNorm2dHelper::analyze(Module& module) {
- // Dot in the ".Conv2d" and ".BatchNorm2d" is an attempt to
- // prevent matching module's whose name might end with Conv2d
- // But are user defined modules.
- const PatternInfo pattern = PatternInfo::parse_from_str(R"IR(
-graph(%self, %x):
- %conv_submodule = match::module[name=".Conv2d"](%self)
- %conv_out = prim::CallMethod[name="forward"](%conv_submodule, %x)
- %bn_submodule = match::module[name=".BatchNorm2d"](%self)
- %bn_out = prim::CallMethod[name="forward"](%bn_submodule, %conv_out)
- return (%bn_out))IR");
-
+void FoldConvBatchNormHelper::analyze(
+ Module& module,
+ const PatternInfo& pattern) {
const Graph& pattern_graph = *pattern.pattern_graph;
const auto& vmap = pattern.vmap;
Value* pattern_conv_out = vmap.at("conv_out");
Value* pattern_bn_out = vmap.at("bn_out");
- Value* pattern_conv_submodule = vmap.at("conv_submodule");
- Value* pattern_bn_submodule = vmap.at("bn_submodule");
+ Value* pattern_conv_submodule = vmap.at("conv");
+ Value* pattern_bn_submodule = vmap.at("batchnorm");
Node* pattern_conv = pattern_conv_out->node();
Node* pattern_bn = pattern_bn_out->node();
@@ -251,11 +245,11 @@
for (auto& method : current.get_methods()) {
GRAPH_DUMP(
current.type()->name()->name() + "::" + method.name() +
- "() before Conv2d-BatchNorm2d folding",
+ "() before Conv-BatchNorm folding",
method.graph());
const auto& matches = findPatternMatches(pattern_graph, *method.graph());
- GRAPH_DEBUG("number of Conv2d-BatchNorm2d matches: ", matches.size());
+ GRAPH_DEBUG("number of Conv-BatchNorm matches: ", matches.size());
Graph* g = method.graph().get();
if (!conv_bn_names_.count(g)) {
// This is to make sure we don't visit one graph multiple times
@@ -329,7 +323,7 @@
} // while
}
-void FoldConvBatchNorm2dHelper::transform() {
+void FoldConvBatchNormHelper::transform() {
for (const auto& item : conv_module_and_params_) {
Module conv(item.first);
auto w_b = item.second;
@@ -353,12 +347,35 @@
} // namespace
-Module FoldConvBatchNorm2d(const Module& module) {
- FoldConvBatchNorm2dHelper h;
+Module FoldConvBatchNorm(const Module& module) {
Module m = module.clone();
- addBiasForConv2dIfNone(m);
- h.analyze(m);
- h.transform();
+
+ addBiasForConvIfNone(m, "Conv2d");
+ addBiasForConvIfNone(m, "Conv3d");
+ // Conv2d + BatchNorm2d
+ const PatternInfo pattern2d = PatternInfo::parse_from_str(
+ R"(
+graph(%self, %input, %conv, %batchnorm):
+ %conv_out = prim::CallMethod[name="forward"](%conv, %input)
+ %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
+ return (%bn_out))",
+ {is_conv2d_module, is_batchnorm2d_module});
+ // Conv3d + BatchNorm3d
+ const PatternInfo pattern3d = PatternInfo::parse_from_str(
+ R"(
+graph(%self, %input, %conv, %batchnorm):
+ %conv_out = prim::CallMethod[name="forward"](%conv, %input)
+ %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
+ return (%bn_out))",
+ {is_conv3d_module, is_batchnorm3d_module});
+
+ const std::vector<std::reference_wrapper<const PatternInfo>> patterns = {
+ pattern2d, pattern3d};
+ for (const auto& pattern : patterns) {
+ FoldConvBatchNormHelper h;
+ h.analyze(m, pattern);
+ h.transform();
+ }
return m;
}
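With the entry point renamed to FoldConvBatchNorm and registered under the existing _jit_pass_fold_convbn binding, Conv3d/BatchNorm3d pairs now fold through the same Python-facing path the tests exercise. A minimal end-to-end sketch in eval mode (the fuse_conv_bn_jit import path is assumed from the test module's imports):

    import torch
    from torch.quantization.quantize_jit import fuse_conv_bn_jit  # import path assumed

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv3d(1, 20, 5, 1)
            self.bn = torch.nn.BatchNorm3d(num_features=20)

        def forward(self, x):
            return self.bn(self.conv(x))

    eager = M().eval()
    scripted = torch.jit.script(eager).eval()
    fused = fuse_conv_bn_jit(scripted)         # folds bn into conv, dropping one CallMethod
    x = torch.rand(1, 1, 6, 6, 6)
    assert torch.allclose(eager(x), fused(x))  # numerics preserved, mirroring the tests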
diff --git a/torch/csrc/jit/passes/fold_conv_bn.h b/torch/csrc/jit/passes/fold_conv_bn.h
index 1a36b38..54b27c6 100644
--- a/torch/csrc/jit/passes/fold_conv_bn.h
+++ b/torch/csrc/jit/passes/fold_conv_bn.h
@@ -11,7 +11,7 @@
* The weight and bias of the Conv2d are correspondingly updated. Should only be
* used on modules in eval mode.
*/
-TORCH_API Module FoldConvBatchNorm2d(const Module& module);
+TORCH_API Module FoldConvBatchNorm(const Module& module);
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp
index 6d0dc6c..92f3d76 100644
--- a/torch/csrc/jit/passes/vulkan_rewrite.cpp
+++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp
@@ -162,7 +162,7 @@
script::Module vulkanOptimizeForMobile(const script::Module& m) {
auto cloned_module = m.clone();
cloned_module.eval();
- cloned_module = FoldConvBatchNorm2d(cloned_module);
+ cloned_module = FoldConvBatchNorm(cloned_module);
vulkanInsertPrePackedOps(cloned_module);
cloned_module = freeze_module(cloned_module);
vulkanFusePrePackedConvWithClamp(cloned_module);
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 153b4e2..4b23705 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -281,7 +281,7 @@
cloned_module.eval();
if (!optimization_blacklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
- cloned_module = FoldConvBatchNorm2d(cloned_module);
+ cloned_module = FoldConvBatchNorm(cloned_module);
}
if (!optimization_blacklist.count(
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 910d5cd..9c3132f 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -230,7 +230,7 @@
.def(
"_jit_pass_quant_fusion",
[](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
- .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
+ .def("_jit_pass_fold_convbn", &FoldConvBatchNorm)
.def(
"_freeze_module",
[](Module& module, std::vector<std::string>& preservedAttrs) {