[SR] Fuse quantized linear/relu (#75775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75775
The fbgemm kernels already implement the fused linear/relu operation, so there is no reason not to use it.
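For illustration only (not part of this diff), the new pass can be exercised on a standalone graph roughly as sketched below; this assumes parseIR from torch/csrc/jit/ir/irparser.h and the QuantizedLinearReluFusion pass declared in this change:

  #include <iostream>
  #include <torch/csrc/jit/ir/ir.h>
  #include <torch/csrc/jit/ir/irparser.h>
  #include <torch/csrc/jit/runtime/static/passes.h>

  int main() {
    auto graph = std::make_shared<torch::jit::Graph>();
    // Parse a graph containing the pattern the pass looks for.
    torch::jit::parseIR(R"IR(
      graph(%input, %packed_params):
        %x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
        %y : Tensor = aten::relu(%x)
        return (%y))IR", graph.get());
    // Rewrite the linear_dynamic_fp16 + relu pair into the fused op.
    torch::jit::QuantizedLinearReluFusion(graph);
    std::cout << *graph; // should now contain quantized::linear_relu_dynamic_fp16
    return 0;
  }
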
ghstack-source-id: 155450342
Test Plan: New unit test, QuantizedLinearDynamicFp16ReluFusion, in benchmarks/static_runtime/test_static_runtime.cc
Reviewed By: navahgar
Differential Revision: D35633297
fbshipit-source-id: a744a33a65ce7dbb9ce8900dbe091b6d56dd4e48
(cherry picked from commit b1361b349862715aa17e6318c5e658cd6401a464)
diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc
index 41c7dfd..ec001df 100644
--- a/benchmarks/static_runtime/test_static_runtime.cc
+++ b/benchmarks/static_runtime/test_static_runtime.cc
@@ -3274,3 +3274,27 @@
at::randn({42, 42}), at::randn({42, 42}), true, false};
testStaticRuntime(src, args1, args2);
}
+
+TEST(StaticRuntime, QuantizedLinearDynamicFp16ReluFusion) {
+ const auto src = R"IR(
+ graph(%input: Tensor, %weights: Tensor):
+ %bias: None = prim::Constant()
+ %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
+ %x = quantized::linear_dynamic_fp16(%input, %packed_params)
+ %y = aten::relu(%x)
+ %ret = aten::clone(%y, %bias)
+ return (%ret)
+ )IR";
+ at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
+ at::Tensor input = torch::randn({3, 2}, torch::kFloat);
+
+ at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
+ at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);
+
+ testStaticRuntime(src, {input, weight}, {input_2, weight_2});
+
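+ // Verify that the pass rewrites the graph to use the fused op.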
+ auto graph = getGraphFromIR(src);
+ QuantizedLinearReluFusion(graph);
+ EXPECT_FALSE(hasNodeWithKind(graph, "quantized::linear_dynamic_fp16"));
+ EXPECT_TRUE(hasNodeWithKind(graph, "quantized::linear_relu_dynamic_fp16"));
+}
diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp
index 57ecd07..cc32288 100644
--- a/torch/csrc/jit/runtime/static/impl.cpp
+++ b/torch/csrc/jit/runtime/static/impl.cpp
@@ -177,6 +177,7 @@
graph, /* custom_ops */ {fromQualString("fb::scale_gradient")});
AddIfThenElseOp(graph);
UseSplitAndSqueeze(graph);
+ QuantizedLinearReluFusion(graph);
GRAPH_DUMP("Final graph after optimizations: ", graph);
}
diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp
index 8de97cd..c19afb3 100644
--- a/torch/csrc/jit/runtime/static/passes.cpp
+++ b/torch/csrc/jit/runtime/static/passes.cpp
@@ -1229,5 +1229,20 @@
}
}
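+// Rewrite the pattern "quantized::linear_dynamic_fp16 -> aten::relu" into the
+// single fused op quantized::linear_relu_dynamic_fp16 provided by fbgemm.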
+void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph) {
+ std::string pattern = R"IR(
+ graph(%input, %packed_params):
+ %x : Tensor = quantized::linear_dynamic_fp16(%input, %packed_params)
+ %y : Tensor = aten::relu(%x)
+ return (%y))IR";
+ std::string fused_pattern = R"IR(
+ graph(%input, %packed_params):
+ %x : Tensor = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
+ return (%x))IR";
+ SubgraphRewriter fuse;
+ fuse.RegisterRewritePattern(pattern, fused_pattern);
+ fuse.runOnGraph(graph);
+}
+
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/passes.h b/torch/csrc/jit/runtime/static/passes.h
index 6672d00..7b79478 100644
--- a/torch/csrc/jit/runtime/static/passes.h
+++ b/torch/csrc/jit/runtime/static/passes.h
@@ -76,5 +76,7 @@
TORCH_API void RemoveUnnecessaryEmbeddingBagOutputs(
std::shared_ptr<Graph>& graph);
+TORCH_API void QuantizedLinearReluFusion(std::shared_ptr<Graph>& graph);
+
} // namespace jit
} // namespace torch