import torch
import numpy as np
from torch.ao.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT

class LinearAPoT(WeightedQuantizedModule):
    r"""
    A quantized linear module with quantized tensors as inputs and outputs
    to support APoT quantization.
    We adopt the same interface as `torch.nn.Linear`, see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        alpha: `alpha` qparam of output Quantized Tensor, type: Tensor
        gamma: `gamma` qparam of output Quantized Tensor, type: Tensor
        quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor
        level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor
        weight: APoT quantized tensor from weight2quantize
        weight_transposed: transposed weight tensor, used in linear transformation calculation (y = x * A^T + b)
    """

    def __init__(self, weight2quantize: torch.Tensor, b: int, k: int):
        assert weight2quantize.dim() == 2
        assert b % k == 0

        super().__init__()

        # b is the total bit-width; k is the base bit-width, so the
        # representation uses n = b // k additive power-of-two terms
        self.b = b
        self.k = k
        self.n = self.b // self.k

        # observe the float weight to compute APoT quantization parameters
        observer = APoTObserver(b=self.b, k=self.k)
        observer(weight2quantize)
        self.alpha, self.gamma, self.quantization_levels, self.level_indices = observer.calculate_qparams(signed=False)

        # quantize the weight and cache its transpose for the linear
        # transformation y = x * A^T + b
        quantized_weight = quantize_APoT(weight2quantize, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
        self.weight = quantized_weight.data
        self.weight_transposed = torch.transpose(self.weight, 0, 1)

    def decompose_APoT(self, x):
        r"""
        Decompose binary representation of APoT values into list of k-sized blocks
        Args:
            x (str): binary representation of an APoT quantized value, as produced by `bin`
        """
        # remove "0b" prefix from binary representation
        x = x[2:]

        # split the remaining bit string into consecutive k-sized blocks
        blocks = []
        while x:
            blocks.append(x[0:self.k])
            x = x[self.k:]

        return blocks
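
    # Illustrative trace (hypothetical values): with k=2, a quantized weight
    # value of 13 gives bin(13) == '0b1101', which decomposes into the k-sized
    # blocks ['11', '01'].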

    def bitshift_mul(self, weight_val, r):
        r"""
        Compute multiplication of weight_val * r using bitshifting
        method discussed in APoT paper: https://arxiv.org/pdf/1909.13144.pdf
        Args:
            weight_val: list of k-sized blocks of binary digits representing an APoT quantized weight value
            r: int representing uniformly quantized activation value
        """
        product = 0

        # walk the blocks from least to most significant
        idx = len(weight_val) - 1
        place = 0
        while idx >= 0:
            block = weight_val[idx]

            # reverse digits in block so bits are visited from least to most significant
            block = block[::-1]

            curr_block_result = 0
            for ele in block:
                # each set bit contributes r shifted by the bit's place value
                if int(ele):
                    curr_block_result += r << place
                place += 1

            idx -= 1
            product += curr_block_result

        return product
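
    # Illustrative trace (hypothetical values): weight_val = ['11', '01'] encodes
    # 0b1101 == 13, so bitshift_mul(['11', '01'], 3) accumulates (3 << 0) from the
    # low block and (3 << 2) + (3 << 3) from the high block, returning 39 == 13 * 3.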

    def matmul(self, decomposed_weight, activation):
        r"""
        Perform matrix multiplication between decomposed_weight and
        activation by calling bitshift_mul function for each value
        Args:
            decomposed_weight (np.ndarray): APoT quantized weight decomposed into binary digit blocks
            activation (Tensor): uniformly quantized activation
        """
        rows1 = activation.size(dim=0)
        cols1 = activation.size(dim=1)

        # inner dimensions cols1 and rows2 are expected to agree
        rows2 = decomposed_weight.shape[0]
        cols2 = decomposed_weight.shape[1]

        result = torch.zeros(rows1, cols2)

        # compute matrix multiplication with bitshifts in place of
        # elementwise multiplications
        for i in range(rows1):
            for j in range(cols2):
                for k in range(rows2):
                    weight_val = decomposed_weight[k][j]
                    r = int(activation[i][k])

                    product = self.bitshift_mul(weight_val, r)

                    result[i][j] += product

        return result
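
    # Illustrative trace (hypothetical values): with activation [[2., 3.]] and
    # decomposed_weight [[['01']], [['10']]] (a 2x1 array of k=2 blocks encoding
    # the weights 1 and 2), matmul returns [[2 * 1 + 3 * 2]] == [[8.]].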

    def forward(self, activation: torch.Tensor) -> torch.FloatTensor:
        r"""
        Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8)
        with bitshifting instead of matrix multiplication.
        Result has dtype torch.float32.
        Args:
            activation (Tensor): uniformly quantized activation tensor
        """
        assert activation.dim() == 2

        weight_rows = self.weight_transposed.size()[0]
        weight_cols = self.weight_transposed.size()[1]

        # decompose each quantized weight value into its k-sized binary blocks
        decomposed_weight: np.ndarray = np.empty(shape=(weight_rows, weight_cols), dtype=object)
        for row in range(weight_rows):
            for col in range(weight_cols):
                decomposed_weight[row][col] = self.decompose_APoT(bin(self.weight_transposed[row][col]))

        result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor)

        return result

    @classmethod
    def from_reference(cls,  # type: ignore[override]
                       ref_qlinear,
                       alpha: torch.Tensor,
                       gamma: torch.Tensor,
                       quantization_levels: torch.Tensor,
                       level_indices: torch.Tensor):
        raise NotImplementedError
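

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API. The weight shape, b, k,
    # and activation values are illustrative assumptions; it also assumes the
    # experimental quantize_APoT yields integer-valued data, as required by the
    # bin() call in forward.
    torch.manual_seed(0)

    float_weight = torch.rand(4, 4)  # 2D float weight; b % k == 0 below
    linear = LinearAPoT(float_weight, b=4, k=2)

    # stand-in for a uniformly quantized activation: small non-negative integers
    activation = torch.randint(0, 16, (2, 4)).float()

    out = linear(activation)
    print(out.shape)  # expected: torch.Size([2, 4])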