blob: b435f4283327de86499c2fdc4238a648f72ed151 [file] [log] [blame]
import torch
from torch._six import container_abcs
from itertools import repeat
def _quantize_weight(float_wt, observer):
    """Quantize a float weight tensor to ``torch.qint8`` using an observer's qparams.

    Args:
        float_wt: Floating-point weight tensor to quantize.
        observer: Object providing ``calculate_qparams()`` (returning a
            ``(scale, zero_point)`` pair), a ``qscheme`` attribute and, for
            per-channel schemes, a ``ch_axis`` attribute.

    Returns:
        A ``torch.qint8`` quantized tensor. It is per-tensor for
        ``per_tensor_symmetric``/``per_tensor_affine`` and per-channel along
        ``observer.ch_axis`` for ``per_channel_symmetric``/``per_channel_affine``.

    Raises:
        ValueError: If ``observer.qscheme`` is none of the four schemes above.
    """
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in (torch.per_tensor_symmetric, torch.per_tensor_affine):
        # The per-tensor path takes scalar qparams, so convert to Python numbers.
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8)
    elif observer.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
        # The per-channel path takes tensors of qparams: double scales and
        # int64 zero points.
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
    else:
        # ``qscheme`` is normally a torch.qscheme object, not a str. Convert it
        # explicitly so callers get the intended ValueError; plain ``+``
        # concatenation would raise TypeError instead.
        raise ValueError("Unexpected qscheme " + str(observer.qscheme))
    return qweight
def _ntuple_from_first(n):
    """Return a function that converts its argument to a tuple of size ``n``.

    The returned ``parse(x)`` keeps replacing ``x`` with its first element
    while ``x`` is an iterable whose length is not ``n``. It then returns
    that value repeated ``n`` times. An iterable of length exactly ``n`` is
    *not* unpacked: the iterable itself is repeated ``n`` times.

    Examples for ``n == 2``: ``3 -> (3, 3)``, ``(3,) -> (3, 3)``,
    ``[[7]] -> (7, 7)``, ``(1, 2) -> ((1, 2), (1, 2))``.

    ``parse`` raises ``TypeError`` (from ``len``/indexing) for iterables
    without ``len``/``[]`` such as generators, and ``IndexError`` when it
    reaches an empty sequence.
    """
    def parse(x):
        while isinstance(x, container_abcs.Iterable):
            if len(x) == n:
                break
            if isinstance(x, str) and len(x) == 1:
                # A one-character string indexes to itself ("a"[0] == "a"), so
                # descending further would loop forever. Treat it as the value.
                break
            x = x[0]
        return tuple(repeat(x, n))
    return parse
# Size-2 parser: e.g. 3 -> (3, 3) and (3,) -> (3, 3).
_pair_from_first = _ntuple_from_first(n=2)