Guard against copying from quantized Tensor to non-quantized Tensor (#29660)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29660
As titled: raise a clear error when `copy_` is called to copy from a quantized Tensor into a non-quantized Tensor, directing users to `dequantize()` instead of silently producing invalid data.
Test Plan:
python test/test_quantized_tensor.py
Imported from OSS
Differential Revision: D18799897
fbshipit-source-id: 5d1b4ef84f5ae8eba830784b74485d78fa1e6fcf
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index c849e0b..7cd2684 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -122,6 +122,10 @@
self.set_quantizer_(src.quantizer());
}
+ if (!self.is_quantized() && src.is_quantized()) {
+ TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
+ }
+
auto iter = TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(self);
diff --git a/test/test_quantized_tensor.py b/test/test_quantized_tensor.py
index d8c2943..915f946 100644
--- a/test/test_quantized_tensor.py
+++ b/test/test_quantized_tensor.py
@@ -267,6 +267,12 @@
qc = deepcopy(q)
self.assertEqual(qc, q)
+ # can't copy from quantized tensor to non-quantized tensor
+ r = torch.empty([numel], dtype=torch.float)
+ q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+ with self.assertRaisesRegex(RuntimeError, "please use dequantize"):
+ r.copy_(q)
+
def test_qtensor_clone(self):
numel = 10
scale = 0.5