blob: 6f2799ba617cad80b35465d69a2be4a8d0c99475 [file] [log] [blame]
import torch
def _quantize_weight(float_wt, observer):
    """Quantize a float weight tensor to ``torch.qint8`` using an observer's qparams.

    Args:
        float_wt: Float tensor holding the weight to quantize.
        observer: Object exposing ``calculate_qparams()``, which returns a
            ``(scale, zero_point)`` pair, and a ``qscheme`` attribute. For
            per-channel quantization an optional ``ch_axis`` attribute
            selects the channel dimension; axis 0 is used when it is absent.

    Returns:
        The quantized tensor: per-tensor quantized when ``observer.qscheme``
        is ``torch.per_tensor_symmetric`` or ``torch.per_tensor_affine``,
        per-channel quantized for any other qscheme.
    """
    wt_scale, wt_zp = observer.calculate_qparams()

    if observer.qscheme in (torch.per_tensor_symmetric, torch.per_tensor_affine):
        # Per-tensor quantization takes plain Python scalars for the qparams.
        return torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8)

    # Every other qscheme is handled as per-channel. Quantize along the axis
    # the observer reports so the qparams line up with the weight dimension
    # they belong to; fall back to axis 0 for observers without ``ch_axis``.
    ch_axis = getattr(observer, "ch_axis", 0)
    return torch.quantize_per_channel(
        float_wt,
        wt_scale.to(torch.double), wt_zp.to(torch.int64), ch_axis, torch.qint8)