Make insert_quant_dequant work with qconfig_dict (#25127)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25127

Extend the insert_quant_dequant pass to recurse into the forward graphs of child modules (reached via prim::CallMethod), instead of stopping at the top-level graph.
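
A rough sketch of how the pass is driven end to end (hedged example: the Python binding names `_jit_pass_insert_observers` / `_jit_pass_insert_quant_dequant`, their exact signatures, and the value type expected in `qconfig_dict` are assumptions based on the tests, not guaranteed by this diff):

```
import torch

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

m = torch.jit.script(M())

# Assumed shape of qconfig_dict: a map from submodule name ('' = top level)
# to a quantization config carrying the observers to insert.
qconfig_dict = {'': torch.quantization.default_qconfig}

# Insert observer calls into the forward graphs, then run calibration data
# through the module so the observers record value ranges.
torch._C._jit_pass_insert_observers(m._c, 'forward', qconfig_dict)
m(torch.rand(1, 3, 8, 8))

# Insert quantize/int_repr/dequantize nodes. With this change the pass also
# recurses into the forward graphs of child modules reached via
# prim::CallMethod (e.g. self.conv), rather than stopping at the top level.
# The C++ pass clones the input module and returns the transformed clone.
m_quant = torch._C._jit_pass_insert_quant_dequant(m._c, 'forward')
```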

Test Plan:
```
python test/test_jit.py 'TestJit.test_insert_quant_dequant'
python test/test_quantizer.py
```

Imported from OSS

Differential Revision: D17001137

fbshipit-source-id: 41b029906fe5c8bc0de01956059388a7d552a380
diff --git a/test/test_jit.py b/test/test_jit.py
index 7ad039c..429a323 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1001,7 +1001,6 @@
         check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temporary turn off the test")
     def test_insert_quant_dequant(self):
         class Observer(torch.nn.Module):
             def __init__(self):
@@ -1045,28 +1044,12 @@
         FileCheck().check("aten::quantize_linear") \
                    .check_next("aten::int_repr") \
                    .check_next("aten::_dequantize_linear") \
-                   .check("aten::quantize_linear") \
-                   .check_next("aten::int_repr") \
-                   .check_next("aten::_dequantize_linear") \
-                   .check("aten::quantize_linear") \
-                   .check_next("aten::int_repr") \
-                   .check_next("aten::_dequantize_linear") \
-                   .check("aten::conv2d") \
+                   .check("prim::CallMethod[name=\"forward\"]") \
                    .check("aten::quantize_linear") \
                    .check_next("aten::int_repr") \
                    .check_next("aten::_dequantize_linear") \
                    .check("return") \
-                   .run(str(m._c._get_method('forward').graph))
-        # Test for inline
-        # FileCheck().check("aten::quantize_linear") \
-        #            .check_next("aten::int_repr") \
-        #            .check_next("aten::_dequantize_linear") \
-        #            .check("prim::CallMethod[name=\"forward\"]") \
-        #            .check("aten::quantize_linear") \
-        #            .check_next("aten::int_repr") \
-        #            .check_next("aten::_dequantize_linear") \
-        #            .check("return") \
-        #            .run(str(get_forward(m).graph))
+                   .run(str(get_forward(m).graph))
 
     def test_quant_fusion(self):
         input_str = """
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index b18f48c..ebab51f 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -229,29 +229,6 @@
   }
 }
 
-} // namespace
-
-// PyBind APIs
-void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
-}
-
-void QuantLinting(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
-}
-
-void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
-}
-
-TORCH_API void InsertObservers(
-    script::Module& module,
-    const std::string& method_name,
-    const QConfigDict& qconfig_dict) {
-  auto module_qconfig_map = getQConfigMap(module, qconfig_dict);
-  InsertObserversImpl(module, method_name, module_qconfig_map);
-}
-
 Node* insertQuantDeQuantCall(Value* v, const IValue& qparams, at::ScalarType t, bool insert_after=true) {
   Graph* g = v->node()->owningGraph();
   Node* quant = createQuantNode(v, g);
@@ -319,6 +296,7 @@
  public:
   QuantizeHelper(const script::Module& m) : module_(m) {}
   IValue getQParams(Value* v);
+  c10::optional<script::Module> findChildModuleToQuantize(Value* v);
   void quantizeBias(Value* v);
   void quantizeTensor(Value* v, bool insert_after=true);
   void removeObserver(Value* v, const std::string& observer_name);
@@ -425,10 +403,23 @@
   q->replaceInputWith(dequant->output(), v);
 }
 
-script::Module InsertQuantDeQuant(
-    script::Module& input_module,
+c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(Value* v) {
+  if (v->node()->kind() == prim::CallMethod && v->node()->s(attr::name) == "forward") {
+    auto child_instance = v->node()->inputs()[0];
+    TORCH_INTERNAL_ASSERT(child_instance->node()->kind() == prim::GetAttr, "Child instance should come from GetAttr.");
+    auto child_module_name = child_instance->node()->s(attr::name);
+    if (child_module_name.find("observer_for_") == std::string::npos) {
+      auto child_module = module_.find_module(child_module_name);
+      TORCH_INTERNAL_ASSERT(child_module, "InsertQuantDeQuant - Child module " + child_module_name + " does not exist");
+      return child_module;
+    }
+  }
+  return c10::nullopt;
+}
+
+void InsertQuantDeQuantImpl(
+    script::Module& module,
     const std::string& method_name) {
-  script::Module module = input_module.clone();
   script::Method method = module.get_method(method_name);
   auto graph = method.graph();
   std::vector<Value*> values_to_quantize;
@@ -476,6 +467,10 @@
 
   for (Value* v : values_to_quantize) {
     if (v->type()->isSubtypeOf(TensorType::get())) {
+      auto child_module = qh.findChildModuleToQuantize(v);
+      if (child_module) {
+        InsertQuantDeQuantImpl(child_module.value(), "forward");
+      }
       if (v->node()->kind() == prim::GetAttr && v->node()->s(c10::attr::name) == "bias") {
         qh.quantizeBias(v);
       } else {
@@ -491,6 +486,23 @@
   }
 
   qh.destroyNodes();
+}
+
+} // namespace
+
+TORCH_API void InsertObservers(
+    script::Module& module,
+    const std::string& method_name,
+    const QConfigDict& qconfig_dict) {
+  auto module_qconfig_map = getQConfigMap(module, qconfig_dict);
+  InsertObserversImpl(module, method_name, module_qconfig_map);
+}
+
+script::Module InsertQuantDeQuant(
+    script::Module& input_module,
+    const std::string& method_name) {
+  script::Module module = input_module.clone();
+  InsertQuantDeQuantImpl(module, method_name);
 
   // NOTE: Remove observer module does not work right now, we'll return
   // the module with observer modules as a temporary workaround
@@ -498,6 +510,19 @@
   return module;
 }
 
+// PyBind APIs
+void PropagateQuantInfo(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void QuantLinting(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
+void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
+  throw std::runtime_error("Pass not implemented yet!");
+}
+
 void QuantFusion(std::shared_ptr<Graph>& graph) {
   SubgraphRewriter rewriter;
   std::string pattern = R"(