blob: 817e54460e07505e574431978a2dd16bebeb8698 [file] [log] [blame]
# Torch
import torch
from torch.quantization import (
MinMaxObserver,
PerChannelMinMaxObserver,
MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver,
MinMaxDynamicQuantObserver,
HistogramObserver,
RecordingObserver,
PlaceholderObserver,
NoopObserver,
FakeQuantize,
default_debug_qconfig,
default_observer,
default_per_channel_weight_observer,
get_observer_dict,
prepare,
)
from torch.quantization._learnable_fake_quantize import (
_LearnableFakeQuantizePerTensorOp,
_LearnableFakeQuantizePerChannelOp
)
import torch.nn as nn
# Standard library
import copy
import io
import itertools
import unittest
import math
import numpy as np
# Testing utils
from hypothesis import given, settings
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
AnnotatedSingleLayerLinearModel,
test_only_eval_fn,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
override_qengines,
)
# Reference method for per-tensor fake quantize
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, quant_min, quant_max):
res = (torch.clamp(torch.round(X * (1.0 / scale) + zero_point), quant_min, quant_max) - zero_point) * scale
return res
# Reference method for the gradient of the fake quantize operator
def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max):
Xq = torch.round(X * (1.0 / scale) + zero_point)
mask = (Xq >= quant_min) * (Xq <= quant_max)
res = torch.zeros_like(dY)
res[mask] = dY[mask]
return res
# Reference method for the gradients of the fake quantize operator
def _fake_quantize_learnable_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max, device):
    r"""Reference gradients of the learnable per-tensor fake quantize op w.r.t. X, scale and zero point.

    This method references the following literatures for back propagation on scale and zero point.
    - https://arxiv.org/pdf/1902.08153.pdf
    - https://arxiv.org/pdf/1903.08066.pdf

    Args:
        dY: upstream gradient. It is multiplied element-wise with the local
            gradients (presumably the same shape as ``X``; TODO confirm).
        X: input tensor.
        scale: scale as a tensor (``.to(device)`` is called on it).
        zero_point: zero point as a single-element tensor. It is rounded to a
            Python int below via ``.item()``.
        quant_min, quant_max: bounds of the quantized range.
        device: device that intermediate and returned tensors are moved to.

    Returns:
        ``(grad_X, grad_scale, grad_zp)``. ``grad_scale`` and ``grad_zp`` are
        sums over all elements, reshaped to 1-element tensors.
    """
    # Round half-up by adding 0.5, clamp into range, then truncate with int().
    zero_point_rounded = int((zero_point + 0.5).clamp(quant_min, quant_max).item())
    Xq = torch.round(X * (1.0 / scale) + zero_point_rounded).clamp(quant_min, quant_max)
    Xfq = (Xq - zero_point_rounded) * scale
    # Xq is already clamped, so "== quant_min" / "== quant_max" marks every
    # element that landed on (or was clipped to) the lower / upper bound.
    indicate_small_scale = (Xq == quant_min).float().to(device)
    indicate_big_scale = (Xq == quant_max).float().to(device)
    indicate_middle_scale = torch.ones(indicate_small_scale.shape).to(device) - \
        indicate_small_scale - indicate_big_scale
    # The zero point only receives gradient from saturated elements.
    indicate_saturate_zp = ((Xq == quant_min).float() + (Xq == quant_max).float()).to(device)
    indicate_unsaturate_zp = torch.ones(indicate_saturate_zp.shape).to(device) - indicate_saturate_zp
    # Local d(Xfq)/d(scale) for each of the three regions.
    grad_small_scale = quant_min - zero_point_rounded
    grad_big_scale = quant_max - zero_point_rounded
    grad_middle_scale = ((Xfq - X) / scale).to(device)
    # Local d(Xfq)/d(zero_point): -scale when saturated, 0 otherwise.
    grad_saturate_zp = -scale.to(device)
    grad_unsaturate_zp = 0
    grad_scale = indicate_small_scale * grad_small_scale + \
        indicate_big_scale * grad_big_scale + \
        indicate_middle_scale * grad_middle_scale
    grad_zp = indicate_saturate_zp * grad_saturate_zp + \
        indicate_unsaturate_zp * grad_unsaturate_zp
    # NOTE(review): grad_X is computed with the original (unrounded) zero_point,
    # while the scale/zero-point gradients above use zero_point_rounded.
    # Confirm this matches the kernel under test.
    grad_X = _fake_quantize_per_tensor_affine_grad_reference(
        dY, X, scale, zero_point, quant_min, quant_max).to(device)
    # Chain rule with dY, then reduce to a single value per parameter.
    grad_scale = (grad_scale * dY).sum().unsqueeze(dim=0)
    grad_zp = (grad_zp * dY).sum().unsqueeze(dim=0)
    return grad_X, grad_scale, grad_zp
# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
new_axis_list = list(range(X.dim()))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = X.permute(tuple(new_axis_list))
return y, new_axis_list
# Reference method for per-channel fake quantize
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    """Pure-PyTorch reference for per-channel affine fake quantization.

    Channel ``ch`` along ``axis`` is fake-quantized with
    ``per_channel_scale[ch]`` and ``per_channel_zero_point[ch]``. The result
    is written into a ``zeros_like`` buffer and permuted back to the layout
    of ``X``.
    """
    X, perm = _permute_to_axis_zero(X, axis)
    out = torch.zeros_like(X)
    for ch in range(X.size(0)):
        scale_c = per_channel_scale[ch]
        zp_c = per_channel_zero_point[ch]
        q = torch.clamp(torch.round(X[ch] * (1.0 / scale_c) + zp_c), quant_min, quant_max)
        out[ch] = (q - zp_c) * scale_c
    return out.permute(tuple(perm))
# Reference method for the gradient of the per-channel fake quantize operator
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    """Reference gradient of per-channel fake quantization with respect to ``X``.

    Each channel along ``axis`` is rounded with its own scale/zero point (no
    clamping). ``dY`` is passed through where the rounded value lies inside
    ``[quant_min, quant_max]`` and is zeroed elsewhere.
    """
    X, perm = _permute_to_axis_zero(X, axis)
    Xq = torch.zeros_like(X)
    for ch in range(X.size(0)):
        Xq[ch] = torch.round(X[ch] * (1.0 / per_channel_scale[ch]) + per_channel_zero_point[ch])
    Xq = Xq.permute(tuple(perm))
    keep = (Xq >= quant_min) & (Xq <= quant_max)
    grad = torch.zeros_like(dY)
    grad[keep] = dY[keep]
    return grad
# Reference method for quantization.
def _quantize_per_tensor(x, scale, zero_point, quant_min, quant_max):
return ((x / scale) + zero_point).round().clamp(quant_min, quant_max)
# Reference method for the per channel gradients of the learnable fake quantize operator
def _fake_quantize_learnable_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max, device):
    r"""Reference per-channel gradients of the learnable fake quantize op w.r.t. X, scale and zero point.

    This method references the following literatures for back propagation on scale and zero point.
    - https://arxiv.org/pdf/1902.08153.pdf
    - https://arxiv.org/pdf/1903.08066.pdf

    Args:
        dY: upstream gradient. It is unbound along ``axis`` in lockstep with ``X``.
        X: input tensor.
        per_channel_scale: 1-D tensor of scales, one per channel along ``axis``.
        per_channel_zero_point: 1-D tensor of zero points, one per channel.
        axis: channel axis of ``X`` / ``dY``.
        quant_min, quant_max: bounds of the quantized range.
        device: device the intermediate and returned tensors are moved to.

    Returns:
        ``(grad_X, grad_scale, grad_zero_point)``. ``grad_scale`` and
        ``grad_zero_point`` hold one reduced value per channel.
    """
    grad_X = _fake_quantize_per_channel_affine_grad_reference(
        dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max).to(device)
    per_channel_scale = per_channel_scale.detach().type(torch.float)
    # Round half-up: add 0.5, clamp into range, then truncate with the int64 cast.
    per_channel_zero_point = ((per_channel_zero_point.detach() + 0.5).clamp(quant_min, quant_max)).type(torch.int64)
    grad_scale = torch.zeros([per_channel_scale.size(0)]).to(device)
    grad_zero_point = torch.zeros([per_channel_zero_point.size(0)]).to(device)
    # Unbind each tensor exactly once. dY is indexed (not zipped) so that a
    # channel-count mismatch still raises instead of being silently truncated.
    X_slices = torch.unbind(X, dim=axis)
    dY_slices = torch.unbind(dY, dim=axis)
    for i, X_i in enumerate(X_slices):
        scale_i = per_channel_scale[i]
        zero_point_i = per_channel_zero_point[i]
        dY_i = dY_slices[i]
        Xq_i = _quantize_per_tensor(
            X_i, scale_i, zero_point_i, quant_min, quant_max).to(device)
        Xfq_i = (Xq_i - zero_point_i) * scale_i

        # Xq_i is clamped, so these flag elements on (or clipped to) each bound.
        indicate_small_scale_i = (Xq_i == quant_min).float().to(device)
        indicate_big_scale_i = (Xq_i == quant_max).float().to(device)
        indicate_middle_scale_i = torch.ones(indicate_small_scale_i.shape).to(device) - \
            indicate_small_scale_i - indicate_big_scale_i
        # The zero point only receives gradient from saturated elements.
        indicate_saturate_zp_i = indicate_small_scale_i + indicate_big_scale_i
        indicate_unsaturate_zp_i = torch.ones(indicate_saturate_zp_i.shape).to(device) - \
            indicate_saturate_zp_i

        # Local d(Xfq)/d(scale) per region and d(Xfq)/d(zero_point).
        grad_small_scale_i = quant_min - zero_point_i
        grad_big_scale_i = quant_max - zero_point_i
        grad_middle_scale_i = ((Xfq_i - X_i) / scale_i).to(device)
        grad_saturate_zp_i = -scale_i.to(device)
        grad_unsaturate_zp_i = 0

        grad_scale_i = indicate_small_scale_i * grad_small_scale_i + \
            indicate_middle_scale_i * grad_middle_scale_i + \
            indicate_big_scale_i * grad_big_scale_i
        grad_zp_i = indicate_saturate_zp_i * grad_saturate_zp_i + \
            indicate_unsaturate_zp_i * grad_unsaturate_zp_i

        # Chain rule with dY, reduced to one scalar per channel.
        grad_scale[i] = (grad_scale_i * dY_i).sum()
        grad_zero_point[i] = (grad_zp_i * dY_i).sum()
    return grad_X, grad_scale, grad_zero_point
def to_tensor(X, device):
    """Build a float32 tensor from ``X`` on ``device``.

    ``X`` is anything ``torch.tensor`` accepts (e.g. a list or NumPy array).
    ``device`` is anything ``torch.device`` accepts (e.g. ``'cpu'``).
    """
    target_device = torch.device(device)
    as_tensor = torch.tensor(X)
    return as_tensor.to(device=target_device, dtype=torch.float32)
# Seed for NumPy's global RNG (set by the fake-quantize tests), and the default
# rtol/atol used for numeric comparisons (some tests override it locally).
NP_RANDOM_SEED, tolerance = 19, 1e-6
class TestObserver(QuantizationTestCase):
    """Tests for the min/max-based observers in ``torch.quantization``."""

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric:
            reduce_range = False
        observers = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                     MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                 dtype=qdtype,
                                                 qscheme=qscheme,
                                                 reduce_range=reduce_range)]
        for myobs in observers:
            # Calculate Qparams should return with a warning for observers with no data
            myobs.calculate_qparams()
            # Exact type check on purpose: MovingAverageMinMaxObserver derives from
            # MinMaxObserver, so isinstance() would match both observers.
            if type(myobs) is MinMaxObserver:
                x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
            else:
                # Moving average of min/max for x and y matches that of
                # extreme values for x/y used for minmax observer
                x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0])

            myobs(x)
            result = myobs(y)
            # The observer must return its input unchanged.
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0)
            self.assertEqual(myobs.max_val, 8.0)
            qparams = myobs.calculate_qparams()
            if reduce_range:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745 * 255 / 127
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725 * 255 / 127
                    ref_zero_point = -64 if qdtype is torch.qint8 else 0
            else:
                if qscheme == torch.per_tensor_symmetric:
                    ref_scale = 0.062745
                    ref_zero_point = 0 if qdtype is torch.qint8 else 128
                else:
                    ref_scale = 0.0313725
                    ref_zero_point = -128 if qdtype is torch.qint8 else 0
            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
            loaded_obs.load_state_dict(loaded_dict)
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=2, max_dims=4,
                                              min_side=1, max_side=10),
                       qparams=hu.qparams()),
           reduce_range=st.booleans())
    def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
        # The generated qparams are not needed; the observer computes its own.
        X, _generated_qparams = X
        x = torch.from_numpy(X)
        obs = MinMaxDynamicQuantObserver(dtype=torch.quint8, reduce_range=reduce_range)
        obs(x)
        qparams = obs.calculate_qparams()
        ref = torch._choose_qparams_per_tensor(x, reduce_range)
        self.assertEqual(ref[0], qparams[0])
        self.assertEqual(ref[1], qparams[1])

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)),
           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
        # reduce_range is always disabled for float qparams in this test
        # (presumably unsupported for that qscheme).
        if qscheme == torch.per_channel_affine_float_qparams:
            reduce_range = False
        # reduce_range cannot be true for symmetric quantization with uint8
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        observers = [PerChannelMinMaxObserver(reduce_range=reduce_range,
                                              ch_axis=ch_axis,
                                              dtype=qdtype,
                                              qscheme=qscheme),
                     MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                           reduce_range=reduce_range,
                                                           ch_axis=ch_axis,
                                                           dtype=qdtype,
                                                           qscheme=qscheme)]
        for myobs in observers:
            # Calculate qparams should work for empty observers
            myobs.calculate_qparams()
            x = torch.tensor(
                [
                    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
                ]
            )
            # Exact type check: the moving-average observer is a subclass of
            # PerChannelMinMaxObserver, so isinstance() would not distinguish them.
            if type(myobs) is MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_float_qparams_ref_scales = [
                [0.0196, 0.0471],
                [0.0353, 0.0196],
                [0.0392, 0.0235],
                [0.0431, 0.0431],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

            self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
            elif qscheme == torch.per_channel_affine_float_qparams:
                ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
                ref_zero_points = [-1 * ref_min_vals[ch_axis][i] / ref_scales[i] for i in range(len(ref_scales))]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (
                    per_channel_affine_qint8_zp[ch_axis]
                    if qdtype is torch.qint8
                    else per_channel_affine_quint8_zp[ch_axis]
                )

            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
            self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001))
            if qscheme == torch.per_channel_affine_float_qparams:
                self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=1))
            else:
                self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
            self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_observer_scriptable(self):
        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()]
        for obs in obs_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

            # The scripted observer must also survive a save/load round trip.
            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_state_dict_respects_device_affinity(self):
        """
        Tests that loading from a state dict loads buffers to the correct
        device.
        """
        device_cpu = torch.device('cpu')
        device_cuda = torch.device('cuda:0')
        test_cases = itertools.product(
            [device_cpu, device_cuda],
            [device_cpu, device_cuda],
            [MinMaxObserver, MovingAverageMinMaxObserver,
             MinMaxDynamicQuantObserver, PerChannelMinMaxObserver,
             MovingAveragePerChannelMinMaxObserver,
             # TODO: enable this (separate PR)
             # HistogramObserver,
             PlaceholderObserver, RecordingObserver, NoopObserver])

        for device_source, device_target, obs_cls in test_cases:
            # calibrated source model
            model = obs_cls()
            model.to(device_source)
            model(torch.randn(4, 1, 4, 4, device=device_source))
            # target model
            model2 = obs_cls()
            model2.to(device_target)
            model2.load_state_dict(model.state_dict())
            # verify that buffers stayed on model2's device
            model_devices = {p.device for p in model2.parameters()} | \
                {b.device for b in model2.buffers()}
            # some observers do not have any buffers, so lessEqual instead of
            # Equal
            self.assertLessEqual(len(model_devices), 1)
            if len(model_devices) == 1:
                model_device = next(iter(model_devices))
                self.assertEqual(model_device, device_target)

    def test_histogram_observer_consistent_buffer_shape(self):
        """
        Ensures that the buffer shapes do not change from uninitialized to
        initialized states for HistogramObserver.
        """
        obs = HistogramObserver()
        min_shape_before = obs.min_val.shape
        max_shape_before = obs.max_val.shape
        for _ in range(2):
            obs(torch.randn(4, 4, 4, 4))
        self.assertEqual(min_shape_before, obs.min_val.shape)
        self.assertEqual(max_shape_before, obs.max_val.shape)

    def test_histogram_observer_save_load_state_dict(self):
        """
        Smoke test on saving/loading state_dict
        """
        obs1 = HistogramObserver()
        obs1(torch.randn(4, 4, 4, 4))
        obs2 = HistogramObserver()
        obs2.load_state_dict(obs1.state_dict())
        self.assertEqual(obs2.min_val.shape, torch.Size([]))
        self.assertEqual(obs2.max_val.shape, torch.Size([]))
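# HistogramObserver that keeps the reference ("master") Python implementation of
# _non_linear_param_search, used to check HistogramObserver's results against it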
class _ReferenceHistogramObserver(HistogramObserver):
    # HistogramObserver whose _non_linear_param_search is replaced by the Python
    # implementation below. test_histogram_observer_against_reference compares
    # its qparams with those of HistogramObserver.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.ignore
    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
        caffe2/quantization/server/norm_minimization.cc

        Returns:
            ``(new_min, new_max)``: ``self.min_val`` offset by the selected
            ``start_bin`` and ``end_bin + 1`` histogram bin widths.
        """
        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.

            The candidate range of source bins is re-split into
            ``self.dst_nbins`` equal destination bins. Each source bin's error
            is accumulated relative to the centre(s) of the destination bin(s)
            it overlaps.
            """
            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

            norm = 0.0
            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width

                # which dst_bins the beginning and end of src_bin belong to?
                # (clamped into [0, dst_nbins - 1])
                dst_bin_of_begin = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width))
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width))
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                # The bin's count is treated as uniformly spread over its width.
                density = self.histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    # src_bin spans several dst_bins: partial first dst_bin ...
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    # ... full dst_bins strictly in between ...
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    # ... and partial last dst_bin.
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )

                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert self.histogram.size()[0] == self.bins, "bins mistmatch"
        bin_width = (self.max_val - self.min_val) / self.bins

        # cumulative sum
        total = sum(self.histogram)
        cSum = torch.cumsum(self.histogram, dim=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move: trim whichever side moves by more bins
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            # No bin boundary changed; alpha/beta still advanced, so keep going.
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            # Greedy search: stop at the first step where the error grows.
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max
class TestRecordHistogramObserver(QuantizationTestCase):
    """Tests for RecordingObserver and HistogramObserver."""

    # TODO: move this to quantize.py
    def test_record_observer(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)
                # run the evaluation and dump all tensors
                test_only_eval_fn(model, self.calib_data)
                test_only_eval_fn(model, self.calib_data)
                observer_dict = {}
                get_observer_dict(model, observer_dict)

                key = 'fc1.module.activation_post_process'
                self.assertIn(key, observer_dict, 'observer is not recorded in the dict')
                recorded = observer_dict[key].get_tensor_value()
                # test_only_eval_fn ran twice above, hence twice as many recordings.
                self.assertEqual(len(recorded), 2 * len(self.calib_data))
                self.assertEqual(recorded[0], model(self.calib_data[0][0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)))
    def test_observer_scriptable(self, qdtype, qscheme):
        obs = RecordingObserver(dtype=qdtype, qscheme=qscheme)
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0]))

        # The scripted observer must also survive a save/load round trip.
        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    @settings(max_examples=10)
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        # Calculate qparams should work for empty observers
        myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        out_x = myobs(x)
        # The observer must not detach its input from autograd.
        self.assertTrue(out_x.requires_grad)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_histogram_observer_one_sided(self):
        myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
        x = torch.tensor([0.0, 0.3, 1.2, 1.7])
        y = torch.tensor([0.1, 1.3, 2.0, 2.7])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0)
        qparams = myobs.calculate_qparams()
        self.assertEqual(qparams[1].item(), 0)

    def test_histogram_observer_same_inputs(self):
        myobs = HistogramObserver(bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
        w = torch.ones(4, requires_grad=True)
        x = torch.zeros(4, requires_grad=True)
        y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        z = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(w)
        myobs(x)
        myobs(x)
        myobs(y)
        myobs(z)
        # Only checks that computing qparams succeeds; the values are not asserted.
        myobs.calculate_qparams()
        # The constant inputs w and x leave no trace in min/max or the histogram.
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

    @given(N=st.sampled_from([10, 1000, 10**6]),
           bins=st.sampled_from([256, 512, 1024, 2048]),
           dtype=st.sampled_from([torch.qint8, torch.quint8]),
           qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]),
           reduce_range=st.booleans())
    def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):
        ref_obs = _ReferenceHistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)
        my_obs = HistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)

        for _ in range(10):
            X = torch.randn(N)
            my_obs(X)
            ref_obs(X)

        ref_qparams = ref_obs.calculate_qparams()
        my_qparams = my_obs.calculate_qparams()

        self.assertEqual(ref_qparams, my_qparams)
class TestFakeQuantizePerTensor(TestCase):
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_forward_per_tensor(self, device, X):
    r"""Compare the forward output of torch.fake_quantize_per_tensor_affine
    against the Python reference implementation.
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, torch_type) = X
    bounds = torch.iinfo(torch_type)
    quant_min, quant_max = bounds.min, bounds.max
    X = to_tensor(X, device)
    expected = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
    actual = torch.fake_quantize_per_tensor_affine(X, scale, zero_point, quant_min, quant_max)
    np.testing.assert_allclose(expected, actual.cpu(), rtol=tolerance, atol=tolerance)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   qparams=hu.qparams(dtypes=torch.quint8)))
@unittest.skip("temporarily disable the test")
def test_backward_per_tensor(self, device, X):
    r"""Compare X.grad from torch.fake_quantize_per_tensor_affine with the
    reference gradient, in which dY is passed through only where the rounded
    value is inside [quant_min, quant_max].
    """
    np.random.seed(NP_RANDOM_SEED)
    X, (scale, zero_point, torch_type) = X
    bounds = torch.iinfo(torch_type)
    quant_min, quant_max = bounds.min, bounds.max
    X = to_tensor(X, device)
    X.requires_grad_()
    # Reference forward result; computed here but not compared by this test.
    Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
    Y_prime = torch.fake_quantize_per_tensor_affine(X, scale, zero_point, quant_min, quant_max)
    upstream = torch.rand(X.shape, dtype=torch.float).to(device)
    expected_grad = _fake_quantize_per_tensor_affine_grad_reference(
        upstream, X, scale, zero_point, quant_min, quant_max)
    Y_prime.backward(upstream)
    actual_grad = X.grad.cpu().detach().numpy()
    np.testing.assert_allclose(expected_grad.cpu(), actual_grad, rtol=tolerance, atol=tolerance)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_learnable_py_module_forward_per_tensor(self, device, X):
    r"""Compare the forward output of the Python autograd op
    _LearnableFakeQuantizePerTensorOp with the per-tensor reference (tolerance 1e-2).
    """
    X, (scale, zero_point, torch_type) = X
    scale = torch.tensor([scale]).to(device)
    zero_point = torch.tensor([zero_point]).to(device)
    bounds = torch.iinfo(torch_type)
    quant_min, quant_max = bounds.min, bounds.max
    X = to_tensor(X, device)
    reference_out = _fake_quantize_per_tensor_affine_reference(
        X, scale, zero_point, quant_min, quant_max).to(device)
    module_out = _LearnableFakeQuantizePerTensorOp.apply(
        X, scale, zero_point, quant_min, quant_max, 1.).to(device)
    loose_tol = 1e-2
    self.assertTrue(
        torch.allclose(reference_out, module_out, rtol=loose_tol, atol=loose_tol),
        "Expected _LearnableFakeQuantizePerTensorOp to have results match the reference forward function")
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_learnable_py_module_backward_per_tensor(self, device, X):
    r"""Compare the gradients of _LearnableFakeQuantizePerTensorOp w.r.t. X,
    scale and zero point with the Python reference (tolerance 1e-2).
    """
    X, (scale, zero_point, torch_type) = X
    scale = torch.tensor([scale]).float().to(device)
    zero_point = torch.tensor([zero_point]).float().to(device)
    bounds = torch.iinfo(torch_type)
    quant_min, quant_max = bounds.min, bounds.max
    X = to_tensor(X, device)
    for leaf in (X, scale, zero_point):
        leaf.requires_grad_()
    Y_prime = _LearnableFakeQuantizePerTensorOp.apply(
        X, scale, zero_point, quant_min, quant_max, 1.)
    dout = torch.rand(X.shape, dtype=torch.float).to(device)
    dX, dScale, dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
        dout, X, scale, zero_point, quant_min, quant_max, device)
    Y_prime.backward(dout)

    # (expected, actual, failure message), all materialized before asserting.
    comparisons = [
        (dX.to(device).detach(), X.grad.to(device).detach(),
         "Expected dX to match X.grad"),
        (dScale.to(device).detach(), scale.grad.to(device).detach(),
         "Expected dScale to match scale.grad"),
        (dZeroPoint.to(device).detach(), zero_point.grad.to(device).detach(),
         "Expected dZeroPoint to match zero_point.grad"),
    ]
    loose_tol = 1e-2
    for expected, actual, message in comparisons:
        self.assertTrue(
            torch.allclose(expected, actual, rtol=loose_tol, atol=loose_tol),
            message)
def _test_learnable_forward_per_tensor(self, X, device, scale_base, zero_point_base):
    r"""Compare torch._fake_quantize_learnable_per_tensor_affine's forward
    output with the Python reference, for 4-bit and 8-bit unsigned ranges.

    Args:
        X: array-like input; converted with ``torch.tensor`` and cast to float.
        device: device to run on.
        scale_base: scale tensor; moved to ``device`` as float.
        zero_point_base: zero-point tensor; cast to int64 on ``device`` and
            clamped into each bit width's range.
    """
    # These conversions do not depend on the bit width, so do them once
    # instead of on every loop iteration.
    X_base = torch.tensor(X).to(device).float()
    scale_base = scale_base.to(device).float()
    zero_point_base = zero_point_base.to(dtype=torch.int64, device=device)
    for n_bits in (4, 8):
        quant_min, quant_max = 0, 2 ** n_bits - 1
        X = X_base.clone()
        scale = scale_base.clone()
        zero_point = zero_point_base.clamp(quant_min, quant_max)
        Y = _fake_quantize_per_tensor_affine_reference(
            X, scale, zero_point, quant_min, quant_max).to(device)
        Y_prime = torch._fake_quantize_learnable_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max).to(device)
        self.assertTrue(
            torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
            "Expected kernel forward function to have results match the reference forward function")
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_learnable_forward_per_tensor_cpu(self, X):
    r"""Run the learnable per-tensor forward comparison on CPU with randomly
    drawn scale and zero point (the generated qparams are ignored).
    """
    X, (_scale, _zero_point, _dtype) = X
    # Draw the scale first, then the zero point, to keep the RNG order.
    scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
    zero_point_base = torch.normal(mean=0, std=128, size=(1,))
    self._test_learnable_forward_per_tensor(X, 'cpu', scale_base, zero_point_base)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_learnable_forward_per_tensor_cuda(self, X):
    r"""Run the learnable per-tensor forward comparison on CUDA with randomly
    drawn scale and zero point (the generated qparams are ignored).

    The skip decorator is outermost so hypothesis does not generate examples
    when CUDA is missing. The skip reason previously read as a double
    negative ("No gpu is not available.").
    """
    X, (_, _, _) = X
    scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
    zero_point_base = torch.normal(mean=0, std=128, size=(1,))
    self._test_learnable_forward_per_tensor(
        X, 'cuda', scale_base, zero_point_base)
def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
    r"""Tests the backward method with additional backprop support for scale and zero point.

    For unsigned 4-bit and 8-bit ranges, a random upstream gradient is pushed
    through ``torch._fake_quantize_learnable_per_tensor_affine``. The resulting
    gradients w.r.t. X, scale and zero point are compared with the Python
    reference gradients.
    """
    base_input = torch.tensor(X).to(device)
    scale_src = scale_base.to(device)
    zp_src = zero_point_base.to(device)
    for bit_width in (4, 8):
        q_lo, q_hi = 0, 2 ** bit_width - 1
        # Brand-new leaf tensors each pass, so .grad never accumulates
        # across bit widths.
        inp = base_input.clone().float().to(device).requires_grad_()
        scale = scale_src.clone().requires_grad_()
        zero_point = zp_src.clone().clamp(q_lo, q_hi).requires_grad_()
        out = torch._fake_quantize_learnable_per_tensor_affine(
            inp, scale, zero_point, q_lo, q_hi).to(device)
        upstream = torch.rand(inp.shape, dtype=torch.float).to(device)
        ref_dX, ref_dScale, ref_dZeroPoint = _fake_quantize_learnable_per_tensor_affine_grad_reference(
            upstream, inp, scale, zero_point, q_lo, q_hi, device)
        out.backward(upstream)
        # Materialise every (expected, actual) pair before asserting anything.
        comparisons = [
            (ref.to(device).detach(), got.to(device).detach(), message)
            for ref, got, message in (
                (ref_dX, inp.grad, "Expected dX to match X.grad"),
                (ref_dScale, scale.grad, "Expected dScale to match scale.grad"),
                (ref_dZeroPoint, zero_point.grad, "Expected dZeroPoint to match zero_point.grad"),
            )
        ]
        for expected, actual, message in comparisons:
            self.assertTrue(
                torch.allclose(
                    expected, actual, rtol=tolerance, atol=tolerance),
                message)
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
def test_learnable_backward_per_tensor_cpu(self, X):
    """Backward check for the learnable per-tensor op on CPU.

    Uses a seeded random scale (N(0, 1) clamped to [1e-4, 100]) and zero point
    (N(0, 128)) instead of the hypothesis-provided qparams.
    """
    torch.random.manual_seed(NP_RANDOM_SEED)
    data, (_, _, _) = X
    raw_scale = torch.normal(mean=0, std=1, size=(1,))
    scale_base = raw_scale.clamp(1e-4, 100)
    zero_point_base = torch.normal(mean=0, std=128, size=(1,))
    self._test_learnable_backward_per_tensor(data, 'cpu', scale_base, zero_point_base)
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                   qparams=hu.qparams(dtypes=torch.quint8)))
@unittest.skipIf(not TEST_CUDA, "No gpu is available.")
def test_learnable_backward_per_tensor_cuda(self, X):
    """Backward check for the learnable per-tensor op on CUDA.

    Uses a seeded random scale (N(0, 1) clamped to [1e-4, 100]) and zero point
    (N(0, 128)) instead of the hypothesis-provided qparams.
    """
    torch.random.manual_seed(NP_RANDOM_SEED)
    X, (_, _, _) = X
    scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
    zero_point_base = torch.normal(mean=0, std=128, size=(1,))
    self._test_learnable_backward_per_tensor(
        X, 'cuda', scale_base, zero_point_base)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   qparams=hu.qparams(dtypes=torch.quint8)))
# https://github.com/pytorch/pytorch/issues/30604
@unittest.skip("temporarily disable the test")
def test_numerical_consistency_per_tensor(self, device, X):
    r"""Checks that a CPU quantize -> dequantize round trip agrees numerically
    with the fake-quantize op evaluated on ``device``.
    """
    np.random.seed(NP_RANDOM_SEED)
    data, (scale, zero_point, torch_type) = X
    dtype_info = torch.iinfo(torch_type)
    data = to_tensor(data, device)
    # quantize_per_tensor and dequantize are only implemented in CPU
    round_trip = torch.dequantize(
        torch.quantize_per_tensor(data.cpu(), scale, zero_point, torch_type))
    fake_quantized = torch.fake_quantize_per_tensor_affine(
        data, scale, zero_point, dtype_info.min, dtype_info.max)
    np.testing.assert_allclose(round_trip, fake_quantized.cpu(), rtol=tolerance, atol=tolerance)
@given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
       X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                   qparams=hu.qparams(dtypes=[torch.quint8])),
       )
def test_fq_module(self, device, X):
    """Checks ``default_fake_quant`` against the reference forward and gradient.

    The module computes its own qparams from the input; those qparams are then
    fed to the reference implementations for comparison.
    """
    np.random.seed(NP_RANDOM_SEED)
    data, (_, _, torch_type) = X
    bounds = torch.iinfo(torch_type)
    quant_min, quant_max = bounds.min, bounds.max
    inp = to_tensor(data, device).requires_grad_()
    fq_module = torch.quantization.default_fake_quant().to(device)
    out = fq_module(inp)
    assert fq_module.scale is not None
    assert fq_module.zero_point is not None
    expected = _fake_quantize_per_tensor_affine_reference(
        inp, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
    np.testing.assert_allclose(expected.cpu().detach().numpy(), out.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)

    # Test backward
    upstream = torch.rand(inp.shape, dtype=torch.float, device=device)
    out.backward(upstream)
    expected_grad = _fake_quantize_per_tensor_affine_grad_reference(
        upstream, inp, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
    np.testing.assert_allclose(expected_grad.cpu().numpy(), inp.grad.cpu().detach().numpy(),
                               rtol=tolerance, atol=tolerance)
def test_fq_serializable(self):
    """Round-trips a per-tensor FakeQuantize ``state_dict`` through torch.save/load.

    A fresh module that loads the saved state must report identical state
    entries and identical qparams.
    """
    observer = default_observer
    quant_min, quant_max = 0, 255
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
    # Run once so the observer records min/max and qparams get populated.
    fq_module(X)
    state_dict = fq_module.state_dict()
    self.assertEqual(state_dict['scale'], 0.094488)
    self.assertEqual(state_dict['zero_point'], 53)

    buf = io.BytesIO()
    torch.save(state_dict, buf)
    buf.seek(0)
    loaded_dict = torch.load(buf)

    loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
    loaded_fq_module.load_state_dict(loaded_dict)
    loaded_state = loaded_fq_module.state_dict()
    for key, value in state_dict.items():
        self.assertEqual(value, loaded_state[key])
    self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())
def test_fake_quant_control(self):
    """Checks that enable/disable_fake_quant and enable/disable_observer
    toggle ``default_fake_quant`` behavior as expected.

    Statement order matters here: every ``torch.rand`` call advances the
    seeded RNG, and the module carries its qparams across calls in mutable
    buffers.
    """
    torch.manual_seed(42)
    X = torch.rand(20, 10, dtype=torch.float32)
    fq_module = torch.quantization.default_fake_quant()
    # Output of fake quant is not identical to input
    Y = fq_module(X)
    self.assertNotEqual(Y, X)
    torch.quantization.disable_fake_quant(fq_module)
    X = torch.rand(20, 10, dtype=torch.float32)
    Y = fq_module(X)
    # Fake quant is disabled, so the output is identical to the input
    self.assertEqual(Y, X)

    # Explicit copy at this point in time, because FakeQuant keeps internal
    # state in mutable buffers.
    scale = fq_module.scale.clone().detach()
    zero_point = fq_module.zero_point.clone().detach()

    torch.quantization.disable_observer(fq_module)
    torch.quantization.enable_fake_quant(fq_module)
    # Inputs in [-5, 5): a wider range than the [0, 1) inputs seen so far, so a
    # re-enabled observer is expected to produce different qparams below.
    X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
    Y = fq_module(X)
    self.assertNotEqual(Y, X)
    # Observer is disabled, scale and zero-point do not change
    self.assertEqual(fq_module.scale, scale)
    self.assertEqual(fq_module.zero_point, zero_point)

    torch.quantization.enable_observer(fq_module)
    Y = fq_module(X)
    self.assertNotEqual(Y, X)
    # Observer is enabled, scale and zero-point are different
    self.assertNotEqual(fq_module.scale, scale)
    self.assertNotEqual(fq_module.zero_point, zero_point)
def test_fake_quant_preserves_qparam_shapes_for_activations(self):
    """A forward pass through a QAT-prepared Linear must not change the shapes
    of its activation FakeQuant's ``scale`` and ``zero_point``.
    """
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    m = Model()
    m.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(m, inplace=True)

    def qparam_shapes():
        fake_quant = m.linear.activation_post_process
        return fake_quant.scale.shape, fake_quant.zero_point.shape

    scale_before, zp_before = qparam_shapes()
    m(torch.rand(4, 4, 4, 4))
    scale_after, zp_after = qparam_shapes()

    self.assertEqual(
        scale_before, scale_after,
        msg="FakeQuant scale shape must stay consistent")
    self.assertEqual(
        zp_before, zp_after,
        msg="FakeQuant zero_point shape must stay consistent")
def test_fake_quant_scriptable(self):
    """Checks that FakeQuantize can be scripted, and that both the scripted
    module and a save/load round trip of it report the same qparams as eager mode.

    This was previously defined only as ``fake_quant_scriptable``. Without the
    ``test_`` prefix, unittest never collected it, so it never ran.
    """
    observer = default_observer
    quant_min = 0
    quant_max = 255
    fq_module = FakeQuantize(observer, quant_min, quant_max)
    scripted_module = torch.jit.script(fq_module)

    X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
    # Feed the same input to both so their observers see identical data.
    fq_module(X)
    scripted_module(X)
    self.assertEqual(fq_module.calculate_qparams(),
                     scripted_module.calculate_qparams())

    buf = io.BytesIO()
    torch.jit.save(scripted_module, buf)
    buf.seek(0)
    loaded_module = torch.jit.load(buf)
    self.assertEqual(fq_module.calculate_qparams(),
                     loaded_module.calculate_qparams())

# Backward-compatible alias for the old name (not collected by unittest).
fake_quant_scriptable = test_fake_quant_scriptable
class TestFakeQuantizePerChannel(TestCase):
    """Tests for the per-channel fake-quantize ops, including the learnable
    variants, and for the per-channel FakeQuantize module."""

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerChannelAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    def _test_learnable_forward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the forward path of the learnable FakeQuantizePerChannelAffine op.

        Runs for unsigned 4-bit and 8-bit ranges. Each pass clamps the
        *original* ``zero_point_base`` into its own range. Previously the 4-bit
        clamp was written back into ``zero_point_base``, so the 8-bit pass only
        ever saw zero points already squeezed into [0, 15].
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** (n_bits) - 1
            scale_base = scale_base.to(device)
            X_curr = X_base.clone()
            scale_curr = scale_base.clone()
            # Clamp into a new tensor; never overwrite zero_point_base here.
            zero_point_curr = zero_point_base.clamp(quant_min, quant_max).to(
                dtype=torch.int64, device=device)
            Y = _fake_quantize_per_channel_affine_reference(
                X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max).to(device)
            Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max).to(device)
            self.assertTrue(
                torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
                "Expected kernel forward function to have results match the reference forward function")

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_forward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_forward_per_channel(
            X_base, 'cpu', scale_base, zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
    def test_learnable_forward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cuda')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_forward_per_channel(
            X_base, 'cuda', scale_base, zero_point_base, axis)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_point_base, axis):
        r"""Tests the backward path of the learnable FakeQuantizePerChannelAffine op.

        For unsigned 4-bit and 8-bit ranges, the gradients w.r.t. X, scale and
        zero point are compared with the Python reference gradients.
        """
        for n_bits in (4, 8):
            quant_min, quant_max = 0, 2 ** n_bits - 1
            scale_base = scale_base.to(device)
            zero_point_base = zero_point_base.to(device=device)

            X_curr = X_base.clone()
            X_curr.requires_grad_()
            scale_curr = scale_base.clone()
            scale_curr.requires_grad_()
            zero_point_curr = zero_point_base.clamp(quant_min, quant_max)
            zero_point_curr.requires_grad_()

            Y_prime = torch._fake_quantize_learnable_per_channel_affine(
                X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max).to(device)

            dout = torch.rand(X_curr.shape, dtype=torch.float).to(device)
            dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
                dout, X_curr, scale_curr, zero_point_curr, axis, quant_min, quant_max, device)
            Y_prime.backward(dout)

            dX_expected = dX.to(device).detach()
            dX_actual = X_curr.to(device).grad.detach()
            dScale_expected = dScale.to(device).detach()
            dScale_actual = scale_curr.to(device).grad.detach()
            dZeroPoint_expected = dZeroPoint.to(device).detach()
            dZeroPoint_actual = zero_point_curr.to(device).grad.detach()
            tolerance = 1e-4

            self.assertTrue(
                torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance),
                "Expected dX to match X.grad")
            self.assertTrue(
                torch.allclose(dScale_expected, dScale_actual, rtol=tolerance, atol=tolerance),
                "Expected dScale to match scale.grad")
            self.assertTrue(
                torch.allclose(dZeroPoint_expected, dZeroPoint_actual, rtol=tolerance, atol=tolerance),
                "Expected dZeroPoint to match zero_point.grad")

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_backward_per_channel_cpu(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (_, _, axis, _) = X
        X_base = torch.tensor(X).to('cpu')
        channel_size = X_base.size(axis)
        scale_base = torch.normal(mean=0, std=1, size=(channel_size,)).clamp(1e-4, 100)
        zero_point_base = torch.normal(mean=0, std=128, size=(channel_size,))
        self._test_learnable_backward_per_channel(
            X_base, 'cpu', scale_base, zero_point_base, axis)

    @given(X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_learnable_backward_per_channel_cuda(self, X):
        torch.random.manual_seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        X_base = torch.tensor(X).to('cuda')
        scale_base = to_tensor(scale, 'cuda')
        zero_point_base = to_tensor(zero_point, 'cuda')
        self._test_learnable_backward_per_channel(
            X_base, 'cuda', scale_base, zero_point_base, axis)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_py_module_forward_per_channel(self, device, X):
        r"""Tests the forward path of the _LearnableFakeQuantizePerChannel op.
        """
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        Y = _fake_quantize_per_channel_affine_reference(
            X, scale, zero_point, axis, quant_min, quant_max).to(device)
        Y_prime = _LearnableFakeQuantizePerChannelOp.apply(
            X, scale, zero_point, axis, quant_min, quant_max, 1.).to(device)
        tolerance = 1e-2
        self.assertTrue(
            torch.allclose(Y, Y_prime, rtol=tolerance, atol=tolerance),
            "Expected _LearnableFakeQuantizePerChannelOp to have results match the reference forward function")

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_learnable_py_module_backward_per_channel(self, device, X):
        r"""Tests the backward path of the _LearnableFakeQuantizePerChannel op.
        """
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device).float()
        X.requires_grad_()
        scale = to_tensor(scale, device).float()
        scale.requires_grad_()
        zero_point = torch.tensor(zero_point).to(device).float()
        zero_point.requires_grad_()
        Y_prime = _LearnableFakeQuantizePerChannelOp.apply(
            X, scale, zero_point, axis, quant_min, quant_max, 1.).to(device)

        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX, dScale, dZeroPoint = _fake_quantize_learnable_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max, device)
        Y_prime.backward(dout)

        dX_expected = dX.to(device).detach()
        dX_actual = X.to(device).grad.detach()
        dScale_expected = dScale.to(device).detach()
        dScale_actual = scale.to(device).grad.detach()
        dZeroPoint_expected = dZeroPoint.to(device).detach()
        dZeroPoint_actual = zero_point.to(device).grad.detach()
        tolerance = 1e-2

        self.assertTrue(
            torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance),
            "Expected dX to match X.grad")
        self.assertTrue(
            torch.allclose(dScale_expected, dScale_actual, rtol=tolerance, atol=tolerance),
            "Expected dScale to match scale.grad")
        self.assertTrue(
            torch.allclose(dZeroPoint_expected, dZeroPoint_actual, rtol=tolerance, atol=tolerance),
            "Expected dZeroPoint to match zero_point.grad")

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_channel(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        # quantize_linear and dequantize are only implemented in CPU
        Y = torch.dequantize(torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(), axis, torch_type))
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(X, fq_module.scale,
                                                        fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(dout, X, fq_module.scale,
                                                              fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        # Run once so the observer records per-channel min/max and qparams.
        fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
def _get_buffer_ids(module):
"""
Object addresses stay constant if and only if all modifications are in-place
"""
return [id(v) for k, v in module._buffers.items()]
class TestDistributed(QuantizationTestCase):
    """Checks that observers and fake-quant modules are compatible with
    nn.DataParallel / SyncBatchNorm conversion and respect device placement."""

    def test_observers_preserve_buffers(self):
        """
        Tests that observers only modify buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        observer_types = [
            torch.quantization.MinMaxObserver.with_args(dtype=torch.qint8),
            torch.quantization.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
            torch.quantization.MinMaxDynamicQuantObserver.with_args(dtype=torch.qint8),
            torch.quantization.PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.quantization.HistogramObserver.with_args(dtype=torch.qint8),
            torch.quantization.RecordingObserver.with_args(dtype=torch.qint8),
            torch.quantization.PlaceholderObserver.with_args(dtype=torch.float16),
        ]

        for observer_type in observer_types:
            observer = observer_type()
            buffer_ids_before = _get_buffer_ids(observer)
            for _i in range(5):
                inputs = torch.rand((4, 4, 4))
                observer(inputs)
            buffer_ids_after = _get_buffer_ids(observer)
            self.assertEqual(
                buffer_ids_before,
                buffer_ids_after,
                msg="{}: Buffers must be modified in place".format(str(observer)))

    def test_fake_quant_preserves_buffers(self):
        """
        Tests that fake quant only modifies buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        model = torch.quantization.FakeQuantize()
        buffer_ids_before = _get_buffer_ids(model)
        for _i in range(5):
            inputs = torch.rand((4, 4, 4))
            model(inputs)
        model.apply(torch.quantization.enable_fake_quant)
        model.apply(torch.quantization.disable_fake_quant)
        model.apply(torch.quantization.enable_observer)
        model.apply(torch.quantization.disable_observer)
        buffer_ids_after = _get_buffer_ids(model)
        self.assertEqual(
            buffer_ids_before,
            buffer_ids_after,
            msg="FakeQuant: Buffers must be modified in place")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_data_parallel(self):
        """
        Tests that doing QAT in nn.DataParallel does not crash.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            # Report a skip instead of silently passing.
            self.skipTest("fbgemm is not a supported quantized engine")
        with override_quantized_engine('fbgemm'):
            device = torch.device('cuda')

            model = nn.Sequential(
                torch.quantization.QuantStub(),
                nn.Conv2d(3, 1, 1, bias=False),
                nn.BatchNorm2d(1),
                nn.ReLU(),
                nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(2),
                nn.AvgPool2d(14),
                nn.Sigmoid(),
                torch.quantization.DeQuantStub(),
            )

            torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)

            model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
            torch.quantization.prepare_qat(model, inplace=True)
            model = nn.DataParallel(model, device_ids=[0, 1])
            model.to(device)
            model.train()

            for epoch in range(3):
                inputs = torch.rand(2, 3, 28, 28).to(device)
                model(inputs)
                if epoch >= 1:
                    model.apply(torch.quantization.disable_observer)
                if epoch >= 2:
                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
                quant_model = copy.deepcopy(model.module)
                quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)
                with torch.no_grad():
                    # Only checks that the converted model runs.
                    quant_model(torch.rand(1, 3, 28, 28))

    def test_qat_convbn_fused_syncbn_replacement(self):
        """
        Tests that SyncBatchNorm replacement works for fused ConvBN.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            # Report a skip instead of silently passing.
            self.skipTest("fbgemm is not a supported quantized engine")
        with override_quantized_engine('fbgemm'):
            # create conv-bn
            class Model(nn.Module):
                def __init__(self):
                    super(Model, self).__init__()
                    self.conv = nn.Conv2d(4, 1, 3, padding=1)
                    self.bn = nn.BatchNorm2d(1)

                def forward(self, x):
                    x = self.conv(x)
                    x = self.bn(x)
                    return x

            model = Model()
            # fuse it
            fused_model = torch.quantization.fuse_modules(
                model,
                [['conv', 'bn']],
            )
            # convert to QAT
            fused_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
            torch.quantization.prepare_qat(fused_model, inplace=True)
            # replace the fused module's BatchNorm with SyncBatchNorm
            fused_model = nn.SyncBatchNorm.convert_sync_batchnorm(fused_model)
            self.assertTrue(
                isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
                "Expected BN to be converted to SyncBN")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_device_affinity(self):
        """
        Tests that converting a model to QAT respects device affinity
        """
        class Model(nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        model = Model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(torch.backends.quantized.engine)
        device = torch.device('cuda:0')
        model.to(device)
        torch.quantization.prepare_qat(model, inplace=True)
        model_devices = {p.device for p in model.parameters()} | \
            {p.device for p in model.buffers()}
        self.assertEqual(len(model_devices), 1)
        model_device = next(iter(model_devices))
        self.assertEqual(model_device, device)

        # ensure that running an input on CUDA works without any needed changes
        inp = torch.randn(4, 1, 4, 4, device=device)
        model(inp)