# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from functools import reduce
from math import gcd
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .ops.quantized_ops import * # noqa
from torch.ao.quantization.fx._decomposed import (
_quant_min_max_bounds_check,
quantized_decomposed_lib,
)
from torch.library import impl
try:
# pyre-ignore[21]: Undefined import.
from fairseq2.nn.embedding import (
Embedding as fsEmbedding,
StandardEmbedding as fsStandardEmbedding,
)
# pyre-ignore[21]: Undefined import.
from fairseq2.nn.projection import Linear as fsLinear
except ImportError:
print("Could not import fairseq2 modules.")
fsEmbedding = nn.Embedding
fsStandardEmbedding = nn.Embedding
fsLinear = nn.Linear
def dynamically_quantize_per_channel(
x,
quant_min,
quant_max,
target_dtype,
group_size: Optional[int] = None,
*,
scales_dtype=torch.float16,
enable_non_multiple_groups=True,
):
"""
Dynamically quantize per channel. This function is used for quantizing weights,
for linear and embedding layers.
Arguments:
x: input tensor,
quant_min: minimum value after quantization,
quant_max: maximum value after quantization,
target_dtype: target data type for weights after quantization,
group_size: number of elements of the channel to quantize together
Keyword arguments:
scales_dtype: data type of scale,
enable_non_multiple_groups: if True, allow the rowsize to not be a multiple of group size,
with a final group of a size less than group size.
Assumptions:
This function assumes symmetric quantization, axis ==0 and a dense memory format.
"""
# assumes symmetric quantization
# assumes axis == 0
# assumes dense memory format
# TODO(future): relax ^ as needed
x_shape_1 = x.shape[1]
if group_size is None or group_size == 0:
items = x_shape_1
elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
assert group_size > 0, "group size must be positive"
assert (
x_shape_1 % group_size
) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
items = group_size
else:
assert group_size > 0, "group size must be positive"
print(
f"row-size of weight matrix {x_shape_1} is not divisible by group size {group_size}, using nearest neighbor rounding"
)
assert (
x_shape_1 % group_size != 0
), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
padding = group_size - (x_shape_1 % group_size)
x = F.pad(x, (0, padding))
items = group_size
# default setup for affine quantization of activations
eps = torch.finfo(torch.float32).eps
x = x.view(x.shape[0], x.shape[1] // items, items)
# get min and max
min_val, max_val = torch.aminmax(x, dim=2)
# print(f"min_val {min_val}")
# print(f"max_val {max_val}")
# calculate scales and zero_points based on min and max
# reference: https://fburl.com/code/srbiybme
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
device = min_val_neg.device
# reference: https://fburl.com/code/4wll53rk
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scales = max_val_pos / (float(quant_max - quant_min) / 2)
# ensure scales is the same dtype as the original tensor
scales = torch.clamp(scales, min=eps).to(x.dtype)
zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
# quantize based on qmin/qmax/scales/zp
# reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
x_div = x / scales.unsqueeze(-1)
x_round = torch.round(x_div)
x_zp = x_round + zero_points.unsqueeze(-1)
quant = (
torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
)
scales = scales.to(dtype=scales_dtype)
quant = quant[:, :x_shape_1]
return quant, scales, zero_points
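# Illustrative usage sketch (not part of the original module): groupwise int8 weight
# quantization of a small fp32 matrix; the _example_* name, shapes, and values here are
# arbitrary and for documentation only.
def _example_dynamically_quantize_per_channel():
    w = torch.randn(8, 16)  # (out_features, in_features)
    quant, scales, zero_points = dynamically_quantize_per_channel(
        w, quant_min=-128, quant_max=127, target_dtype=torch.int8, group_size=4
    )
    # quant: int8 tensor of shape (8, 16); scales: (8, 4), float16 by default;
    # zero_points: (8, 4) int64 zeros (symmetric quantization)
    return quant, scales, zero_points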
# TODO: move this to https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
quantized_decomposed_lib.define(
"choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token",
"CompositeExplicitAutograd",
)
def choose_qparams_per_token(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)
Args:
input (torch.Tensor): original float32/float16 Tensor
dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
Returns:
scales and zero_points, both float32 Tensors
"""
scales = input.abs().amax(dim=-1, keepdim=True)
if scales.dtype == torch.float16:
scales = (
scales.float()
) # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
if dtype == torch.int8:
n_bits = 8
quant_max = 2 ** (n_bits - 1) - 1
else:
raise Exception(f"unsupported dtype in choose_qparams_per_token: {dtype}")
scales = scales.clamp(min=1e-5).div(quant_max)
zero_points = torch.zeros_like(scales)
return scales, zero_points
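# Illustrative sketch (not part of the original module): per-token symmetric qparams for
# a (batch, seq, hidden) activation; one scale per token along the last dimension.
# The _example_* name and shapes are arbitrary.
def _example_choose_qparams_per_token():
    x = torch.randn(2, 3, 8)
    scales, zero_points = torch.ops.quantized_decomposed.choose_qparams_per_token(
        x, torch.int8
    )
    # scales: (2, 3, 1) float32; zero_points: zeros of the same shape (symmetric scheme)
    return scales, zero_points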
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token",
"Meta",
)
def choose_qparams_per_token_meta(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
size = (1, input.size(-1))
return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
size, dtype=torch.int64, device=input.device
)
# TODO: move this to https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
quantized_decomposed_lib.define(
"choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token_asymmetric",
"CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)
Args:
input (torch.Tensor): original float32/float16 Tensor
dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
Returns:
scales and zero_points, both float32 Tensors
"""
# Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
qmin, qmax = -128, 127
min_val, max_val = torch.aminmax(input, dim=-1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
eps = torch.finfo(torch.float32).eps # use xnnpack eps?
# scale
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = scale.clamp(min=eps)
# zero point
descaled_min = min_val_neg / scale
descaled_max = max_val_pos / scale
zero_point_from_min_error = qmin + descaled_min
zero_point_from_max_error = qmax + descaled_max
zero_point = torch.where(
zero_point_from_min_error + zero_point_from_max_error > 0,
qmin - descaled_min,
qmax - descaled_max,
)
zero_point = torch.clamp(zero_point, qmin, qmax).round()
return scale.to(torch.float32), zero_point.to(torch.float32)
@impl(
quantized_decomposed_lib,
"choose_qparams_per_token_asymmetric",
"Meta",
)
def choose_qparams_per_token_asymmetric_meta(
input: torch.Tensor,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
size = (1, input.size(-1))
return torch.empty(size, dtype=torch.double, device=input.device), torch.empty(
size, dtype=torch.int64, device=input.device
)
def _per_token_quant_qparam_dim_check(input, scales, zero_points):
num_tokens = math.prod(list(input.size())[:-1])
assert (
num_tokens == scales.numel()
), f"num_tokens: {num_tokens} scales: {scales.size()}"
assert (
num_tokens == zero_points.numel()
), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
quantized_decomposed_lib.define(
"quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
"int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)
@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
def quantize_per_token(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
):
"""Per token quantization for the Tensor using the quantization parameters to map
from floating point to quantized values. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)
Args:
input (torch.Tensor): original float32 or bfloat16 Tensor
scales (float32 torch.Tensor): quantization parameter for per token affine quantization
zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
quant_min (int): minimum quantized value for output Tensor
quant_max (int): maximum quantized value for output Tensor
dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
Returns:
Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
are not stored in the Tensor, we are storing them in function arguments instead
"""
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
_per_token_quant_qparam_dim_check(input, scales, zero_points)
input = (
torch.round(input / scales + zero_points).clamp(quant_min, quant_max).to(dtype)
)
return input
@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
def quantize_per_token_meta(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
):
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=dtype)
quantized_decomposed_lib.define(
"dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
"int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
)
@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
def dequantize_per_token(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
output_dtype: torch.dtype = torch.float32,
):
"""Per token dequantization for the Tensor using the quantization parameters to map
from floating point to quantized values. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)
Args:
input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
scales (float32 torch.Tensor): quantization parameter for per token affine quantization
zero_points (int32 torch.Tensor): quantization parameter for per token affine quantization
quant_min (int): minimum quantized value for input Tensor
quant_max (int): maximum quantized value for input Tensor
dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
Returns:
dequantized Tensor with dtype `output_dtype`
"""
input = input - zero_points
input = input.to(output_dtype) * scales
return input
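# Illustrative round-trip sketch (not part of the original module): quantize an fp32
# activation per token to int8 and dequantize it back; shapes and values are arbitrary.
def _example_per_token_round_trip():
    x = torch.randn(2, 3, 8)
    scales, zps = torch.ops.quantized_decomposed.choose_qparams_per_token(x, torch.int8)
    xq = torch.ops.quantized_decomposed.quantize_per_token(
        x, scales, zps, -128, 127, torch.int8
    )
    xdq = torch.ops.quantized_decomposed.dequantize_per_token(
        xq, scales, zps, -128, 127, torch.int8, torch.float32
    )
    # xdq approximates x up to the int8 rounding error
    return xq, xdq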
@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
def dequantize_per_token_meta(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
output_dtype: torch.dtype = torch.float32,
):
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
# TODO: support fp16
return torch.empty_like(input, dtype=output_dtype)
def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2
to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
max_val_abs = torch.max(-min_val_neg, max_val_pos)
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))
scales = max_val_abs / (float(max_int - min_int) / 2)
scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
# TODO: make sure abs(scales) is not too small?
zeros = torch.full_like(scales, 0)
return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
w.shape[0], -1
)
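# Illustrative sketch (not part of the original module): symmetric 4-bit group qparams
# for a 2-d weight; one (scale, zero) pair per group of `groupsize` columns.
def _example_get_group_qparams_symmetric():
    w = torch.randn(4, 64)
    scales, zeros = get_group_qparams_symmetric(w, n_bit=4, groupsize=32)
    # scales, zeros: (4, 2) float32; zeros are all 0 because the scheme is symmetric
    return scales, zeros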
def pack_scales_and_zeros(scales, zeros, precision=torch.float16):
assert scales.shape == zeros.shape
assert scales.dtype == precision
assert zeros.dtype == precision
return (
torch.cat(
[
scales.reshape(scales.size(0), scales.size(1), 1),
zeros.reshape(zeros.size(0), zeros.size(1), 1),
],
2,
)
.transpose(0, 1)
.contiguous()
)
quantized_decomposed_lib.define(
"quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
"int quant_max, ScalarType dtype, int group_size) -> Tensor"
)
# TODO: dtype is ignored for now
@impl(
quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd"
)
def quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
group_size=128,
):
assert group_size > 1
# needed for GPTQ single column quantize
if group_size > input.shape[-1] and scales.shape[-1] == 1:
group_size = input.shape[-1]
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
# TODO: check for dtype, currently we can't express torch.int4 so it's omitted
to_quant = input.reshape(-1, group_size)
assert torch.isnan(to_quant).sum() == 0
scales = scales.reshape(-1, 1)
zero_points = zero_points.reshape(-1, 1)
input_int8 = (
to_quant.div(scales)
.add(zero_points)
.round()
.clamp_(quant_min, quant_max)
.to(dtype)
.reshape_as(input)
)
return input_int8
@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
def quantize_per_channel_group_meta(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
group_size=128,
):
"""Groupwise quantization within each channel for an 2-d Tensor using the quantization parameters
to map from floating point to quantized values. This means for each row of a 2-d Tensor
(M, N), we calculate scales/zero_points for each `group_size` elements
and quantize every `group_size` elements with the same quantization parameter.
The dimension for scales/zero_points will be (M * ceil(N, group_size),)
Args:
input (torch.Tensor): original float32 or bfloat16 Tensor
scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
quant_min (int): minimum quantized value for output Tensor
quant_max (int): maximum quantized value for output Tensor
dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
Returns:
Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
are not stored in the Tensor, we are storing them in function arguments instead
"""
assert group_size > 1
# needed for GPTQ single column quantize
if group_size > input.shape[-1] and scales.shape[-1] == 1:
group_size = input.shape[-1]
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
return torch.empty_like(input, dtype=dtype)
def group_quantize_tensor_symmetric(
w, n_bit=4, group_size=128, precision=torch.float32
):
scales, zeros = get_group_qparams_symmetric(w, n_bit, group_size, precision)
n_bit = 4
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))
# TODO: currently we don't know how to express torch.int4, we'll
# add torch.int4 to core later
w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
w, scales, zeros, min_int, max_int, torch.int8, group_size
)
return w_int8, scales, zeros
quantized_decomposed_lib.define(
"dequantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
"int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
)
@impl(
quantized_decomposed_lib,
"dequantize_per_channel_group",
"CompositeExplicitAutograd",
)
def dequantize_per_channel_group(
w_int8: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
group_size: int = 128,
output_dtype: torch.dtype = torch.float32,
):
"""Groupwise dequantization within each channel for an 2-d Tensor using the quantization parameters
to map from floating point to quantized values. This means for each row of a 2-d Tensor
(M, N), we calculate scales/zero_points for each `group_size` elements
and quantize every `group_size` elements with the same quantization parameter.
The dimension for scales/zero_points will be (M * ceil(N, group_size),)
Args:
input (torch.Tensor): quantized Tensor (uint8/int8 etc.)
scales (float32 torch.Tensor): quantization parameter for per channel group affine quantization
zero_points (int32 torch.Tensor): quantization parameter for per channel group affine quantization
quant_min (int): minimum quantized value for input Tensor
quant_max (int): maximum quantized value for input Tensor
dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor
Returns:
dequantized Tensor with dtype `output_dtype`
"""
assert group_size > 1
# needed for GPTQ single column dequantize
if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
group_size = w_int8.shape[-1]
assert w_int8.shape[-1] % group_size == 0
assert w_int8.dim() == 2
w_int8_grouped = w_int8.reshape(-1, group_size)
scales = scales.reshape(-1, 1)
zero_points = zero_points.reshape(-1, 1)
w_dq = (
w_int8_grouped.sub(zero_points).mul(scales).reshape_as(w_int8).to(output_dtype)
)
return w_dq
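# Illustrative round-trip sketch (not part of the original module): groupwise symmetric
# quantization of a weight to the int4 range followed by dequantization through the
# registered op; shapes and the group size are arbitrary.
def _example_per_channel_group_round_trip():
    w = torch.randn(4, 64)
    w_int8, scales, zeros = group_quantize_tensor_symmetric(w, n_bit=4, group_size=32)
    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
        w_int8, scales, zeros, -8, 7, torch.int8, 32, torch.float32
    )
    # w_dq approximates w up to the 4-bit rounding error
    return w_dq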
def down_size(size):
assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
return (*size[:-1], size[-1] // 2)
def up_size(size):
return (*size[:-1], size[-1] * 2)
quantized_decomposed_lib.define("pack_int4_from_int8(Tensor int8_data) -> Tensor")
@impl(quantized_decomposed_lib, "pack_int4_from_int8", "CompositeExplicitAutograd")
def pack_int4_from_int8(int8_data: torch.Tensor) -> torch.Tensor:
# converting to uint8 for operations
shape = int8_data.shape
assert shape[-1] % 2 == 0
int8_data = int8_data.contiguous().view(-1)
return (int8_data[::2] << 4 | int8_data[1::2]).view(down_size(shape))
quantized_decomposed_lib.define("unpack_int4_to_int8(Tensor int8_data) -> Tensor")
@impl(quantized_decomposed_lib, "unpack_int4_to_int8", "CompositeExplicitAutograd")
def unpack_int4_to_int8(int8_data: torch.Tensor) -> torch.Tensor:
"""Get the original weight from the normalized float weight format"""
# since we are using int8 we will decode 2 entries per byte
# Shift elements down 4 and select out the bottom 4 bits
shape = int8_data.shape
first_elements = (int8_data >> 4).to(torch.int8)
second_elements = (int8_data & 0b1111).to(torch.int8)
return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape))
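# Illustrative sketch (not part of the original module): pack two 4-bit values into each
# int8 byte and unpack them again. Non-negative nibble values are used so the round trip
# is exact; signed values would need extra masking.
def _example_pack_unpack_int4():
    nibbles = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int8)
    packed = torch.ops.quantized_decomposed.pack_int4_from_int8(nibbles)  # shape (2, 2)
    unpacked = torch.ops.quantized_decomposed.unpack_int4_to_int8(packed)  # shape (2, 4)
    return packed, unpacked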
def per_token_dynamic_quant(input: torch.Tensor) -> torch.Tensor:
orig_dtype = input.dtype
(
scales,
zero_points,
) = torch.ops.quantized_decomposed.choose_qparams_per_token(input, torch.int8)
# TODO: get these from torch.int8
quant_min = -128
quant_max = 127
input = torch.ops.quantized_decomposed.quantize_per_token(
input, scales, zero_points, quant_min, quant_max, torch.int8
)
input = torch.ops.quantized_decomposed.dequantize_per_token(
input, scales, zero_points, quant_min, quant_max, torch.int8, orig_dtype
)
return input
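# Illustrative sketch (not part of the original module): fake-quantize an activation per
# token with dynamic int8 quantization; the output keeps the input's shape and dtype.
def _example_per_token_dynamic_quant():
    x = torch.randn(2, 4, 8)
    x_fq = per_token_dynamic_quant(x)
    # x_fq approximates x up to the int8 rounding error
    return x_fq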
class QuantHandler:
def __init__(self, mod):
self.mod = mod
def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass
def convert_for_runtime(self) -> nn.Module:
pass
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(module, node_type):
for name, child in module.named_children():
print(f"name: {name}")
if isinstance(child, nn.Linear):
if (
(node_type == "*")
or (node_type == "output" and name == "output")
or (node_type == "!output" and name != "output")
):
print(f"{name, child}")
print(f"in_features: {child.in_features}")
print(f"out_features: {child.out_features}")
setattr(
module,
name,
WeightOnlyInt8Linear(child.in_features, child.out_features),
)
else:
replace_linear_weight_only_int8_per_channel(child, node_type)
class WeightOnlyInt8QuantHandler:
def __init__(
self,
mod,
*,
node_type: str = "*",
bitwidth: Optional[int] = None,
group_size: Optional[int] = None,
):
self.mod = mod
self.group_size = group_size
self.node_type = node_type
if bitwidth is None:
self.bitwidth = 8
else:
self.bitwidth = bitwidth
@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
cur_state_dict = self.mod.state_dict()
if self.bitwidth == 4:
range_min = -8
range_max = 7
elif self.bitwidth == 8:
range_min = -128
range_max = 127
else:
raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
for fqn, mod in self.mod.named_modules():
# print(f"maybe? quantize {fqn}...{type(mod)}")
if isinstance(mod, torch.nn.Linear) or isinstance(mod, fsLinear):
# print(f"candidate {fqn}, nodetype {self.node_type}")
if (
(self.node_type == "*")
or (self.node_type == "output" and fqn in ["output", "final_proj"])
or (
self.node_type == "!output"
and fqn not in ["output", "final_proj"]
)
):
print(
f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
)
# print(f"initial weight shape {mod.weight.shape}")
input_weight = mod.weight.float()
# print(f"expanded weight shape {input_weight.shape}")
weight, scales, _ = dynamically_quantize_per_channel(
input_weight,
range_min,
range_max,
torch.int8,
self.group_size,
scales_dtype=mod.weight.dtype,
)
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
return cur_state_dict
def convert_for_runtime(self) -> nn.Module:
replace_linear_weight_only_int8_per_channel(self.mod, self.node_type)
return self.mod
def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
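# Illustrative usage sketch (not part of the original module), assuming `model` is an
# nn.Module whose nn.Linear layers should get weight-only int8 per-channel quantization.
def _example_weight_only_int8_quantize(model: nn.Module) -> nn.Module:
    handler = WeightOnlyInt8QuantHandler(model, node_type="*", group_size=None)
    return handler.quantized_model()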
class WeightOnlyInt8Linear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer(
"weight", torch.empty((out_features, in_features), dtype=torch.int8)
)
self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
##### embedding table quantization ######
def replace_embedding_weight_only_grouped_int8_per_channel(
module, bitwidth: int = 8, group_size: Optional[int] = None
):
for name, child in module.named_children():
print(f"name: {name}")
if isinstance(child, nn.Embedding):
print(f"{name, child}")
print(f"weights size: {child.weight.size()}")
setattr(
module,
name,
QuantizedGroupEmbedding(
vocab_size=child.weight.shape[0],
embedding_dim=child.weight.shape[1],
group_size=group_size,
),
)
else:
replace_embedding_weight_only_grouped_int8_per_channel(
child, bitwidth, group_size
)
class EmbeddingOnlyInt8QuantHandler:
def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
self.mod = mod
self.group_size = group_size
self.bitwidth = bitwidth
@torch.no_grad()
def create_quantized_state_dict(self) -> Dict:
cur_state_dict = self.mod.state_dict()
if self.bitwidth == 4:
range_min = -8
range_max = 7
elif self.bitwidth == 8:
range_min = -128
range_max = 127
else:
raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
for fqn, mod in self.mod.named_modules():
if (
isinstance(mod, nn.Embedding)
or isinstance(mod, fsEmbedding)
or isinstance(mod, fsStandardEmbedding)
):
print("****")
print(f"Embedding identified: {fqn, mod}")
print(f"weights size: {mod.weight.size()}")
# print(f"quantize {fqn}...")
print(
f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
)
weight, scales, _ = dynamically_quantize_per_channel(
mod.weight.float(),
range_min,
range_max,
torch.int8,
self.group_size,
scales_dtype=mod.weight.dtype,
)
# Update state dict
cur_state_dict[f"{fqn}.weight"] = weight
# squeeze makes groupsize=rowsize unidimensional
cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
return cur_state_dict
def convert_for_runtime(self) -> nn.Module:
replace_embedding_weight_only_grouped_int8_per_channel(
self.mod, self.bitwidth, self.group_size
)
return self.mod
def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
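# Illustrative usage sketch (not part of the original module), assuming `model` contains
# nn.Embedding modules that should be replaced by QuantizedGroupEmbedding with int8
# weights; the group size of 32 is an arbitrary example.
def _example_embedding_int8_quantize(model: nn.Module) -> nn.Module:
    handler = EmbeddingOnlyInt8QuantHandler(model, bitwidth=8, group_size=32)
    return handler.quantized_model()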
class QuantizedGroupEmbedding(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embedding_dim: int,
group_size: Optional[int] = None,
device=None,
dtype=torch.half,
) -> None:
super().__init__()
if group_size is None:
group_size = embedding_dim
self.group_size = group_size
self.dtype = dtype
self.register_buffer(
"weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
)
groups_per_row = (embedding_dim + group_size - 1) // group_size
if groups_per_row > 1:
self.register_buffer(
"scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
)
else:
self.register_buffer(
"scales", torch.ones((vocab_size,), dtype=torch.float16)
)
@torch.no_grad()
def forward(self, indices: torch.Tensor) -> torch.Tensor:
return torch.ops.llama_quantized.embedding_byte.dtype(
self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
)
# result_weights = self.weight.index_select(0, indices.view(-1))
# result_scales = self.scales.index_select(0, indices.view(-1))
#
# r = result_weights.to(dtype=result_scales.dtype) * result_scales
# return r.view(indices.size() + (-1,))
##### weight only int4 per channel groupwise quantized code ######
def prepare_int4_weight_and_scales_and_zeros(weight, group_size, precision):
weight_int8, scales, zeros = group_quantize_tensor_symmetric(
weight,
n_bit=4,
group_size=group_size,
precision=precision,
)
# TODO: better API
# weight_int4packed = torch.ops.quantized_decomposed.pack_int4_from_int8(weight_int8)
return weight_int8, scales, zeros
def linear_forward_8da4w(
x, weight_int8, scales, zeros, out_features, group_size, precision
):
x = per_token_dynamic_quant(x)
# TODO: verify and remove following reshape code
# origin_x_size = x.size()
# x = x.reshape(-1, origin_x_size[-1])
# TODO: better API
# weight_int8 = torch.ops.quantized_decomposed.unpack_int4_to_int8(weight_int4packed)
n_bit = 4
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
w_dq = torch.ops.quantized_decomposed.dequantize_per_channel_group(
weight_int8,
scales,
zeros,
quant_min,
quant_max,
torch.int8,
group_size,
precision,
)
# x = x.to(torch.float16)
# w_dq = w_dq.to(torch.float16)
c = torch.nn.functional.linear(x, w_dq)
# new_shape = origin_x_size[:-1] + (out_features,)
# c = c.reshape(new_shape)
return c
def find_multiple(n: int, *args: Tuple[int]) -> int:
k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,)) # type: ignore[9]
if n % k == 0:
return n
return n + k - (n % k)
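# Illustrative sketch (not part of the original module): find_multiple rounds n up to the
# nearest multiple of the least common multiple of the given values, e.g. when padding
# in_features for int4 kernels.
def _example_find_multiple():
    assert find_multiple(100, 32) == 128  # next multiple of 32 at or above 100
    assert find_multiple(1024, 256) == 1024  # already a multiple, returned unchanged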
def _check_linear_int4_k(k, group_size=1):
return k % group_size == 0
def _calc_padded_size_linear_int4(k, groupsize=1):
return find_multiple(k, groupsize)
def replace_linear_8da4w(
module,
group_size,
padding_allowed,
precision,
scales_precision,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if _check_linear_int4_k(child.in_features, group_size) or padding_allowed:
setattr(
module,
name,
Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
group_size=group_size,
precision=precision,
scales_precision=scales_precision,
),
)
else:
replace_linear_8da4w(
child,
group_size,
padding_allowed,
precision,
scales_precision,
)
class Int8DynActInt4WeightQuantHandler:
def __init__(
self,
mod,
group_size=256,
padding_allowed=False,
precision=torch.float32,
scales_precision=torch.float32,
):
self.mod = mod
self.group_size = group_size
self.padding_allowed = padding_allowed
self.precision = precision
self.scales_precision = scales_precision
# assert group_size in [32, 64, 128, 256]
@torch.no_grad()
def create_quantized_state_dict(self):
cur_state_dict = self.mod.state_dict()
for fqn, mod in self.mod.named_modules():
if isinstance(mod, torch.nn.Linear):
                assert mod.bias is None
out_features = mod.out_features
in_features = mod.in_features
print("in features:", in_features, " out features:", out_features)
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")
assert (
in_features % self.group_size == 0
), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
weight = mod.weight.data
"""
if not _check_linear_int4_k(
in_features, self.group_size
):
if self.padding_allowed:
print(
f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
)
padded_in_features = _calc_padded_size_linear_int4(
in_features, self.group_size
)
weight = F.pad(
weight, pad=(0, padded_in_features - in_features)
)
else:
raise RuntimeError(
f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+ "and that group_size"
)
"""
(
weight_int4pack,
scales,
zeros,
) = prepare_int4_weight_and_scales_and_zeros(
weight.to(self.precision),
self.group_size,
self.scales_precision,
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
return cur_state_dict
def convert_for_runtime(self):
replace_linear_8da4w(
self.mod,
self.group_size,
self.padding_allowed,
self.precision,
self.scales_precision,
)
return self.mod
def quantized_model(self) -> nn.Module:
model_updated_state_dict = self.create_quantized_state_dict()
self.convert_for_runtime()
self.mod.load_state_dict(model_updated_state_dict)
return self.mod
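# Illustrative usage sketch (not part of the original module), assuming `model` has
# bias-free nn.Linear layers whose in_features are divisible by the chosen group size.
def _example_8da4w_quantize(model: nn.Module) -> nn.Module:
    handler = Int8DynActInt4WeightQuantHandler(model, group_size=256)
    return handler.quantized_model()  # int8 dynamic activations, int4-range weights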
class Int8DynActInt4WeightLinear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: torch.Tensor
"""
This module implements a dynamic quantized linear layer with int4 weight.
Weights are per channel groupwise quantized. Parameters of importance
group_size: the number of elements in each quantized group
precision: precision of input and output. e.g. torch.float32 means input
activation is float32 and output is float32.
scales_precision: precision of per group scale.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=True,
device=None,
dtype=None,
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
super().__init__()
# always pad if needed since it becomes a noop at runtime if not needed
# self.origin_in_features = in_features
assert (
in_features % group_size == 0
), f"require in_features:{in_features} % group_size:{group_size} == 0"
# in_features = _calc_padded_size_linear_int4(
# in_features, group_size
# )
self.in_features = in_features
self.out_features = out_features
assert not bias, "require bias=False"
self.group_size = group_size
# Precision of the activation which also indicates
# output precision of the dynamically quantized linear layer
        # that this module represents.
self.precision = precision
# currently storing unpacked int8 weights
self.register_buffer(
"weight",
torch.empty((out_features, in_features), dtype=torch.int8),
)
self.register_buffer(
"scales",
torch.empty(
(out_features, in_features // group_size),
dtype=scales_precision,
),
)
self.register_buffer(
"zeros",
torch.empty(
(out_features, in_features // group_size),
dtype=scales_precision,
),
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(self.precision)
# padding is removed for perf
# input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_8da4w(
input,
self.weight,
self.scales,
self.zeros,
self.out_features,
            self.group_size,
self.precision,
)
#### GPTQ ########
try:
from GPTQ import ( # pyre-ignore[21]
evaluate,
GenericGPTQRunner,
get_task_dict,
InputRecorder,
lm_eval,
MultiInput,
)
except ImportError:
    pass
class GPTQQuantHandler(QuantHandler):
"""
    This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
    Unlike the base QuantHandler class, the user does not need to implement create_quantized_state_dict; instead, they must implement
    __init__ so that it defines the functions that configure the quantization mode, and they are still expected to implement convert_for_runtime.
    The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
    create_quantized_state_dict. Here is a description of each function.
get_qparams_func:
A function that calculates the quantization qparams for an input tensor.
Args:
weight: A 2d weight tensor with non-integer dtype.
Returns:
qparams: it can have any format but will need to be handled by the other defined functions below.
quantize_func:
A function that applies quantization to an input tensor. It should be noted
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
or a single column.
Args:
weight: A 2d weight tensor with non-integer dtype.
qparams: the output from get_qparams_func
Returns:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
dequantize_func:
A function that dequantizes an input quantized weight tensor. It should be noted
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
or a single column.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
weight: A 2d weight tensor with non-integer dtype.
combine_qparams_list_func:
A function that combines several qparams into one qparam.
Args:
qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
on a single group from a weight tensor
Returns:
qparams: an object of the same format as the qparams above.
skip_layer_func:
A function that determines which linear layers should be skipped during GPTQ
Args:
weight: A 2d weight tensor with non-integer dtype.
Returns:
skip: boolean indicating whether layer should be skipped
make_names_and_values_dict_func:
A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
corresponding quantized weights and qparams.
"""
def __init__(self):
assert self.mod is not None
assert self.get_qparams_func is not None
assert self.quantize_func is not None
assert self.dequantize_func is not None
assert self.combine_qparams_list_func is not None
assert self.make_names_and_values_dict_func is not None
@staticmethod
def get_inputs(
model,
tokenizer,
calibration_tasks,
calibration_limit,
calibration_seq_length,
pad_calibration_inputs,
) -> "MultiInput": # pyre-ignore[11]
input_recorder = InputRecorder(
model,
tokenizer,
calibration_seq_length,
pad_calibration_inputs,
)
try:
lm_eval.tasks.initialize_tasks()
        except Exception:
            # tolerate lm_eval versions that do not provide initialize_tasks
            pass
task_dict = get_task_dict(calibration_tasks)
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
evaluate(
input_recorder,
task_dict,
limit=calibration_limit,
)
inputs = input_recorder.get_recorded_inputs()
assert inputs is not None, (
f"No inputs were collected, use a task other than {calibration_tasks}, "
+ "use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "
+ f"{calibration_seq_length})"
)
print(f"Obtained {len(inputs[0].values)} calibration samples")
return inputs
@torch.no_grad()
def create_quantized_state_dict(
self,
tokenizer,
blocksize,
percdamp,
groupsize,
calibration_tasks,
calibration_limit,
calibration_seq_length,
pad_calibration_inputs,
) -> Dict:
inputs = GPTQQuantHandler.get_inputs(
self.mod,
tokenizer,
calibration_tasks,
calibration_limit,
calibration_seq_length,
pad_calibration_inputs,
)
print("Tracing model for GPTQ")
GPTQ_runner = GenericGPTQRunner(
self.mod,
inputs,
blocksize,
percdamp,
groupsize,
).configure_quantization_mode(
self.get_qparams_func, # pyre-ignore[16]
self.quantize_func, # pyre-ignore[16]
self.dequantize_func, # pyre-ignore[16]
self.combine_qparams_list_func, # pyre-ignore[16]
self.make_names_and_values_dict_func, # pyre-ignore[16]
self.skip_layer_func, # pyre-ignore[16]
)
print("Applying GPTQ to weights")
GPTQ_runner.run()
return GPTQ_runner.get_quantized_state_dict()
def convert_for_runtime(self) -> "nn.Module":
pass
class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler):
def __init__(
self,
groupsize=128,
inner_k_tiles=8,
padding_allowed=True,
precision=torch.float32,
):
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
self.precision = precision
self.dyn_quant_func = lambda x: per_token_dynamic_quant(x)
n_bit = 4
self.get_qparams_func = lambda w: get_group_qparams_symmetric(
w, n_bit, groupsize, self.precision
)
quant_min = -(2 ** (n_bit - 1))
quant_max = 2 ** (n_bit - 1) - 1
self.quantize_func = lambda w, qparams: torch.ops.quantized_decomposed.quantize_per_channel_group(
w, qparams[0], qparams[1], quant_min, quant_max, torch.int8, groupsize
)
self.dequantize_func = lambda q, qparams: torch.ops.quantized_decomposed.dequantize_per_channel_group(
q,
qparams[0],
qparams[1],
quant_min,
quant_max,
torch.int8,
groupsize,
self.precision,
)
self.combine_qparams_list_func = lambda qparams_list: [
torch.cat(x, dim=1) for x in zip(*qparams_list)
]
        # skip unless padding_allowed=True or it's correctly sized
        self.skip_layer_func = lambda linear_weight: not (
            _check_linear_int4_k(linear_weight.shape[-1], groupsize) or padding_allowed
        )
# we need to do the padding here, both for q and the qparams if necessary
def make_names_and_values_dict_func(q, qparams):
k = q.shape[1]
            new_k = _calc_padded_size_linear_int4(k, groupsize)
# how much we need to pad the weight
delta_k = new_k - q.shape[1]
final_q = F.pad(q, pad=(0, delta_k))
scales_and_zeros = pack_scales_and_zeros(*qparams, precision=self.precision)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
# TODO: split scales and zero_points
final_s_and_z = F.pad(
scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
)
return {"weight": final_q, "scales_and_zeros": final_s_and_z}
self.make_names_and_values_dict_func = make_names_and_values_dict_func
super().__init__()
def convert_for_runtime(self, model):
replace_linear_8da4w(
model,
self.groupsize,
self.padding_allowed,
torch.int8,
self.precision,
)
return model