import torch
import torch.nn as nn
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import get_new_attr_name_with_prefix, maybe_get_next_module
from ..observer import (
PerChannelMinMaxObserver,
_with_args,
ObserverBase,
)
from ..utils import check_min_max_valid
from collections import namedtuple
from typing import Dict, Any, Tuple, Optional
import warnings
class _InputEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of input columns, and
computing the quantization parameters for the overall min/max input values.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
The running minimum/maximum :math:`x_\text{min/max}` are computed in the
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
with the difference that the running min/max values are stored per column.
This observer is intended to be used along with a WeightEqualizationObserver
to calculate the equalization scale.
"""
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
quant_min=None, quant_max=None, factory_kwargs=None) -> None:
super(_InputEqualizationObserver, self).__init__()
if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
raise TypeError("Input qscheme must be per-tensor")
self.dtype = dtype
self.qscheme = qscheme
self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
self.equalization_scale = torch.empty(0)
def forward(self, x_orig):
# TODO: Allow for convolutional layers
if not (x_orig.ndim == 2):
raise ValueError("InputEqualizationObserver only supports Linear layers")
return self.input_obs(x_orig)
def get_input_minmax(self):
return (self.input_obs.min_vals, self.input_obs.max_vals)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
def calculate_scaled_minmax(self):
r""" Returns the scaled min/max inputs
"""
if self.equalization_scale.nelement() == 0:
warnings.warn(
"Must call calculate_scale before calling calculate_qparams.\
Returning default min and max input."
)
return torch.tensor([0]), torch.tensor([0])
# Calculate qparams for the scaled min/max inputs
# Scale the input by the equalization scale located at the same column
# index
(min_inputs, max_inputs) = self.get_input_minmax()
min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))
return min_input_scaled, max_input_scaled
with_args = classmethod(_with_args)
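# Illustrative sketch (not part of the module's API surface): for a 2-D input of
# shape (batch, in_features), the observer records one (min, max) pair per column:
#
#   inp_obs = _InputEqualizationObserver(dtype=torch.quint8)
#   inp_obs(torch.randn(16, 4))                      # records stats for 4 columns
#   min_vals, max_vals = inp_obs.get_input_minmax()  # each has shape (4,)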
class _WeightEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of weight columns and
rows, and computing the quantization parameters for the weight rows.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
This observer is made up of a single PerChannelMinMaxObserver, `weight_col_obs`, used
to record the running minimum and maximum of columns of incoming weight
tensors. This observer is intended to be used along with an
InputEqualizationObserver to calculate the equalization scale.
The running minimum/maximum :math:`w_\text{min/max}` are computed in the
same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
"""
def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
quant_max=None, factory_kwargs=None) -> None:
super(_WeightEqualizationObserver, self).__init__()
self.dtype = dtype
self.qscheme = qscheme
self.ch_axis = 1
self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
qscheme=qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs)
self.equalization_scale = torch.empty(0)
def forward(self, w_orig):
# TODO: Allow for convolutional layers
if not (w_orig.ndim == 2):
raise ValueError("WeightEqualizationObserver only supports Linear layers")
return self.weight_col_obs(w_orig)
def get_weight_col_minmax(self):
return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
with_args = classmethod(_with_args)
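# Note (illustrative): for an nn.Linear weight of shape (out_features, in_features),
# ch_axis=1 records per-column statistics, so they line up one-to-one with the
# columns tracked by _InputEqualizationObserver on the layer's input:
#
#   w_obs = _WeightEqualizationObserver(dtype=torch.qint8)
#   w_obs(torch.randn(8, 4))                      # 8 output rows, 4 input columns
#   min_w, max_w = w_obs.get_weight_col_minmax()  # each has shape (4,)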
def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
r""" Calculates the equalization scale and sets the equalization_scale value
in the observers.
Args:
input_obs: Observer that tracks the ranges for the input columns
weight_obs: Observer that tracks the ranges for the weight columns
"""
(min_inputs, max_inputs) = input_obs.get_input_minmax()
(min_weights, max_weights) = weight_obs.get_weight_col_minmax()
if not (check_min_max_valid(min_inputs, max_inputs) and check_min_max_valid(min_weights, max_weights)):
return torch.tensor(1)
if not (min_inputs.shape == min_weights.shape):
raise ValueError(
"Input and Weight must have the same column dimension. " +
f"Found {min_inputs.shape} and {max_inputs.shape} instead."
)
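# Derivation sketch: choosing s_j = sqrt(range(w_j) / range(x_j)) for column j
# equalizes the two ranges, since
#   range(x_j * s_j) = range(w_j / s_j) = sqrt(range(x_j) * range(w_j)),
# i.e. the scaled input and weight ranges meet at their geometric mean.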
equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
return equalization_scale
class EqualizationQConfig(namedtuple('EqualizationQConfig', ['input_activation', 'weight'])):
"""
Describes how to quantize a layer or a part of the network specifically for
input-weight equalization by providing settings (observer classes) for
inputs, outputs, and weights.
Note that EqualizationQConfig needs to contain observer **classes** (like
MinMaxObserver) or a callable that returns instances on invocation, not the
concrete observer instances themselves.
The quantization preparation function will instantiate the observers multiple
times, once for each layer.
Observer classes usually have reasonable default arguments, but they can be
overridden with the `with_args` method (which behaves like functools.partial):
my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
"""
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
raise ValueError("EqualizationQConfig received observer instance, please pass observer class instead. " +
"Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
self = super(EqualizationQConfig, cls).__new__(cls, input_activation, weight)
return self
input_equalization_observer = _InputEqualizationObserver.with_args(
dtype=torch.quint8, qscheme=torch.per_tensor_symmetric)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
default_equalization_qconfig = EqualizationQConfig(input_activation=input_equalization_observer,
weight=weight_equalization_observer)
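# Illustrative usage (an assumption, not a pinned API): the equalization qconfig is
# typically supplied per layer through a config dict whose format mirrors qconfig_dict,
# e.g. a global entry applying it to every supported layer:
#
#   equalization_config = {"": default_equalization_qconfig}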
def node_supports_equalization(node: Node, modules) -> bool:
""" Checks if the current node supports equalization
Currently we only support nn.Linear modules and F.linear functions
"""
if node.op == 'call_module':
return isinstance(modules[node.target], nn.Linear)
elif node.op == 'call_function':
return node.target == nn.functional.linear
return False
def is_equalization_observer(observer: nn.Module) -> bool:
return (isinstance(observer, _InputEqualizationObserver) or
isinstance(observer, _WeightEqualizationObserver))
def get_op_node_and_weight_eq_obs(
input_eq_obs_node: Node,
model: GraphModule,
modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
""" Gets the following weight equalization observer. There should always
exist a weight equalization observer after an input equalization observer.
Returns the node containing the weight equalization observer, and the weight
equalization observer if it has been newly created
"""
# Find the op node that comes directly after the input equalization observer
op_node = None
for user in input_eq_obs_node.users.keys():
if node_supports_equalization(user, modules):
op_node = user
break
assert(op_node is not None)
if op_node.op == 'call_module':
# If the op_node is a nn.Linear layer, then it must have a
# WeightEqualizationObserver configuration
equalization_qconfig_map: Dict[str, Any] = model._equalization_qconfig_map # type: ignore[assignment]
assert(equalization_qconfig_map.get(op_node.name, None) is not None)
weight_eq_obs = equalization_qconfig_map.get(op_node.name, None).weight()
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
return op_node, weight_eq_obs
elif op_node.op == 'call_function':
# TODO
return None, None
return None, None
def maybe_get_next_input_eq_obs(node: Node, modules: Dict[str, nn.Module]) -> Optional[_InputEqualizationObserver]:
""" Gets the following input equalization observer if it exists.
For example, in the case of connecting linear layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
If the node being passed in is the linear1 node, then we want to return eq_obs2,
the following equalization observer for linear2.
However, if there are no connecting layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
Then we want to return None.
"""
assert(node_supports_equalization(node, modules))
# Locate the following output observer if it exists
maybe_obs_node = maybe_get_next_module(node, modules, ObserverBase)
if maybe_obs_node is None:
return None
maybe_eq_obs_node = maybe_get_next_module(maybe_obs_node, modules, _InputEqualizationObserver)
if maybe_eq_obs_node is None:
return None
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
assert(isinstance(maybe_eq_obs, _InputEqualizationObserver))
return maybe_eq_obs
def maybe_get_next_equalization_scale(node: Node, modules: Dict[str, nn.Module]) -> Optional[torch.Tensor]:
""" If the next next node is an InputEqualizationObserver then we want to
return its equalization scale, else we return 1
This is used in the case where there are two connecting linear layers:
linear1 -> LinearOutObs -> InputEqObs -> linear2
In this case, the node given is linear1 and we want to locate the InputEqObs.
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
if next_inp_eq_obs:
return next_inp_eq_obs.equalization_scale
return None
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
""" Scales the following input quantization observer's min/max values by
updating the values with the scaled min/max values calculated by the input
equalization observer
"""
input_eq_obs = modules[str(node.target)]
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
input_quant_obs_node = node.args[0]
assert(isinstance(input_quant_obs_node, Node))
input_quant_obs = modules[str(input_quant_obs_node.target)]
if not isinstance(input_quant_obs, ObserverBase):
return
min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
input_quant_obs.min_val = min_input_scaled
input_quant_obs.max_val = max_input_scaled
def scale_weight_node(
node: Node,
modules: Dict[str, nn.Module],
equalization_scale: torch.Tensor,
next_equalization_scale: Optional[torch.Tensor],
) -> None:
""" Scale the weights for input-weight equalization by multiplying the
weight by 1/equalization_scale and next_equalization_scale
Args:
node: Current node whose weights we want to scale
equalization_scale: Current node's calculated equalization scale
next_equalization_scale: Next node's calculated equalization scale if
the following node needs to be equalized, None otherwise
"""
assert(isinstance(node.target, str))
# Scale the weights for input-weight equalization
# If the following layer needs to be equalized then we will multiply its scale
weight = modules[node.target].weight
assert(isinstance(weight, torch.Tensor))
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale))
if next_equalization_scale is None:
modules[node.target].weight = nn.Parameter(scaled_weight)
return
# Multiply the weights row-wise by the next equalization scale so that this
# layer's outputs are pre-scaled for the next layer's equalized inputs
scaled_weight = torch.mul(scaled_weight, next_equalization_scale.reshape(-1, 1))
modules[node.target].weight = nn.Parameter(scaled_weight)
# TODO: The bias may need to be scaled for connecting linear layers
bias = modules[node.target].bias
assert(isinstance(bias, torch.Tensor))
scaled_bias = torch.mul(bias, next_equalization_scale)
modules[node.target].bias = nn.Parameter(scaled_bias)
def update_obs_for_equalization(model: GraphModule, modules: Dict[str, nn.Module]) -> Dict[str, _WeightEqualizationObserver]:
""" Update all of the observer's equalization scale. For each
InputEqualizationObserver, we will find the location of the next
WeightEqualizationObserver, create it, and calculate the equalization scale
based on the two observers.
We will then return a dictionary mapping operation node names to
the corresponding WeightEqualizationObservers for that operation.
"""
weight_eq_obs_dict = {}
for node in model.graph.nodes:
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
input_eq_obs = modules[node.target]
assert(isinstance(input_eq_obs, _InputEqualizationObserver))
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
if op_node is None or weight_eq_obs is None:
continue
weight_eq_obs(modules[str(op_node.target)].weight)
# Calculate and set the equalization scale values
equalization_scale = calculate_equalization_scale(input_eq_obs, weight_eq_obs)
input_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs_dict[op_node.name] = weight_eq_obs
return weight_eq_obs_dict
def convert_eq_obs(
model: GraphModule,
modules: Dict[str, nn.Module],
weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
""" Converts the equalization operations and updates the other nodes in the
following way:
- Removes the input equalization observers and inserts a mul operator
along with an equalization scale node wherever applicable (we do not
want to insert a mul operator between connecting linear layers).
- Updates the input quantization observers with the scaled input min/max
values.
- Scales the weights by the current and next equalization scales.
- Removes the weight equalization observer node if it exists.
Before (after prepare):
weight values
|
WeightQuantObs
|
WeightEqObs
|
x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
After this function:
scaled weight values
|
equalization scale WeightQuantObs
| |
x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
After convert:
equalization scale scaled weight values
| |
x -> mul -> quantize_per_tensor -> quantized::linear
Note that although the equalization observer appears after the quantization
observer after prepare_fx, the mul node appears before the quantization node
after convert_fx. Placing the equalization observer after the quantization
observer in prepare_fx lets us keep the invariant that the part of the graph
before the node currently being processed is never modified while its
observers are being inserted.
Having the equalization observer before the quantization observer would also
cause some inconsistencies in the ordering of the quantization and
equalization observers.
For example, a single linear layer would look like:
x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
But between two connected linear layers, it would look like:
linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
"""
for node in model.graph.nodes:
if node.op == 'call_module' and isinstance(modules[node.target], _InputEqualizationObserver):
inp_quant_obs_node = node.args[0]
prev_node = inp_quant_obs_node.args[0]
# TODO: Possible special handling for connected linear layers
# Update the following input quantization observer's min/max values
scale_input_observer(node, modules)
# Remove the InputEqualization node and add a mul operator before
# the quantization observer node that appears before the equalization node
# Before: x -> input_quant_obs -> input_eq_obs -> linear
# After: x -> mul -> input_quant_obs -> linear
# Create a node containing the equalization scale
with model.graph.inserting_before(inp_quant_obs_node):
get_new_eq_scale_name = get_new_attr_name_with_prefix(prev_node.name + '_equalization_scale')
name = get_new_eq_scale_name(modules)
setattr(model, name, modules[node.target].equalization_scale)
eq_scale_node = model.graph.create_node('get_attr', name)
# Create a node multiplying the input with the equalization scale
with model.graph.inserting_after(eq_scale_node):
inputs = (prev_node, eq_scale_node)
mul_node = model.graph.create_node("call_function", torch.mul, inputs)
# Set the mul node to be the input_quant_obs_node's input instead of
# the previous node
inp_quant_obs_node.replace_input_with(prev_node, mul_node)
# For all of the current node's users, replace the current node with
# the input quantization observer node
orig_users = list(node.users.keys())
for user_node in orig_users:
user_node.replace_input_with(node, inp_quant_obs_node)
# Erase the InputEqualizationObserver node
model.graph.erase_node(node)
elif weight_eq_obs_dict.get(node.name, None) is not None:
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert(isinstance(weight_eq_obs, _WeightEqualizationObserver))
equalization_scale = weight_eq_obs.equalization_scale
# Scales the weights and runs the weight quantization observers
maybe_next_equalization_scale = maybe_get_next_equalization_scale(node, modules)
scale_weight_node(node, modules, equalization_scale, maybe_next_equalization_scale)
def _convert_equalization_ref(model: GraphModule):
""" Reference function which applies changes needed for equalization, but
does not quantize the nodes
"""
modules = dict(model.named_modules(remove_duplicate=False))
# Calculate the equalization scale, update the observers with the scaled
# inputs, and scale the weight
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
convert_eq_obs(model, modules, weight_eq_obs_dict)
return GraphModule(model, model.graph)
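# Illustrative end-to-end flow (a sketch under assumptions, not a pinned API): prepare
# the model with both a qconfig dict and an equalization config (the exact prepare_fx
# keyword for the equalization config is assumed here), calibrate, then apply the
# reference conversion:
#
#   prepared = prepare_fx(model, qconfig_dict, equalization_config)  # keyword may differ
#   for data in calibration_loader:
#       prepared(data)
#   equalized = _convert_equalization_ref(prepared)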