Make insert_quant_dequant work with qconfig_dict (#25127)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25127
Extend the insert_quant_dequant pass to recurse into the forward-method graphs of child modules.
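For illustration, a minimal sketch of the nested-module case this extension targets (the `Parent`/`Child` names are hypothetical, not from this PR): when call inlining is disabled, the child's forward appears in the parent graph as a `prim::CallMethod` node, so quant/dequant nodes must also be inserted into the child's own graph.
```
import torch

class Child(torch.nn.Module):
    def __init__(self):
        super(Child, self).__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)

    def forward(self, x):
        return self.conv(x)

class Parent(torch.nn.Module):
    def __init__(self):
        super(Parent, self).__init__()
        self.child = Child()

    def forward(self, x):
        # With call inlining disabled, this lowers to
        # prim::CallMethod[name="forward"] on self.child,
        # which insert_quant_dequant now follows recursively.
        return self.child(x)

m = torch.jit.script(Parent())
print(m.graph)
```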
Test Plan:
```
python test/test_jit.py 'TestJit.test_insert_quant_dequant'
python test/test_quantizer.py
```
Imported from OSS
Differential Revision: D17001137
fbshipit-source-id: 41b029906fe5c8bc0de01956059388a7d552a380
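The updated test below asserts on the parent graph with FileCheck. A hedged sketch of the idiom, using `torch.testing.FileCheck` as test_jit.py does, with `m` standing in for a scripted module already processed by the pass:
```
from torch.testing import FileCheck

# After the pass, the parent graph should carry its own
# quantize/int_repr/dequantize triple and keep the child call as a
# prim::CallMethod node that is rewritten recursively in the child's graph.
FileCheck().check("aten::quantize_linear") \
           .check_next("aten::int_repr") \
           .check_next("aten::_dequantize_linear") \
           .check("prim::CallMethod[name=\"forward\"]") \
           .check("return") \
           .run(str(m.graph))
```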
diff --git a/test/test_jit.py b/test/test_jit.py
index 7ad039c..429a323 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1001,7 +1001,6 @@
check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)
@_tmp_donotuse_dont_inline_everything
- @unittest.skip("temporary turn off the test")
def test_insert_quant_dequant(self):
class Observer(torch.nn.Module):
def __init__(self):
@@ -1045,28 +1044,12 @@
FileCheck().check("aten::quantize_linear") \
.check_next("aten::int_repr") \
.check_next("aten::_dequantize_linear") \
- .check("aten::quantize_linear") \
- .check_next("aten::int_repr") \
- .check_next("aten::_dequantize_linear") \
- .check("aten::quantize_linear") \
- .check_next("aten::int_repr") \
- .check_next("aten::_dequantize_linear") \
- .check("aten::conv2d") \
+ .check("prim::CallMethod[name=\"forward\"]") \
.check("aten::quantize_linear") \
.check_next("aten::int_repr") \
.check_next("aten::_dequantize_linear") \
.check("return") \
- .run(str(m._c._get_method('forward').graph))
- # Test for inline
- # FileCheck().check("aten::quantize_linear") \
- # .check_next("aten::int_repr") \
- # .check_next("aten::_dequantize_linear") \
- # .check("prim::CallMethod[name=\"forward\"]") \
- # .check("aten::quantize_linear") \
- # .check_next("aten::int_repr") \
- # .check_next("aten::_dequantize_linear") \
- # .check("return") \
- # .run(str(get_forward(m).graph))
+ .run(str(get_forward(m).graph))
def test_quant_fusion(self):
input_str = """
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index b18f48c..ebab51f 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -229,29 +229,6 @@
}
}
-} // namespace
-
-// PyBind APIs
-void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
- throw std::runtime_error("Pass not implemented yet!");
-}
-
-void QuantLinting(std::shared_ptr<Graph>& graph) {
- throw std::runtime_error("Pass not implemented yet!");
-}
-
-void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
- throw std::runtime_error("Pass not implemented yet!");
-}
-
-TORCH_API void InsertObservers(
- script::Module& module,
- const std::string& method_name,
- const QConfigDict& qconfig_dict) {
- auto module_qconfig_map = getQConfigMap(module, qconfig_dict);
- InsertObserversImpl(module, method_name, module_qconfig_map);
-}
-
Node* insertQuantDeQuantCall(Value* v, const IValue& qparams, at::ScalarType t, bool insert_after=true) {
Graph* g = v->node()->owningGraph();
Node* quant = createQuantNode(v, g);
@@ -319,6 +296,7 @@
public:
QuantizeHelper(const script::Module& m) : module_(m) {}
IValue getQParams(Value* v);
+ c10::optional<script::Module> findChildModuleToQuantize(Value* v);
void quantizeBias(Value* v);
void quantizeTensor(Value* v, bool insert_after=true);
void removeObserver(Value* v, const std::string& observer_name);
@@ -425,10 +403,23 @@
q->replaceInputWith(dequant->output(), v);
}
-script::Module InsertQuantDeQuant(
- script::Module& input_module,
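+// If `v` is produced by a call to a child module's forward method, return
+// that child so quant/dequant nodes can be inserted into its graph as well.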
+c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(Value* v) {
+ if (v->node()->kind() == prim::CallMethod && v->node()->s(attr::name) == "forward") {
+ auto child_instance = v->node()->inputs()[0];
+ TORCH_INTERNAL_ASSERT(child_instance->node()->kind() == prim::GetAttr, "Child instance should come from GetAttr.");
+ auto child_module_name = child_instance->node()->s(attr::name);
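+ // Skip auto-inserted observer submodules, which carry an "observer_for_" prefix.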
+ if (child_module_name.find("observer_for_") == std::string::npos) {
+ auto child_module = module_.find_module(child_module_name);
+ TORCH_INTERNAL_ASSERT(child_module, "InsertQuantDeQuant - Child module " + child_module_name + " does not exist");
+ return child_module;
+ }
+ }
+ return c10::nullopt;
+}
+
+void InsertQuantDeQuantImpl(
+ script::Module& module,
const std::string& method_name) {
- script::Module module = input_module.clone();
script::Method method = module.get_method(method_name);
auto graph = method.graph();
std::vector<Value*> values_to_quantize;
@@ -476,6 +467,10 @@
for (Value* v : values_to_quantize) {
if (v->type()->isSubtypeOf(TensorType::get())) {
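+ // If this value comes from a child module's forward call, quantize the child's graph first.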
+ auto child_module = qh.findChildModuleToQuantize(v);
+ if (child_module) {
+ InsertQuantDeQuantImpl(child_module.value(), "forward");
+ }
if (v->node()->kind() == prim::GetAttr && v->node()->s(c10::attr::name) == "bias") {
qh.quantizeBias(v);
} else {
@@ -491,6 +486,23 @@
}
qh.destroyNodes();
+}
+
+} // namespace
+
+TORCH_API void InsertObservers(
+ script::Module& module,
+ const std::string& method_name,
+ const QConfigDict& qconfig_dict) {
+ auto module_qconfig_map = getQConfigMap(module, qconfig_dict);
+ InsertObserversImpl(module, method_name, module_qconfig_map);
+}
+
+script::Module InsertQuantDeQuant(
+ script::Module& input_module,
+ const std::string& method_name) {
+ script::Module module = input_module.clone();
+ InsertQuantDeQuantImpl(module, method_name);
// NOTE: Removing the observer modules does not work right now, so we
// return the module with the observers still attached as a temporary workaround.
@@ -498,6 +510,19 @@
return module;
}
+// PyBind APIs
+void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
+ throw std::runtime_error("Pass not implemented yet!");
+}
+
+void QuantLinting(std::shared_ptr<Graph>& graph) {
+ throw std::runtime_error("Pass not implemented yet!");
+}
+
+void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
+ throw std::runtime_error("Pass not implemented yet!");
+}
+
void QuantFusion(std::shared_ptr<Graph>& graph) {
SubgraphRewriter rewriter;
std::string pattern = R"(