Support quantizing any method called (#25505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25505
Add support for quantizing all the methods called by the forward method, including
child module methods and other methods on the current module.
This relies on module-level constant prop; we still need to figure out how to run
constant prop for these methods as well. We could either run constant prop at the
module level or inside the quantization function, but this needs some discussion.
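
A minimal sketch of the call pattern this now covers (the module `M` and its
method names are illustrative, not taken from this PR): `forward` invokes both a
child module and another method on the same module, and each call shows up in the
scripted graph as a `prim::CallMethod` node that the pass recurses into.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def scale(self, x):
        # Method on the current module; scripted as prim::CallMethod[name="scale"].
        return x * 2.0

    def forward(self, x):
        # Child module call; scripted as prim::CallMethod[name="forward"] on self.conv.
        return self.scale(self.conv(x))

m = torch.jit.script(M())
```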
Test Plan:
python test/test_jit.py 'TestJit.insert_quant_dequant'
python test/test_quantizer.py
Imported from OSS
Differential Revision: D17208887
fbshipit-source-id: 21749457b21b00a6edada290c26324e2fb210b10
diff --git a/test/test_jit.py b/test/test_jit.py
index 8beae27..7351ab9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1115,7 +1115,7 @@
FileCheck().check("aten::quantize_linear") \
.check_next("aten::int_repr") \
.check_next("aten::_dequantize_linear") \
- .check("prim::CallMethod[name=\"forward\"]") \
+ .check("aten::conv2d") \
.check("aten::quantize_linear") \
.check_next("aten::int_repr") \
.check_next("aten::_dequantize_linear") \
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 1b0b1f4..4921e63 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -366,7 +366,7 @@
public:
QuantizeHelper(const script::Module& m) : module_(m) {}
IValue getQParams(Value* v);
- c10::optional<script::Module> findChildModuleToQuantize(Value* v);
+ c10::optional<script::Module> findChildModuleToQuantize(Value* child_instance);
void quantizeBias(Value* v);
void quantizeTensor(Value* v, bool insert_after = true);
void removeObserver(Value* v, const std::string& observer_name);
@@ -479,21 +479,18 @@
}
c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(
- Value* v) {
- if (v->node()->kind() == prim::CallMethod) {
- auto child_instance = v->node()->inputs()[0];
+ Value* child_instance) {
+ TORCH_INTERNAL_ASSERT(
+ child_instance->node()->kind() == prim::GetAttr,
+ "Child instance should come from GetAttr.");
+ auto child_module_name = child_instance->node()->s(attr::name);
+ if (child_module_name.find("observer_for_") == std::string::npos) {
+ auto child_module = module_.find_module(child_module_name);
TORCH_INTERNAL_ASSERT(
- child_instance->node()->kind() == prim::GetAttr,
- "Child instance should come from GetAttr.");
- auto child_module_name = child_instance->node()->s(attr::name);
- if (child_module_name.find("observer_for_") == std::string::npos) {
- auto child_module = module_.find_module(child_module_name);
- TORCH_INTERNAL_ASSERT(
- child_module,
- "InsertQuantDeQuant - Child module " + child_module_name +
- " does not exist");
- return child_module;
- }
+ child_module,
+ "InsertQuantDeQuant - Child module " + child_module_name +
+ " does not exist");
+ return child_module;
}
return c10::nullopt;
}
@@ -515,51 +512,50 @@
}
}
- std::vector<Value*> values_to_quantize;
- std::unordered_map<script::ModulePtr, script::Module>
- child_modules_to_quantize;
QuantizeHelper qh(module);
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
- for (Node* n : b->nodes()) {
+ for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
+ Node* n = *it++;
for (Value* v : n->outputs()) {
- if (v->type()->isSubtypeOf(TensorType::get())) {
- auto child_module = qh.findChildModuleToQuantize(v);
- if (child_module) {
- child_modules_to_quantize[child_module.value().module_object()] =
- child_module.value();
+ if (!v->type()->isSubtypeOf(TensorType::get())) {
+ continue;
+ }
+ if (v->node()->kind() == prim::CallMethod) {
+ auto module_instance = v->node()->inputs()[0];
+ auto module_method_name = v->node()->s(attr::name);
+ c10::optional<script::Module> m;
+ // calling method on self
+ if (module_instance == graph->inputs()[0]) {
+ m = module;
+ } else {
+ m = qh.findChildModuleToQuantize(module_instance);
}
- values_to_quantize.push_back(v);
+ if (m) {
+ InsertQuantDeQuantImpl(m.value(), module_method_name);
+ }
+ }
+ if (v->node()->kind() == prim::GetAttr &&
+ v->node()->s(c10::attr::name) == "bias") {
+ qh.quantizeBias(v);
+ } else {
+ qh.quantizeTensor(v);
}
}
- // Schedule subblocks (if any) for visiting.
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
- for (Value* v : values_to_quantize) {
- if (v->node()->kind() == prim::GetAttr &&
- v->node()->s(c10::attr::name) == "bias") {
- qh.quantizeBias(v);
- } else {
- qh.quantizeTensor(v);
- }
- }
-
for (Value* v : input_values) {
qh.quantizeTensor(v, false);
}
- for (auto& item : child_modules_to_quantize) {
- InsertQuantDeQuantImpl(item.second, "forward");
- }
-
qh.destroyNodes();
}
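
For reference, the updated test expectation above is checked with FileCheck, which
scans a graph's IR for patterns in order. A self-contained sketch of the mechanism
(the function `f` is illustrative, not the test's actual module):

```python
import torch
from torch.testing import FileCheck

@torch.jit.script
def f(x):
    return torch.conv2d(x, torch.ones(1, 1, 1, 1))

# Asserts the pattern appears in the graph IR; the test above uses the same
# mechanism to assert conv2d is quantized directly rather than being hidden
# behind a prim::CallMethod into the child module.
FileCheck().check("aten::conv2d").run(f.graph)
```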