# blob: 1ddc24bd5138f28bb59bc9a3c4b8c7ba2ee73d72 [file] [log] [blame]
import torch
import torch.jit
import numpy as np
import unittest
from caffe2.python import core
from common_utils import TestCase, run_tests
def canonical(graph):
    """Return the canonicalized TorchScript ``graph`` rendered as text.

    The graph is passed through ``torch._C._jit_pass_canonicalize`` before
    printing, so tests can compare it against an expected IR string.
    """
    canonical_graph = torch._C._jit_pass_canonicalize(graph)
    return str(canonical_graph)
@unittest.skipIf(
    "Relu_ENGINE_DNNLOWP" not in core._REGISTERED_OPERATORS,
    "fbgemm-based Caffe2 ops are not linked",
)
class TestQuantized(TestCase):
    """Tests for the c10 quantized ops exposed under ``torch.ops.c10``.

    Quantized values are passed as ``(uint8 tensor, scale, zero_point)``
    tuples. The whole class is skipped when the DNNLOWP Relu operator is
    not registered with Caffe2.
    """

    def test_relu(self):
        """quantized_relu clamps values below the zero point up to it."""
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.quantized_relu(a)
        # 4 and 1 lie below zero_point=5 (negative real values), so they
        # become 5; the scale and zero_point are expected to pass through.
        np.testing.assert_equal(
            r[0].numpy(),
            torch.tensor([5, 6, 5, 10], dtype=torch.uint8).numpy(),
        )
        # numpy.testing helpers take (actual, desired); keep that order so
        # failure messages label the values correctly.
        np.testing.assert_almost_equal(r[1], 0.01)
        self.assertEqual(r[2], 5)

    def test_quantize(self):
        """dequantize followed by quantize reproduces the quantized tuple."""
        a = (torch.tensor([4, 6, 1, 10], dtype=torch.uint8), 0.01, 5)
        r = torch.ops.c10.dequantize(a)
        # Expected values are (q - zero_point) * scale for each element.
        np.testing.assert_almost_equal(r.numpy(), [-0.01, 0.01, -0.04, 0.05])
        # Default args: only checks that quantize accepts a bare tensor; the
        # scale/zero_point it picks are not asserted here.
        torch.ops.c10.quantize(r)
        # Explicit scale/zero_point must round-trip back to the input tuple.
        q = torch.ops.c10.quantize(r, scale=0.01, zero_point=5)
        np.testing.assert_equal(q[0].numpy(), a[0].numpy())
        np.testing.assert_almost_equal(q[1], a[1])
        self.assertEqual(q[2], a[2])

    def test_script(self):
        """quantized_relu is callable from TorchScript on a tuple argument."""
        @torch.jit.script
        def foo(x):
            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
            return torch.ops.c10.quantized_relu(x)

        self.assertExpectedInline(canonical(foo.graph), '''\
graph(%x : (Tensor, float, int)):
  %1 : (Tensor, float, int) = c10::quantized_relu(%x)
  return (%1)
''')
if __name__ == "__main__":
    # Run this module's tests through the shared common_utils runner.
    run_tests()