import torch
from torch.quantization import (
propagate_qconfig_,
convert,
DEFAULT_QAT_MODULE_MAPPING,
DEFAULT_MODULE_MAPPING,
)
from torch.fx import (
GraphModule,
Proxy,
)
from torch.fx.graph import (
Graph,
Node,
map_arg,
)
from .pattern_utils import (
matches,
register_quant_pattern,
get_quant_patterns,
register_dynamic_pattern,
get_dynamic_quant_patterns,
)
from .utils import _parent_name
from abc import ABC, abstractmethod
import copy
import enum
import operator
# Quantization type (dynamic quantization, static quantization,
# quantization aware training).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
DYNAMIC = 0
STATIC = 1
QAT = 2
# ------------------------
# Helper Functions
# ------------------------
def get_qparams(activation_post_process):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
dtype = activation_post_process.dtype
return scale, zero_point, dtype
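# e.g. a default min/max observer might yield something like
# (0.0235, 64, torch.quint8) here -- illustrative values only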
def quantize_node(node, activation_post_process):
    scale, zero_point, dtype = get_qparams(activation_post_process)
    # wrap the node in a Proxy so the quantize op is recorded in its graph
    # (same idiom as the dequantize calls in Quantizer.convert below)
    return torch.quantize_per_tensor(Proxy(node), scale, zero_point, dtype).node
def quantize(quantizer, node):
    return quantize_node(node, quantizer.activation_post_process_map[node.name])
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv2d : [1],
torch.nn.functional.linear : [1],
}
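# e.g. in torch.nn.functional.conv2d(input, weight, bias, ...) the weight is
# the positional argument at index 1, hence the [1] above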
# Pattern Registrations
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
# Base Pattern Handler
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
def __init__(self, quantizer, node):
""" Records pattern information in __init__, which will be used
in convert
"""
        # this is an indicator of whether all the inputs are Node or not,
        # since some ops may be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul with a scalar operand
self.all_nodes = True
@abstractmethod
def convert(self, quantizer, node, load_arg, debug=False):
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
return NotImplemented
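# A concrete handler records the matched node(s) in __init__ and emits the
# corresponding quantized node(s) in convert. Sketch (SomeOp and
# torch.ops.quantized.some_op are hypothetical):
#
#     @register_quant_pattern(torch.some_op)
#     class SomeOp(QuantizeHandler):
#         def convert(self, quantizer, node, load_arg, debug=False):
#             return quantizer.quantized_graph.create_node(
#                 'call_function', torch.ops.quantized.some_op,
#                 load_arg(quantized=True)(node.args), node.kwargs)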
@register_quant_pattern(operator.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
class Add(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.add
self.add_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# add scalar
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
return quantizer.quantized_graph.create_node(
'call_function', op,
load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point, _ = get_qparams(activation_post_process)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
            kwargs = dict(self.add_node.kwargs)  # copy to avoid mutating the matched node
            kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
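# The node emitted above for a two-tensor add is roughly equivalent to
# torch.ops.quantized.add(qx, qy, scale=..., zero_point=...) (illustrative)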
@register_quant_pattern(operator.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
class Mul(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.mul
self.mul_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# mul scalar
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point, _ = get_qparams(activation_post_process)
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
            kwargs = dict(self.mul_node.kwargs)  # copy to avoid mutating the matched node
            kwargs.update({'scale': scale, 'zero_point': zero_point})
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
return NotImplemented
        activation_post_process = quantizer.activation_post_process_map[node.name]
        scale, zero_point, _ = get_qparams(activation_post_process)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
# handle conv, maybe followed by relu
# NB: matching order is reversed, i.e. we match from the bottom of this list
# to the top
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.conv_node = node
if node.op == 'call_module':
self.conv = quantizer.modules[self.conv_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for conv module
if self.conv_node.op == 'call_module':
            # note that relu should already be fused into the conv module in
            # the fusion step
            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.conv) in [
torch.nn.intrinsic.ConvReLU1d,
torch.nn.intrinsic.ConvReLU2d,
torch.nn.intrinsic.ConvReLU3d
]:
self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
# TODO: make the mapping configurable?
assert type(self.conv) in DEFAULT_MODULE_MAPPING, \
'unhandled conv type:{}'.format(type(self.conv))
qconv_cls = DEFAULT_MODULE_MAPPING[type(self.conv)]
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
elif self.conv_node.op == 'call_function':
if self.relu_node is not None:
raise Exception("functional conv + relu is not supported yet")
            if debug:
                # emit quantize ops for the input and weight, then load the
                # float versions of all args for a reference float conv2d
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
conv_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.conv2d, args, kwargs)
return quantize_node(
conv_out, quantizer.activation_post_process_map[self.conv_node.name])
else:
                assert len(self.conv_node.args) == 7, \
                    'only conv2d calls with all arguments specified are supported right now when debug=False'
args = load_arg(quantized=[0, 1])(self.conv_node.args)
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = [weight] + list(other_args)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
# construct conv input
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
qconv_args = [conv_input, packed_weight, scale, zero_point]
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
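# In the non-debug path above, the converted graph for a functional conv2d
# thus looks roughly like (names are placeholders):
#   packed = torch.ops.quantized.conv2d_prepack(q_weight, bias, stride, padding, dilation, groups)
#   out = torch.ops.quantized.conv2d(q_input, packed, scale, zero_point)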
# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.linear_node = node
if node.op == 'call_module':
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for linear module
if self.linear_node.op == 'call_module':
            # note that relu should already be fused into the linear module in
            # the fusion step
            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = torch.nn.quantized.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
                raise Exception('unhandled linear type: {}'.format(type(self.linear)))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
elif self.linear_node.op == 'call_function':
            if debug:
                # emit quantize ops for the input and weight, then load the
                # float versions of all args for a reference float linear
                args = load_arg(quantized=[0, 1])(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
linear_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
return quantize_node(
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
args = load_arg(quantized=[0, 1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct linear input
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
qlinear_args = [linear_input, packed_weight, scale, zero_point]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
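# In the non-debug path above, the converted graph for a functional linear
# thus looks roughly like (names are placeholders):
#   packed = torch.ops.quantized.linear_prepack(q_weight, bias)
#   out = torch.ops.quantized.linear(q_input, packed, scale, zero_point)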
# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('chunk')
@register_quant_pattern('contiguous')
@register_quant_pattern('mean')
@register_quant_pattern('reshape')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        # these ops work on both float and quantized tensors, so we just copy
        # the node and load each arg from whichever environment it lives in
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
class DefaultQuant(QuantizeHandler):
def convert(self, quantizer, node):
assert self.all_nodes
return quantize(quantizer, node)
# 2. Post Training Dynamic Quantization Patterns
@register_dynamic_pattern(torch.nn.Linear)
@register_dynamic_pattern(torch.nn.functional.linear)
class DynamicLinear(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.linear_node = node
if node.op == 'call_module':
assert isinstance(quantizer.modules[node.target], torch.nn.Linear)
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
if self.linear_node.op == 'call_module':
quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target,
(load_arg(quantized=False)(self.linear_node.args[0]),),
{})
elif self.linear_node.op == 'call_function':
if debug:
# quantize and dequantize weight
args = load_arg(quantized=[1])(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
else:
# quantize and dequantize weight
args = load_arg(quantized=[1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct dynamic linear input
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
qdynamic_linear_args = [linear_input, packed_weight]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)
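# In the non-debug path above, the converted graph for a dynamically quantized
# functional linear thus looks roughly like (names are placeholders):
#   packed = torch.ops.quantized.linear_prepack(q_weight, bias)
#   out = torch.ops.quantized.linear_dynamic(float_input, packed)
# input quantization parameters are computed on the fly inside linear_dynamic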
class Quantizer:
def __init__(self):
# mapping from matched node to activation_post_process
# must be filled before convert
self.activation_post_process_map = None
def _qat_swap_modules(self, root):
convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
def _generate_qconfig_map(self, root, input_graph):
def get_qconfig(module):
return module.qconfig if hasattr(module, 'qconfig') else None
self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == 'get_param':
parent, _ = _parent_name(node.target)
self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
elif node.op == 'call_function':
self.qconfig_map[node.name] = get_qconfig(root)
elif node.op == 'call_method':
self_obj = node.args[0]
# qconfig for call_method should be the same as the `self` object for the call
self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
elif node.op == 'call_module':
self.qconfig_map[node.name] = get_qconfig(self.modules[node.target])
def _prepare(self, model, qconfig_dict, inplace, quant_type):
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
input_graph = model.graph
self.quant_type = quant_type
# TODO: allow user specified patterns
if self.quant_type == QuantType.DYNAMIC:
self.patterns = get_dynamic_quant_patterns()
else:
self.patterns = get_quant_patterns()
propagate_qconfig_(input_root, qconfig_dict)
if input_root.training:
self._qat_swap_modules(input_root)
self.modules = dict(input_root.named_modules())
# map from node name to qconfig, used in _find_matches
self._generate_qconfig_map(input_root, input_graph)
# match the patterns that will get quantized
matches = self._find_matches(input_graph, self.modules, self.patterns)
        # find _inputs_ to matched nodes that are not quantized; these
        # have to be quantized, which requires measuring stats,
        # so initialize a DefaultQuant object for each
quants = self._find_quants(input_graph, matches)
self.activation_post_process_map = dict()
env = {}
observed_graph = Graph()
observed = set()
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
for node in input_graph.nodes:
if node.name in observed:
continue
def get_new_observer_name(parent_module):
i = 0
def get_observer_name(i):
return 'activation_post_process_' + str(i)
observer_name = get_observer_name(i)
while hasattr(parent_module, observer_name):
i += 1
observer_name = get_observer_name(i)
return observer_name
root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
if root_node is None:
env[node.name] = observed_graph.node_copy(node, load_arg)
elif root_node is node:
env[node.name] = observed_graph.node_copy(node, load_arg)
def insert_observer(node, observer):
observer_name = get_new_observer_name(input_root)
setattr(input_root, observer_name, observer)
self.activation_post_process_map[node.name] = observer
env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
observed.add(node.name)
# don't need to insert observer for output in dynamic quantization
if self.quant_type == QuantType.DYNAMIC:
continue
if isinstance(obj, CopyNode):
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
# propagate observed property from input
if node.args[0].name in observed:
observed.add(node.name)
elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
if node.args[0].name in observed:
observed.add(node.name)
elif qconfig is not None and obj.all_nodes:
# observer for outputs
insert_observer(node, qconfig.activation())
else:
env[node.name] = observed_graph.node_copy(node, load_arg)
if node.name not in observed and node.name in quants:
observer_name = get_new_observer_name(input_root)
_, qconfig, is_weight = quants[node.name]
if qconfig is not None:
self.activation_post_process_map[node.name] = qconfig.weight() if is_weight else qconfig.activation()
setattr(input_root, observer_name, self.activation_post_process_map[node.name])
env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
observed.add(node.name)
observed_graph.output(load_arg(input_graph.result))
return GraphModule(input_root, observed_graph)
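    # The observed graph produced by _prepare interleaves the original ops
    # with observer modules, roughly (illustrative):
    #   x -> conv -> activation_post_process_0 -> relu -> activation_post_process_1 -> ...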
def prepare(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)
def prepare_dynamic(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
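    # Typical end-to-end usage -- an illustrative sketch; `model` and
    # `calibration_data` are placeholders, not part of this module:
    #
    #   from torch.fx import symbolic_trace
    #   quantizer = Quantizer()
    #   observed = quantizer.prepare(
    #       symbolic_trace(model), {'': torch.quantization.default_qconfig})
    #   for inputs in calibration_data:
    #       observed(inputs)  # populate observer statistics
    #   quantized = quantizer.convert(observed)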
def convert(self, observed, inplace=False, debug=False):
assert self.activation_post_process_map is not None
# move to cpu since we only have quantized cpu kernels
observed.eval().cpu()
observed_root = observed.root
observed_graph = observed.graph
if not inplace:
observed_root = copy.deepcopy(observed_root)
self.modules = dict(observed_root.named_modules())
matches = self._find_matches(observed.graph, self.modules, self.patterns)
quants = self._find_quants(observed.graph, matches)
self.quantized_graph = Graph()
env = {}
quant_env = {}
def load_non_quantized(n):
if n.name not in env:
assert n.name in quant_env, \
'trying to load float node but did not find node:' + n.name + \
' in quantized environment:' + str(quant_env)
env[n.name] = Proxy(quant_env[n.name]).dequantize().node
return env[n.name]
def load_quantized(n):
if n.name not in quant_env:
assert n.name in env, \
'trying to load quantized node but did not find node:' + n.name + \
' in float environment:' + str(env)
assert n.name in quants, 'did not find quant object for node:' + n.name
quant = quants[n.name][0]
quant_env[n.name] = quant.convert(self, env[n.name])
return quant_env[n.name]
def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either environment'
if n.name in quant_env:
return quant_env[n.name]
else:
return env[n.name]
def load_arg(quantized):
"""
if quantized is a list, then arg should be a list and the args with corresponding
indexes will be quantized
if quantized is a boolean, then all args will be quantized/not quantized
if quantized is None, then we'll load the node as long as it exists
"""
assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)
def load_arg_impl(arg):
if quantized is None:
return map_arg(arg, load_x)
if isinstance(quantized, bool):
return map_arg(arg, load_quantized if quantized else load_non_quantized)
elif isinstance(quantized, (tuple, list)):
assert isinstance(arg, (tuple, list)), arg
loaded_arg = []
# for now, we only support quantizing positional arguments
for i, a in enumerate(arg):
if i in quantized:
loaded_arg.append(map_arg(a, load_quantized))
else:
loaded_arg.append(map_arg(a, load_non_quantized))
return type(arg)(loaded_arg)
return load_arg_impl
def is_quantized(node):
assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but quant_env
            # will take precedence
if node.name in quant_env:
return True
elif node.name in env:
return False
for node in observed_graph.nodes:
root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
if root_node is node:
result = obj.convert(self, node, load_arg)
quantized = True
# Need to get correct quantized/non-quantized state for the output of CopyNode
if isinstance(obj, CopyNode):
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
quantized = is_quantized(node.args[0])
if self.quant_type == QuantType.DYNAMIC:
quantized = False
if quantized:
quant_env[node.name] = result
else:
env[node.name] = result
continue
elif root_node is not None:
continue
# handle activation post process calls
if node.op == 'call_module':
if node.target.split('.')[-1].startswith('activation_post_process_'):
observer_module = self.modules[node.target]
prev_node = node.args[0]
if prev_node.name in quant_env:
# if previous node is already quantized, we'll just remove the activation_post_process
quant_env[node.name] = quant_env[prev_node.name]
continue
# replace activation post process with quantization ops
                    parent_name = ''  # qparams will be registered on the root module
scale, zero_point = observer_module.calculate_qparams()
# TODO: per channel
scale = float(scale)
zero_point = int(zero_point)
dtype = observer_module.dtype
qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype}
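                    # e.g. an activation_post_process_0(x) call is replaced by
                    # torch.quantize_per_tensor(x, _scale_0, _zero_point_0, _dtype_0)
                    # with those qparams registered as attributes on the root module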
i = 0
def noattr(module, qparams, i):
for name in qparams.keys():
if hasattr(module, name + str(i)):
return False
return True
def get_next_i(module, qparams):
i = 0
while not noattr(module, qparams, i):
i += 1
return i
parent_module = self.modules[parent_name]
i = get_next_i(parent_module, qparams)
inputs = [load_non_quantized(node.args[0])]
for key, value in qparams.items():
setattr(parent_module, key + str(i), value)
qparam_full_path = key + str(i)
if parent_name:
qparam_full_path = parent_name + '.' + qparam_full_path
inputs.append(self.quantized_graph.get_param(qparam_full_path))
quant_env[node.name] = self.quantized_graph.create_node('call_function', torch.quantize_per_tensor, inputs, {})
continue
            # dequantize inputs for nodes that are not quantized
env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
self.quantized_graph.output(load_non_quantized(observed_graph.result))
to_be_removed = []
for name, _ in observed_root.named_modules():
if name.split('.')[-1].startswith('activation_post_process_'):
to_be_removed.append(name)
for n in to_be_removed:
delattr(observed_root, n)
return GraphModule(observed_root, self.quantized_graph)
def _find_matches(self, graph, modules, patterns):
        match_map = {}  # node name -> (root_node, matched_nodes, handler, qconfig)
all_matched = set()
def record_match(pattern, node, matched):
if isinstance(pattern, tuple):
s, *args = pattern
record_match(s, node, matched)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, matched)
else:
matched.append(node)
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, value in patterns.items():
if matches(modules, node, pattern):
matched = []
record_match(pattern, node, matched)
for n in matched:
match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
return match_map
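    # e.g. for a (torch.nn.ReLU, torch.nn.functional.conv2d) pattern, both the
    # relu node and the conv node map to the same
    # (relu_node, [relu_node, conv_node], handler, qconfig) entry (illustrative)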
def _find_quants(self, graph, matches):
quants = {}
def visit(node, qconfig):
def visit_arg(arg):
# note: we have to measure quantization information
# even for nodes where we might not use it because it is already
# quantized. This is because each match has the option to
# say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
is_weight = False
if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
for i, node_arg in enumerate(node.args):
if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
is_weight = True
if self.quant_type != QuantType.DYNAMIC or is_weight:
# overwrite previous quant config
quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
return visit_arg
for node in graph.nodes:
if node.name in matches:
root_node, matched, obj, qconfig = matches[node.name]
# don't attach observer/fake_quant for CopyNode
if isinstance(obj, CopyNode):
qconfig = None
if root_node is node:
# matched[-1] is the first op in the sequence and
# matched[0] is the last op in the sequence
# inputs
map_arg(matched[-1].args, visit(matched[-1], qconfig))
map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
# output
map_arg(matched[0], visit(None, qconfig))
return quants