| import torch |
| import numpy as np |
| |
| from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule |
| from torch.ao.quantization.experimental.observer import APoTObserver |
| from torch.ao.quantization.experimental.quantizer import quantize_APoT |
| |
| class LinearAPoT(WeightedQuantizedModule): |
| r""" |
| A quantized linear module with quantized tensor as inputs and outputs |
| to support APoT quantization. |
| We adopt the same interface as `torch.nn.Linear`, see |
| https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation. |
| |
| Similar to :class:`~torch.nn.Linear`, attributes will be randomly |
| initialized at module creation time and will be overwritten later |
| |
| Attributes: |
| alpha: `alpha` qparam of output Quantized Tensor, type: Tensor |
| gamma: `gamma` qparam of output Quantized Tensor, type: Tensor |
| quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor |
| level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor |
| weight: APoT quantized tensor from weight2quantize |
| weight_transposed: transposed weight tensor, used in linear transformation calculation (y = x * A^T + b) |
| """ |
| |
| def __init__(self, weight2quantize: torch.Tensor, b: int, k: int): |
| assert weight2quantize.dim() == 2 |
| assert b % k == 0 |
| |
| super().__init__() |
| |
| self.b = b |
| self.k = k |
| self.n = self.b // self.k |
| |
| observer = APoTObserver(b=self.b, k=self.k) |
| |
| observer(weight2quantize) |
| |
| self.alpha, self.gamma, self.quantization_levels, self.level_indices = observer.calculate_qparams(signed=False) |
| |
| quantized_weight = quantize_APoT(weight2quantize, self.alpha, self.gamma, self.quantization_levels, self.level_indices) |
| self.weight = quantized_weight.data |
| self.weight_transposed = torch.transpose(self.weight, 0, 1) |
| |
| def decompose_APoT(self, x): |
| r""" |
| Decompose binary representation of APoT values into list of k-sized blocks |
| Args: |
| x (Tensor): binary representation of APoT quantized tensor |
| """ |
| # remove "0b" prefix from binary representation |
| x = x[2:] |
| |
| # initialize list of blocks |
| blocks = [] |
| |
| while x: |
| blocks.append(x[0:self.k]) |
| x = x[self.k:] |
| |
| return blocks |
| |
| def bitshift_mul(self, weight_val, r): |
| r""" |
| Compute multiplication of weight_val * r using bitshifting |
| method discussed in APoT paper: https://arxiv.org/pdf/1909.13144.pdf |
| Args: |
| weight_val: list of binary digits representing APoT quantized weight value |
| r: int representing uniformly quantized activation value |
| """ |
| product = 0 |
| |
| idx = len(weight_val) - 1 |
| place = 0 |
| |
| while idx >= 0: |
| block = weight_val[idx] |
| |
| # reverse digits in block |
| block = block[::-1] |
| |
| curr_block_result = 0 |
| |
| for ele in block: |
| if int(ele): |
| curr_block_result += r << place |
| place += 1 |
| |
| idx -= 1 |
| product += curr_block_result |
| |
| return product |
| |
| |
| def matmul(self, decomposed_weight, activation): |
| r""" |
| Perform matrix multiplication between decomposed_weight and |
| activation by calling bitshift_mul function for each value |
| Args: |
| decomposed_weight (Tensor): APoT quantized weight decomposed into binary |
| activation (Tensor): uniformly quantized activation |
| """ |
| rows1 = activation.size(dim=0) |
| cols1 = activation.size(dim=1) |
| |
| rows2 = decomposed_weight.shape[0] |
| cols2 = decomposed_weight.shape[1] |
| |
| result = torch.zeros(rows1, cols2) |
| |
| # compute matrix multiplication with bitshifts |
| for i in range(rows1): |
| for j in range(cols2): |
| for k in range(rows2): |
| weight_val = decomposed_weight[k][j] |
| r = int(activation[i][k]) |
| |
| product = self.bitshift_mul(weight_val, r) |
| |
| result[i][j] += product |
| |
| return result |
| |
| def forward(self, activation: torch.Tensor) -> torch.FloatTensor: |
| r""" |
| Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8) |
| with bitshifting instead of matrix multiplication. |
| Result has dtype torch.float32 |
| Args: |
| activation (Tensor): uniformly quantized activation tensor |
| """ |
| assert activation.dim() == 2 |
| |
| weight_rows = self.weight_transposed.size()[0] |
| weight_cols = self.weight_transposed.size()[1] |
| |
| decomposed_weight: np.ndarray = np.empty(shape=(weight_rows, weight_cols), dtype=object) |
| for row in range(weight_rows): |
| for col in range(weight_cols): |
| decomposed_weight[row][col] = self.decompose_APoT(bin(self.weight_transposed[row][col])) |
| |
| result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor) |
| |
| return result |
| |
| @classmethod |
| def from_reference(cls, # type: ignore[override] |
| ref_qlinear, |
| alpha: torch.Tensor, |
| gamma: torch.Tensor, |
| quantization_levels: torch.Tensor, |
| level_indices: torch.Tensor): |
| raise NotImplementedError |