import torch
from torch.fx.graph import (
Node,
)
from ..quantization_mappings import (
get_static_quant_module_class,
get_quantized_operator,
)
from .pattern_utils import (
register_quant_pattern,
register_dynamic_quant_pattern,
)
from .utils import (
_parent_name,
quantize_node,
get_per_tensor_qparams,
)
from abc import ABC, abstractmethod
import operator
# -------------------------
# Pattern Registrations
# -------------------------
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
# Base Pattern Handler
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
def __init__(self, quantizer, node):
""" Records pattern information in __init__, which will be used
in convert
"""
        # indicates whether all the inputs are Nodes; some ops (e.g. add/mul)
        # are quantized differently depending on whether all of their inputs
        # are tensors or not
self.all_nodes = True
@abstractmethod
def convert(self, quantizer, node, load_arg, debug=False):
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
return NotImplemented
@register_quant_pattern(operator.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
class Add(QuantizeHandler):
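    """ Handles operator.add, optionally fused with a following ReLU.
    When all inputs are Nodes, the op is converted to torch.ops.quantized.add
    (or add_relu) using the output scale/zero_point from the observer;
    otherwise the scalar overload of the same op is emitted.
    """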
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.add
self.add_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# add scalar
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
return quantizer.quantized_graph.create_node(
'call_function', op,
load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
            # copy the kwargs so that the original node's kwargs are not mutated
            kwargs = {**self.add_node.kwargs, 'scale': scale, 'zero_point': zero_point}
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
@register_quant_pattern(operator.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
class Mul(QuantizeHandler):
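    """ Handles operator.mul, optionally fused with a following ReLU.
    When all inputs are Nodes, the op is converted to torch.ops.quantized.mul
    (or mul_relu) using the output scale/zero_point from the observer;
    otherwise the scalar overload of the same op is emitted.
    """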
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.mul
self.mul_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# mul scalar
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
            # copy the kwargs so that the original node's kwargs are not mutated
            kwargs = {**self.mul_node.kwargs, 'scale': scale, 'zero_point': zero_point}
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
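    """ Converts torch.cat to torch.ops.quantized.cat, forwarding the output
    scale and zero_point computed by the observer attached to this node.
    """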
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
return NotImplemented
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
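    """ Handles conv modules and functional conv2d, optionally followed by ReLU.
    Module calls are swapped for their statically quantized counterparts;
    functional conv2d is lowered to torch.ops.quantized.conv2d with a prepacked
    weight, or kept as a float conv followed by a quantize node when debug=True.
    """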
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.conv_node = node
if node.op == 'call_module':
self.conv = quantizer.modules[self.conv_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for conv module
if self.conv_node.op == 'call_module':
            # note that relu should already be fused into the conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion has not been executed, ' \
                'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.conv) in [
torch.nn.intrinsic.ConvReLU1d,
torch.nn.intrinsic.ConvReLU2d,
torch.nn.intrinsic.ConvReLU3d
]:
self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
qconv_cls = get_static_quant_module_class(type(self.conv))
quantized = qconv_cls.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
elif self.conv_node.op == 'call_function':
if self.relu_node is not None:
raise Exception("functional conv + relu is not supported yet")
if debug:
                # run the float conv on non-quantized args and quantize the output below
                args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
conv_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.conv2d, args, kwargs)
root_module = quantizer.modules['']
return quantize_node(
root_module, quantizer.quantized_graph, conv_out, quantizer.activation_post_process_map[self.conv_node.name])
else:
                assert len(self.conv_node.args) == 7, \
                    'only conv2d calls with all arguments specified are supported ' \
                    'right now when debug=False'
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = tuple([weight] + list(other_args))
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
# construct conv input
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qconv_args = (conv_input, packed_weight, scale, zero_point)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
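    """ Handles linear modules and functional linear, optionally followed by ReLU.
    Module calls are swapped for torch.nn.quantized.Linear (or LinearReLU);
    functional linear is lowered to torch.ops.quantized.linear with a prepacked
    weight, or kept as a float linear followed by a quantize node when debug=True.
    """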
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.linear_node = node
if node.op == 'call_module':
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for linear module
if self.linear_node.op == 'call_module':
            # note that relu should already be fused into the linear module in the fusion step
            assert self.relu_node is None, 'linear module and relu fusion has not been executed, ' \
                'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = torch.nn.quantized.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
raise Exception("unhandled linear type:", type(self.linear))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
elif self.linear_node.op == 'call_function':
if debug:
                # run the float linear on non-quantized args and quantize the output below
                args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
linear_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
# TODO: this code can be merged with dynamic linear code
# linear args
# (x, weight, bias, ...)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
                # all args from the bias position onward (bias included)
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = (weight, bias)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct linear input
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
qlinear_args = (linear_input, packed_weight, scale, zero_point)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
class BatchNorm(QuantizeHandler):
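    """ Swaps BatchNorm2d/BatchNorm3d (and the fused BNReLU2d/BNReLU3d) modules
    for their statically quantized counterparts.
    """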
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
assert node.op == 'call_module'
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# 1. attach activation post process to module
activation_post_process = quantizer.activation_post_process_map[node.name]
if type(self.bn) in \
[torch.nn.intrinsic.BNReLU2d,
torch.nn.intrinsic.BNReLU3d]:
self.bn[1].activation_post_process = activation_post_process
else:
self.bn.activation_post_process = activation_post_process
qbn_cls = get_static_quant_module_class(type(self.bn))
quantized = qbn_cls.from_float(self.bn)
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.bn_node.target,
load_arg(quantized=[0])(self.bn_node.args),
load_arg(quantized=False)(self.bn_node.kwargs))
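# arguments on the float ops that the corresponding quantized ops do not accept;
# they are dropped from kwargs in DefaultNode.convert below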
ARGS_TO_SKIP = {
torch._ops.ops.quantized.hardswish: ['inplace'],
torch._ops.ops.quantized.instance_norm:
['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}
@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.Hardswish)
@register_quant_pattern(torch.nn.InstanceNorm1d)
@register_quant_pattern(torch.nn.InstanceNorm2d)
@register_quant_pattern(torch.nn.InstanceNorm3d)
@register_quant_pattern(torch.nn.LayerNorm)
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
class DefaultNode(QuantizeHandler):
    ''' Common quantized op: the first input and the first output will be quantized
    '''
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
return NotImplemented
assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
'call_function are handled in DefaultNode'
activation_post_process = quantizer.activation_post_process_map[node.name]
if node.op == 'call_module':
module = quantizer.modules[node.target]
module.activation_post_process = activation_post_process
quantized_module_cls = get_static_quant_module_class(type(module))
quantized_module = quantized_module_cls.from_float(module)
parent_name, name = _parent_name(node.target)
setattr(quantizer.modules[parent_name], name, quantized_module)
return quantizer.quantized_graph.create_node(
'call_module',
node.target,
load_arg(quantized=[0])(node.args),
load_arg(quantized=False)(node.kwargs))
else:
# call_function
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
if quantized_op in ARGS_TO_SKIP:
args_to_skip = ARGS_TO_SKIP[quantized_op]
for arg in args_to_skip:
if arg in kwargs:
kwargs.pop(arg)
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)
# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
@register_quant_pattern(torch.nn.functional.elu)
class ELU(QuantizeHandler):
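    """ Like DefaultNode, but drops the 'inplace' argument before calling the
    quantized op.
    """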
def convert(self, quantizer, node, load_arg, debug=False):
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
quantized_op = get_quantized_operator(node.target)
args = load_arg(quantized=[0])(node.args)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
        kwargs.pop('inplace', None)
return quantizer.quantized_graph.create_node(
'call_function', quantized_op, args, kwargs)
# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.Sigmoid)
@register_quant_pattern(torch.nn.Tanh)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardsigmoid)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.chunk)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.max)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sigmoid)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.tanh)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('chunk')
@register_quant_pattern('clamp')
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('hardsigmoid')
@register_quant_pattern('hardsigmoid_')
@register_quant_pattern('leaky_relu')
@register_quant_pattern('leaky_relu_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@register_quant_pattern('relu')
@register_quant_pattern('relu_')
@register_quant_pattern('repeat')
@register_quant_pattern('repeat_interleave')
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('sigmoid')
@register_quant_pattern('sigmoid_')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('tanh')
@register_quant_pattern('tanh_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
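    """ These ops do not require any quantization parameters; the node is copied
    to the quantized graph as is, and its output is quantized whenever its
    input is quantized.
    """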
def convert(self, quantizer, node, load_arg, debug=False):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuant(QuantizeHandler):
def convert(self, quantizer, node):
assert self.all_nodes
root_module = quantizer.modules['']
return quantize_node(
root_module,
quantizer.quantized_graph,
node, quantizer.activation_post_process_map[node.name])
# 2. Post Training Dynamic Quantization Patterns
@register_dynamic_quant_pattern(torch.nn.Linear)
@register_dynamic_quant_pattern(torch.nn.functional.linear)
class DynamicLinear(QuantizeHandler):
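    """ Dynamic quantization for linear: module calls are swapped for
    torch.nn.quantized.dynamic.Linear, and functional linear is lowered to
    torch.ops.quantized.linear_dynamic with a prepacked (int8 or fp16) weight.
    """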
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.linear_node = node
if node.op == 'call_module':
assert isinstance(quantizer.modules[node.target], torch.nn.Linear)
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
if self.linear_node.op == 'call_module':
quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target,
(load_arg(quantized=False)(self.linear_node.args[0]),),
{})
elif self.linear_node.op == 'call_function':
if debug:
                # run the float linear on non-quantized args in debug mode
                args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
else:
# linear args:
# (x, observed_weight, bias)
# get observer for the weight
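                # self.linear_node.args[1] is the output of the weight observer call;
                # its first argument is the original weight node, whose name keys
                # the weight observer in activation_post_process_map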
weight_observer = quantizer.activation_post_process_map[self.linear_node.args[1].args[0].name]
if weight_observer.dtype == torch.float16:
linear_weight = load_arg(quantized=False)(self.linear_node.args[1])
prepack_op = torch.ops.quantized.linear_prepack_fp16
else:
linear_weight = load_arg(quantized=True)(self.linear_node.args[1])
prepack_op = torch.ops.quantized.linear_prepack
bias = None
                # all args from the bias position onward (bias included)
other_args = load_arg(quantized=False)(self.linear_node.args[2:])
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = (linear_weight, bias)
# pack weight
packed_weight = quantizer.quantized_graph.create_node(
'call_function', prepack_op, prepack_args, {})
# construct dynamic linear input
non_quantized_input = load_arg(quantized=False)(self.linear_node.args[0])
qdynamic_linear_args = (non_quantized_input, packed_weight)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)