import copy
from typing import List

import torch
| |
| |
class QuantizedLinear(torch.jit.ScriptModule):
    """Replacement for a ``torch.nn.Linear`` that runs with FBGEMM int8 weights.

    The float weight of the wrapped ``Linear`` is quantized once, at
    construction time, with ``torch.fbgemm_linear_quantize_weight``. The
    float original is not retained. ``forward`` calls
    ``torch.fbgemm_linear_int8_weight`` on a float32 copy of the input and
    casts the result back to the input's dtype.

    Attributes:
        in_features, out_features: copied from the wrapped ``Linear``.
        weight: quantized weight, stored as a frozen ``Parameter``.
        col_offsets: the offsets tensor returned by the quantizer, stored as
            a frozen ``Parameter``.
        scale, zero_point: weight quantization parameters; declared as
            TorchScript constants via ``__constants__``.
        bias: float32 bias ``Parameter``. When the wrapped ``Linear`` has no
            bias this is all zeros, because the FBGEMM kernel always takes a
            bias tensor.
        packed_tensor_ptr: buffer holding the FBGEMM-packed weight.
    """
    __constants__ = ['scale', 'zero_point']

    def __init__(self, other):
        super(QuantizedLinear, self).__init__()
        self.in_features = other.in_features
        self.out_features = other.out_features

        # Quantize the weight; the float original is discarded.
        weight, col_offsets, self.scale, self.zero_point = torch.fbgemm_linear_quantize_weight(
            other.weight.clone().float())
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.col_offsets = torch.nn.Parameter(col_offsets, requires_grad=False)

        if other.bias is not None:
            bias = other.bias.clone().float()
        else:
            # fbgemm_linear_int8_weight requires a bias tensor; adding zeros
            # is equivalent to having no bias at all.
            bias = torch.zeros(self.out_features, dtype=torch.float)
        self.bias = torch.nn.Parameter(bias)

        # The packer receives (weight.size(1), weight.size(0)), i.e.
        # (in_features, out_features) for an nn.Linear weight laid out as
        # (out_features, in_features).
        self.register_buffer(
            'packed_tensor_ptr',
            torch.fbgemm_pack_quantized_matrix(
                self.weight.clone(), self.weight.size(1), self.weight.size(0)))

    @torch.jit.script_method
    def _unpack(self):
        # Rebuild the packed weight in place from ``self.weight``.
        self.packed_tensor_ptr.set_(
            torch.fbgemm_pack_quantized_matrix(
                self.weight, self.weight.size(1), self.weight.size(0)))

    @torch.jit.script_method
    def _pack(self):
        # Replace the packed weight in place with a 0-dim uint8 zero tensor.
        # NOTE(review): presumably invoked before serialization (with
        # ``_unpack`` restoring the packed buffer afterwards) -- confirm
        # against the save/load caller.
        self.packed_tensor_ptr.set_(
            torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach())

    @torch.jit.script_method
    def forward(self, input):
        # The FBGEMM kernel is fed float32 input; the result is cast back so
        # the output dtype matches the caller's input.
        out = torch.fbgemm_linear_int8_weight(
            input.float(), self.weight, self.packed_tensor_ptr, self.col_offsets,
            self.scale, self.zero_point, self.bias)
        return out.type_as(input)

    def extra_repr(self):
        # Use normal attribute access instead of ``self.__dict__``: on a
        # ScriptModule these attributes are not guaranteed to live in the
        # instance dict.
        return 'in_features={}, out_features={}, scale={}, zero_point={}'.format(
            self.in_features, self.out_features, self.scale, self.zero_point)
| |
| |
def quantize_linear_modules(module):
    """Replace, in place, every ``torch.nn.Linear`` below *module* with a ``QuantizedLinear``.

    The module tree is walked through direct children only. Each parent has
    its Linear children swapped out by attribute name, and the walk recurses
    into every other child. The root *module* itself is never replaced.

    Args:
        module: root ``torch.nn.Module`` whose descendants are converted.

    Returns:
        The same *module* object, which allows call chaining.
    """
    # Snapshot the children so the parent's registry is not modified while a
    # generator over it is still live.
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.Linear):
            # ``name`` is a direct attribute name (never a dotted path), so
            # setattr replaces exactly this child on its real parent.
            setattr(module, name, QuantizedLinear(child))
        else:
            quantize_linear_modules(child)
    return module