[quant][graphmode] Clean up and add more logging (#40196)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40196
- separate passes in insert observers to make it more robust
- added print for quantization type
- added more logging for insert observers
Test Plan: Imported from OSS
Differential Revision: D22106545
fbshipit-source-id: 6d8d722e33c1259b1a6a501853c801c275dbfcff
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 68acda1..2ac2c51 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -178,6 +178,7 @@
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/xnnpack_rewrite.cpp",
"torch/csrc/jit/passes/quantization/helper.cpp",
+ "torch/csrc/jit/passes/quantization/quantization_type.cpp",
"torch/csrc/jit/passes/quantization/insert_observers.cpp",
"torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp",
"torch/csrc/jit/passes/quantization/dedup_module_uses.cpp",
diff --git a/torch/csrc/jit/passes/quantization/insert_observers.cpp b/torch/csrc/jit/passes/quantization/insert_observers.cpp
index 6626c32..7b28583 100644
--- a/torch/csrc/jit/passes/quantization/insert_observers.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -47,8 +47,10 @@
const c10::optional<QConfig>& parent_qconfig = c10::nullopt) {
c10::optional<QConfig> qconfig;
if (qconfig_dict.find(key) != qconfig_dict.end()) {
+ GRAPH_DEBUG("Got module config for key:", key);
qconfig = qconfig_dict.at(key);
} else {
+ GRAPH_DEBUG("Inheriting qconfig from parent module:", key);
qconfig = parent_qconfig;
}
map[module._ivalue()] = qconfig;
@@ -251,11 +253,24 @@
class InsertObserversHelper {
public:
- explicit InsertObserversHelper(const ModuleQConfigMap& map)
- : module_qconfig_map_(map) {}
+ explicit InsertObserversHelper(
+ const ModuleQConfigMap& map, QuantType quant_type)
+ : module_qconfig_map_(map), quant_type_(quant_type) {}
+ // TODO: replace (module, method_name) with graph?
+ // preprocess to clean up the graph from tracing
void preprocess(Module& module, const std::string& method_name);
+ // Fill the map between the caller input/output to input/output
+ // of called graph, this is used to navigate through the graph
+ // to find the observer for a given value
+ void fillBoundaryValueMap(Module& module, const std::string& method_name);
+
+
+ // analyze the graph and record necessary information that can
+ // be used in insert observers
+ void analyze(Module& module, const std::string& method_name);
+
/**
* Recursively insert observers for the method, also we'll process
* the nodes in the graph in the order of execution of these nodes
@@ -289,10 +304,6 @@
std::unordered_set<Value*> graph_observed_values =
std::unordered_set<Value*>());
- void setQuantType(QuantType quant_type) {
- quant_type_ = quant_type;
- }
-
private:
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
insertObserversFor(
@@ -327,11 +338,6 @@
return block_observed_values.count(v) || observed_values_.count(v);
}
- // Fill the map between the caller input/output to input/output
- // of called graph, this is used to navigate through the graph
- // to find the observer for a given value
- void fillBoundaryValueMap(Module& module, const std::string& method_name);
-
// Fill the map from value to the corresponding observer module
// this map is used in insertObservers to actually insert
// observers to the module
@@ -799,6 +805,7 @@
if (observed_values_.count(v)) {
return;
}
+ GRAPH_DEBUG("Inserting observer for:", v->debugName());
Module observer = observer_module.deepcopy();
std::string observer_name = "_observer_" + c10::to_string(uid_++);
while (module.hasattr(observer_name)) {
@@ -899,6 +906,12 @@
void InsertObserversHelper::fillBoundaryValueMap(
Module& module,
const std::string& method_name) {
+ for (auto& invoked_method : getInvokedMethods(module, method_name)) {
+ auto& invoked_module = std::get<0>(invoked_method);
+ const auto& invoked_method_name = std::get<1>(invoked_method);
+ fillBoundaryValueMap(invoked_module, invoked_method_name);
+ }
+
auto graph = module.get_method(method_name).graph();
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
@@ -924,19 +937,26 @@
// add mapping from callsite value to value in called graph
for (auto i = 0U; i < g->outputs().size(); ++i) {
auto* return_val = g->outputs()[i];
+ GRAPH_DEBUG("Boundary Map[return]:", n->output(i)->debugName(),
+ " -> ", return_val->debugName());
boundary_value_map_[n->output(i)].insert(return_val);
}
for (auto i = 0U; i < g->inputs().size(); ++i) {
auto caller_input_index = i + input_offset;
auto* caller_input = n->input(caller_input_index);
auto* input_val = g->inputs()[i];
+ GRAPH_DEBUG("Boundary Map[input]:", caller_input->debugName(),
+ " -> ", input_val->debugName());
boundary_value_map_[caller_input].insert(input_val);
}
} else if (n->kind() == prim::If) {
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
for (Value* v : n->outputs()) {
- boundary_value_map_[v].insert(subblock->outputs()[v->offset()]);
+ Value* subblock_output = subblock->outputs()[v->offset()];
+ GRAPH_DEBUG("Boundary Map[if_output]:", v->debugName(),
+ " -> ", subblock_output->debugName());
+ boundary_value_map_[v].insert(subblock_output);
}
}
} else {
@@ -970,12 +990,23 @@
replaceConvolutionWithAtenConv(graph);
// fuse decomposed linear into aten::linear
FuseLinear(graph);
+}
+
+void InsertObserversHelper::analyze(
+ Module& module,
+ const std::string& method_name) {
+ for (auto& invoked_method : getInvokedMethods(module, method_name)) {
+ auto& invoked_module = std::get<0>(invoked_method);
+ const auto& invoked_method_name = std::get<1>(invoked_method);
+ analyze(invoked_module, invoked_method_name);
+ }
// fill out various internal state which will be later used in
- // insertObservers to insert the correct observers
+ // insertObservers to insert the correct observer
addValuesToDelayObservation(module, method_name);
fillValueObserverMap(module, method_name);
- fillBoundaryValueMap(module, method_name);
+ Method method = module.get_method(method_name);
+ auto graph = method.graph();
fillPassThroughValueMap(graph);
}
@@ -1021,6 +1052,8 @@
auto qconfig = *qconfig_opt;
for (auto* v : graph->inputs()) {
if (valueNeedsToBeQuantized(v)) {
+ GRAPH_DEBUG("Recording observer for ", v->debugName());
+ GRAPH_DUMP("In graph:", v->owningGraph());
observer_for_value_[v] = getObserverModuleFor(v, qconfig);
}
}
@@ -1032,6 +1065,8 @@
for (Node* n : b->nodes()) {
for (Value* v : n->outputs()) {
if (valueNeedsToBeQuantized(v)) {
+ GRAPH_DEBUG("Recording observer for ", v->debugName());
+ GRAPH_DUMP("In graph:", v->owningGraph());
observer_for_value_[v] = getObserverModuleFor(v, qconfig);
}
}
@@ -1046,11 +1081,16 @@
c10::optional<Module> InsertObserversHelper::getObserverFor(Value* v) {
if (observer_for_value_.count(v)) {
auto observer = observer_for_value_.at(v);
+ GRAPH_DEBUG("Got observer module config for:", v->debugName());
return observer;
}
c10::optional<Module> result;
if (boundary_value_map_.count(v)) {
for (Value* next : boundary_value_map_.at(v)) {
+ GRAPH_DEBUG("Going through boundary map:", v->debugName(), " --> ",
+ next->debugName());
+ GRAPH_DUMP("From graph:", v->owningGraph());
+ GRAPH_DUMP("To graph:", next->owningGraph());
auto observer_opt = getObserverFor(next);
if (observer_opt) {
// Need to make sure all values are
@@ -1065,6 +1105,8 @@
}
}
}
+ GRAPH_DEBUG("Observer module config for ", v->debugName(), ":",
+ result.has_value());
return result;
}
@@ -1125,7 +1167,7 @@
}
for (auto* v : block->inputs()) {
- block_input_observers.push_back(getObserverFor(v));
+ block_input_observers.emplace_back(getObserverFor(v));
}
for (auto* v : block->outputs()) {
@@ -1153,11 +1195,13 @@
}
}
// NB: Why do we need to process the graph even if it's visited?
- // Reason is `graph_observed_values` can
+ // Reason is `block_observed_values` can
// change depending on where the method is called, and
// outputs that's been observed(third item of the returned result)
// can change depending on that, so for each graph we'll need to go through
- // the whole process of inserting observers
+ // the whole process of inserting observers, the observers inserted in this
+ // block won't change, but the information we return to the caller will change
+ // based on `block_observed_values`
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(block);
@@ -1369,9 +1413,14 @@
// Since the types are changed after clone, we need to fill
// the qconfig map again
fillQConfigMap(module, qconfig_dict, module_qconfig_map);
- InsertObserversHelper helper(module_qconfig_map);
- helper.setQuantType(quant_type);
+ GRAPH_DEBUG("Quant type:", quant_type);
+ InsertObserversHelper helper(module_qconfig_map, quant_type);
helper.preprocess(module, method_name);
+ helper.fillBoundaryValueMap(module, method_name);
+ // analyze needs to run after fillBoundaryValueMap
+ // since we need to know the boundary value mapping to trace
+ // through the calls
+ helper.analyze(module, method_name);
helper.insertObservers(module, method_name, /* is_entry_point */ true);
return module;
}
diff --git a/torch/csrc/jit/passes/quantization/quantization_type.cpp b/torch/csrc/jit/passes/quantization/quantization_type.cpp
new file mode 100644
index 0000000..8c8d327
--- /dev/null
+++ b/torch/csrc/jit/passes/quantization/quantization_type.cpp
@@ -0,0 +1,15 @@
+#include <torch/csrc/jit/passes/quantization/quantization_type.h>
+
+namespace torch {
+namespace jit {
+
+std::ostream& operator<<(std::ostream& os, QuantType t) { // print QuantType as a lowercase name, for GRAPH_DEBUG logging
+  switch(t) {
+    case QuantType::DYNAMIC: os << "dynamic"; break;
+    case QuantType::STATIC: os << "static"; break;
+    default: os.setstate(std::ios_base::failbit); // unrecognized enum value: mark the stream failed instead of printing garbage
+  }
+  return os;
+}
+
+}}
diff --git a/torch/csrc/jit/passes/quantization/quantization_type.h b/torch/csrc/jit/passes/quantization/quantization_type.h
index b5699eb..cdcfe8f 100644
--- a/torch/csrc/jit/passes/quantization/quantization_type.h
+++ b/torch/csrc/jit/passes/quantization/quantization_type.h
@@ -1,11 +1,14 @@
#pragma once
+#include <ostream>
namespace torch {
namespace jit {
// Quantization type (dynamic quantization, static quantization).
// Should match the Python enum in quantize_script.py
-enum class QuantType { DYNAMIC, STATIC };
+enum QuantType : uint8_t { DYNAMIC = 0, STATIC };
+
+std::ostream& operator<<(std::ostream& os, QuantType t);
} // namespace jit
} // namespace torch