Enable relu fusion with prepacked linear/conv. (#35705)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35705

Introduces a graph-rewrite pass, fusePrePackedLinearConvWithClamp (exposed to Python as torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv), that fuses aten::relu/relu_ and aten::hardtanh/hardtanh_ into the preceding prepacked::linear_clamp_prepack / prepacked::conv2d_clamp_prepack ops by folding the clamping bounds (0/None for relu, min/max for hardtanh) into the prepack call. Hardtanh is fused only when its min/max arguments are constants (see isClampFusable), so that the prepacking ops can still be folded away afterwards.
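
A minimal sketch of how the new pass slots into the existing prepack pipeline, mirroring the updated test; the TinyConvRelu toy module below is hypothetical, and a build with USE_XNNPACK=1 is assumed:

    import torch
    import torch.nn.functional as F

    class TinyConvRelu(torch.nn.Module):
        def __init__(self):
            super(TinyConvRelu, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(8, 3, 3, 3))
            self.bias = torch.nn.Parameter(torch.rand(8))

        def forward(self, x):
            # conv2d followed by relu is the shape the new pass targets.
            return F.relu(F.conv2d(x, self.weight, self.bias))

    scripted = torch.jit.script(TinyConvRelu())
    scripted.eval()
    # 1. Rewrite aten::conv2d/linear into prepacked::*_clamp_prepack + *_clamp_run.
    torch._C._jit_pass_insert_prepacked_ops(scripted._c)
    # 2. Freeze so weights, biases, and clamp bounds become graph constants.
    scripted._c = torch._C._freeze_module(scripted._c)
    # 3. New pass: fold aten::relu/relu_/hardtanh/hardtanh_ into the prepack's clamp bounds.
    torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted._c)
    # 4. Optionally fold away the (now constant) prepacking ops.
    torch._C._jit_pass_fold_prepacking_ops(scripted._c)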

Test Plan:
python test/test_xnnpack_integration.py

Imported from OSS

Differential Revision: D20746592

fbshipit-source-id: 6c22f60a20e9121618c85077b9b58fb8d4082b3b
diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py
index ea07620..ae1f5e9 100644
--- a/test/test_xnnpack_integration.py
+++ b/test/test_xnnpack_integration.py
@@ -74,10 +74,10 @@
         strides = (stride_h, stride_w)
         paddings = (pad_h, pad_w)
         dilations = (dilation, dilation)
-        assume(height + 2 * paddings[0] >=
-               dilations[0] * (kernels[0] - 1) + 1)
-        assume(width + 2 * paddings[1] >=
-               dilations[1] * (kernels[1] - 1) + 1)
+        assume(height + 2 * paddings[0]
+               >= dilations[0] * (kernels[0] - 1) + 1)
+        assume(width + 2 * paddings[1]
+               >= dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
         if (format is not None):
@@ -325,10 +325,10 @@
         strides = (stride_h, stride_w)
         paddings = (pad_h, pad_w)
         dilations = (dilation, dilation)
-        assume(height + 2 * paddings[0] >=
-               dilations[0] * (kernels[0] - 1) + 1)
-        assume(width + 2 * paddings[1] >=
-               dilations[1] * (kernels[1] - 1) + 1)
+        assume(height + 2 * paddings[0]
+               >= dilations[0] * (kernels[0] - 1) + 1)
+        assume(width + 2 * paddings[1]
+               >= dilations[1] * (kernels[1] - 1) + 1)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
         if (format is not None):
@@ -387,29 +387,35 @@
                      " Please build with USE_XNNPACK=1.")
 class TestXNNPACKRewritePass(TestCase):
     def test_linear(self):
-        def validate_transformed_module(module_name, pattern_count_map, data_shape, prepack_removal=False):
-            scripted_model = torch.jit.script(module_name())
+        def validate_transformed_module(
+                module_instance,
+                pattern_count_map,
+                data_shape,
+                prepack_removal=False,
+                fuse_clamping_ops=False):
+            scripted_model = torch.jit.script(module_instance)
             scripted_model.eval()
-            input_data = torch.rand(data_shape)
+            input_data = torch.normal(1, 20, size=data_shape)
             ref_result = scripted_model(input_data)
             torch._C._jit_pass_insert_prepacked_ops(scripted_model._c)
-            if (prepack_removal):
+            if fuse_clamping_ops or prepack_removal:
                 scripted_model._c = torch._C._freeze_module(scripted_model._c)
+            if fuse_clamping_ops:
+                torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv(scripted_model._c)
+            if (prepack_removal):
                 torch._C._jit_pass_fold_prepacking_ops(scripted_model._c)
 
             buffer = io.BytesIO()
             torch.jit.save(scripted_model, buffer)
             buffer.seek(0)
             deserialized_scripted_model = torch.jit.load(buffer)
-            file_check = FileCheck()
             for pattern, v in pattern_count_map.items():
                 if (v == 0):
-                    file_check.check(pattern)
+                    FileCheck().check(pattern).run(deserialized_scripted_model.graph)
                 elif (v == -1):
-                    file_check.check_not(pattern)
+                    FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
                 else:
-                    file_check.check_count(pattern, v, exactly=True)
-            file_check.run(deserialized_scripted_model.graph)
+                    FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
             xnnpack_result = deserialized_scripted_model(input_data)
             torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3)
 
@@ -438,8 +444,8 @@
         pattern_count_map = {"Tensor = prim::CallFunction": -1,
                              "prepacked::linear_clamp_prepack": 1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(Linear, pattern_count_map, data_shape)
-        validate_transformed_module(LinearNoBias, pattern_count_map, data_shape)
+        validate_transformed_module(Linear(), pattern_count_map, data_shape)
+        validate_transformed_module(LinearNoBias(), pattern_count_map, data_shape)
 
         # Conv params
         batch_size = 2
@@ -479,7 +485,7 @@
         pattern_count_map = {"Tensor = aten::conv2d": -1,
                              "prepacked::conv2d_clamp_prepack": 1,
                              "prepacked::conv2d_clamp_run": 1}
-        validate_transformed_module(Conv2D, pattern_count_map, data_shape)
+        validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
 
         input_data = torch.rand((batch_size, input_channels, height, width))
         conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w))
@@ -490,7 +496,7 @@
         linear_weight_shape = (weight_output_dim, linear_input_shape)
 
         class M(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, activation_fn=F.relu):
                 super(M, self).__init__()
                 self.conv_weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)))
                 self.conv_bias = torch.nn.Parameter(torch.Tensor(torch.rand((conv_bias_shape))))
@@ -500,24 +506,125 @@
                 self.paddings = paddings
                 self.dilations = dilations
                 self.groups = groups
+                self.activation_fn = activation_fn
 
             def forward(self, x):
                 o = F.conv2d(x, self.conv_weight, self.conv_bias,
                              self.strides, self.paddings, self.dilations, self.groups)
+                o = self.activation_fn(o)
                 o = o.permute([0, 2, 3, 1])
                 o = F.linear(o, self.linear_weight, self.linear_bias)
-                return F.relu(o)
+                return self.activation_fn(o)
 
         pattern_count_map = {"Tensor = aten::conv2d": -1,
                              "prepacked::conv2d_clamp_prepack": 1,
                              "prepacked::conv2d_clamp_run": 1,
-                             "Tensor = prim::CallFunction": -1,
                              "prepacked::linear_clamp_prepack": 1,
                              "prepacked::linear_clamp_run": 1}
-        validate_transformed_module(M, pattern_count_map, data_shape)
+        validate_transformed_module(M(), pattern_count_map, data_shape)
+        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+        pattern_count_map["Tensor = prim::CallFunction"] = -1
+        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
+
+        # Non-inplace relu fusion test.
+        pattern_count_map = {"aten::relu": 2,
+                             "prepacked::conv2d_clamp_prepack": -1,
+                             "prepacked::conv2d_clamp_run": 1,
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True)
         pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
         pattern_count_map["prepacked::linear_clamp_prepack"] = -1
-        validate_transformed_module(M, pattern_count_map, data_shape, True)
+        pattern_count_map["aten::relu"] = -1
+        validate_transformed_module(M(), pattern_count_map, data_shape, prepack_removal=True, fuse_clamping_ops=True)
+
+        # Inplace relu fusion test.
+        pattern_count_map = {"aten::relu": 2,
+                             "prepacked::conv2d_clamp_prepack": -1,
+                             "prepacked::conv2d_clamp_run": 1,
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(M(F.relu_), pattern_count_map, data_shape, prepack_removal=True)
+        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+        pattern_count_map["aten::relu"] = -1
+        validate_transformed_module(M(F.relu_), pattern_count_map, data_shape,
+                                    prepack_removal=True, fuse_clamping_ops=True)
+
+        # Non-inplace hardtanh fusion test.
+        pattern_count_map = {"aten::hardtanh": 2,
+                             "prepacked::conv2d_clamp_prepack": -1,
+                             "prepacked::conv2d_clamp_run": 1,
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape, prepack_removal=True)
+        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+        pattern_count_map["aten::hardtanh"] = -1
+        validate_transformed_module(M(F.hardtanh), pattern_count_map, data_shape,
+                                    prepack_removal=True, fuse_clamping_ops=True)
+
+        # Inplace hardtanh fusion test.
+        pattern_count_map = {"aten::hardtanh_": 2,
+                             "prepacked::conv2d_clamp_prepack": -1,
+                             "prepacked::conv2d_clamp_run": 1,
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape, prepack_removal=True)
+        pattern_count_map["prepacked::conv2d_clamp_prepack"] = -1
+        pattern_count_map["prepacked::linear_clamp_prepack"] = -1
+        pattern_count_map["aten::hardtanh_"] = -1
+        validate_transformed_module(M(F.hardtanh_), pattern_count_map, data_shape,
+                                    prepack_removal=True, fuse_clamping_ops=True)
+
+        class MFusionAntiPattern(torch.nn.Module):
+            def __init__(self):
+                super(MFusionAntiPattern, self).__init__()
+                self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
+                self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.linear(x, self.linear_weight, self.linear_bias)
+                o = F.relu(o)
+                o = F.hardtanh(o)
+                return o
+
+        # Unfusable hardtanh.
+        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be.
+                             "aten::relu": -1,  # relu is fused.
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(MFusionAntiPattern(), pattern_count_map, (16, linear_weight_shape[1]),
+                                    prepack_removal=True, fuse_clamping_ops=True)
+
+        class MFusionAntiPatternParamMinMax(torch.nn.Module):
+            def __init__(self):
+                super(MFusionAntiPatternParamMinMax, self).__init__()
+                self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
+                self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                min = x[0, 0]
+                max = min + 10
+                o = F.linear(x, self.linear_weight, self.linear_bias)
+                o = F.hardtanh(o, min, max)
+                return o
+
+        # Unfusable hardtanh.
+        pattern_count_map = {"aten::hardtanh": 1,  # hardtanh cannot be.
+                             "prepacked::linear_clamp_prepack": -1,
+                             "prepacked::linear_clamp_run": 1}
+        validate_transformed_module(MFusionAntiPatternParamMinMax(), pattern_count_map, (16, linear_weight_shape[1]),
+                                    prepack_removal=True, fuse_clamping_ops=True)
 
 
 if __name__ == "__main__":
diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.h b/torch/csrc/jit/passes/graph_rewrite_helper.h
index 1da61d2..f6e6f4d 100644
--- a/torch/csrc/jit/passes/graph_rewrite_helper.h
+++ b/torch/csrc/jit/passes/graph_rewrite_helper.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/irparser.h>
 
 namespace torch {
 namespace jit {
@@ -17,6 +18,25 @@
     const std::unordered_map<std::string, Value*>& vmap);
 void replaceConvolutionWithConv2d(std::shared_ptr<Graph>& graph);
 
+// This struct contains a compiled IR pattern slated for use in the
+// findPatternMatches function. The struct encapsulates the common
+// information from parseIR that is used in conjunction with the
+// pattern matching facility. A const instance of this struct can
+// also be stored away to cache the compiled IR pattern and reduce
+// runtime cost.
+struct PatternInfo {
+  std::string pattern_string;
+  std::unique_ptr<Graph> pattern_graph;
+  std::unordered_map<std::string, Value*> vmap;
+
+  static PatternInfo parse_from_str(std::string pattern_string) {
+    PatternInfo rv{
+        std::move(pattern_string), std::make_unique<Graph>(), decltype(vmap){}};
+    parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
+    return rv;
+  }
+};
+
 } // namespace graph_rewrite_helper
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 4381465..6af2e8f 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -32,6 +32,7 @@
 using graph_rewrite_helper::getFuncName;
 using graph_rewrite_helper::getIValue;
 using graph_rewrite_helper::getValue;
+using graph_rewrite_helper::PatternInfo;
 using graph_rewrite_helper::replaceConvolutionWithConv2d;
 
 // Map of quantization parameter name and value
@@ -39,25 +40,6 @@
 // _scalar_type and _axis(for per channel quantization)
 using QParamVector = std::vector<std::pair<std::string, IValue>>;
 
-// This struct contains a compiled IR patterns slated for use in the
-// findPatternMatches function. The struct encapsulates the common
-// information from parseIR that is used in conjunction with the
-// pattern matching facility. A const instance of this struct can
-// also be stored away to cache the compiled IR pattern and reduce
-// runtime cost
-struct PatternInfo {
-  std::string pattern_string;
-  std::unique_ptr<Graph> pattern_graph;
-  std::unordered_map<std::string, Value*> vmap;
-
-  static PatternInfo parse_from_str(std::string pattern_string) {
-    PatternInfo rv{
-        std::move(pattern_string), std::make_unique<Graph>(), decltype(vmap){}};
-    parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
-    return rv;
-  }
-};
-
 struct PatternsAndModules {
   bool is_conv;
   bool is_per_channel;
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index 67b0556..2d10870 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -3,6 +3,7 @@
 
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
@@ -86,6 +87,182 @@
   rewriter.runOnGraph(graph);
 }
 
+bool isClampFusable(
+    const Match& match,
+    const std::unordered_map<std::string, Value*>& vmap) {
+  const auto& match_vmap = match.values_map;
+  TORCH_CHECK(
+      vmap.find("dummy_min_max") != vmap.end(),
+      "Expected to find dummy_min_max Value in the subgraph to be replaced.");
+  auto dummy_min_max =
+      graph_rewrite_helper::getIValue("dummy_min_max", match_vmap, vmap);
+
+  auto is_fusable = !dummy_min_max || dummy_min_max.value().isNone();
+
+  // Also check if the output_min and output_max values are actually constant.
+  // If hardtanh's min/max Value's are not actually constants, we will end up
+  // rerouting those values to prepack op. And if they are not constants
+  // we will not be able to remove prepacking ops.
+  if (vmap.find("output_min") != vmap.end()) {
+    // aten::relu pattern does not have output_min/output_max.
+    // aten::hardtanh/_ does.
+    TORCH_CHECK(
+        vmap.find("output_max") != vmap.end(),
+        "Expected to find output_max as well given "
+        "output_min exist in pattern graph.");
+    // If output_min/max are not constant, we get c10::nullopt.
+    auto output_min =
+        graph_rewrite_helper::getIValue("output_min", match_vmap, vmap);
+    auto output_max =
+        graph_rewrite_helper::getIValue("output_max", match_vmap, vmap);
+    is_fusable =
+        is_fusable && (output_min.has_value() && output_max.has_value());
+  }
+
+  return is_fusable;
+}
+
+void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
+  SubgraphRewriter rewriter;
+
+  std::string linear_prepack_run_hardtanh_fused = R"(
+    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
+            %weight, %bias, %output_min, %output_max)
+        %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        return (%res))";
+
+  std::string conv2d_prepack_run_hardtanh_fused = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %output_min, %output_max)
+        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%r) )";
+
+  std::string linear_prepack_run_hardtanh = R"(
+    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = prepacked::linear_clamp_prepack(
+            %weight, %bias, %dummy_min_max, %dummy_min_max)
+        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
+        return (%res))";
+
+  rewriter.RegisterRewritePattern(
+      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);
+
+  std::string conv2d_prepack_run_hardtanh = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
+
+  std::string linear_prepack_run_hardtanh_inplace = R"(
+    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = prepacked::linear_clamp_prepack(
+            %weight, %bias, %dummy_min_max, %dummy_min_max)
+        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
+        return (%res))";
+
+  std::string conv2d_prepack_run_hardtanh_inplace = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
+        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
+
+  rewriter.runOnGraph(graph, isClampFusable);
+}
+
+void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
+  SubgraphRewriter rewriter;
+
+  std::string linear_prepack_run_relu_fused = R"(
+    graph(%input, %weight, %bias, %dummy_min_max):
+        %output_min: float = prim::Constant[value=0.0]()
+        %output_max: None = prim::Constant()
+        %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
+            %weight, %bias, %output_min, %output_max)
+        %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        return (%res))";
+
+  std::string conv2d_prepack_run_relu_fused = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %output_min: float = prim::Constant[value=0.0]()
+        %output_max: None = prim::Constant()
+        %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %output_min, %output_max)
+        %r = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        return (%r) )";
+
+  std::string linear_prepack_run_relu = R"(
+    graph(%input, %weight, %bias, %dummy_min_max):
+        %packed_weight_bias = prepacked::linear_clamp_prepack(
+            %weight, %bias, %dummy_min_max, %dummy_min_max)
+        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        %res = aten::relu(%linear_res)
+        return (%res))";
+
+  rewriter.RegisterRewritePattern(
+      linear_prepack_run_relu, linear_prepack_run_relu_fused);
+
+  std::string conv2d_prepack_run_relu = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::relu(%conv2d_res)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
+
+  std::string linear_prepack_run_relu_inplace = R"(
+    graph(%input, %weight, %bias, %dummy_min_max):
+        %packed_weight_bias = prepacked::linear_clamp_prepack(
+            %weight, %bias, %dummy_min_max, %dummy_min_max)
+        %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
+        %res = aten::relu_(%linear_res)
+        return (%res))";
+
+  std::string conv2d_prepack_run_relu_inplace = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
+          %dilation:int[], %groups:int, %dummy_min_max):
+        %packed_weight_bias = prepacked::conv2d_clamp_prepack(
+            %weight, %bias, %stride, %padding, %dilation, %groups,
+            %dummy_min_max, %dummy_min_max)
+        %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
+        %r = aten::relu_(%conv2d_res)
+        return (%r) )";
+
+  rewriter.RegisterRewritePattern(
+      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
+  rewriter.RegisterRewritePattern(
+      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
+  rewriter.runOnGraph(graph, isClampFusable);
+}
+
 } // namespace
 
 void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
@@ -103,6 +280,12 @@
   }
 }
 
+void fusePrePackedLinearConvWithClamp(script::Module& module) {
+  auto graph = module.get_method("forward").graph();
+  fuseReluWithPackedOps(graph);
+  fuseHardtanhWithPackedOps(graph);
+}
+
 void FoldPrePackingOps(script::Module& m) {
   PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
     return (
@@ -133,6 +316,11 @@
       "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
 }
 
+void fusePrePackedLinearConvWithClamp(script::Module& module) {
+  TORCH_INTERNAL_ASSERT(
+      "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
+}
+
 void FoldPrePackingOps(script::Module& m) {
   TORCH_INTERNAL_ASSERT(
       "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h
index b969bfd..e69fff1 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.h
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.h
@@ -7,6 +7,7 @@
 namespace jit {
 TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
 TORCH_API void insertPrePackedOps(script::Module& module);
+TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
 TORCH_API void FoldPrePackingOps(script::Module& module);
 TORCH_API void optimizeForMobile(script::Module& module);
 } // namespace jit
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 364337e..8693a19 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -502,6 +502,11 @@
           "_jit_pass_insert_prepacked_ops",
           [](script::Module& module) { return insertPrePackedOps(module); })
       .def(
+          "_jit_pass_fuse_clamp_w_prepacked_linear_conv",
+          [](script::Module& module) {
+            return fusePrePackedLinearConvWithClamp(module);
+          })
+      .def(
           "_jit_pass_fold_prepacking_ops",
           [](script::Module& module) { return FoldPrePackingOps(module); })
       .def(