Refactor observer removal and tensor quantization into separate passes
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30360
Differential Revision: D18670373
Pulled By: lly-zero-one
fbshipit-source-id: 1481d6e4d5ce40376577b8deb0a0f74d5559076e
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 55748da..d3289fe 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -477,39 +477,25 @@
std::tuple<IValue, IValue> getQParams(Value* v);
c10::optional<script::Module> findChildModuleToQuantize(
Value* child_instance);
- void quantizeTensor(Value* v);
- // Remove the observer for value `v`. This function returns
- // the original value (i.e. before observation), and thus all
- // uses of the passed-in `v` should be replaced by the caller with
- // the return value
- Value* removeObserver(Value* v, const std::string& observer_name);
- void removeModulesAndNodes() {
- // Remove observer modules from last one to first one in order to
- // reduce the time complexity, assuming all the observer modules
- // are added after the existing modules, we'll have complexity of
- // O(N) where N is number of observer moduels with this optimization
- for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
- auto observer_name = observer_modules_to_remove_[i];
- module_._ivalue()->unsafeRemoveAttr(observer_name);
- module_.type()->unsafeRemoveAttribute(observer_name);
- }
- // Destroy observer forward calls
- for (auto& n : nodes_to_destroy_) {
- n->destroy();
- }
- }
+ void collectObserverNodesAndValueToQuantize(Value*);
+ void removeObservers();
+ void quantizeTensors();
private:
script::Module& module_;
std::vector<std::string> observer_modules_to_remove_;
std::vector<Node*> nodes_to_destroy_;
+ std::vector<Value*> values_to_quantize_;
+ std::unordered_map<Value*, std::tuple<IValue, IValue> > values_to_qparams_;
};
-Value* QuantizeHelper::removeObserver(
- Value* v,
- const std::string& observer_name) {
- // remove observer_module
- observer_modules_to_remove_.push_back(observer_name);
+
+void QuantizeHelper::collectObserverNodesAndValueToQuantize(Value* v) {
+ auto observer_name = findObserverName(v);
+ if (!observer_name) {
+ return;
+ }
+ observer_modules_to_remove_.push_back(observer_name.value());
Node* observer = v->node();
TORCH_INTERNAL_ASSERT(
@@ -517,12 +503,52 @@
observer->s(attr::name) == "forward" &&
observer->inputs()[0]->node()->kind() == prim::GetAttr &&
observer->inputs()[0]->node()->s(attr::name) == observer_name);
+
// Observer forward call node
nodes_to_destroy_.push_back(observer);
// GetAttr node for observer module
nodes_to_destroy_.push_back(observer->inputs()[0]->node());
- v->replaceAllUsesWith(observer->input(1));
- return observer->input(1);
+ Value* new_value = observer->input(1);
+ v->replaceAllUsesWith(new_value);
+ values_to_quantize_.push_back(new_value);
+ values_to_qparams_.insert({new_value, getQParams(v)});
+}
+
+void QuantizeHelper::removeObservers() {
+ for (auto& n : nodes_to_destroy_) {
+ n->removeAllInputs();
+ }
+ for (auto& n : nodes_to_destroy_) {
+ n->destroy();
+ }
+ // Remove observer modules from last one to first one in order to
+ // reduce the time complexity, assuming all the observer modules
+ // are added after the existing modules, we'll have complexity of
+  // O(N) where N is number of observer modules with this optimization
+ for (int64_t i = observer_modules_to_remove_.size() - 1; i >= 0; --i) {
+ auto observer_name = observer_modules_to_remove_[i];
+ module_._ivalue()->unsafeRemoveAttr(observer_name);
+ module_.type()->unsafeRemoveAttribute(observer_name);
+ }
+}
+
+void QuantizeHelper::quantizeTensors() {
+ for (auto& v : values_to_quantize_) {
+ TORCH_INTERNAL_ASSERT(values_to_qparams_.count(v));
+ auto tp = values_to_qparams_[v];
+ auto qparams = std::get<0>(tp);
+ auto scalar_type = std::get<1>(tp);
+    // NB: v here is already the input of the original observer call;
+    // collectObserverNodesAndValueToQuantize stored it after unhooking the observer
+ Node* dequant;
+ dequant = insertQuantDeQuantCall(v, qparams, scalar_type);
+ v->replaceAllUsesWith(dequant->output());
+ Node* q = dequant->input(0)->node();
+    // replaceAllUsesWith rewrote all uses of v, but we want to keep one: the one
+    // used in the quant node. Restore it here:
+ q->replaceInputWith(dequant->output(), v);
+ }
+ // no need to clear the vector or map
}
void checkCalculateQParamsResult(const IValue& qparams) {
@@ -575,26 +601,6 @@
return std::make_tuple(qparams, scalar_type);
}
-void QuantizeHelper::quantizeTensor(Value* v) {
- auto observer_name = findObserverName(v);
- if (!observer_name) {
- return;
- }
- auto tp = getQParams(v);
- auto qparams = std::get<0>(tp);
- auto scalar_type = std::get<1>(tp);
- // NB: v is updated here, since removeObserver replaces
- // v with the input to the observer call
- v = removeObserver(v, observer_name.value());
- Node* dequant;
- dequant = insertQuantDeQuantCall(v, qparams, scalar_type);
- v->replaceAllUsesWith(dequant->output());
- Node* q = dequant->input(0)->node();
- // replaceAllUsesWith rewrote all uses of V, but we want to keep one: the one
- // used in quant node. Restore it here:
- q->replaceInputWith(dequant->output(), v);
-}
-
c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(
Value* child_instance) {
TORCH_INTERNAL_ASSERT(
@@ -650,7 +656,7 @@
InsertQuantDeQuantImpl(m.value(), module_method_name);
}
}
- qh.quantizeTensor(v);
+ qh.collectObserverNodesAndValueToQuantize(v);
}
for (Block* subblock : n->blocks()) {
@@ -660,10 +666,10 @@
}
for (Value* v : input_values) {
- qh.quantizeTensor(v);
+ qh.collectObserverNodesAndValueToQuantize(v);
}
-
- qh.removeModulesAndNodes();
+ qh.removeObservers();
+ qh.quantizeTensors();
}
void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {