| import torch |
| import torch.nn.quantized.functional as F |
| |
| import numpy as np |
| from common_utils import TestCase, run_tests |
| |
def _quantize(x, scale, zero_point, qmin=0, qmax=255, dtype=np.uint8):
    """Quantize a numpy array with an affine (scale, zero_point) mapping.

    Each element becomes ``round(x / scale + zero_point)``, clamped to
    ``[qmin, qmax]`` and cast to ``dtype``. ``np.round`` rounds halves to
    the nearest even integer.

    Args:
        x: numpy array (or array-like supporting ``/``) of real values.
        scale: quantization step size.
        zero_point: offset added after scaling.
        qmin: lowest quantized value kept (inclusive).
        qmax: highest quantized value kept (inclusive).
        dtype: numpy dtype of the result. Defaults to ``np.uint8`` (the
            original behaviour). Pass a matching dtype when using a range
            outside [0, 255], e.g. ``np.int8`` with ``qmin=-128, qmax=127``.
            Otherwise the final cast wraps values instead of saturating them.

    Returns:
        numpy array of ``dtype`` holding the quantized values.
    """
    qx = np.round(x / scale + zero_point)
    # Clamp before casting so out-of-range values saturate rather than wrap.
    return np.clip(qx, qmin, qmax).astype(dtype)
| |
class FunctionalAPITest(TestCase):
    def test_functional_api(self):
        """Compare quantized ``F.relu`` against a float-space reference.

        The reference applies ReLU to the float values and then quantizes
        them with ``_quantize``. The quantized path quantizes to quint8
        first and then calls ``F.relu``. The integer representations must
        match exactly.
        """
        scale, zero_point = 2.0, 1
        x = torch.arange(-5, 5, dtype=torch.float)

        # Reference: ReLU in float space (new array, x is not mutated),
        # then quantize.
        relu_ref = np.maximum(x.numpy(), 0)
        expected_q = _quantize(relu_ref, scale, zero_point)

        # Path under test: quantize first, then apply the quantized ReLU.
        qx = x.quantize_linear(scale=scale, zero_point=zero_point,
                               dtype=torch.quint8)
        actual_q = F.relu(qx)

        np.testing.assert_equal(expected_q, actual_q.int_repr())
| |
if __name__ == "__main__":
    # When executed as a script, hand control to common_utils.run_tests.
    run_tests()