[quant][refactor tests] Move qtensor serialization tests from test_deprecated_jit (#59089)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59089

Move the qtensor serialization tests from test_deprecated_jit_quant.py into
test_quantized_tensor.py: test_pickle_checkpoint_qtensor (with its
_test_pickle_checkpoint_qtensor helper) moves over, and test_serialize_qtensor
is renamed to test_jit_serialization.
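
For reference, a minimal standalone sketch (not part of this PR; tensor values
and shapes are illustrative) of the save/load round trip these tests exercise:

    import io
    import torch

    # Quantize a float tensor, then pickle it through torch.save / torch.load.
    q = torch.quantize_per_tensor(
        torch.rand(2, 3), scale=0.1, zero_point=10, dtype=torch.quint8)
    buf = io.BytesIO()
    torch.save(q, buf)
    buf.seek(0)
    loaded = torch.load(buf)
    assert torch.equal(loaded.dequantize(), q.dequantize())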

Test Plan:
python test/test_quantization.py
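
To run only the moved tests directly (test class name assumed; it is not shown
in this diff):

    python test/test_quantization.py TestQuantizedTensor.test_jit_serialization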

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D28750065

fbshipit-source-id: 5c4350d49dd07710b86ba330de80369403c6013c
diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py
index 4a89332..70f309f 100644
--- a/test/quantization/core/test_quantized_tensor.py
+++ b/test/quantization/core/test_quantized_tensor.py
@@ -6,7 +6,7 @@
 from copy import deepcopy
 from hypothesis import given
 from hypothesis import strategies as st
-
+from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM
 import torch.testing._internal.hypothesis_utils as hu
@@ -903,3 +903,68 @@
             ref = param_search_greedy(x.numpy(), bit_rate=bit_width)
             self.assertEqual(y[0].numpy(), ref[0])
             self.assertEqual(y[1].numpy(), ref[1])
+
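+    # Checkpoint per-tensor and per-channel quantized tensors from inside a
+    # ScriptModule via torch.save, then reload with torch.load and compare.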
+    def _test_pickle_checkpoint_qtensor(self, device):
+        with TemporaryFileName() as fname:
+            class M(torch.jit.ScriptModule):
+                __constants__ = ['fname']
+
+                def __init__(self):
+                    super(M, self).__init__()
+                    self.fname = fname
+
+                @torch.jit.script_method
+                def forward(self, x, y):
+                    torch.save((x, y), self.fname)
+                    return y
+
+            q = torch.quantize_per_tensor(
+                torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
+            qc = torch.quantize_per_channel(
+                torch.rand(2, 3, dtype=torch.float),
+                scales=torch.tensor([0.1, 0.5, 0.01]),
+                zero_points=torch.tensor([10, 0, 20]),
+                axis=1, dtype=torch.quint8).to(device)
+            m = M()
+            m(q, qc)
+            with open(fname, "rb") as handle:
+                loaded_q, loaded_qc = torch.load(handle)
+                self.assertEqual(loaded_q, q)
+                self.assertEqual(loaded_qc, qc)
+
+    def test_pickle_checkpoint_qtensor(self):
+        self._test_pickle_checkpoint_qtensor('cpu')
+
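+    # Round-trip a scripted module holding a quantized buffer (per-tensor or
+    # per-channel) through torch.jit.save / torch.jit.load.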
+    def test_jit_serialization(self):
+        class SimpleQTensor(torch.jit.ScriptModule):
+            def __init__(self, per_channel):
+                super(SimpleQTensor, self).__init__()
+                x = torch.rand(5, 5).float()
+                if not per_channel:
+                    x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
+                else:
+                    s = torch.rand(5, dtype=torch.float64) + 0.1
+                    zp = torch.randint(5, 15, (5,))
+                    x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
+                self.register_buffer('x', x_q)
+
+            @torch.jit.script_method
+            def forward(self):
+                return self.x
+
+        for per_channel in [False, True]:
+            model = SimpleQTensor(per_channel)
+            buffer = io.BytesIO()
+            torch.jit.save(model, buffer)
+            buffer.seek(0)
+            model_loaded = torch.jit.load(buffer)
+            self.assertEqual(model_loaded(), model())
+
+if __name__ == '__main__':
+    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
+                       "\tpython test/test_quantization.py TESTNAME\n\n"
+                       "instead.")
diff --git a/test/quantization/jit/test_deprecated_jit_quant.py b/test/quantization/jit/test_deprecated_jit_quant.py
index d98778f..662ead3 100644
--- a/test/quantization/jit/test_deprecated_jit_quant.py
+++ b/test/quantization/jit/test_deprecated_jit_quant.py
@@ -2,15 +2,12 @@
 from torch.testing._internal.common_quantization import (
     skipIfNoFBGEMM
 )
-from torch.testing._internal.common_utils import TemporaryFileName
 from torch.testing._internal.common_utils import suppress_warnings
 from torch.testing._internal.jit_utils import JitTestCase
 
 from typing import Tuple
 import copy
-import io
 
-# TODO: Move some tensor tests here like test_serialize_qtensor to test_quantize_tensor.py
 class TestDeprecatedJitQuantized(JitTestCase):
     @skipIfNoFBGEMM
     def test_rnn_cell_quantized(self):
@@ -258,62 +255,6 @@
             torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
             torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
 
-    def _test_pickle_checkpoint_qtensor(self, device):
-        with TemporaryFileName() as fname:
-            class M(torch.jit.ScriptModule):
-                __constants__ = ['fname']
-
-                def __init__(self):
-                    super(M, self).__init__()
-                    self.fname = fname
-
-                @torch.jit.script_method
-                def forward(self, x, y):
-                    torch.save((x, y), self.fname)
-                    return y
-
-            q = torch.quantize_per_tensor(
-                torch.rand(2, 3, dtype=torch.float), scale=0.1, zero_point=10, dtype=torch.quint8).to(device)
-            qc = torch.quantize_per_channel(
-                torch.rand(2, 3, dtype=torch.float),
-                scales=torch.tensor([0.1, 0.5, 0.01]),
-                zero_points=torch.tensor([10, 0, 20]),
-                axis=1, dtype=torch.quint8).to(device)
-            m = M()
-            m(q, qc)
-            with open(fname, "rb") as handle:
-                loaded_q, loaded_qc = torch.load(fname)
-                self.assertEqual(loaded_q, q)
-                self.assertEqual(loaded_qc, qc)
-
-    def test_pickle_checkpoint_qtensor(self):
-        self._test_pickle_checkpoint_qtensor('cpu')
-
-    def test_serialize_qtensor(self):
-        class SimpleQTensor(torch.jit.ScriptModule):
-            def __init__(self, per_channel):
-                super(SimpleQTensor, self).__init__()
-                x = torch.rand(5, 5).float()
-                if not per_channel:
-                    x_q = torch.quantize_per_tensor(x, 0.2, 10, torch.quint8)
-                else:
-                    s = torch.rand(5, dtype=torch.float64) + 0.1
-                    zp = torch.randint(5, 15, (5,))
-                    x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8)
-                self.register_buffer('x', x_q)
-
-            @torch.jit.script_method
-            def forward(self):
-                return self.x
-
-        for per_channel in [False, True]:
-            model = SimpleQTensor(per_channel)
-            buffer = io.BytesIO()
-            torch.jit.save(model, buffer)
-            buffer.seek(0)
-            model_loaded = torch.jit.load(buffer)
-            self.assertEqual(model_loaded(), model())
-
     @skipIfNoFBGEMM
     def test_erase_class_tensor_shapes(self):
         class Linear(torch.nn.Module):
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 6627a04..0821fa1 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -10,6 +10,7 @@
 # - quantized tensor
 
 # 1. Quantized Kernels
+# TODO: merge the different quantized op tests into one test class
 from quantization.core.test_quantized_op import TestQuantizedOps  # noqa: F401
 from quantization.core.test_quantized_op import TestQNNPackOps  # noqa: F401
 from quantization.core.test_quantized_op import TestQuantizedLinear  # noqa: F401