[quant][graphmode] Clean up and add more logging (#40196)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40196

- separated the insert observers pass into explicit phases to make it more robust (see the sketch below)
- added printing for the quantization type
- added more logging to insert observers
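
For reference, the reworked entry point now drives the phases explicitly and
in order; a minimal sketch assembled from the insert_observers.cpp hunks
below (the actual call site is in the last hunk of that file):

    InsertObserversHelper helper(module_qconfig_map, quant_type);
    helper.preprocess(module, method_name);           // clean up traced graphs
    helper.fillBoundaryValueMap(module, method_name); // map values across call boundaries
    helper.analyze(module, method_name);              // record which values need observers
    helper.insertObservers(module, method_name, /* is_entry_point */ true);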

Test Plan: Imported from OSS

Differential Revision: D22106545

fbshipit-source-id: 6d8d722e33c1259b1a6a501853c801c275dbfcff
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 68acda1..2ac2c51 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -178,6 +178,7 @@
     "torch/csrc/jit/passes/utils/subgraph_utils.cpp",
     "torch/csrc/jit/passes/xnnpack_rewrite.cpp",
     "torch/csrc/jit/passes/quantization/helper.cpp",
+    "torch/csrc/jit/passes/quantization/quantization_type.cpp",
     "torch/csrc/jit/passes/quantization/insert_observers.cpp",
     "torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp",
     "torch/csrc/jit/passes/quantization/dedup_module_uses.cpp",
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index 6626c32..7b28583 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -47,8 +47,10 @@
     const c10::optional<QConfig>& parent_qconfig = c10::nullopt) {
   c10::optional<QConfig> qconfig;
   if (qconfig_dict.find(key) != qconfig_dict.end()) {
+    GRAPH_DEBUG("Got module config for key:", key);
     qconfig = qconfig_dict.at(key);
   } else {
+    GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
     qconfig = parent_qconfig;
   }
   map[module._ivalue()] = qconfig;
@@ -251,11 +253,23 @@
 
 class InsertObserversHelper {
  public:
-  explicit InsertObserversHelper(const ModuleQConfigMap& map)
-      : module_qconfig_map_(map) {}
+  explicit InsertObserversHelper(
+      const ModuleQConfigMap& map, QuantType quant_type)
+      : module_qconfig_map_(map), quant_type_(quant_type) {}
 
+  // TODO: replace (module, method_name) with graph?
+  // Preprocess to clean up graph patterns left over from tracing
   void preprocess(Module& module, const std::string& method_name);
 
+  // Fill the map from caller input/output values to the input/output
+  // values of the called graph; this is used to navigate across graph
+  // boundaries to find the observer for a given value
+  void fillBoundaryValueMap(Module& module, const std::string& method_name);
+
+  // Analyze the graph and record the information that will later
+  // be used by insertObservers
+  void analyze(Module& module, const std::string& method_name);
+
   /**
    * Recursively insert observers for the method, also we'll process
    * the nodes in the graph in the order of execution of these nodes
@@ -289,10 +304,6 @@
       std::unordered_set<Value*> graph_observed_values =
           std::unordered_set<Value*>());
 
-  void setQuantType(QuantType quant_type) {
-    quant_type_ = quant_type;
-  }
-
  private:
   std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
   insertObserversFor(
@@ -327,11 +338,6 @@
     return block_observed_values.count(v) || observed_values_.count(v);
   }
 
-  // Fill the map between the caller input/output to input/output
-  // of called graph, this is used to navigate through the graph
-  // to find the observer for a given value
-  void fillBoundaryValueMap(Module& module, const std::string& method_name);
-
   // Fill the map from value to the corresponding observer module
   // this map is used in insertObservers to actually insert
   // observers to the module
@@ -799,6 +805,7 @@
   if (observed_values_.count(v)) {
     return;
   }
+  GRAPH_DEBUG("Inserting observer for:", v->debugName());
   Module observer = observer_module.deepcopy();
   std::string observer_name = "_observer_" + c10::to_string(uid_++);
   while (module.hasattr(observer_name)) {
@@ -899,6 +906,12 @@
 void InsertObserversHelper::fillBoundaryValueMap(
     Module& module,
     const std::string& method_name) {
+  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
+    auto& invoked_module = std::get<0>(invoked_method);
+    const auto& invoked_method_name = std::get<1>(invoked_method);
+    fillBoundaryValueMap(invoked_module, invoked_method_name);
+  }
+
   auto graph = module.get_method(method_name).graph();
   std::stack<Block*> blocks_to_visit;
   blocks_to_visit.push(graph->block());
@@ -924,19 +937,26 @@
         // add mapping from callsite value to value in called graph
         for (auto i = 0U; i < g->outputs().size(); ++i) {
           auto* return_val = g->outputs()[i];
+          GRAPH_DEBUG("Boundary Map[return]:", n->output(i)->debugName(),
+                      " -> ", return_val->debugName());
           boundary_value_map_[n->output(i)].insert(return_val);
         }
         for (auto i = 0U; i < g->inputs().size(); ++i) {
           auto caller_input_index = i + input_offset;
           auto* caller_input = n->input(caller_input_index);
           auto* input_val = g->inputs()[i];
+          GRAPH_DEBUG("Boundary Map[input]:", caller_input->debugName(),
+                      " -> ", input_val->debugName());
           boundary_value_map_[caller_input].insert(input_val);
         }
       } else if (n->kind() == prim::If) {
         for (Block* subblock : n->blocks()) {
           blocks_to_visit.push(subblock);
           for (Value* v : n->outputs()) {
-            boundary_value_map_[v].insert(subblock->outputs()[v->offset()]);
+            Value* subblock_output = subblock->outputs()[v->offset()];
+            GRAPH_DEBUG("Boundary Map[if_output]:", v->debugName(),
+                        " -> ", subblock_output->debugName());
+            boundary_value_map_[v].insert(subblock_output);
           }
         }
       } else {
@@ -970,12 +990,23 @@
   replaceConvolutionWithAtenConv(graph);
   // fuse decomposed linear into aten::linear
   FuseLinear(graph);
+}
+
+void InsertObserversHelper::analyze(
+    Module& module,
+    const std::string& method_name) {
+  for (auto& invoked_method : getInvokedMethods(module, method_name)) {
+    auto& invoked_module = std::get<0>(invoked_method);
+    const auto& invoked_method_name = std::get<1>(invoked_method);
+    analyze(invoked_module, invoked_method_name);
+  }
 
   // fill out various internal state which will be later used in
-  // insertObservers to insert the correct observers
+  // insertObservers to insert the correct observer
   addValuesToDelayObservation(module, method_name);
   fillValueObserverMap(module, method_name);
-  fillBoundaryValueMap(module, method_name);
+  Method method = module.get_method(method_name);
+  auto graph = method.graph();
   fillPassThroughValueMap(graph);
 }
 
@@ -1021,6 +1052,8 @@
   auto qconfig = *qconfig_opt;
   for (auto* v : graph->inputs()) {
     if (valueNeedsToBeQuantized(v)) {
+      GRAPH_DEBUG("Recording observer for ", v->debugName());
+      GRAPH_DUMP("In graph:", v->owningGraph());
       observer_for_value_[v] = getObserverModuleFor(v, qconfig);
     }
   }
@@ -1032,6 +1065,8 @@
     for (Node* n : b->nodes()) {
       for (Value* v : n->outputs()) {
         if (valueNeedsToBeQuantized(v)) {
+          GRAPH_DEBUG("Recording observer for ", v->debugName());
+          GRAPH_DUMP("In graph:", v->owningGraph());
           observer_for_value_[v] = getObserverModuleFor(v, qconfig);
         }
       }
@@ -1046,11 +1081,16 @@
 c10::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
   if (observer_for_value_.count(v)) {
     auto observer = observer_for_value_.at(v);
+    GRAPH_DEBUG("Got observer module config for:", v->debugName());
     return observer;
   }
   c10::optional<Module> result;
   if (boundary_value_map_.count(v)) {
     for (Value* next : boundary_value_map_.at(v)) {
+      GRAPH_DEBUG("Going through boundary map:", v->debugName(), " --> ",
+                  next->debugName());
+      GRAPH_DUMP("From graph:", v->owningGraph());
+      GRAPH_DUMP("To graph:", next->owningGraph());
       auto observer_opt = getObserverFor(next);
       if (observer_opt) {
         // Need to make sure all values are
@@ -1065,6 +1105,8 @@
       }
     }
   }
+  GRAPH_DEBUG("Observer module config for ", v->debugName(), ":",
+              result.has_value());
   return result;
 }
 
@@ -1125,7 +1167,7 @@
     }
 
     for (auto* v : block->inputs()) {
-      block_input_observers.push_back(getObserverFor(v));
+      block_input_observers.emplace_back(getObserverFor(v));
     }
 
     for (auto* v : block->outputs()) {
@@ -1153,11 +1195,13 @@
     }
   }
   // NB: Why do we need to process the graph even if it's visited?
-  // Reason is `graph_observed_values` can
+  // Reason is `block_observed_values` can
   // change depending on where the method is called, and
   // outputs that's been observed(third item of the returned result)
   // can change depending on that, so for each graph we'll need to go through
-  // the whole process of inserting observers
+  // the whole process of inserting observers. The observers inserted in this
+  // block won't change, but the information we return to the caller will,
+  // depending on `block_observed_values`
 
   std::stack<Block*> blocks_to_visit;
   blocks_to_visit.push(block);
@@ -1369,9 +1413,14 @@
   // Since the types are changed after clone, we need to fill
   // the qconfig map again
   fillQConfigMap(module, qconfig_dict, module_qconfig_map);
-  InsertObserversHelper helper(module_qconfig_map);
-  helper.setQuantType(quant_type);
+  GRAPH_DEBUG("Quant type:", quant_type);
+  InsertObserversHelper helper(module_qconfig_map, quant_type);
   helper.preprocess(module, method_name);
+  helper.fillBoundaryValueMap(module, method_name);
+  // analyze needs to run after fillBoundaryValueMap
+  // since we need the boundary value mapping to trace
+  // values through the call sites
+  helper.analyze(module, method_name);
   helper.insertObservers(module, method_name, /* is_entry_point */ true);
   return module;
 }
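
Note: the new GRAPH_DEBUG/GRAPH_DUMP calls go through the standard JIT logging
machinery, so (assuming the usual jit_log setup) their output can be enabled
per file via the PYTORCH_JIT_LOG_LEVEL environment variable, e.g.:

    PYTORCH_JIT_LOG_LEVEL=">>insert_observers" python run_quantization.py

Here `>>` raises the verbosity to the GRAPH_DEBUG level, and
run_quantization.py is a placeholder for any script that triggers this pass.
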
diff --git a/torch/csrc/jit/passes/quantization/quantization_type.cpp b/torch/csrc/jit/passes/quantization/quantization_type.cpp
new file mode 100644
index 0000000..8c8d327
--- /dev/null
+++ b/torch/csrc/jit/passes/quantization/quantization_type.cpp
@@ -0,0 +1,16 @@
+#include <torch/csrc/jit/passes/quantization/quantization_type.h>
+
+namespace torch {
+namespace jit {
+
+std::ostream& operator<<(std::ostream& os, QuantType t) {
+  switch (t) {
+    case QuantType::DYNAMIC: os << "dynamic"; break;
+    case QuantType::STATIC: os << "static"; break;
+    default: os.setstate(std::ios_base::failbit);
+  }
+  return os;
+}
+
+} // namespace jit
+} // namespace torch
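
With the streaming operator defined, a QuantType prints as a readable name
instead of an integer. A minimal standalone usage sketch (not part of this
patch):

    #include <iostream>
    #include <torch/csrc/jit/passes/quantization/quantization_type.h>

    int main() {
      torch::jit::QuantType t = torch::jit::QuantType::STATIC;
      std::cout << "Quant type: " << t << "\n"; // prints "Quant type: static"
    }

This is also what lets GRAPH_DEBUG("Quant type:", quant_type) above render
"dynamic"/"static", since GRAPH_DEBUG stringifies its arguments through
operator<<.
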
diff --git a/torch/csrc/jit/passes/quantization/quantization_type.h b/torch/csrc/jit/passes/quantization/quantization_type.h
index b5699eb..cdcfe8f 100644
--- a/torch/csrc/jit/passes/quantization/quantization_type.h
+++ b/torch/csrc/jit/passes/quantization/quantization_type.h
@@ -1,11 +1,15 @@
 #pragma once
+#include <cstdint>
+#include <ostream>
 
 namespace torch {
 namespace jit {
 
 // Quantization type (dynamic quantization, static quantization).
 // Should match the Python enum in quantize_script.py
-enum class QuantType { DYNAMIC, STATIC };
+enum QuantType : uint8_t { DYNAMIC = 0, STATIC };
+
+std::ostream& operator<<(std::ostream& os, QuantType t);
 
 } // namespace jit
 } // namespace torch
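
Since the enum now pins DYNAMIC = 0 and a fixed uint8_t underlying type,
converting the integer values passed from the Python side is well-defined; a
hypothetical sketch (raw_value stands in for whatever the binding receives):

    int raw_value = 0; // 0 = dynamic, 1 = static, per the comment above
    auto quant_type = static_cast<torch::jit::QuantType>(raw_value);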