import warnings

import torch
import torch.nn as nn
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: Optional custom observer for the output; if ``None``, a
            MinMaxObserver is used

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column.

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`.

    .. note:: If the running minimum equals the running maximum, the scales
        and zero_points are set to 1.0 and 0.
    """
    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, output_obs=None,
                 factory_kwargs=None) -> None:
        super(_InputEqualizationObserver, self).__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                  qscheme=qscheme,
                                                  quant_min=quant_min,
                                                  quant_max=quant_max,
                                                  factory_kwargs=factory_kwargs)

        if output_obs is None:
            self.output_obs = MinMaxObserver(dtype=dtype,
                                             qscheme=qscheme,
                                             quant_min=quant_min,
                                             quant_max=quant_max,
                                             factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs
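        # Populated later via set_equalization_scale(); an empty tensor means
        # the equalization scale has not been computed yet.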
        self.equalization_scale = torch.empty(0)
    def forward(self, x_orig):
        # TODO: Allow for convolutional layers
        if x_orig.ndim != 2:
            raise ValueError("InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)
    def get_input_minmax(self):
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        self.equalization_scale = equalization_scale
    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the scaled min/max input values
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call set_equalization_scale before calling "
                "calculate_qparams. Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])
        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))

        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input

class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.

    This observer is made up of 2 PerChannelMinMaxObservers:
        - weight_col_obs: Used to record the running minimum and maximum of
          columns of incoming weight tensors
        - weight_row_obs: Used to record the running minimum and maximum of
          rows of incoming weight tensors

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the
    same way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    .. note:: If the running minimum equals the running maximum, the scales
        and zero_points are set to 1.0 and 0.
    """
    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
                 quant_max=None, factory_kwargs=None) -> None:
        super(_WeightEqualizationObserver, self).__init__()

        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        self.weight_row_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)
        # Initialize the min/max row indices to None so calculate_qparams can
        # detect when _forward has not been called yet (without this,
        # referencing them would raise an AttributeError).
        self.min_weights_ind = None
        self.max_weights_ind = None
        self.equalization_scale = torch.empty(0)
    def forward(self, w_orig):
        # TODO: Allow for convolutional layers
        if w_orig.ndim != 2:
            raise ValueError("WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)
    def _forward(self, w_orig):
        r"""
        Calculates the min/max values of each weight column and weight row.
        """
        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Calculate the column indices of the min/max weight in each row
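        # (these indices are used in calculate_qparams to pick the matching
        # equalization scale entry when scaling each weight row)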
        num_row, _ = w_orig.shape
        min_weights_ind = []
        max_weights_ind = []
        for i in range(num_row):
            min_weights_ind.append(torch.nonzero(w_orig[i] == self.weight_row_obs.min_vals[i])[0][0])
            max_weights_ind.append(torch.nonzero(w_orig[i] == self.weight_row_obs.max_vals[i])[0][0])
        self.min_weights_ind = torch.tensor(min_weights_ind)
        self.max_weights_ind = torch.tensor(max_weights_ind)

        return w_orig
    def get_weight_col_minmax(self):
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        self.equalization_scale = equalization_scale
    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the scaled weight rows
        """

        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call set_equalization_scale before calling "
                "calculate_qparams. Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the min/max weight in each "
                "row (by calling forward) before calculating the qparams. "
                "Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])
        # Calculate the qparams for weights by using the rows
        # Scale the weight rows by the reciprocal of the equalization scale
        # located at the same column index
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(min_weights, torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(max_weights, torch.reciprocal(self.equalization_scale[self.max_weights_ind]))

        (scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight

def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
                                 weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
    r"""Calculates the equalization scale between the input and weight columns.
    The caller is responsible for storing the result in both observers via
    set_equalization_scale.

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns
    """

    (min_inputs, max_inputs) = input_obs.get_input_minmax()
    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()

    if not (min_inputs.shape == min_weights.shape):
        raise ValueError(
            "Input and Weight must have the same column dimension. " +
            f"Found {min_inputs.shape} and {min_weights.shape} instead."
        )
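    # Per column i, scaling inputs by S_i and weights by 1/S_i preserves the
    # product x @ W.T; choosing S_i = sqrt(w_range_i / x_range_i) makes the
    # scaled input range and scaled weight range equal in every column.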
    equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))

    return equalization_scale
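
if __name__ == "__main__":
    # Minimal end-to-end sketch of the intended flow (illustrative shapes and
    # random data; not part of the library API): observe a Linear layer's
    # input and weight, derive the equalization scale, then compute qparams.
    torch.manual_seed(0)
    x = torch.randn(16, 4)  # (batch, in_features)
    w = torch.randn(8, 4)   # (out_features, in_features)

    input_obs = _InputEqualizationObserver()
    weight_obs = _WeightEqualizationObserver()
    input_obs(x)
    weight_obs(w)

    # Compute the per-column scale, then hand it to both observers.
    scale = calculate_equalization_scale(input_obs, weight_obs)
    input_obs.set_equalization_scale(scale)
    weight_obs.set_equalization_scale(scale)

    print("equalization scale:", scale)
    print("input qparams:", input_obs.calculate_qparams())
    print("weight qparams:", weight_obs.calculate_qparams())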