[quant][graphmode][refactor] Move values_to_skip check inside valueNeedsToBeQuantized (#33275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33275
att
Test Plan:
.
Imported from OSS
Differential Revision: D20123592
fbshipit-source-id: 2b56ea8bab27eb9ea2bf792c83e48a7af8917e1a
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index d9b6129..95e0a77 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -189,23 +189,6 @@
});
}
-bool valueNeedsToBeQuantized(Value* v) {
- if (!v->type()->isSubtypeOf(TensorType::get())) {
- return false;
- }
- // Check whether producer is quantizable
- if (nodeQuantizable(v->node())) {
- return true;
- }
- // Check whether user is quantizable
- for (const auto& use : v->uses()) {
- if (nodeQuantizable(use.user)) {
- return true;
- }
- }
- return false;
-}
-
script::Module findChildModule(
const script::Module& module,
const std::vector<std::string>& path) {
@@ -435,6 +418,8 @@
script::Module& module,
const std::string& method_name);
+ bool valueNeedsToBeQuantized(Value* v);
+
void insertObserverFor(
Value* v,
script::Module& module,
@@ -755,6 +740,24 @@
}
}
+bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
+ if (!v->type()->isSubtypeOf(TensorType::get()) ||
+ values_to_skip_.count(v)) {
+ return false;
+ }
+ // Check whether producer is quantizable
+ if (nodeQuantizable(v->node())) {
+ return true;
+ }
+ // Check whether user is quantizable
+ for (const auto& use : v->uses()) {
+ if (nodeQuantizable(use.user)) {
+ return true;
+ }
+ }
+ return false;
+}
+
void InsertObserversHelper::insertObservers(
script::Module& module,
const std::string& method_name) {
@@ -793,7 +796,7 @@
// observing a potentially mutated value due to some in-place operation
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
auto& v = graph->inputs()[idx];
- if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
+ if (valueNeedsToBeQuantized(v)) {
insertObserverFor(v, module, qconfig_opt);
}
}
@@ -810,7 +813,7 @@
// Record all outputs in the values_to_observe - we'll later add observers
// for all values from it.
for (Value* v : n->outputs()) {
- if (!values_to_skip_.count(v) && valueNeedsToBeQuantized(v)) {
+ if (valueNeedsToBeQuantized(v)) {
values_to_observe.push_back(v);
}
}