Enable relu fusion with prepacked linear/conv. (#35705)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35705
Introduces a graph rewrite pass that fuses relu and hardtanh (clamping ops) with prepacked linear/conv ops, folding the activation's bounds into the prepacked op's output_min/output_max so the standalone activation node can be removed.
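
A minimal usage sketch of the pass sequence, mirroring the flow in test_xnnpack_integration.py (LinearRelu and its shapes are hypothetical, not part of this change; assumes a USE_XNNPACK=1 build):

import torch
import torch.nn.functional as F

class LinearRelu(torch.nn.Module):
    def __init__(self):
        super(LinearRelu, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(4, 8))
        self.bias = torch.nn.Parameter(torch.rand(4))

    def forward(self, x):
        # linear followed by relu; the relu's [0, inf) bound can be folded
        # into the prepacked linear's output_min/output_max.
        return F.relu(F.linear(x, self.weight, self.bias))

scripted = torch.jit.script(LinearRelu())
scripted.eval()
# Rewrite aten::linear / aten::conv2d into prepacked clamp ops.
torch._C._jit_pass_insert_prepacked_ops(scripted._c)
# Freeze first (as the test does before fusing) so weights/bias become constants.
scripted._c = torch._C._freeze_module(scripted._c)
# New pass from this diff: folds the trailing relu/hardtanh into the clamp bounds.
torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted._c)
# Optionally fold the prepack ops themselves away.
torch._C._jit_pass_fold_prepacking_ops(scripted._c)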
Test Plan:
python test/test_xnnpack_integration.py
Imported from OSS
Differential Revision: D20746592
fbshipit-source-id: 6c22f60a20e9121618c85077b9b58fb8d4082b3b
diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py
index ea07620..ae1f5e9 100644
--- a/test/test_xnnpack_integration.py
+++ b/test/test_xnnpack_integration.py
@@ -74,10 +74,10 @@
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
dilations = (dilation, dilation)
- assume(height + 2 * paddings[0] >=
- dilations[0] * (kernels[0] - 1) + 1)
- assume(width + 2 * paddings[1] >=
- dilations[1] * (kernels[1] - 1) + 1)
+ assume(height + 2 * paddings[0]
+ >= dilations[0] * (kernels[0] - 1) + 1)
+ assume(width + 2 * paddings[1]
+ >= dilations[1] * (kernels[1] - 1) + 1)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
@@ -325,10 +325,10 @@
strides = (stride_h, stride_w)
paddings = (pad_h, pad_w)
dilations = (dilation, dilation)
- assume(height + 2 * paddings[0] >=
- dilations[0] * (kernels[0] - 1) + 1)
- assume(width + 2 * paddings[1] >=
- dilations[1] * (kernels[1] - 1) + 1)
+ assume(height + 2 * paddings[0]
+ >= dilations[0] * (kernels[0] - 1) + 1)
+ assume(width + 2 * paddings[1]
+ >= dilations[1] * (kernels[1] - 1) + 1)
input_data = torch.rand((batch_size, input_channels, height, width))
if (format is not None):
@@ -387,29 +387,35 @@
" Please build with USE_XNNPACK=1.")
class TestXNNPACKRewritePass(TestCase):
def test_linear(self):
- def validate_transformed_module(module_name, pattern_count_map, data_shape, prepack_removal=False):
- scripted_model = torch.jit.script(module_name())
+ def validate_transformed_module(
+ module_instance,
+ pattern_count_map,
+ data_shape,
+ prepack_removal=False,
+ fuse_clamping_ops=False):
+ scripted_model = torch.jit.script(module_instance)
scripted_model.eval()
- input_data = torch.rand(data_shape)
+ input_data = torch.normal(1, 20, size=data_shape)
ref_result = scripted_model(input_data)
torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
- if (prepack_removal):
+ if fuse_clamping_ops or prepack_removal:
scripted_model._c = torch._C._freeze_module(scripted_model._c)
+ if fuse_clamping_ops:
+ torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
+ if (prepack_removal):
torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
buffer = io.BytesIO()
torch.jit.save(scripted_model, buffer)
buffer.seek(0)
deserialized_scripted_model = torch.jit.load(buffer)
- file_check = FileCheck()
for pattern, v in pattern_count_map.items():
if (v == 0):
- file_check.check(pattern)
+ FileCheck().check(pattern).run(deserialized_scripted_model.graph)
elif (v == -1):
- file_check.check_not(pattern)
+ FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
else:
- file_check.check_count(pattern, v, exactly=True)
- file_check.run(deserialized_scripted_model.graph)
+ FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
xnnpack_result = deserialized_scripted_model(input_data)
torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
@@ -438,8 +444,8 @@
pattern_count_map = {"Tensor = prim::CallFunction": -1,
"prepacked::linear_clamp_prepack": 1,
"prepacked::linear_clamp_run": 1}
- validate_transformed_module(Linear, pattern_count_map, data_shape)
- validate_transformed_module(LinearNoBias, pattern_count_map, data_shape)
+ validate_transformed_module(Linear(), pattern_count_map, data_shape)
+ validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
# Conv params
batch_size = 2
@@ -479,7 +485,7 @@
pattern_count_map = {"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": 1,
"prepacked::conv2d_clamp_run": 1}
- validate_transformed_module(Conv2D, pattern_count_map, data_shape)
+ validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
input_data = torch.rand((batch_size, input_channels, height, width))
conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
@@ -490,7 +496,7 @@
linear_weight_shape = (weight_output_dim, linear_input_shape)
class M(torch.nn.Module):
- def __init__(self):
+ def __init__(self, activation_fn=F.relu):
super(M, self).__init__()
self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand((conv_bias_shape))))
@@ -500,24 +506,125 @@
self.paddings = paddings
self.dilations = dilations
self.groups = groups
+ self.activation_fn = activation_fn
def forward(self, x):
o = F.conv2d(x, self.conv_weight, self.conv_bias,
self.strides, self.paddings, self.dilations, self.groups)
+ o = self.activation_fn(o)
o = o.permute([0, 2, 3, 1])
o = F.linear(o, self.linear_weight, self.linear_bias)
- return F.relu(o)
+ return self.activation_fn(o)
pattern_count_map = {"Tensor = aten::conv2d": -1,
"prepacked::conv2d_clamp_prepack": 1,
"prepacked::conv2d_clamp_run": 1,
- "Tensor = prim::CallFunction": -1,
"prepacked::linear_clamp_prepack": 1,
"prepacked::linear_clamp_run": 1}
- validate_transformed_module(M, pattern_count_map, data_shape)
+ validate_transformed_module(M(), pattern_count_map, data_shape)
+ pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+ pattern_count_map["Tensor = prim::CallFunction"] = -1
+ pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+ validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
+
+    # Out-of-place relu fusion test.
+ pattern_count_map = {"aten::relu": 2,
+ "prepacked::conv2d_clamp_prepack": -1,
+ "prepacked::conv2d_clamp_run": 1,
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
pattern_count_map["prepacked::linear_clamp_prepack"] = -1
- validate_transformed_module(M, pattern_count_map, data_shape, True)
+ pattern_count_map["aten::relu"] = -1
+ validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True)
+
+    # In-place relu fusion test.
+ pattern_count_map = {"aten::relu": 2,
+ "prepacked::conv2d_clamp_prepack": -1,
+ "prepacked::conv2d_clamp_run": 1,
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(M(F.relu_), pattern_count_map, data_shape, prepack_removal=True)
+ pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+ pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+ pattern_count_map["aten::relu"] = -1
+ validate_transformed_module(M(F.relu_), pattern_count_map, data_shape,
+ prepack_removal=True, fuse_clamping_ops=True)
+
+    # Out-of-place hardtanh fusion test.
+ pattern_count_map = {"aten::hardtanh": 2,
+ "prepacked::conv2d_clamp_prepack": -1,
+ "prepacked::conv2d_clamp_run": 1,
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True)
+ pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+ pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+ pattern_count_map["aten::hardtanh"] = -1
+ validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape,
+ prepack_removal=True, fuse_clamping_ops=True)
+
+    # In-place hardtanh fusion test.
+ pattern_count_map = {"aten::hardtanh_": 2,
+ "prepacked::conv2d_clamp_prepack": -1,
+ "prepacked::conv2d_clamp_run": 1,
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True)
+ pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+ pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+ pattern_count_map["aten::hardtanh_"] = -1
+ validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape,
+ prepack_removal=True, fuse_clamping_ops=True)
+
+ class MFusionAntiPattern(torch.nn.Module):
+ def __init__(self):
+ super(MFusionAntiPattern, self).__init__()
+ self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
+ self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+ self.strides = strides
+ self.paddings = paddings
+ self.dilations = dilations
+ self.groups = groups
+
+ def forward(self, x):
+ o = F.linear(x, self.linear_weight, self.linear_bias)
+ o = F.relu(o)
+ o = F.hardtanh(o)
+ return o
+
+ # Unfusable hardtanh.
+ pattern_count_map = {"aten::hardtanh": 1, # hardtanh cannot be.
+ "aten::relu": -1, # relu is fused.
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(MFusionAntiPattern(), pattern_count_map, (16, linear_weight_shape[1]),
+ prepack_removal=True, fuse_clamping_ops=True)
+
+ class MFusionAntiPatternParamMinMax(torch.nn.Module):
+ def __init__(self):
+ super(MFusionAntiPatternParamMinMax, self).__init__()
+ self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
+ self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+ self.strides = strides
+ self.paddings = paddings
+ self.dilations = dilations
+ self.groups = groups
+
+ def forward(self, x):
+ min = x[0, 0]
+ max = min + 10
+ o = F.linear(x, self.linear_weight, self.linear_bias)
+ o = F.hardtanh(o, min, max)
+ return o
+
+ # Unfusable hardtanh.
+ pattern_count_map = {"aten::hardtanh": 1, # hardtanh cannot be.
+ "prepacked::linear_clamp_prepack": -1,
+ "prepacked::linear_clamp_run": 1}
+ validate_transformed_module(MFusionAntiPatternParamMinMax(), pattern_count_map, (16, linear_weight_shape[1]),
+ prepack_removal=True, fuse_clamping_ops=True)
if __name__ == "__main__":
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h
index 1da61d2..f6e6f4d 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.h
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.h
@@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/irparser.h>
namespace torch {
namespace jit {
@@ -17,6 +18,25 @@
const std::unordered_map<std::string, Value*>& vmap);
void replaceConvolutionWithConv2d(std::shared_ptr<Graph>& graph);
+// This struct contains a compiled IR pattern slated for use in the
+// findPatternMatches function. The struct encapsulates the common
+// information from parseIR that is used in conjunction with the
+// pattern matching facility. A const instance of this struct can
+// also be stored away to cache the compiled IR pattern and reduce
+// runtime cost.
+struct PatternInfo {
+ std::string pattern_string;
+ std::unique_ptr<Graph> pattern_graph;
+ std::unordered_map<std::string, Value*> vmap;
+
+ static PatternInfo parse_from_str(std::string pattern_string) {
+ PatternInfo rv{
+ std::move(pattern_string), std::make_unique<Graph>(), decltype(vmap){}};
+ parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
+ return rv;
+ }
+};
+
} // namespace graph_rewrite_helper
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 4381465..6af2e8f 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -32,6 +32,7 @@
using graph_rewrite_helper::getFuncName;
using graph_rewrite_helper::getIValue;
using graph_rewrite_helper::getValue;
+using graph_rewrite_helper::PatternInfo;
using graph_rewrite_helper::replaceConvolutionWithConv2d;
// Map of quantization parameter name and value
@@ -39,25 +40,6 @@
// _scalar_type and _axis(for per channel quantization)
using QParamVector = std::vector<std::pair<std::string, IValue>>;
-// This struct contains a compiled IR patterns slated for use in the
-// findPatternMatches function. The struct encapsulates the common
-// information from parseIR that is used in conjunction with the
-// pattern matching facility. A const instance of this struct can
-// also be stored away to cache the compiled IR pattern and reduce
-// runtime cost
-struct PatternInfo {
- std::string pattern_string;
- std::unique_ptr<Graph> pattern_graph;
- std::unordered_map<std::string, Value*> vmap;
-
- static PatternInfo parse_from_str(std::string pattern_string) {
- PatternInfo rv{
- std::move(pattern_string), std::make_unique<Graph>(), decltype(vmap){}};
- parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
- return rv;
- }
-};
-
struct PatternsAndModules {
bool is_conv;
bool is_per_channel;
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 67b0556..2d10870 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -3,6 +3,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
@@ -86,6 +87,182 @@
rewriter.runOnGraph(graph);
}
+bool isClampFusable(
+ const Match& match,
+ const std::unordered_map<std::string, Value*>& vmap) {
+ const auto& match_vmap = match.values_map;
+ TORCH_CHECK(
+ vmap.find("dummy_min_max") != vmap.end(),
+ "Expected to find dummy_min_max Value in the subgraph to be replaced.");
+ auto dummy_min_max =
+ graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
+
+ auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
+
+  // Also check whether the output_min and output_max values are actually
+  // constants. If hardtanh's min/max Values are not constants, we would end
+  // up rerouting those values to the prepack op, and then we would not be
+  // able to remove the prepacking ops.
+ if (vmap.find("output_min") != vmap.end()) {
+ // aten::relu pattern does not have output_min/output_max.
+ // aten::hardtanh/_ does.
+ TORCH_CHECK(
+ vmap.find("output_max") != vmap.end(),
+ "Expected to find output_max as well given "
+ "output_min exist in pattern graph.");
+ // If output_min/max are not constant, we get c10::nullopt.
+ auto output_min =
+ graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
+ auto output_max =
+ graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
+ is_fusable =
+ is_fusable && (output_min.has_value() && output_max.has_value());
+ }
+
+ return is_fusable;
+}
+
+void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
+ SubgraphRewriter rewriter;
+
+ std::string linear_prepack_run_hardtanh_fused = R"(
+ graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
+ %weight, %bias, %output_min, %output_max)
+ %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ return (%res))";
+
+ std::string conv2d_prepack_run_hardtanh_fused = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %output_min, %output_max)
+ %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ return (%r) )";
+
+ std::string linear_prepack_run_hardtanh = R"(
+ graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = prepacked::linear_clamp_prepack(
+ %weight, %bias, %dummy_min_max, %dummy_min_max)
+ %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ %res = aten::hardtanh(%linear_res, %output_min, %output_max)
+ return (%res))";
+
+ rewriter.RegisterRewritePattern(
+ linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);
+
+ std::string conv2d_prepack_run_hardtanh = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+
+ std::string linear_prepack_run_hardtanh_inplace = R"(
+ graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = prepacked::linear_clamp_prepack(
+ %weight, %bias, %dummy_min_max, %dummy_min_max)
+ %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
+ return (%res))";
+
+ std::string conv2d_prepack_run_hardtanh_inplace = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+ %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+
+ rewriter.runOnGraph(graph, isClampFusable);
+}
+
+void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
+ SubgraphRewriter rewriter;
+
+ std::string linear_prepack_run_relu_fused = R"(
+ graph(%input, %weight, %bias, %dummy_min_max):
+ %output_min: float = prim::Constant[value=0.0]()
+ %output_max: None = prim::Constant()
+ %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
+ %weight, %bias, %output_min, %output_max)
+ %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ return (%res))";
+
+ std::string conv2d_prepack_run_relu_fused = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %output_min: float = prim::Constant[value=0.0]()
+ %output_max: None = prim::Constant()
+ %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %output_min, %output_max)
+ %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ return (%r) )";
+
+ std::string linear_prepack_run_relu = R"(
+ graph(%input, %weight, %bias, %dummy_min_max):
+ %packed_weight_bias = prepacked::linear_clamp_prepack(
+ %weight, %bias, %dummy_min_max, %dummy_min_max)
+ %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ %res = aten::relu(%linear_res)
+ return (%res))";
+
+ rewriter.RegisterRewritePattern(
+ linear_prepack_run_relu, linear_prepack_run_relu_fused);
+
+ std::string conv2d_prepack_run_relu = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::relu(%conv2d_res)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+
+ std::string linear_prepack_run_relu_inplace = R"(
+ graph(%input, %weight, %bias, %dummy_min_max):
+ %packed_weight_bias = prepacked::linear_clamp_prepack(
+ %weight, %bias, %dummy_min_max, %dummy_min_max)
+ %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+ %res = aten::relu_(%linear_res)
+ return (%res))";
+
+ std::string conv2d_prepack_run_relu_inplace = R"(
+ graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+ %dilation:int[], %groups:int, %dummy_min_max):
+ %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+ %weight, %bias, %stride, %padding, %dilation, %groups,
+ %dummy_min_max, %dummy_min_max)
+ %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+ %r = aten::relu_(%conv2d_res)
+ return (%r) )";
+
+ rewriter.RegisterRewritePattern(
+ linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
+ rewriter.RegisterRewritePattern(
+ conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+ rewriter.runOnGraph(graph, isClampFusable);
+}
+
} // namespace
void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
@@ -103,6 +280,12 @@
}
}
+void fusePrePackedLinearConvWithClamp(script::Module& module) {
+ auto graph = module.get_method("forward").graph();
+ fuseReluWithPackedOps(graph);
+ fuseHardtanhWithPackedOps(graph);
+}
+
void FoldPrePackingOps(script::Module& m) {
PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
return (
@@ -133,6 +316,11 @@
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
+void fusePrePackedLinearConvWithClamp(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
+}
+
void FoldPrePackingOps(script::Module& m) {
TORCH_INTERNAL_ASSERT(
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h
index b969bfd..e69fff1 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.h
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -7,6 +7,7 @@
namespace jit {
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
TORCH_API void insertPrePackedOps(script::Module& module);
+TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
TORCH_API void FoldPrePackingOps(script::Module& module);
TORCH_API void optimizeForMobile(script::Module& module);
} // namespace jit
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 364337e..8693a19 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -502,6 +502,11 @@
"_jit_pass_insert_prepacked_ops",
[](script::Module& module) { return insertPrePackedOps(module); })
.def(
+ "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
+ [](script::Module& module) {
+ return fusePrePackedLinearConvWithClamp(module);
+ })
+ .def(
"_jit_pass_fold_prepacking_ops",
[](script::Module& module) { return FoldPrePackingOps(module); })
.def(