#include <torch/csrc/jit/passes/quantization.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/quantization_patterns.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <c10/core/QScheme.h>
#include <algorithm>
#include <stack>
namespace torch {
namespace jit {
namespace {
using OptionalModuleVector = std::vector<c10::optional<Module>>;
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
using NameModuleVector = std::vector<std::pair<std::string, Module>>;
using graph_rewrite_helper::getValue;
using graph_rewrite_helper::getIValue;
using graph_rewrite_helper::getFuncName;
using graph_rewrite_helper::replaceConvolutionWithConv2d;
// Map of quantization parameter name and value
// for example _scale, _zero_point,
// _scalar_type and _axis(for per channel quantization)
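// e.g. (illustrative values):
// {"_scale": 0.03, "_zero_point": 1, "_scalar_type": c10::kQUInt8},
// plus {"_axis": 0} in the per channel case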
using QParamMap = std::unordered_map<std::string, IValue>;
// This struct contains a compiled IR pattern slated for use in the
// findPatternMatches function. The struct encapsulates the common
// information from parseIR that is used in conjunction with the
// pattern matching facility. A const instance of this struct can
// also be stored away to cache the compiled IR pattern and reduce
// runtime cost
struct PatternInfo {
std::string pattern_string;
std::unique_ptr<Graph> pattern_graph;
std::unordered_map<std::string, Value*> vmap;
static PatternInfo parse_from_str(std::string pattern_string) {
PatternInfo rv{
std::move(pattern_string), std::make_unique<Graph>(), decltype(vmap){}};
parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
return rv;
}
};
struct PatternsAndModules {
bool is_conv;
bool is_per_channel;
const PatternInfo& pattern;
Module packed_params_module;
};
void fillQConfigMap(
const Module& module,
const QConfigDict& qconfig_dict,
ModuleQConfigMap& map,
const std::string& key = "",
const c10::optional<QConfig>& parent_qconfig = c10::nullopt) {
c10::optional<QConfig> qconfig;
if (qconfig_dict.find(key) != qconfig_dict.end()) {
qconfig = qconfig_dict.at(key);
} else {
qconfig = parent_qconfig;
}
map[module._ivalue()] = qconfig;
for (const NameModule& s : module.named_children()) {
std::string child_key;
if (key == "") {
child_key = s.name;
} else {
child_key = key + "." + s.name;
}
fillQConfigMap(s.value._ivalue(), qconfig_dict, map, child_key, qconfig);
}
}
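// Illustrative example (hypothetical hierarchy): given a top level module
// with a child `sub` that has a child `conv`, a qconfig_dict of
// {"": qconfig_a, "sub": qconfig_b} assigns qconfig_a to the top level
// module and qconfig_b to both `sub` and `sub.conv`, since a child without
// its own entry inherits the parent's QConfig.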
bool isFunctionNode(Node* n,
const std::vector<std::string>& call_funcs,
const std::vector<std::string>& aten_funcs) {
std::vector<Symbol> aten_func_symbols;
std::transform(
aten_funcs.begin(),
aten_funcs.end(),
std::back_inserter(aten_func_symbols),
[](const std::string& s) { return Symbol::aten(s); });
bool is_quantizable =
std::find(aten_func_symbols.begin(), aten_func_symbols.end(), n->kind()) !=
aten_func_symbols.end();
if (n->kind() == prim::CallFunction) {
auto func_name = getFuncName(n->inputs()[0]);
is_quantizable |=
std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
call_funcs.end();
}
return is_quantizable;
}
// For ops that don't require observation themselves, return
// the list of input indexes that we should check to see
// if they are observed/quantized. If so, we can say the output
// of this op is observed/quantized as well, since for these ops we can
// derive the quantization parameters for the output given the inputs
std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
std::vector<std::string> single_input_aten_funcs = {
"max_pool2d",
"avg_pool2d",
"flatten",
"max",
"min",
"mean",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"upsample_linear1d",
"upsample_bilinear2d",
"upsample_trilinear3d",
"upsample_bicubic2d",
// TODO: sort returns a tuple of Tensors, we have
// to extend the API to support that
// "sort",
};
std::vector<std::string> single_input_call_funcs = {
"adaptive_avg_pool2d",
"_max_pool2d",
};
if (isFunctionNode(
n,
// We don't have call functions
// after inlining
/* call_funcs = */ single_input_call_funcs,
/* aten_funcs = */ {})) {
return {1};
} else if (isFunctionNode(
n,
// We don't have call functions
// after inlining
/* call_funcs = */ {},
/* aten_funcs = */ single_input_aten_funcs)) {
return {0};
}
return {};
}
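// For example (illustrative): for `%y = aten::max_pool2d(%x, ...)` this
// returns {0}, meaning that if %x is observed/quantized then %y can be
// treated as observed/quantized as well; for ops outside the lists above
// it returns an empty vector.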
bool nodeQuantizable(Node* n) {
return isFunctionNode(
n,
/* call_funcs = */ {
"conv2d",
"linear",
"relu",
}, /* aten_funcs = */ {
"conv2d",
"linear",
"relu",
"addmm",
"matmul",
"add_"
});
}
Module findChildModule(
const Module& module,
const std::vector<std::string>& path) {
Module m = module;
for (const auto& p : path) {
m = m.attr(p).toModule();
}
return m;
}
// Check if value is the input of the graph
bool hitGraphInput(Value* value) {
Graph* graph = value->owningGraph();
const auto& inputs = graph->inputs();
return std::find(inputs.begin(), inputs.end(), value) != inputs.end();
}
// Get the module access path for a Value representing a module instance
// by tracing back the GetAttr nodes and recording all the attribute
// names along the way.
// For example, the module access path will be ['conv1', 'basic_block', 'sub']
// for `self.sub.basic_block.conv1`
std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
std::vector<std::string> path;
// Iterator to traverse back the GetAttr calls
Value* iter = instance;
// trace back the instance to recover the path of the submodule
while (!hitGraphInput(iter) && iter->node()->kind() == prim::GetAttr) {
Node* get_attr = iter->node();
// record the name of GetAttr
path.push_back(get_attr->s(attr::name));
// trace back the chain of GetAttr
iter = get_attr->inputs()[0];
}
TORCH_CHECK(iter == self,
"Can't handle the access pattern of GetAttr "
"in getModuleAccessPath, traced back to: ",
iter->debugName(),
" which is not self: ",
self->debugName());
return path;
}
Module getInvokedModule(
Module& module, Node* n, Value* self) {
auto* instance = n->inputs()[0];
auto path = getModuleAccessPath(instance, self);
return findChildModule(module, path);
}
class ModuleCloneHelper {
public:
/** Clone according to the module qconfig map. This handles the case
* where we have two module instances sharing the same ClassType
* but configured with different QConfigs.
* The code is copied and modified from
* https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/api/module.cpp
*/
Module clone(
const Module& module,
const ModuleQConfigMap& module_qconfig_map) {
std::unordered_map<TypePtr, QConfigTypePtrMap> type_remap;
return clone_impl(module, module_qconfig_map, type_remap);
}
private:
Module clone_impl(
const Module& module,
const ModuleQConfigMap& module_qconfig_map,
std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
auto qconfig = module_qconfig_map.at(module._ivalue());
auto type = module.type();
// Create a new _ivalue in the same compilation unit.
// Since we now have shared ClassTypes, we need to preserve the shared
// ClassType during cloning, so we first use the type and qconfig to check
// whether the type has already been cloned. If so, we'll create a new module
// with the cloned ClassType; if not, we'll create a new module and a new
// ClassType.
bool type_already_cloned = type_remap.find(type) != type_remap.end() &&
type_remap.at(type).find(qconfig) != type_remap.at(type).end();
Module r;
if (type_already_cloned) {
// if we cloned the class type before, we'll reuse it
Module new_module(
module._ivalue()->compilation_unit(),
type_remap.at(type).at(qconfig)->cast<ClassType>());
r = new_module;
} else {
Module new_module(
*type->name(), module._ivalue()->compilation_unit(), true);
r = new_module;
type_remap[type][module_qconfig_map.at(module._ivalue())] = r.type();
}
// Copy slots. If a slot is a module - recursively clone it.
size_t N = type->numAttributes();
for (size_t i = 0; i < N; ++i) {
IValue s = module._ivalue()->getSlot(i);
if (type->getAttribute(i)->is_module()) {
const Module& orig = Module(s.toObject());
Module cloned =
clone_impl(orig, module_qconfig_map, type_remap);
r.register_module(type->getAttributeName(i), cloned);
} else {
r.register_attribute(
type->getAttributeName(i),
type->getAttribute(i),
s,
type->is_parameter(i));
}
}
// only clone the methods and constants if the ClassType is not cloned
// before
if (!type_already_cloned) {
for (size_t i = 0; i < type->numConstants(); ++i) {
r.type()->addConstant(type->getConstantName(i), type->getConstant(i));
}
// Clone methods remapping the types to the cloned ones.
for (auto& fn : type->methods()) {
clone_method(module, r, *fn, module_qconfig_map, type_remap);
}
}
return r;
}
void remapTypes(
Block* block,
Value* self,
const Module& source,
Module& target,
const ModuleQConfigMap& module_qconfig_map,
const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
type_remap_fn) {
// remapping of %self is done outside of this function. We don't
// support the case where a module is passed in as an argument of
// the method, because in that case we would need to do a more
// comprehensive analysis to decide the QConfig for the module.
for (size_t i = 1; i < block->inputs().size(); ++i) {
TORCH_CHECK(
!block->inputs()[i]->type()->cast<ClassType>(),
"We don't support quantizing methods that has Object as arguments");
}
for (Node* node : block->nodes()) {
// remapping type for module instance
if (node->kind() == prim::CallMethod) {
Value* instance = node->inputs()[0];
auto path = getModuleAccessPath(instance, self);
auto child = findChildModule(source, path);
auto qconfig = module_qconfig_map.at(child._ivalue());
instance->setType(type_remap_fn(instance->type(), qconfig));
}
// We don't remap outputs; the remapping of module types
// is done in the CallMethod case above. We don't support type
// remapping for modules returned from methods or functions.
for (Block* sub_block : node->blocks()) {
remapTypes(
sub_block, self, source, target, module_qconfig_map, type_remap_fn);
}
for (Symbol name : node->attributeNames()) {
if (node->kindOf(name) == AttributeKind::g) {
remapTypes(
node->g(name).get(),
source,
target,
module_qconfig_map,
type_remap_fn);
} else if (node->kindOf(name) == AttributeKind::gs) {
for (const auto& g : node->gs(name)) {
remapTypes(
g.get(), source, target, module_qconfig_map, type_remap_fn);
}
}
}
}
}
void remapTypes(
Graph* graph,
const Module& source,
Module& target,
const ModuleQConfigMap& module_qconfig_map,
const std::function<TypePtr(TypePtr, c10::optional<QConfig>)>&
type_remap_fn) {
remapTypes(
graph->block(),
graph->inputs()[0],
source,
target,
module_qconfig_map,
type_remap_fn);
}
void clone_method(
const Module& source,
Module& target,
const Function& method,
const ModuleQConfigMap& module_qconfig_map,
const std::unordered_map<TypePtr, QConfigTypePtrMap>& type_remap) {
auto type_remap_fn = [&](TypePtr type_ptr,
const c10::optional<QConfig>& qconfig) {
if (type_remap.find(type_ptr) != type_remap.end()) {
const auto& qconfig_map = type_remap.at(type_ptr);
if (qconfig_map.find(qconfig) != qconfig_map.end()) {
return qconfig_map.at(qconfig);
}
}
return type_ptr;
};
auto graph = method.graph()->copy();
remapTypes(graph.get(), source, target, module_qconfig_map, type_remap_fn);
// remap self
graph->inputs()[0]->setType(target.type());
const auto this_method_name =
c10::QualifiedName(*target.type()->name(), method.name());
auto copied = target._ivalue()->compilation_unit()->create_function(
this_method_name, graph);
target.type()->addMethod(copied);
// we'll use default schema for cloned method
}
};
class InsertObserversHelper {
public:
explicit InsertObserversHelper(const ModuleQConfigMap& map)
: module_qconfig_map_(map) {}
void preprocess(
Module& module,
const std::string& method_name);
/**
* Recursively insert observers for the method. We process the
* nodes in the graph in the order of execution, since we need the
* context information to decide whether we want to
* observe/quantize a value or not; we also don't want to observe a
* value multiple times.
*
* argument: is_entry_point means whether the current method is the forward
* method of the top level module.
*
* Since we want to insert observers at the call site instead of in the called
* graph, we'll postpone inserting observers to the caller as much as
* possible. If we know the current method is the outermost method, then
* we will insert all observers in the graph instead of postponing this to the
* parent; note that this assumes we don't have recursive method
* calls.
*
* Returns a tuple of vectors of observer modules for input and output; these
* are used for inserting observers for the input/output values
* since we need to insert those at the call site.
* Also returns a vector of indexes of the outputs that are already observed;
* this is used for propagating the observed
* property of a value through CallMethods, because we should skip inserting
* observers for ops that don't require observation.
*/
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>>
insertObservers(
Module& module,
const std::string& method_name,
bool is_entry_point = false,
std::unordered_set<Value*> graph_observed_values =
std::unordered_set<Value*>());
private:
ModuleMethodVector getInvokedMethods(
Module& module,
const std::string& method_name);
bool valueNeedsToBeQuantized(Value* v);
// Fill the map from caller input/output values to the input/output
// values of the called graph; this is used to navigate through the graph
// to find the observer for a given value
void fillBoundaryValueMap(
Module& module, const std::string& method_name);
// Fill the map from value to the corresponding observer module
// this map is used in insertObservers to actually insert
// observers to the module
void fillValueObserverMap(Module& module, const std::string& method_name);
// Clone observer module and add it to the original module,
// and insert a call to observer forward function
void insertObserverFor(
Value* v,
Module& module,
const Module& observer_module,
NameModuleVector& observer_name_and_modules);
c10::optional<Module> getObserverFor(Value* v);
void propagateObservedProperty(Value* output, std::unordered_set<Value*>& graph_observed_values);
void skipValuesInPattern(
Graph& graph,
const PatternInfo& pattern);
void addIntermediateValuesToSkipObserver(
const Module& module,
const std::string& method_name);
// Fill the map from a value to the list of values that can pass the
// observed property to it
void fillPassThroughValueMap(const std::shared_ptr<Graph>& graph);
const ModuleQConfigMap& module_qconfig_map_;
// Values we want to skip observing, used to skip values in
// the middle of the ops that are supposed to be fused, e.g.
// the output value of conv in the conv - relu pattern
std::unordered_set<Value*> values_to_skip_;
std::unordered_set<Graph*> visited_graph_of_observer_map_;
std::unordered_map<Value*, Module> observer_for_value_;
std::unordered_map<Value*, Value*> caller_to_callee_;
// Map from values at the call site to the corresponding values in the
// CallMethod graph
std::unordered_map<Value*, std::unordered_set<Value*>> boundary_value_map_;
std::unordered_set<Value*> observed_values_;
// This is used for observed values to pass through ops like flatten,
// so that the output value of flatten does not need to be observed.
// The key of the map is the output value of an op, and the value of
// the map is the list of input values that the output can inherit
// the observed property from; it is a vector since in general an
// output might inherit the property from more than one input.
std::unordered_map<Value*, std::vector<Value*>> pass_through_value_map_;
// Unique id generator for observer modules, used for generating
// unique observer names when we insert observer modules; we
// record the last unique id used so we don't have to start
// from 0 every time when searching for a unique id.
int uid_ = 0;
// Set of observer forward call nodes
std::unordered_set<Node*> observer_nodes_;
// Map from graph to a vector of observer name and observer modules we
// want to add to the module instance that has the graph
std::unordered_map<Graph*, NameModuleVector> graph_observer_map_;
// These are the IR patterns we match to skip inserting observers.
// They are compiled once on construction and used repeatedly within
// the pass.
const PatternInfo conv_functional_relu = PatternInfo::parse_from_str(R"(
graph(%self, %input, %inplace):
%relu = prim::Constant[name="relu"]()
%first_module = match::module[name="Conv2d"](%self)
%first_output = prim::CallMethod[name="forward"](%first_module, %input)
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )");
const PatternInfo conv_relu = PatternInfo::parse_from_str(R"(
graph(%self, %input):
%first_module = match::module[name="Conv2d"](%self)
%first_output = prim::CallMethod[name="forward"](%first_module, %input)
%second_module = match::module[name="ReLU"](%self)
%second_output = prim::CallMethod[name="forward"](%second_module, %first_output)
return (%second_output) )");
const PatternInfo matmul_add = PatternInfo::parse_from_str(R"(
graph(%input, %weight, %bias, %4):
%weight_t = aten::t(%weight)
%first_output = aten::matmul(%input, %weight_t)
%second_output = aten::add_(%first_output, %bias, %4)
return (%second_output) )");
const PatternInfo add_module_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b):
%one = prim::Constant[value=1]()
%first_output = aten::add_(%a, %b, %one)
%second_module = match::module[name="ReLU"](%self)
%second_output = prim::CallMethod[name="forward"](%second_module, %first_output)
return (%second_output) )");
const PatternInfo add_functional_relu = PatternInfo::parse_from_str(R"(
graph(%self, %a, %b, %inplace):
%one = prim::Constant[value=1]()
%first_output = aten::add_(%a, %b, %one)
%relu = prim::Constant[name="relu"]()
%second_output = prim::CallFunction(%relu, %first_output, %inplace)
return (%second_output) )");
const std::vector<std::reference_wrapper<const PatternInfo>> skip_patterns = {
conv_functional_relu,
conv_relu,
matmul_add,
add_module_relu,
add_functional_relu,
};
};
// Check if `use` is an aten function of name `func_name` and if value
// `v` is the nth argument of the function
bool isAtenFuncNthArg(
Value* v,
Node* use,
const std::string& func_name,
int n) {
return use->kind() == Symbol::aten(func_name) && v == use->inputs().at(n);
}
// Check if `use` is a CallFunction of name `func_name` and if value
// `v` is the nth argument of the function
bool isCallFunctionNthArg(
Value* v,
Node* use,
const std::string& func_name,
int n) {
return use->kind() == prim::CallFunction &&
getFuncName(use->inputs()[0]) == func_name && v == use->inputs().at(n);
}
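// Note: for prim::CallFunction, input 0 is the function itself, so user
// visible arguments start at index 1; e.g. (illustrative) the bias of a
// functional `linear` call is input 3 (function, input, weight, bias).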
struct FuncArg {
std::string func_name;
int arg_index;
};
using AtenFuncArgs = std::vector<FuncArg>;
using CallFuncArgs = std::vector<FuncArg>;
// Check any use of `v` matches the aten function call
// or CallFunction patterns
bool matchArgPattern(
Value* v,
const AtenFuncArgs& aten_func_args,
const CallFuncArgs& call_func_args) {
for (const Use& u : v->uses()) {
for (const auto& func_arg : aten_func_args) {
if (isAtenFuncNthArg(v, u.user, func_arg.func_name, func_arg.arg_index)) {
return true;
}
}
for (const auto& func_arg : call_func_args) {
if (isCallFunctionNthArg(
v, u.user, func_arg.func_name, func_arg.arg_index)) {
return true;
}
}
}
return false;
}
bool isBiasOfConvOrLinear(Value* v) {
bool result = matchArgPattern(
v,
AtenFuncArgs({{"conv2d", 2}, {"linear", 2}}),
CallFuncArgs({{"linear", 3}}));
if (result) {
TORCH_CHECK(
v->uses().size() == 1,
"Graph mode quantization only supports conv/linear bias being used by"
" one node.");
}
return result;
}
bool isWeightOfConvOrLinear(Value* v) {
bool result = matchArgPattern(
v,
AtenFuncArgs({{"conv2d", 1}, {"linear", 1}}),
CallFuncArgs({{"linear", 2}}));
if (result) {
TORCH_CHECK(
v->uses().size() == 1,
"Graph mode quantization only supports conv/linear weight being used by"
" one node.");
}
return result;
}
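// QConfig is assumed here to be a tuple of two observer modules:
// std::get<0> is the activation observer and std::get<1> is the weight
// observer, so weights of conv/linear get the weight observer and all
// other values get the activation observer.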
Module getObserverModuleFor(Value* v, const QConfig& qconfig) {
return
isWeightOfConvOrLinear(v) ? std::get<1>(qconfig) : std::get<0>(qconfig);
}
ModuleMethodVector InsertObserversHelper::getInvokedMethods(
Module& module,
const std::string& method_name) {
ModuleMethodVector invoked_methods;
Method method = module.get_method(method_name);
auto graph = method.graph();
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
// Skip observer nodes
if (observer_nodes_.count(n)) {
continue;
}
if (n->kind() == prim::CallMethod) {
invoked_methods.push_back(std::make_pair(getInvokedModule(module, n, graph->inputs()[0]), n->s(attr::name)));
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
return invoked_methods;
}
void InsertObserversHelper::insertObserverFor(
Value* v,
Module& module,
const Module& observer_module,
NameModuleVector& observer_name_and_modules) {
if (observed_values_.count(v)) {
return;
}
Module observer = observer_module.clone_instance();
std::string observer_name = "_observer_" + c10::to_string(uid_++);
while (module.hasattr(observer_name)) {
observer_name = "_observer_" + c10::to_string(uid_++);
}
module.register_module(observer_name, observer);
observer_name_and_modules.push_back(std::make_pair(observer_name, observer));
auto* g = v->owningGraph();
// Get handle of observer module
Node* observer_instance =
g->createGetAttr(g->inputs()[0], observer_name)->insertAfter(v->node());
observer_instance->output()->setDebugName(observer_name);
{
WithInsertPoint guard(observer_instance->next());
// Match arguments to types of observer's arguments
MatchedSchema forward_matched_schema = matchSchema(
observer.get_method("forward").function().getSchema(),
v->node()->sourceRange(),
*g,
{observer_instance->output(), v},
{});
// Insert call to observer's forward
Node* call = g->insertMethodCall("forward", forward_matched_schema)->node();
call->output()->copyMetadata(v);
// Replace v with the output of observer
v->replaceAllUsesWith(call->output());
// The above also replaced the input to `call`, so switch it back to
// the correct value
call->replaceInput(1, v);
observer_nodes_.emplace(call);
}
observed_values_.insert(v);
}
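// Illustrative sketch (hypothetical names) of the IR produced by
// insertObserverFor for a value %x observed by `_observer_0`:
//   %_observer_0 = prim::GetAttr[name="_observer_0"](%self)
//   %x.1 = prim::CallMethod[name="forward"](%_observer_0, %x)
// All uses of %x other than the observer call itself are redirected
// to the observer's output.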
void InsertObserversHelper::skipValuesInPattern(
Graph& graph,
const PatternInfo& pattern) {
const Graph& pattern_graph = *pattern.pattern_graph;
const std::unordered_map<std::string, Value*>& vmap = pattern.vmap;
const auto& matches = findPatternMatches(pattern_graph, graph);
for (const auto& match : matches) {
auto output_value = match.values_map.at(vmap.at("first_output"));
GRAPH_DEBUG("Skipping value in function pattern:",
output_value->debugName());
values_to_skip_.insert(output_value);
}
}
void InsertObserversHelper::addIntermediateValuesToSkipObserver(
const Module& module,
const std::string& method_name) {
Method method = module.get_method(method_name);
auto graph = method.graph();
for (const auto& pattern : skip_patterns) {
skipValuesInPattern(*graph, pattern);
}
}
void InsertObserversHelper::fillPassThroughValueMap(const std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
auto input_indexes = getGeneralOpTensorInputIndexes(n);
for (auto i : input_indexes) {
for (auto* output : n->outputs()) {
pass_through_value_map_[output].push_back(n->input(i));
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
}
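// Illustrative example: for `%y = aten::flatten(%x, ...)` this records
// pass_through_value_map_[%y] = {%x}, so later %y is treated as observed
// once %x is (see propagateObservedProperty).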
void InsertObserversHelper::fillBoundaryValueMap(
Module& module, const std::string& method_name) {
auto graph = module.get_method(method_name).graph();
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
auto* self = graph->inputs()[0];
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
if (n->kind() == prim::CallMethod) {
auto m = getInvokedModule(module, n, self);
auto g = m.get_method(n->s(attr::name)).graph();
// add mapping from callsite value to value in called graph
for (auto i = 0; i < g->outputs().size(); ++i) {
auto* return_val = g->outputs()[i];
boundary_value_map_[n->output(i)].insert(return_val);
}
for (auto i = 0; i < g->inputs().size(); ++i) {
auto* input_val = g->inputs()[i];
boundary_value_map_[n->input(i)].insert(input_val);
caller_to_callee_[n->input(i)] = input_val;
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
}
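// Illustrative example (hypothetical IR): for a call site
//   %y = prim::CallMethod[name="forward"](%sub, %x)
// whose callee graph is `graph(%self, %input): ... return (%out)`, this
// records %y -> {%out} and %x -> {%input} in boundary_value_map_, and
// %x -> %input in caller_to_callee_.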
void InsertObserversHelper::preprocess(
Module& module,
const std::string& method_name) {
Method method = module.get_method(method_name);
auto graph = method.graph();
// TODO: remove constant prop and add a separate graph
// cleanup step before inserting observers
// to clean up the traced graph
ConstantPooling(graph);
ConstantPropagation(graph);
// must do constant propagation first before replacement
replaceConvolutionWithConv2d(graph);
// fuse decomposed linear into aten::linear
FuseLinear(graph);
addIntermediateValuesToSkipObserver(module, method_name);
fillValueObserverMap(module, method_name);
fillBoundaryValueMap(module, method_name);
fillPassThroughValueMap(graph);
for (auto& invoked_method : getInvokedMethods(module, method_name)) {
auto& invoked_module = std::get<0>(invoked_method);
const auto& invoked_method_name = std::get<1>(invoked_method);
preprocess(invoked_module, invoked_method_name);
}
}
bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
if (!v->type()->isSubtypeOf(TensorType::get()) ||
isBiasOfConvOrLinear(v)) {
return false;
}
// Check whether producer is quantizable
if (nodeQuantizable(v->node())) {
return true;
}
// Check whether user is quantizable
for (const auto& use : v->uses()) {
if (nodeQuantizable(use.user)) {
return true;
}
}
return false;
}
void InsertObserversHelper::fillValueObserverMap(
Module& module,
const std::string& method_name) {
Method method = module.get_method(method_name);
auto graph = method.graph();
if (visited_graph_of_observer_map_.count(graph.get())) {
return;
}
visited_graph_of_observer_map_.insert(graph.get());
std::stack<Block*> blocks_to_visit;
auto qconfig_opt = module_qconfig_map_.at(module._ivalue());
if (!qconfig_opt) {
return;
}
auto qconfig = *qconfig_opt;
for (auto* v : graph->inputs()) {
if (valueNeedsToBeQuantized(v)) {
observer_for_value_[v] = getObserverModuleFor(v, qconfig);
}
}
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
for (Value* v : n->outputs()) {
if (valueNeedsToBeQuantized(v)) {
observer_for_value_[v] = getObserverModuleFor(v, qconfig);
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
}
c10::optional<Module>
InsertObserversHelper::getObserverFor(Value* v) {
if (observer_for_value_.count(v)) {
auto observer = observer_for_value_.at(v);
return observer;
}
c10::optional<Module> result;
if (boundary_value_map_.count(v)) {
for (Value* next : boundary_value_map_.at(v)) {
auto observer_opt = getObserverFor(next);
if (observer_opt) {
// Need to make sure all values are
// configured with same observer
if (result) {
TORCH_CHECK(
*observer_opt == *result,
"Expecting all values in the graph only configured with one observer");
} else {
result = observer_opt;
}
}
}
}
return result;
}
std::tuple<OptionalModuleVector, OptionalModuleVector, std::vector<size_t>> InsertObserversHelper::insertObservers(
Module& module,
const std::string& method_name,
bool is_entry_point,
std::unordered_set<Value*> graph_observed_values) {
auto graph = module.get_method(method_name).graph();
// graph input/output values, used to skip inserting observers
// for the input and output of the graph; we have to insert the observers
// at the call site because the graph itself can be shared
std::unordered_set<Value*> graph_inputs_outputs;
// list of observer modules for input values
std::vector<c10::optional<Module>> graph_input_observers;
// list of observer modules for output values
std::vector<c10::optional<Module>> graph_output_observers;
// if the current graph is the entry point graph (the forward graph
// of the top level module), we can insert observers in the graph;
// otherwise we leave the graph inputs/outputs to the caller
if (!is_entry_point) {
for (auto* v : graph->inputs()) {
graph_inputs_outputs.insert(v);
graph_input_observers.push_back(getObserverFor(v));
}
for (auto* v : graph->outputs()) {
graph_inputs_outputs.insert(v);
graph_output_observers.push_back(getObserverFor(v));
}
}
// This means the graph has been processed before; we just
// need to attach observer modules and construct the information
// needed by the call site here
bool visited = graph_observer_map_.count(graph.get());
if (visited) {
// instance clone of observer module and setAttr
for (const auto& observer_attrs : graph_observer_map_.at(graph.get())) {
const auto& name = std::get<0>(observer_attrs);
const auto& observer = std::get<1>(observer_attrs);
module._ivalue()->setAttr(name, observer.clone_instance()._ivalue());
}
}
// NB: Why do we need to process the graph even if it's visited?
// The reason is that `graph_observed_values` can
// change depending on where the method is called, and the
// outputs that have been observed (the third item of the returned result)
// can change depending on that, so for each graph we'll need to go through
// the whole process of inserting observers
GRAPH_DUMP("inserting observer for:", graph);
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
auto* self = graph->inputs()[0];
// We first construct a map from values to observer modules, then
// insert the observers later; this is to avoid interference
// of the inserted observers with the analysis that decides where
// to insert observers. Also, we only insert observers for
// "intermediate values", i.e. values that are not the input/output of
// the graph
std::unordered_map<Value*, Module> values_to_observe;
for (auto* v : graph->inputs()) {
if (!graph_inputs_outputs.count(v) && !values_to_observe.count(v)) {
if (auto observer_opt = getObserverFor(v)) {
values_to_observe[v] = *observer_opt;
}
}
}
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
if (observer_nodes_.count(n)) {
continue;
}
if (n->kind() == prim::CallMethod) {
auto m = getInvokedModule(module, n, self);
std::unordered_set<Value*> callee_observed_inputs;
for (auto i = 0; i < n->inputs().size(); ++i) {
if (graph_observed_values.count(n->inputs()[i])) {
callee_observed_inputs.insert(caller_to_callee_[n->inputs()[i]]);
}
}
auto info_from_callee = insertObservers(m, n->s(attr::name), false, callee_observed_inputs);
auto input_observers = std::get<0>(info_from_callee);
auto output_observers = std::get<1>(info_from_callee);
auto callee_observed_outputs = std::get<2>(info_from_callee);
for (auto idx : callee_observed_outputs) {
graph_observed_values.insert(n->outputs()[idx]);
}
for (auto i = 0; i < n->inputs().size(); ++i) {
if (input_observers[i] && !graph_inputs_outputs.count(n->inputs()[i])
&& !graph_observed_values.count(n->inputs()[i])) {
values_to_observe[n->inputs()[i]] = *input_observers[i];
graph_observed_values.insert(n->inputs()[i]);
}
}
for (auto i = 0; i < n->outputs().size(); ++i) {
if (output_observers[i] && !graph_inputs_outputs.count(n->outputs()[i])
&& !graph_observed_values.count(n->outputs()[i])) {
values_to_observe[n->outputs()[i]] = *output_observers[i];
graph_observed_values.insert(n->outputs()[i]);
}
}
} else {
for (Value* v : n->outputs()) {
propagateObservedProperty(v, graph_observed_values);
if (!graph_inputs_outputs.count(v) && !graph_observed_values.count(v)) {
if (auto observer_opt = getObserverFor(v)) {
values_to_observe[v] = *observer_opt;
graph_observed_values.insert(v);
}
}
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
std::vector<size_t> output_idxs;
for (auto i = 0; i < graph->outputs().size(); ++i) {
if (graph_observed_values.count(graph->outputs()[i])) {
output_idxs.push_back(i);
}
}
if (!visited) {
NameModuleVector observer_name_and_modules;
for (const auto& item : values_to_observe) {
auto* v = item.first;
auto observer = item.second;
if (!values_to_skip_.count(v)) {
insertObserverFor(v, module, observer, observer_name_and_modules);
}
}
graph_observer_map_[graph.get()] = observer_name_and_modules;
}
return std::make_tuple(graph_input_observers, graph_output_observers, output_idxs);
}
void InsertObserversHelper::propagateObservedProperty(
Value* output, std::unordered_set<Value*>& graph_observed_values) {
if (pass_through_value_map_.count(output)) {
// since the vector is always non-empty, `all_observed` will
// not just keep its initial value
bool all_observed = true;
for (Value* v : pass_through_value_map_.at(output)) {
all_observed &= observed_values_.count(v) || graph_observed_values.count(v);
}
if (all_observed) {
// This is to propagate the observed property through
// ops that don't require observation
graph_observed_values.insert(output);
}
}
}
void insertDeQuantCall(Graph* graph,
Value* quantized_val,
Value* original_val,
const std::vector<Use>& uses) {
for (size_t i = 0; i < uses.size(); ++i) {
Node* dequant =
graph->create(Symbol::aten("dequantize"), {quantized_val});
dequant->output()->setDebugName(
original_val->debugName() + ".dequant." + c10::guts::to_string(i));
uses[i].user->replaceInputWith(original_val, dequant->output());
graph->insertNode(dequant);
}
}
void insertQuantDeQuantCall(Value* self, Node* observer, bool is_per_channel) {
Graph* g = observer->owningGraph();
// Original value that is observed
Value* v = observer->input(1);
std::string quantize_func;
std::vector<Value*> inputs = {v};
// Insert the quantize call right after the node that produces v
WithInsertPoint ins(v->node()->next());
std::string prefix = v->debugName();
// Insert GetAttr nodes for quantization parameters
if (is_per_channel) {
quantize_func = "quantize_per_channel";
inputs.push_back(g->insertGetAttr(self, prefix + "_scale"));
inputs.push_back(g->insertGetAttr(self, prefix + "_zero_point"));
inputs.push_back(g->insertGetAttr(self, prefix + "_axis"));
} else {
quantize_func = "quantize_per_tensor";
inputs.push_back(
g->insertGetAttr(self, prefix + "_scale")->setType(FloatType::get()));
inputs.push_back(g->insertGetAttr(self, prefix + "_zero_point")
->setType(IntType::get()));
}
inputs.push_back(
g->insertGetAttr(self, prefix + "_scalar_type")->setType(IntType::get()));
Node* quant = g->create(at::Symbol::aten(quantize_func), inputs);
quant->output()->setDebugName(v->debugName() + ".quant");
g->insertNode(quant);
// two passes to insert the dequant for every use:
// in the first pass, identify all the nodes using "v"
std::vector<Use> uses;
for (const auto& use : v->uses()) {
// Skip quant node and observer node (we need to keep
// observer nodes around since we need them to
// find the quantization parameters)
if (use.user != quant && use.user != observer) {
uses.push_back(use);
}
}
// in second pass, replace the input "v" with dequant output
insertDeQuantCall(g, quant->output(), v, uses);
}
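// Illustrative sketch (hypothetical names) of the per tensor IR inserted
// by insertQuantDeQuantCall for an observed value %x:
//   %x_scale = prim::GetAttr[name="x_scale"](%self)
//   %x_zero_point = prim::GetAttr[name="x_zero_point"](%self)
//   %x_scalar_type = prim::GetAttr[name="x_scalar_type"](%self)
//   %x.quant = aten::quantize_per_tensor(%x, %x_scale, %x_zero_point, %x_scalar_type)
//   %x.dequant.0 = aten::dequantize(%x.quant)
// where each original use of %x is rewritten to use its own dequantize
// output.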
// find the observer for Value `v` and return the name of the observer
c10::optional<std::string> findObserverName(Value* v) {
// Note that here we just check for the name of the observer, but ideally
// we should be comparing the type of the observer; this is a temporary
// workaround until a data-only clone in module.clone is supported.
Node* n = v->node();
if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
auto module_instance = n->inputs().at(0);
if (module_instance->node()->kind() == prim::GetAttr &&
module_instance->node()->s(attr::name).find("_observer_") !=
std::string::npos) {
return module_instance->node()->s(attr::name);
}
}
return c10::nullopt;
}
c10::QScheme toAffine(c10::QScheme qscheme) {
switch (qscheme) {
case c10::kPerTensorAffine:
case c10::kPerTensorSymmetric:
return c10::kPerTensorAffine;
case c10::kPerChannelAffine:
case c10::kPerChannelSymmetric:
return c10::kPerChannelAffine;
default:
return qscheme;
}
}
class InsertQuantDeQuantHelper {
public:
InsertQuantDeQuantHelper() {}
void run(Module& module, const std::string& method_name);
ModuleMethodVector getInvokedMethods(
Module& module,
const std::string& method_name);
// Get the quantization parameter map of the given Value in the Graph
// by searching for the observer module of the value and extracting the
// quantization parameters from the observer module
std::tuple<c10::QScheme, QParamMap> getQSchemeAndQParamMap(
Module& module,
Node* n);
void checkQScheme(Graph* g, c10::QScheme qscheme) {
if (qscheme_for_graph_.count(g)) {
TORCH_CHECK(
qscheme_for_graph_.at(g) == toAffine(qscheme),
"Quantizing same graph with different types of "
"QSchemes is not supported.\n",
" Expecting:",
c10::toString(qscheme_for_graph_.at(g)),
" Got:",
c10::toString(qscheme));
} else {
qscheme_for_graph_[g] = toAffine(qscheme);
}
}
c10::optional<Module> findChildModuleToQuantize(
Module& module,
Value* child_instance);
void collectObserverNodesAndValueToQuantize(Module& module, Value*);
// Cleanup observer nodes from graph and observer modules
// from module object and ClassType
void cleanup(Module& module);
void cleanup(Module& module, Graph* g);
void quantizeTensors(Module& module, Graph* g, Value* self);
private:
std::unordered_map<Graph*, std::vector<std::string>>
observer_modules_to_remove_;
// We only remove observer module attributes from the type on the
// first encounter of the graph; after that, since the attributes
// are already removed from the ClassType, we'll use the list of slot indexes
// to replay this removal
std::unordered_map<Graph*, std::vector<int>> removed_observer_slots_;
std::unordered_map<Graph*, std::vector<Node*>> nodes_to_destroy_;
// Map from Graph to observer nodes; we can use an observer node to
// recover the information about the original value that was observed and
// the quantization parameters
std::unordered_map<Graph*, std::vector<Node*>> observer_nodes_;
// Record the qscheme for every graph; this is for checking that
// each graph is only quantized with one type of QScheme
std::unordered_map<Graph*, c10::QScheme> qscheme_for_graph_;
};
void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
Module& module,
Value* v) {
auto* g = v->owningGraph();
auto observer_name = findObserverName(v);
if (!observer_name) {
return;
}
observer_modules_to_remove_[g].push_back(observer_name.value());
Node* observer = v->node();
TORCH_INTERNAL_ASSERT(
observer->kind() == prim::CallMethod &&
observer->s(attr::name) == "forward" &&
observer->inputs()[0]->node()->kind() == prim::GetAttr &&
observer->inputs()[0]->node()->s(attr::name) == observer_name);
// Observer forward call node
nodes_to_destroy_[g].push_back(observer);
// GetAttr node for observer module
nodes_to_destroy_[g].push_back(observer->inputs()[0]->node());
Value* original_value = observer->input(1);
v->replaceAllUsesWith(original_value);
observer_nodes_[g].push_back(observer);
}
void InsertQuantDeQuantHelper::cleanup(Module& module) {
for (auto& method : module.get_methods()) {
cleanup(module, method.graph().get());
}
for (Module m : module.children()) {
cleanup(m);
}
}
void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
GRAPH_DUMP("Before Remove Observers:", g);
if (nodes_to_destroy_.count(g)) {
for (auto& n : nodes_to_destroy_.at(g)) {
n->removeAllInputs();
}
for (auto& n : nodes_to_destroy_.at(g)) {
n->destroy();
}
nodes_to_destroy_.at(g).clear();
}
// 1. If we have seen this graph before, it means the observer
// attributes have been removed from the type (see step 2) but the slot
// indexes of these attributes were kept in the list, so we'll replay the
// observer slot removal using these slot indexes
if (removed_observer_slots_.count(g)) {
for (auto slot : removed_observer_slots_.at(g)) {
module._ivalue()->unsafeRemoveSlot(slot);
}
}
// 2. Remove observer modules from the last one to the first one, in order
// to reduce the time complexity: assuming all the observer modules
// are added after the existing modules, with this optimization we'll have
// a complexity of O(N), where N is the number of observer modules
if (observer_modules_to_remove_.count(g)) {
auto& observers = observer_modules_to_remove_.at(g);
for (int64_t i = observers.size() - 1; i >= 0; --i) {
auto observer_name = observers[i];
GRAPH_DEBUG("Trying to remove: ", observer_name);
if (module.type()->hasAttribute(observer_name)) {
// We record the slot index here in order to replay the
// slot removal in other objects that share the ClassType,
// since we're going to remove the attribute from the ClassType here
removed_observer_slots_[g].push_back(
module.type()->getAttributeSlot(observer_name));
module._ivalue()->unsafeRemoveAttr(observer_name);
module.type()->unsafeRemoveAttribute(observer_name);
}
}
observers.clear();
}
GRAPH_DUMP("After remove observers :", g);
}
void InsertQuantDeQuantHelper::quantizeTensors(
Module& module,
Graph* g,
Value* self) {
if (!observer_nodes_.count(g)) {
return;
}
for (auto* n : observer_nodes_.at(g)) {
auto* original_value = n->input(1);
auto tp = getQSchemeAndQParamMap(module, n);
checkQScheme(g, std::get<0>(tp));
auto qparam_map = std::get<1>(tp);
for (auto& pr : qparam_map) {
const auto& name = pr.first;
const auto& qparam = pr.second;
module.register_attribute(
original_value->debugName() + name, qparam.type(), qparam);
}
bool is_per_channel = qparam_map.at("_scale").isTensor();
insertQuantDeQuantCall(self, n, is_per_channel);
}
}
void checkGetQParamsResult(const IValue& qparams) {
TORCH_CHECK(
qparams.isTuple(),
"`get_qparams` function is expected to return a "
"Tuple, but got:",
qparams.tagKind());
auto tp = qparams.toTuple();
TORCH_CHECK(
tp->elements().size() == 2 || tp->elements().size() == 3,
"`get_qparams` function is expected to return a "
"Tuple of size 2 or 3, got Tuple of size ",
tp->elements().size());
// Expect first two elements of the tuple to be Tensor
for (size_t i = 0; i < 2; ++i) {
TORCH_CHECK(
tp->elements()[i].isTensor(),
"Element of Tuple is expected to be Tensor, but element ",
i,
" has type: ",
tp->elements()[i].tagKind());
}
// Expect the third element of the tuple to be int
if (tp->elements().size() == 3) {
TORCH_CHECK(
tp->elements()[2].isInt(),
"Element of Tuple is expected to be int, but element ",
2,
" has type: ",
tp->elements()[2].tagKind());
}
}
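// For example (illustrative): a per tensor observer's get_qparams is
// expected to return (scale, zero_point) as two Tensors, while a per
// channel observer returns (scale, zero_point, axis) where axis is an int.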
std::tuple<c10::QScheme, QParamMap> InsertQuantDeQuantHelper::
getQSchemeAndQParamMap(Module& module, Node* n) {
// TODO: refactor findObserverName to take Node* as input
Value* v = n->output();
TORCH_INTERNAL_ASSERT(
v->type()->isSubtypeOf(TensorType::get()),
"Expected output of observer node to be Tensor");
auto observer_name = findObserverName(v);
TORCH_INTERNAL_ASSERT(
observer_name,
"getQSchemeAndParamMap expects the corresponding observer for ",
v->debugName(),
" exists.");
auto observer_module = module.attr(observer_name.value()).toModule();
auto get_qparams = observer_module.get_method("get_qparams");
IValue result = get_qparams(std::vector<IValue>());
checkGetQParamsResult(result);
auto scalar_type = observer_module.attr("dtype");
TORCH_CHECK(
scalar_type.toScalarType() != at::ScalarType::Undefined,
"dtype of observer can't be undefined");
auto tp = result.toTuple();
at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
std::unordered_map<std::string, IValue> qparams = {
{"_scalar_type", scalar_type},
};
auto qscheme = observer_module.attr("qscheme").toQScheme();
if (qscheme == c10::kPerChannelAffine ||
qscheme == c10::kPerChannelSymmetric) {
qparams["_scale"] = scale;
qparams["_zero_point"] = zero_point;
qparams["_axis"] = tp->elements()[2].toInt();
} else {
qparams["_scale"] = scale.item<double>();
qparams["_zero_point"] = zero_point.item<int64_t>();
}
return std::make_tuple(qscheme, qparams);
}
c10::optional<Module> InsertQuantDeQuantHelper::
findChildModuleToQuantize(Module& module, Value* child_instance) {
TORCH_INTERNAL_ASSERT(
child_instance->node()->kind() == prim::GetAttr,
"Child instance should come from GetAttr.");
auto child_module_name = child_instance->node()->s(attr::name);
if (child_module_name.find("_observer_") == std::string::npos) {
return module.attr(child_module_name).toModule();
}
return c10::nullopt;
}
ModuleMethodVector InsertQuantDeQuantHelper::getInvokedMethods(
Module& module,
const std::string& method_name) {
auto graph = module.get_method(method_name).graph();
ModuleMethodVector invoked_methods;
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
if (n->kind() == prim::CallMethod) {
auto module_instance = n->inputs()[0];
auto module_method_name = n->s(attr::name);
c10::optional<Module> m;
// calling method on self
if (module_instance == graph->inputs()[0]) {
m = module;
} else {
m = findChildModuleToQuantize(module, module_instance);
}
if (m) {
invoked_methods.push_back({*m, module_method_name});
}
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
return invoked_methods;
}
void InsertQuantDeQuantHelper::run(
Module& module,
const std::string& method_name) {
for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
auto& invoked_module = std::get<0>(invoked_methods);
const auto& invoked_method_name = std::get<1>(invoked_methods);
run(invoked_module, invoked_method_name);
}
Method method = module.get_method(method_name);
auto graph = method.graph();
// We only need to register new parameters if the graph has
// been quantized before
// TODO: dedup this part with code in quantizeTensors
if (observer_nodes_.count(graph.get())) {
for (auto* n : observer_nodes_.at(graph.get())) {
auto* original_value = n->input(1);
auto tp = getQSchemeAndQParamMap(module, n);
checkQScheme(graph.get(), std::get<0>(tp));
auto qparam_map = std::get<1>(tp);
for (auto& pr : qparam_map) {
const auto& name = pr.first;
const auto& qparam = pr.second;
module._ivalue()->setAttr(original_value->debugName() + name, qparam);
}
}
return;
}
// prim::Param nodes do not belong to the graph. Hence the insertion
// point is the beginning of the graph. This also safeguards against
// observing a potentially mutated value due to some in-place operation
std::vector<Value*> input_values;
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
auto& v = graph->inputs()[idx];
if (v->type()->isSubtypeOf(TensorType::get())) {
input_values.push_back(v);
}
}
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
Node* n = *it++;
for (Value* v : n->outputs()) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
continue;
}
collectObserverNodesAndValueToQuantize(module, v);
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
for (Value* v : input_values) {
collectObserverNodesAndValueToQuantize(module, v);
}
GRAPH_DUMP("Before Quantize Tensors:", graph);
Value* self = graph->inputs()[0];
quantizeTensors(module, graph.get(), self);
GRAPH_DUMP("After Quantize Tensors:", graph);
}
void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
std::string linear_with_quant = R"(
graph(%linear, %a_dequant, %w_quant, %b):
%w_dequant = aten::dequantize(%w_quant)
%r = prim::CallFunction(%linear, %a_dequant, %w_dequant, %b)
return (%r) )";
std::string linear_with_quant_prepack = R"(
graph(%linear, %a_dequant, %w_quant, %b):
%packed_params = quantized::linear_prepack(%w_quant, %b)
%w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::linear_unpack(%packed_params)
%w_dequant = aten::dequantize(%w_quant_unpacked)
%r = prim::CallFunction(%linear, %a_dequant, %w_dequant, %b)
return (%r) )";
// Filter to match linear CallFunction
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto linear_value = match_vmap.at(vmap.at("linear"));
auto func_name = getFuncName(linear_value);
if (func_name == "linear") {
return true;
}
return false;
};
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(linear_with_quant, linear_with_quant_prepack);
rewriter.runOnGraph(graph, filter);
}
void insertPrepackUnpackForConv2d(std::shared_ptr<Graph>& graph) {
std::string conv_with_quant = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
%w_dequant = aten::dequantize(%w_quant)
%r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups)
return (%r) )";
std::string conv_with_quant_prepack = R"(
graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
%packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
%w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv2d_unpack(%packed_params)
%w_dequant = aten::dequantize(%w_quant_unpacked)
%r = aten::conv2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
return (%r) )";
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(conv_with_quant, conv_with_quant_prepack);
rewriter.runOnGraph(graph);
}
c10::optional<IValue> toTwoElementIntList(Value* v) {
auto* n = v->node();
if (n->kind() == prim::Constant) {
auto iv = toIValue(v);
if (iv && iv.value().isIntList() && iv.value().toIntList().size() == 2) {
return iv;
}
}
if (n->kind() == prim::ListConstruct && n->inputs().size() == 2) {
auto e0 = toIValue(n->inputs()[0]);
auto e1 = toIValue(n->inputs()[1]);
if (!e0 || !e1 || !e0.value().isInt() || !e1.value().isInt()) {
return c10::nullopt;
}
return IValue(c10::List<int64_t>({e0.value().toInt(), e1.value().toInt()}));
}
return c10::nullopt;
}
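// For example (illustrative): both a prim::Constant holding the list [1, 1]
// and a prim::ListConstruct of two int constants are accepted; anything
// else yields c10::nullopt.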
// A helper class to make each use of a submodule unique
class ModuleUseDeduper {
public:
ModuleUseDeduper(Module& module) : module_(module) {}
void dedup() {
for (auto& method : module_.get_methods()) {
const auto& graph = method.graph();
findModuleUses(graph.get());
}
dedupModuleUses();
}
private:
// Analyze the code to record information representing
// uses of the module, which we'll use later to actually perform the dedup
// operation. Please see the comments on the member variables of this class
// for more information
void findModuleUses(Graph* graph) {
GRAPH_DUMP("Finding module uses for ", graph);
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
Value* self = graph->inputs()[0];
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
if (n->kind() != prim::CallMethod) {
continue;
}
Value* instance = n->inputs()[0];
// `path` is the chain of attribute names we collect by tracing back
// the GetAttr access chain from `instance` until we hit the input of
// the graph (self)
auto path = getModuleAccessPath(instance, self);
// path.size() == 0 means we're calling a method
// on self; we don't need to dedup uses of self
if (path.size() == 0) {
continue;
}
value_to_path_map_[instance] = path;
auto m = findChildModule(module_, path);
// If we fail to insert the module into the unique_modules_ set,
// it means there are uses of this module before this point,
// so we'll have to rewrite this use
if (!unique_modules_.insert(m._ivalue()).second) {
uses_to_rewrite_.push_back(instance);
GRAPH_DEBUG("Found use to rewrite: ", instance->debugName());
}
}
}
}
// Deduplicate module uses given the information we recorded before
void dedupModuleUses() {
for (Value* v : uses_to_rewrite_) {
const auto& path = value_to_path_map_.at(v);
const auto& m = findChildModule(module_, path);
// add a clone of the child module to the parent of the duplicated module
const auto& child_name = addChildModule(module_, m, path);
TORCH_INTERNAL_ASSERT(v->node()->kind() == prim::GetAttr);
// change the name in GetAttr call
auto original_name = v->node()->s(attr::name);
v->node()->s_(attr::name, child_name);
GRAPH_UPDATE(
"Module use dedup: changing use of original module ",
original_name,
" to ",
child_name);
}
}
std::string addChildModule(
Module& module,
const Module& child_module,
const std::vector<std::string>& path) {
TORCH_INTERNAL_ASSERT(
path.size() > 0, "path must have at least one element.");
// Parent module of the leaf child module corresponding to
// the path
auto parent_of_leaf = findChildModule(
module, std::vector<std::string>(path.begin(), path.end() - 1));
// Original name of the child module
std::string original_name = path[path.size() - 1];
int uid = 0;
std::string child_name = original_name + "_" + c10::to_string(uid++);
while (parent_of_leaf.hasattr(child_name)) {
child_name = original_name + "_" + c10::to_string(uid++);
}
parent_of_leaf.register_module(child_name, child_module.clone_instance());
return child_name;
}
Module module_;
// Map from the value of a module instance to the list of names of submodules
// starting from the top level module, e.g. ["sub1", "sub2", "relu"];
// this also serves as a cache of calling `getModuleAccessPath` on the value
std::unordered_map<Value*, std::vector<std::string>> value_to_path_map_;
// Set of unique modules that are used in the graphs
std::unordered_set<ModulePtr> unique_modules_;
// Values that represent the module instance (the use of the module)
// that we'll need to rewrite as a use of a cloned module
// instance
std::vector<Value*> uses_to_rewrite_;
};
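// Illustrative example (hypothetical names): if `self.relu` is invoked from
// two different call sites, the GetAttr of the second use is rewritten to
// point at a freshly registered clone such as `relu_0`, so every call site
// sees a unique module instance.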
struct ConvBNParameters {
at::Tensor conv_w;
at::Tensor conv_b;
at::Tensor bn_rm;
at::Tensor bn_rv;
double bn_eps = 0.0;
at::Tensor bn_w;
at::Tensor bn_b;
};
static bool hastensor(Module& m, const char* name) {
return m.hasattr(name) && m.attr(name).isTensor();
}
class FoldConvBatchNorm2dHelper {
public:
/**
* In this step we find all Conv2d - BatchNorm2d patterns in the graph
* and extract the corresponding parameters for these two modules,
* and record the information for the modifications of the graph without
* actually performing these modifications.
*/
void analyze(Module& module);
/**
* In this step we perform all the modifications including
* setting the attributes for conv module, rewriting values
* and deleting nodes in the graph
*/
void transform();
private:
bool tryExtractingConvBNParameters(
Module& conv,
Module& bn,
ConvBNParameters& r);
/**
* Given the current weight and bias tensors of a Conv2d module and parameters
* of the BatchNorm2d module we're folding with, compute the updated values
* for the weight and bias.
*
* The function is basically copied from torch/nn/utils/fusion.py
*/
std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias(
const ConvBNParameters& p);
std::unordered_map<ModulePtr,
std::tuple<at::Tensor, at::Tensor>> conv_module_and_params_;
std::unordered_map<Graph*, std::vector<std::tuple<std::string, std::string>>> conv_bn_names_;
std::unordered_map<Value*, Value*> rewrite_map_;
std::vector<Value*> values_to_rewrite_;
std::unordered_set<Node*> nodes_to_delete_;
};
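// Conv-BN folding math (mirrors torch/nn/utils/fusion.py):
//   new_w = conv_w * (bn_w / sqrt(bn_rv + bn_eps)), broadcast over the
//           output-channel dimension of conv_w
//   new_b = (conv_b - bn_rm) / sqrt(bn_rv + bn_eps) * bn_w + bn_b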
std::tuple<at::Tensor, at::Tensor> FoldConvBatchNorm2dHelper::
computeUpdatedConvWeightAndBias(const ConvBNParameters& p) {
at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps);
at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape({-1, 1, 1, 1});
at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b;
return std::make_tuple(new_w, new_b);
}
bool FoldConvBatchNorm2dHelper::tryExtractingConvBNParameters(
Module& conv,
Module& bn,
ConvBNParameters& r) {
if (!hastensor(conv, "weight") || !conv.hasattr("bias") ||
!hastensor(bn, "weight") || !hastensor(bn, "bias") ||
!hastensor(bn, "running_mean") || !hastensor(bn, "running_var") ||
!bn.hasattr("eps")) {
return false;
}
r.bn_rm = bn.attr("running_mean").toTensor();
r.bn_rv = bn.attr("running_var").toTensor();
r.bn_eps = bn.attr("eps").toDouble();
r.bn_w = bn.attr("weight").toTensor();
r.bn_b = bn.attr("bias").toTensor();
r.conv_w = conv.attr("weight").toTensor();
r.conv_b = at::zeros_like(r.bn_rm);
auto bias_opt = conv.attr("bias").toOptional<at::Tensor>();
if (bias_opt) {
r.conv_b = *bias_opt;
}
return true;
}
void FoldConvBatchNorm2dHelper::analyze(Module& module) {
const PatternInfo pattern = PatternInfo::parse_from_str(R"IR(
graph(%self, %x):
%conv_submodule = match::module[name="Conv2d"](%self)
%conv_out = prim::CallMethod[name="forward"](%conv_submodule, %x)
%bn_submodule = match::module[name="BatchNorm2d"](%self)
%bn_out = prim::CallMethod[name="forward"](%bn_submodule, %conv_out)
return (%bn_out))IR");
const Graph& pattern_graph = *pattern.pattern_graph;
const auto& vmap = pattern.vmap;
Value* pattern_conv_out = vmap.at("conv_out");
Value* pattern_bn_out = vmap.at("bn_out");
Value* pattern_conv_submodule = vmap.at("conv_submodule");
Value* pattern_bn_submodule = vmap.at("bn_submodule");
Node* pattern_conv = pattern_conv_out->node();
Node* pattern_bn = pattern_bn_out->node();
// We will put submodules into this worklist and keep processing items from it
// one by one. We start by just putting the top module there.
std::stack<Module> worklist({module});
while (!worklist.empty()) {
Module current = worklist.top();
worklist.pop();
// Queue submodules for processing
for (const Module& submodule : current.children()) {
worklist.push(submodule);
}
// Process all methods of the current module
for (auto& method : current.get_methods()) {
GRAPH_DUMP(
current.type()->name()->name() + "::" + method.name() +
"() before Conv2d-BatchNorm2d folding",
method.graph());
const auto& matches = findPatternMatches(pattern_graph, *method.graph());
GRAPH_DEBUG("number of Conv2d-BatchNorm2d matches: ", matches.size());
Graph* g = method.graph().get();
if (!conv_bn_names_.count(g)) {
// This is to make sure we don't visit one graph multiple times
conv_bn_names_[g] = {};
for (const Match& match : matches) {
GRAPH_DEBUG("Checking next match...");
Node* matched_conv = match.nodes_map.at(pattern_conv);
Node* matched_bn = match.nodes_map.at(pattern_bn);
Node* matched_conv_submodule =
match.values_map.at(pattern_conv_submodule)->node();
Node* matched_bn_submodule =
match.values_map.at(pattern_bn_submodule)->node();
TORCH_INTERNAL_ASSERT(matched_conv_submodule->kind() == prim::GetAttr);
TORCH_INTERNAL_ASSERT(matched_bn_submodule->kind() == prim::GetAttr);
const auto& conv_module_name = matched_conv_submodule->s(Symbol::attr("name"));
const auto& bn_module_name = matched_bn_submodule->s(Symbol::attr("name"));
Module conv_submodule =
current.attr(conv_module_name).toModule();
Module bn_submodule =
current.attr(bn_module_name).toModule();
ConvBNParameters params;
if (!tryExtractingConvBNParameters(
conv_submodule, bn_submodule, params)) {
GRAPH_DEBUG(
"Conv and BN modules didn't have all required parameters or attributes...");
continue;
}
conv_bn_names_[g].push_back(
std::make_tuple(conv_module_name, bn_module_name));
// We are using a separate vector for saving Values we want to rewrite to
// make sure that the order in which we perform these transformations is
// deterministic. Iterating through keys of rewrite_map would result in
// non-determinism that might not manifest as a bug now, but can bite us
// later.
values_to_rewrite_.push_back(matched_bn->output());
rewrite_map_[matched_bn->output()] = matched_conv->output();
GRAPH_UPDATE(
"Rewriting %",
matched_bn->output()->debugName(),
" with %",
matched_conv->output()->debugName());
nodes_to_delete_.insert(matched_bn);
nodes_to_delete_.insert(matched_bn_submodule);
GRAPH_UPDATE("Deleting ", *matched_bn);
GRAPH_UPDATE("Deleting ", *matched_bn_submodule);
auto slot = conv_submodule.type()->getAttributeSlot("bias");
TORCH_CHECK(conv_submodule.type()->is_parameter(slot),
"Expected conv module to have a bias parameter");
} // matches
}
for (const auto& conv_bn : conv_bn_names_.at(g)) {
Module conv_submodule =
current.attr(std::get<0>(conv_bn))
.toModule();
Module bn_submodule =
current.attr(std::get<1>(conv_bn))
.toModule();
ConvBNParameters params;
TORCH_INTERNAL_ASSERT(tryExtractingConvBNParameters(
conv_submodule, bn_submodule, params));
auto new_w_b = computeUpdatedConvWeightAndBias(params);
conv_module_and_params_[conv_submodule._ivalue()] = new_w_b;
} // conv_bn module
} // methods
} // while
}
void FoldConvBatchNorm2dHelper::transform() {
for (const auto& item : conv_module_and_params_) {
Module conv(item.first);
auto w_b = item.second;
conv.setattr("weight", std::get<0>(w_b));
conv.setattr("bias", std::get<1>(w_b));
}
// Perform planned rewritings
for (auto v : values_to_rewrite_) {
v->replaceAllUsesWith(rewrite_map_.at(v));
}
// Perform planned deletions
for (auto n : nodes_to_delete_) {
n->removeAllInputs();
}
for (auto n : nodes_to_delete_) {
n->destroy();
}
}
} // namespace
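// Entry point for observer insertion: fills the qconfig map for the original
// module, clones the module (unless `inplace` is true), refills the map for
// the cloned types, and then runs InsertObserversHelper on `method_name`.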
TORCH_API Module InsertObservers(
Module& input_module,
const std::string& method_name,
const QConfigDict& qconfig_dict,
bool inplace) {
ModuleQConfigMap map_before_clone;
fillQConfigMap(input_module, qconfig_dict, map_before_clone);
ModuleCloneHelper mh;
Module module =
inplace ? input_module : mh.clone(input_module, map_before_clone);
ModuleQConfigMap module_qconfig_map;
// Since the types are changed after clone, we need to fill
// the qconfig map again
fillQConfigMap(module, qconfig_dict, module_qconfig_map);
InsertObserversHelper helper(module_qconfig_map);
helper.preprocess(module, method_name);
helper.insertObservers(module, method_name, true);
return module;
}
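// Entry point for quant/dequant insertion: clones the module (unless
// `inplace` is true) and runs InsertQuantDeQuantHelper, which rewrites
// observed values into quantize/dequantize calls using the quantization
// parameters collected by the observers.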
Module InsertQuantDeQuant(
Module& input_module,
const std::string& method_name,
bool inplace) {
Module module = inplace ? input_module : input_module.clone();
InsertQuantDeQuantHelper h;
h.run(module, method_name);
h.cleanup(module);
return module;
}
void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
throw std::runtime_error("Pass not implemented yet!");
}
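// Rewrites call sites of the functional linear (a prim::CallFunction whose
// callee is named "linear") into aten::linear, so that later passes only
// need to match a single form of linear.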
void SwapFunctionalLinear(Module& module) {
for (auto& method : module.get_methods()) {
std::shared_ptr<Graph> g = method.graph();
SwapFunctionalLinear(g);
}
for (Module m : module.children()) {
SwapFunctionalLinear(m);
}
}
void SwapFunctionalLinear(std::shared_ptr<Graph>& graph) {
std::string functional_linear = R"(
graph(%linear, %input, %weight, %bias):
%r = prim::CallFunction(%linear, %input, %weight, %bias)
return (%r) )";
std::string aten_linear = R"(
graph(%linear, %input, %weight, %bias):
%r = aten::linear(%input, %weight, %bias)
return (%r) )";
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto linear = getValue("linear", match_vmap, vmap);
auto func_name = getFuncName(linear);
return func_name == "linear";
};
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(functional_linear, aten_linear);
// TODO: runOnGraph takes const ref?
rewriter.runOnGraph(graph, filter);
}
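// Replicate dequantize nodes so that each use of a dequantized value gets
// its own dequantize call. A sketch of the rewrite (hypothetical IR):
//   Before: %dq = aten::dequantize(%q)
//           %r1 = aten::flatten(%dq, %start, %end)
//           %r2 = aten::relu(%dq)
//   After:  %dq1 = aten::dequantize(%q)
//           %r1 = aten::flatten(%dq1, %start, %end)
//           %dq2 = aten::dequantize(%q)
//           %r2 = aten::relu(%dq2)
// This establishes the one-dequantize-per-use invariant that SwapDeQuant
// relies on below.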
void ReplicateDeQuant(std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks_to_visit;
std::vector<Node*> dequant_nodes_to_rewrite;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
if (n->kind() == Symbol::aten("dequantize") &&
n->output()->uses().size() > 1) {
dequant_nodes_to_rewrite.push_back(n);
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
for (Node* n : dequant_nodes_to_rewrite) {
WithInsertPoint ins(n->next());
auto* quantized_val = n->inputs()[0];
auto* dequantized_val = n->output();
// copy uses to vector since value->uses() is a reference
// and changing the graph will also change the uses() list
std::vector<Use> uses = dequantized_val->uses();
insertDeQuantCall(graph.get(), quantized_val, dequantized_val, uses);
}
for (Node* n : dequant_nodes_to_rewrite) {
n->removeAllInputs();
}
for (Node* n : dequant_nodes_to_rewrite) {
n->destroy();
}
}
// This pass handles ops that do not require observation,
// for example: flatten, average_pool, upsample.
// It is called after inlining and before graph execution.
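// A sketch of the swap for a single-output op (hypothetical IR):
//   Before: %dq = aten::dequantize(%q)
//           %f = aten::flatten(%dq, %start, %end)
//   After:  %f = aten::flatten(%q, %start, %end)
//           %f_dq = aten::dequantize(%f)  // one dequantize per use of %f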
void SwapDeQuant(std::shared_ptr<Graph>& graph) {
std::stack<Block*> blocks_to_visit;
blocks_to_visit.push(graph->block());
while (!blocks_to_visit.empty()) {
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
auto input_indexes = getGeneralOpTensorInputIndexes(n);
if (input_indexes.size() > 0) {
bool is_dequantized = true;
for (auto i : input_indexes) {
is_dequantized &= n->inputs()[i]->node()->kind() == Symbol::aten("dequantize");
}
if (!is_dequantized) {
continue;
}
// Delete the dequantize nodes; ReplicateDeQuant guarantees
// one dequantize node for each use of the value
for (auto i : input_indexes) {
auto* dequantized_val = n->inputs()[i];
auto* dequantize_node = dequantized_val->node();
TORCH_INTERNAL_ASSERT(dequantized_val->uses().size() == 1,
"Expect to have one dequantize node for each use");
// Replace uses of dequantized_val with the input of
// the dequantize node
dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
dequantize_node->removeAllInputs();
dequantize_node->destroy();
}
TORCH_CHECK(n->outputs().size() == 1, "We only support dequantize swapping for ops"
" with one output right now");
auto* output = n->output();
WithInsertPoint ins(n->next());
std::vector<Use> uses = output->uses();
// Insert new dequantize node for each use of the output
insertDeQuantCall(graph.get(), output, output, uses);
}
for (Block* subblock : n->blocks()) {
blocks_to_visit.push(subblock);
}
}
}
}
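// Runs the subgraph rewrites that fuse explicit quantize/dequantize chains
// into fused quantized ops; the patterns and their replacements are defined
// in quantization_patterns.h (e.g. a dequantize -> aten::conv2d ->
// quantize_per_tensor chain becoming a quantized::conv2d call).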
void QuantFusion(std::shared_ptr<Graph>& graph) {
for (const auto& item : quant_fusion_pattern_and_replacements()) {
SubgraphRewriter rewriter;
rewriter.RegisterRewritePattern(item.first, item.second);
rewriter.runOnGraph(graph);
}
}
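// Public entry point for Conv2d-BatchNorm2d folding: clones the module,
// then runs the analyze and transform steps of FoldConvBatchNorm2dHelper
// defined above.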
Module FoldConvBatchNorm2d(const Module& module) {
FoldConvBatchNorm2dHelper h;
Module m = module.clone();
h.analyze(m);
h.transform();
return m;
}
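// Folds a constant quantize_per_tensor call on the weight into a
// pre-quantized buffer. A sketch of the rewrite (attribute names taken from
// the pattern below):
//   Before: %w = prim::GetAttr[name="weight"](%self)
//           %wq = aten::quantize_per_tensor(%w, %scale, %zp, %dtype)
//   After:  %wq = prim::GetAttr[name="_quantized_weight"](%self)
// where self._quantized_weight is registered as a buffer holding the
// quantized tensor.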
void FoldQuantizeCallIntoBuffer(
Module& module,
const std::string& method_name) {
const PatternInfo& pattern = PatternInfo::parse_from_str(R"(
graph(%self, %scale, %zero_point, %dtype):
%weight = prim::GetAttr[name="weight"](%self)
%weight_quant = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
return (%weight_quant) )");
const Graph& pattern_graph = *pattern.pattern_graph;
const auto& vmap = pattern.vmap;
auto method = module.get_method(method_name);
auto graph = method.graph();
const auto& matches = findPatternMatches(pattern_graph, *graph);
// Extra filter on scale/zero_point/dtype to make sure they are Constant
auto filter = [](const Match& match,
const std::unordered_map<std::string, Value*>& vmap) {
const auto& match_vmap = match.values_map;
auto scale_node = match_vmap.at(vmap.at("scale"))->node();
auto zero_point_node = match_vmap.at(vmap.at("zero_point"))->node();
auto dtype_node = match_vmap.at(vmap.at("dtype"))->node();
return scale_node->kind() == prim::Constant &&
zero_point_node->kind() == prim::Constant &&
dtype_node->kind() == prim::Constant;
};
std::unordered_set<Node*> nodes_to_delete;
for (const auto& match : matches) {
if (!filter(match, vmap)) {
continue;
}
auto match_vmap = match.values_map;
auto float_weight = module.attr("weight").toTensor().data();
auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
auto zero_point =
toIValue(match_vmap.at(vmap.at("zero_point"))).value().toInt();
auto dtype =
toIValue(match_vmap.at(vmap.at("dtype"))).value().toScalarType();
module.register_buffer(
"_quantized_weight",
at::quantize_per_tensor(float_weight, scale, zero_point, dtype));
// Replace the GetAttr[weight]->quantize_per_tensor sequence
// with a simple GetAttr[_quantized_weight] node.
Value* orig_weight = match_vmap.at(vmap.at("weight"));
Value* orig_weight_quant = match_vmap.at(vmap.at("weight_quant"));
orig_weight->node()->s_(attr::name, "_quantized_weight");
orig_weight_quant->replaceAllUsesWith(orig_weight);
nodes_to_delete.insert(orig_weight_quant->node());
}
for (Node* n : nodes_to_delete) {
n->destroy();
}
}
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
insertPrepackUnpackForLinear(graph);
insertPrepackUnpackForConv2d(graph);
}
void InsertPrepackUnpack(Module& module) {
for (auto& method : module.get_methods()) {
auto graph = method.graph();
InsertPrepackUnpack(graph);
}
for (Module m : module.children()) {
InsertPrepackUnpack(m);
}
}
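// Helper that folds quantize + prepack call chains into attributes of a
// wrapper module holding the pre-packed parameters: each matched pattern is
// replaced with a GetAttr chain self.<wrapper>._packed_params, where
// <wrapper> is a freshly registered clone of linear_params_module or
// conv_params_module.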
struct FoldPrepackedWeightIntoModuleHelper {
void run(
Module& module,
const std::string& method_name,
const Module& linear_params_module,
const Module& conv_params_module) {
auto method = module.get_method(method_name);
auto graph = method.graph();
GRAPH_DUMP("Before FoldPrepackWeightIntoModule: ", graph);
// (is_conv, is_per_channel, pattern, packed_params_module)
std::vector<PatternsAndModules> pattern_and_modules = {
{false, false, linear_prepack_per_tensor, linear_params_module},
{false, true, linear_prepack_per_channel, linear_params_module},
{true, false, conv2d_prepack, conv_params_module},
{true, true, conv2d_prepack_per_channel, conv_params_module}};
for (const auto& pm : pattern_and_modules) {
const Graph& pattern_graph = *pm.pattern.pattern_graph;
const auto& vmap = pm.pattern.vmap;
const auto& matches = findPatternMatches(pattern_graph, *graph);
TORCH_INTERNAL_ASSERT(
matches.size() <= 1, "We only support at most one match right now");
for (const auto& match : matches) {
const auto& match_vmap = match.values_map;
auto w_dtype_opt = getIValue("w_dtype", match_vmap, vmap);
auto w_scale_opt = getIValue("w_scale", match_vmap, vmap);
auto w_zero_point_opt = getIValue("w_zero_point", match_vmap, vmap);
if (!w_dtype_opt || !w_scale_opt || !w_zero_point_opt) {
GRAPH_DEBUG(
"dtype, scale or zero_point for weight(",
getValue("w_dtype", match_vmap, vmap)->debugName(),
", ",
getValue("w_scale", match_vmap, vmap)->debugName(),
", ",
getValue("w_zero_point", match_vmap, vmap)->debugName(),
") is not constant, skipping the match.");
continue;
}
auto w_dtype = w_dtype_opt.value().toScalarType();
auto w = module.attr("weight").toTensor().data();
at::Tensor w_quant;
if (pm.is_per_channel) {
auto w_axis_opt = getIValue("w_axis", match_vmap, vmap);
if (!w_axis_opt) {
GRAPH_DEBUG(
"axis for weight ",
getValue("w_axis", match_vmap, vmap)->debugName(),
" is non-constant, skipping the match");
continue;
}
auto w_scale = w_scale_opt.value().toTensor().to(at::kFloat);
auto w_zero_point = w_zero_point_opt.value().toTensor().to(at::kInt);
int w_axis = w_axis_opt.value().toInt();
TORCH_CHECK(
w_scale.sizes() == w_zero_point.sizes(),
"scale and zero_point must have the same size");
w_quant = at::quantize_per_channel(
w, w_scale, w_zero_point, w_axis, w_dtype);
} else {
auto w_scale = w_scale_opt.value().toDouble();
auto w_zero_point = w_zero_point_opt.value().toInt();
w_quant = at::quantize_per_tensor(w, w_scale, w_zero_point, w_dtype);
}
c10::optional<at::Tensor> b = c10::nullopt;
if (hastensor(module, "bias")) {
b = module.attr("bias").toTensor().data();
}
Module wrapper_module = pm.packed_params_module.clone();
auto set_weight_bias = wrapper_module.get_method("set_weight_bias");
std::string module_name_prefix;
if (pm.is_conv) {
module_name_prefix = "_conv_packed_params_module_for_";
auto stride_opt =
toTwoElementIntList(getValue("stride", match_vmap, vmap));
auto padding_opt =
toTwoElementIntList(getValue("padding", match_vmap, vmap));
auto dilation_opt =
toTwoElementIntList(getValue("dilation", match_vmap, vmap));
auto groups_opt = getIValue("groups", match_vmap, vmap);
auto set_conv_params = wrapper_module.get_method("set_conv_params");
if (!stride_opt || !padding_opt || !dilation_opt) {
GRAPH_DEBUG(
"Failed to extract two element IntList for stride/padding/dilation, (",
getValue("stride", match_vmap, vmap)->debugName(),
", ",
getValue("padding", match_vmap, vmap)->debugName(),
", ",
getValue("dilation", match_vmap, vmap)->debugName(),
") skipping the match");
continue;
}
set_conv_params(std::vector<IValue>{stride_opt.value(),
padding_opt.value(),
dilation_opt.value(),
groups_opt.value()});
} else {
module_name_prefix = "_linear_packed_params_module_for_";
}
set_weight_bias(std::vector<IValue>{IValue(w_quant), IValue(b)});
auto w_quant_val = getValue("w_quant", match_vmap, vmap);
// Generate a unique attribute name for the wrapper module
int uid = 0;
auto module_name = module_name_prefix + c10::to_string(uid++);
while (module.hasattr(module_name)) {
module_name = module_name_prefix + c10::to_string(uid++);
}
GRAPH_UPDATE("Adding new module: ", module_name);
module.register_module(module_name, wrapper_module);
// Add GetAttr of the packed module
auto packed_params_val = getValue("packed_params", match_vmap, vmap);
WithInsertPoint ins(packed_params_val->node());
// wrapper_module =
// self.{_conv,_linear}_packed_params_module_for_{unique_id}
Value* packed_params_module =
graph->insertGetAttr(graph->inputs()[0], module_name)
->setType(wrapper_module.type());
GRAPH_UPDATE("Adding GetAttr node for the wrapper module");
// packed_params = wrapper_module._packed_params
Value* packed_params_from_attr =
graph->insertGetAttr(packed_params_module, "_packed_params");
GRAPH_UPDATE(
"Adding GetAttr node for _packed_params: ",
packed_params_from_attr->debugName());
packed_params_val->replaceAllUsesWith(packed_params_from_attr);
// Delete nodes
std::vector<Node*> nodes_to_delete = {w_quant_val->node(),
packed_params_val->node()};
for (auto n : nodes_to_delete) {
n->removeAllInputs();
}
for (auto n : nodes_to_delete) {
GRAPH_UPDATE("Deleting node: ", n);
n->destroy();
}
}
}
}
void run(
Module& module,
const Module& linear_params_module,
const Module& conv_params_module) {
for (auto& method : module.get_methods()) {
run(module, method.name(), linear_params_module, conv_params_module);
}
for (Module m : module.children()) {
run(m, linear_params_module, conv_params_module);
}
}
const PatternInfo linear_prepack_per_tensor = PatternInfo::parse_from_str(R"(
graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype):
%w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
%packed_params = quantized::linear_prepack(%w_quant, %b)
return (%packed_params) )");
const PatternInfo linear_prepack_per_channel = PatternInfo::parse_from_str(R"(
graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype):
%w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
%packed_params = quantized::linear_prepack(%w_quant, %b)
return (%packed_params) )");
const PatternInfo conv2d_prepack = PatternInfo::parse_from_str(R"(
graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype, %stride, %padding, %dilation, %groups):
%w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
%packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
return (%packed_params) )");
const PatternInfo conv2d_prepack_per_channel = PatternInfo::parse_from_str(R"(
graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype, %stride, %padding, %dilation, %groups):
%w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
%packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
return (%packed_params) )");
};
void FoldPrepackedWeightIntoModule(
Module& module,
const Module& linear_params_module,
const Module& conv_params_module) {
FoldPrepackedWeightIntoModuleHelper h;
h.run(module, linear_params_module, conv_params_module);
}
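// Rewrites the graphs so that distinct uses of the same submodule instance
// refer to distinct (cloned) instances; see ModuleUseDeduper above.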
void DedupModuleUses(Module& module) {
ModuleUseDeduper d(module);
d.dedup();
}
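// Final lowering step of the graph mode quantization flow. A sketch of the
// overall pipeline (assuming the usual driver order): InsertObservers ->
// run calibration -> InsertQuantDeQuant -> Finalize. Finalize itself
// inlines the forward graph, replicates and swaps dequantize nodes, inserts
// prepack/unpack ops, fuses quantized patterns, and freezes the module.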
script::Module Finalize(script::Module& module) {
SwapFunctionalLinear(module);
auto graph = module.get_method("forward").graph();
Inline(*graph);
ReplicateDeQuant(graph);
SwapDeQuant(graph);
InsertPrepackUnpack(graph);
ConstantPropagation(graph);
QuantFusion(graph);
return freeze_module(module);
}
} // namespace jit
} // namespace torch