Refactor observer removal and tensor quantization

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30360
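
This splits the old per-value `quantizeTensor(v)` into three phases, so
observer removal no longer interleaves with graph rewriting. A minimal
sketch of the new call sequence (names taken from this diff; the driver
loop over values is simplified):

    // Phase 1: per observed value, record the observer nodes to destroy,
    // the value to quantize, and its qparams.
    for (Value* v : input_values) {
      qh.collectObserverNodesAndValueToQuantize(v);
    }
    // Phase 2: destroy observer forward calls and drop observer submodules.
    qh.removeObservers();
    // Phase 3: insert quant/dequant calls for every recorded value.
    qh.quantizeTensors();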

Differential Revision: D18670373

Pulled By: lly-zero-one

fbshipit-source-id: 1481d6e4d5ce40376577b8deb0a0f74d5559076e
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 55748da..d3289fe 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -477,39 +477,25 @@
   std::tuple<IValue, IValue> getQParams(Value* v);
   c10::optional<script::Module> findChildModuleToQuantize(
       Value* child_instance);
-  void quantizeTensor(Value* v);
-  // Remove the observer for value `v`. This function returns
-  // the original value (i.e. before observation), and thus all
-  // uses of the passed-in `v` should be replaced by the caller with
-  // the return value
-  Value* removeObserver(Value* v, const std::string& observer_name);
-  void removeModulesAndNodes() {
-    // Remove observer modules from last one to first one in order to
-    // reduce the time complexity, assuming all the observer modules
-    // are added after the existing modules, we'll have complexity of
-    // O(N) where N is number of observer moduels with this optimization
-    for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
-      auto observer_name = observer_modules_to_remove_[i];
-      module_._ivalue()->unsafeRemoveAttr(observer_name);
-      module_.type()->unsafeRemoveAttribute(observer_name);
-    }
-    // Destroy observer forward calls
-    for (auto& n : nodes_to_destroy_) {
-      n->destroy();
-    }
-  }
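+  // Finds the observer for `v`, records the observer call and GetAttr
+  // nodes for later destruction, redirects uses of `v` to the observer
+  // input, and saves that value together with its qparams.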
+  void collectObserverNodesAndValueToQuantize(Value*);
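+  // Destroys the recorded observer nodes and removes the observer
+  // submodules from the module.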
+  void removeObservers();
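+  // Inserts quant/dequant calls for every collected value, using the
+  // qparams saved during collection.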
+  void quantizeTensors();
 
  private:
   script::Module& module_;
   std::vector<std::string> observer_modules_to_remove_;
   std::vector<Node*> nodes_to_destroy_;
+  std::vector<Value*> values_to_quantize_;
+  std::unordered_map<Value*, std::tuple<IValue, IValue>> values_to_qparams_;
 };
 
-Value* QuantizeHelper::removeObserver(
-    Value* v,
-    const std::string& observer_name) {
-  // remove observer_module
-  observer_modules_to_remove_.push_back(observer_name);
+void QuantizeHelper::collectObserverNodesAndValueToQuantize(Value* v) {
+  auto observer_name = findObserverName(v);
+  if (!observer_name) {
+    return;
+  }
+  observer_modules_to_remove_.push_back(observer_name.value());
 
   Node* observer = v->node();
   TORCH_INTERNAL_ASSERT(
@@ -517,12 +503,52 @@
       observer->s(attr::name) == "forward" &&
       observer->inputs()[0]->node()->kind() == prim::GetAttr &&
       observer->inputs()[0]->node()->s(attr::name) == observer_name);
+
   // Observer forward call node
   nodes_to_destroy_.push_back(observer);
   // GetAttr node for observer module
   nodes_to_destroy_.push_back(observer->inputs()[0]->node());
-  v->replaceAllUsesWith(observer->input(1));
-  return observer->input(1);
+  Value* new_value = observer->input(1);
+  v->replaceAllUsesWith(new_value);
+  values_to_quantize_.push_back(new_value);
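+  // Compute qparams now, while the observer module still exists; the
+  // observers are removed later in removeObservers().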
+  values_to_qparams_.insert({new_value, getQParams(v)});
+}
+
+void QuantizeHelper::removeObservers() {
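+  // Clear inputs before destroying: the observer forward call consumes
+  // the output of its GetAttr node, so dropping inputs first ensures no
+  // node in nodes_to_destroy_ still has uses when it is destroyed.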
+  for (auto& n : nodes_to_destroy_) {
+    n->removeAllInputs();
+  }
+  for (auto& n : nodes_to_destroy_) {
+    n->destroy();
+  }
+  // Remove observer modules from last to first to reduce time complexity:
+  // assuming all observer modules were added after the existing modules,
+  // removal is O(N) in the number of observer modules.
+  for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
+    auto observer_name = observer_modules_to_remove_[i];
+    module_._ivalue()->unsafeRemoveAttr(observer_name);
+    module_.type()->unsafeRemoveAttribute(observer_name);
+  }
+}
+
+void QuantizeHelper::quantizeTensors() {
+  for (auto& v : values_to_quantize_) {
+    TORCH_INTERNAL_ASSERT(values_to_qparams_.count(v));
+    auto tp = values_to_qparams_[v];
+    auto qparams = std::get<0>(tp);
+    auto scalar_type = std::get<1>(tp);
+    // `v` was already redirected to the observer's input when it was
+    // collected, so it can be quantized directly here.
+    Node* dequant = insertQuantDeQuantCall(v, qparams, scalar_type);
+    v->replaceAllUsesWith(dequant->output());
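+    // The quant node is the producer of the dequant node's input.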
+    Node* q = dequant->input(0)->node();
+    // replaceAllUsesWith rewrote all uses of `v`, but we want to keep
+    // one: the input of the quant node. Restore it here:
+    q->replaceInputWith(dequant->output(), v);
+  }
+  // No need to clear values_to_quantize_ or values_to_qparams_; the
+  // helper is not reused after this pass.
 }
 
 void checkCalculateQParamsResult(const IValue& qparams) {
@@ -575,26 +601,6 @@
   return std::make_tuple(qparams, scalar_type);
 }
 
-void QuantizeHelper::quantizeTensor(Value* v) {
-  auto observer_name = findObserverName(v);
-  if (!observer_name) {
-    return;
-  }
-  auto tp = getQParams(v);
-  auto qparams = std::get<0>(tp);
-  auto scalar_type = std::get<1>(tp);
-  // NB: v is updated here, since removeObserver replaces
-  // v with the input to the observer call
-  v = removeObserver(v, observer_name.value());
-  Node* dequant;
-  dequant = insertQuantDeQuantCall(v, qparams, scalar_type);
-  v->replaceAllUsesWith(dequant->output());
-  Node* q = dequant->input(0)->node();
-  // replaceAllUsesWith rewrote all uses of V, but we want to keep one: the one
-  // used in quant node. Restore it here:
-  q->replaceInputWith(dequant->output(), v);
-}
-
 c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(
     Value* child_instance) {
   TORCH_INTERNAL_ASSERT(
@@ -650,7 +656,7 @@
             InsertQuantDeQuantImpl(m.value(), module_method_name);
           }
         }
-        qh.quantizeTensor(v);
+        qh.collectObserverNodesAndValueToQuantize(v);
       }
 
       for (Block* subblock : n->blocks()) {
@@ -660,10 +666,10 @@
   }
 
   for (Value* v : input_values) {
-    qh.quantizeTensor(v);
+    qh.collectObserverNodesAndValueToQuantize(v);
   }
-
-  qh.removeModulesAndNodes();
+  qh.removeObservers();
+  qh.quantizeTensors();
 }
 
 void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {