[quant][refactor tests] Move qtensor serialization tests from test_deprecated_jit (#59089)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59089
Moves the qtensor serialization tests (test_pickle_checkpoint_qtensor and test_serialize_qtensor, the latter renamed to test_jit_serialization) out of test_deprecated_jit_quant.py and into test_quantized_tensor.py, alongside the other quantized tensor tests.
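For context, the round trip these tests exercise is saving and loading a quantized tensor with torch.save / torch.load. A minimal sketch (file name and quantization parameters are illustrative, not taken from this diff):

    import torch

    x = torch.rand(2, 3)
    q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)
    torch.save(q, "qtensor.pt")        # pickle the quantized tensor to disk
    loaded = torch.load("qtensor.pt")  # restore it, qparams included
    assert torch.equal(loaded.dequantize(), q.dequantize())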
Test Plan:
python test/test_quantization.py
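The moved tests can also be run individually through the same harness, e.g. (assuming they land in the existing TestQuantizedTensor class, which this diff does not show):

python test/test_quantization.py TestQuantizedTensor.test_pickle_checkpoint_qtensor
python test/test_quantization.py TestQuantizedTensor.test_jit_serialization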
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D28750065
fbshipit-source-id: 5c4350d49dd07710b86ba330de80369403c6013c
diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py
index 4a89332..70f309f 100644
--- a/test/quantization/core/test_quantized_tensor.py
+++ b/test/quantization/core/test_quantized_tensor.py
@@ -6,7 +6,7 @@
from copy import deepcopy
from hypothesis import given
from hypothesis import strategies as st
-
+from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
import torch.testing._internal.hypothesis_utils as hu
@@ -903,3 +903,64 @@
ref = param_search_greedy(x.numpy(), bit_rate=bit_width)
self.assertEqual(y[0].numpy(), ref[0])
self.assertEqual(y[1].numpy(), ref[1])
+
+ def _test_pickle_checkpoint_qtensor(self, device):
+ with TemporaryFileName() as fname:
+ class M(torch.jit.ScriptModule):
+ __constants__ = ['fname']
+
+ def __init__(self):
+ super(M, self).__init__()
+ self.fname = fname
+
+ @torch.jit.script_method
+ def forward(self, x, y):
+                    torch.save((x, y), self.fname)  # torch.save of qtensors from inside TorchScript
+ return y
+
+ q = torch.quantize_per_tensor(
+ torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
+ qc = torch.quantize_per_channel(
+ torch.rand(2, 3, dtype=torch.float),
+ scales=torch.tensor([0.1, 0.5, 0.01]),
+ zero_points=torch.tensor([10, 0, 20]),
+ axis=1, dtype=torch.quint8).to(device)
+            m = M()
+            m(q, qc)  # forward() writes (q, qc) to fname
+            with open(fname, "rb") as handle:
+                loaded_q, loaded_qc = torch.load(handle)
+ self.assertEqual(loaded_q, q)
+ self.assertEqual(loaded_qc, qc)
+
+ def test_pickle_checkpoint_qtensor(self):
+ self._test_pickle_checkpoint_qtensor('cpu')
+
+ def test_jit_serialization(self):
+ class SimpleQTensor(torch.jit.ScriptModule):
+ def __init__(self, per_channel):
+ super(SimpleQTensor, self).__init__()
+ x = torch.rand(5, 5).float()
+ if not per_channel:
+ x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
+ else:
+ s = torch.rand(5, dtype=torch.float64) + 0.1
+ zp = torch.randint(5, 15, (5,))
+ x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
+                self.register_buffer('x', x_q)  # qtensor stored as a module buffer
+
+ @torch.jit.script_method
+ def forward(self):
+ return self.x
+
+ for per_channel in [False, True]:
+ model = SimpleQTensor(per_channel)
+ buffer = io.BytesIO()
+ torch.jit.save(model, buffer)
+ buffer.seek(0)
+ model_loaded = torch.jit.load(buffer)
+ self.assertEqual(model_loaded(), model())
+
+if __name__ == '__main__':
+ raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+ "\tpython test/test_quantization.py TESTNAME\n\n"
+ "instead.")
diff --git a/test/quantization/jit/test_deprecated_jit_quant.py b/test/quantization/jit/test_deprecated_jit_quant.py
index d98778f..662ead3 100644
--- a/test/quantization/jit/test_deprecated_jit_quant.py
+++ b/test/quantization/jit/test_deprecated_jit_quant.py
@@ -2,15 +2,12 @@
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM
)
-from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
from typing import Tuple
import copy
-import io


-# TODO: Move some tensor tests here like test_serialize_qtensor to test_quantize_tensor.py
class TestDeprecatedJitQuantized(JitTestCase):
@skipIfNoFBGEMM
def test_rnn_cell_quantized(self):
@@ -258,62 +255,6 @@
torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)

-    def _test_pickle_checkpoint_qtensor(self, device):
- with TemporaryFileName() as fname:
- class M(torch.jit.ScriptModule):
- __constants__ = ['fname']
-
- def __init__(self):
- super(M, self).__init__()
- self.fname = fname
-
- @torch.jit.script_method
- def forward(self, x, y):
- torch.save((x, y), self.fname)
- return y
-
- q = torch.quantize_per_tensor(
- torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
- qc = torch.quantize_per_channel(
- torch.rand(2, 3, dtype=torch.float),
- scales=torch.tensor([0.1, 0.5, 0.01]),
- zero_points=torch.tensor([10, 0, 20]),
- axis=1, dtype=torch.quint8).to(device)
- m = M()
- m(q, qc)
- with open(fname, "rb") as handle:
- loaded_q, loaded_qc = torch.load(fname)
- self.assertEqual(loaded_q, q)
- self.assertEqual(loaded_qc, qc)
-
- def test_pickle_checkpoint_qtensor(self):
- self._test_pickle_checkpoint_qtensor('cpu')
-
- def test_serialize_qtensor(self):
- class SimpleQTensor(torch.jit.ScriptModule):
- def __init__(self, per_channel):
- super(SimpleQTensor, self).__init__()
- x = torch.rand(5, 5).float()
- if not per_channel:
- x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
- else:
- s = torch.rand(5, dtype=torch.float64) + 0.1
- zp = torch.randint(5, 15, (5,))
- x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
- self.register_buffer('x', x_q)
-
- @torch.jit.script_method
- def forward(self):
- return self.x
-
- for per_channel in [False, True]:
- model = SimpleQTensor(per_channel)
- buffer = io.BytesIO()
- torch.jit.save(model, buffer)
- buffer.seek(0)
- model_loaded = torch.jit.load(buffer)
- self.assertEqual(model_loaded(), model())
-
@skipIfNoFBGEMM
def test_erase_class_tensor_shapes(self):
class Linear(torch.nn.Module):
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 6627a04..0821fa1 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -10,6 +10,7 @@
# - quantized tensor

# 1. Quantized Kernels
+# TODO: merge the different quantized op tests into one test class
from quantization.core.test_quantized_op import TestQuantizedOps # noqa: F401
from quantization.core.test_quantized_op import TestQNNPackOps # noqa: F401
from quantization.core.test_quantized_op import TestQuantizedLinear # noqa: F401