[SR] Fuse quantized linear/relu (#75775)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75775

fbgemm kernels already implement the fused kernel, no reason not to use it
ghstack-source-id: 155450342

Test Plan: New unit tests

Reviewed By: navahgar

Differential Revision: D35633297

fbshipit-source-id: a744a33a65ce7dbb9ce8900dbe091b6d56dd4e48
(cherry picked from commit b1361b349862715aa17e6318c5e658cd6401a464)
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 41c7dfd..ec001df 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -3274,3 +3274,27 @@
       at::randn({42, 42}), at::randn({42, 42}), true, false};
   testStaticRuntime(src, args1, args2);
 }
+
+TEST(StaticRuntime, QuantizedLinearDynamicFp16ReluFusion) {
+  const auto src = R"IR(
+    graph(%input: Tensor, %weights: Tensor):
+        %bias: None = prim::Constant()
+        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
+        %x = quantized::linear_dynamic_fp16(%input, %packed_params)
+        %y = aten::relu(%x)
+        %ret = aten::clone(%y, %bias)
+        return (%ret)
+  )IR";
+  at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
+  at::Tensor input = torch::randn({3, 2}, torch::kFloat);
+
+  at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
+  at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
+
+  testStaticRuntime(src, {input, weight}, {input_2, weight_2});
+
+  auto graph = getGraphFromIR(src);
+  QuantizedLinearReluFusion(graph);
+  EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
+  EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
+}
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 57ecd07..cc32288 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -177,6 +177,7 @@
       graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
   AddIfThenElseOp(graph);
   UseSplitAndSqueeze(graph);
+  QuantizedLinearReluFusion(graph);
   GRAPH_DUMP("Final graph after optimizations: ", graph);
 }
 
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 8de97cd..c19afb3 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -1229,5 +1229,20 @@
   }
 }
 
+void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph) {
+  std::string pattern = R"IR(
+    graph(%input, %packed_params):
+        %x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
+        %y : Tensor = aten::relu(%x)
+        return (%y))IR";
+  std::string fused_pattern = R"IR(
+    graph(%input, %packed_params):
+        %x : Tensor = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
+        return (%x))IR";
+  SubgraphRewriter fuse;
+  fuse.RegisterRewritePattern(pattern, fused_pattern);
+  fuse.runOnGraph(graph);
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 6672d00..7b79478 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -76,5 +76,7 @@
 TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
     std::shared_ptr<Graph>& graph);
 
+TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
+
 } // namespace jit
 } // namespace torch