[quant] Make aten::repeat work for quantized tensors (#40644)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40644
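This adds a native `empty_quantized(int[] size, Tensor qtensor)` factory that creates an empty quantized tensor of the given size using the quantization parameters of an existing quantized tensor (per-tensor scale/zero_point, or per-channel scales/zero_points/axis), and uses it in `Tensor::repeat` so that the result of `aten::repeat` is allocated with the input's quantization parameters instead of going through the plain `at::empty` path. `unfold` is also registered for the QuantizedCPU/QuantizedCUDA dispatch keys, and `torch.empty_quantized` is listed in `torch/_overrides.py` alongside the other factory functions.

A rough sketch of the behavior this enables (values are illustrative and mirror the new test):

    import torch

    scale, zero_point = 1.0, 2
    x = torch.quantize_per_tensor(torch.arange(3, dtype=torch.float), scale, zero_point, torch.quint8)
    y = x.repeat(4, 2)  # with this change, returns a quantized tensor of shape (4, 6)
    assert y.qscheme() == torch.per_tensor_affine
    assert y.q_scale() == scale and y.q_zero_point() == zero_point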
Test Plan: Imported from OSS
Differential Revision: D22268558
fbshipit-source-id: 3bc9a129bece1b547c519772ecc6b980780fb904
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 015df4f..17377fb 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -762,7 +762,12 @@
   Tensor xtensor = self.expand(padded_size);
-  Tensor result = at::empty(target_size, self.options());
+  Tensor result;
+  if (self.is_quantized()) {
+    result = at::empty_quantized(target_size, self);
+  } else {
+    result = at::empty(target_size, self.options());
+  }
   // return an empty tensor if one of the repeat dimensions is zero
   if (zero_tensor) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 403260d..eeb6163 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1116,6 +1116,11 @@
   variants: method
   device_guard: False
+- func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
+  variants: function
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: empty_quantized
+
 - func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
   device_guard: False
@@ -4970,6 +4975,7 @@
   device_guard: False
   dispatch:
     CPU, CUDA: unfold
+    QuantizedCPU, QuantizedCUDA: unfold
 - func: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
   variants: function
diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp
index 0e5c7e8..d3e368a 100644
--- a/aten/src/ATen/native/quantized/TensorFactories.cpp
+++ b/aten/src/ATen/native/quantized/TensorFactories.cpp
@@ -76,5 +76,28 @@
   TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
 }
+// Create an empty quantized Tensor with the given size, based on the options
+// and quantization parameters of the input quantized Tensor
+Tensor empty_quantized(IntArrayRef size, const Tensor& qtensor) {
+  Tensor output;
+  if (qtensor.qscheme() == kPerTensorAffine) {
+    output = at::_empty_affine_quantized(size, qtensor.options(),
+                                         qtensor.q_scale(),
+                                         qtensor.q_zero_point());
+  } else if (qtensor.qscheme() == kPerChannelAffine) {
+    output = at::_empty_per_channel_affine_quantized(
+        size,
+        qtensor.q_per_channel_scales(),
+        qtensor.q_per_channel_zero_points(),
+        qtensor.q_per_channel_axis(),
+        qtensor.options());
+  } else {
+    TORCH_CHECK(false,
+        "QScheme not supported by empty_quantized:",
+        toString(qtensor.qscheme()));
+  }
+  return output;
+}
+
} // namespace native
} // namespace at
diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py
index 018f1f5..3043aa1 100644
--- a/test/quantization/test_quantized_tensor.py
+++ b/test/quantization/test_quantized_tensor.py
@@ -526,6 +526,17 @@
         with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
             qz = qy.squeeze()
+    def test_repeat(self):
+        scale, zero_point, dtype = 1.0, 2, torch.uint8
+        for device in get_supported_device_types():
+            q_int = torch.randint(0, 100, [3], dtype=dtype, device=device)
+            q_int_repeat = q_int.repeat(4, 2)
+            q_ref = torch._make_per_tensor_quantized_tensor(q_int_repeat, scale=scale, zero_point=zero_point)
+
+            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
+            q_repeat = q.repeat(4, 2)
+            self.assertEqual(q_ref, q_repeat)
+
def test_qscheme_pickle(self):
f = Foo()
buf = io.BytesIO()
diff --git a/torch/_overrides.py b/torch/_overrides.py
index d320eb5..c5eb8e9 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -102,6 +102,7 @@
     torch.empty,
     torch.empty_meta,
     torch.empty_strided,
+    torch.empty_quantized,
     torch.eye,
     torch.from_file,
     torch.full,