[quant] aten::repeat works for quantized tensors (#40644)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40644

Test Plan: Imported from OSS

Differential Revision: D22268558

fbshipit-source-id: 3bc9a129bece1b547c519772ecc6b980780fb904
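
What this enables: Tensor.repeat on a quantized tensor now produces a
quantized result. A minimal sketch of the newly supported call, mirroring
the test added below (scale/zero_point values are illustrative):

    import torch

    x = torch.rand(3)
    q = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint8)
    q2 = q.repeat(4, 2)  # now routed through the quantized path
    # the repeated tensor keeps the input's quantization parameters
    assert q2.q_scale() == q.q_scale()
    assert q2.q_zero_point() == q.q_zero_point()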
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 015df4f..17377fb 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -762,7 +762,12 @@
 
   Tensor xtensor = self.expand(padded_size);
 
-  Tensor result = at::empty(target_size, self.options());
+  Tensor result;
+  if (self.is_quantized()) {
+    result = at::empty_quantized(target_size, self);
+  } else {
+    result = at::empty(target_size, self.options());
+  }
 
   // return an empty tensor if one of the repeat dimensions is zero
   if (zero_tensor) {
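
Note on the branch above: a quantized tensor's scale/zero_point live on its
quantizer, not in TensorOptions, so the old at::empty(target_size,
self.options()) call cannot reproduce them; repeat then fills `result` in
place, so the output buffer itself must be quantized. A rough Python view of
the state that would otherwise be lost:

    q = torch.quantize_per_tensor(torch.rand(3), scale=0.1, zero_point=2, dtype=torch.quint8)
    q.q_scale(), q.q_zero_point()  # (0.1, 2): carried by the quantizer, not by dtype/device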
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 403260d..eeb6163 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1116,6 +1116,11 @@
   variants: method
   device_guard: False
 
+- func: empty_quantized(int[] size, Tensor qtensor) -> Tensor
+  variants: function
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: empty_quantized
+
 - func: empty.out(int[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
   device_guard: False
 
@@ -4970,6 +4975,7 @@
   device_guard: False
   dispatch:
     CPU, CUDA: unfold
+    QuantizedCPU, QuantizedCUDA: unfold
 
 - func: unfold_backward(Tensor grad_in, int[] input_sizes, int dim, int size, int step) -> Tensor
   variants: function
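
Two registrations here: empty_quantized is a new quantized-only factory
(schema: size first, then a reference quantized tensor), and unfold gains
QuantizedCPU/QuantizedCUDA dispatch because repeat writes its output through
unfold views of the result buffer, so the existing kernel is reused for
quantized tensors. The Python-visible call shape, as declared above:

    out = torch.empty_quantized([4, 6], q)  # q supplies dtype and quantization parameters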
diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp
index 0e5c7e8..d3e368a 100644
--- a/aten/src/ATen/native/quantized/TensorFactories.cpp
+++ b/aten/src/ATen/native/quantized/TensorFactories.cpp
@@ -76,5 +76,28 @@
   TORCH_CHECK(false, "Creation of quantized tensor requires quantized dtype like torch.quint8");
 }
 
+// Create an empty quantized Tensor with size, based on the options
+// and quantization parameters of the input quantized Tensor
+Tensor empty_quantized(IntArrayRef size, const Tensor& qtensor) {
+  Tensor output;
+  if (qtensor.qscheme() == kPerTensorAffine) {
+    output = at::_empty_affine_quantized(size, qtensor.options(),
+                                         qtensor.q_scale(),
+                                         qtensor.q_zero_point());
+  } else if (qtensor.qscheme() == kPerChannelAffine) {
+    output = at::_empty_per_channel_affine_quantized(
+        size,
+        qtensor.q_per_channel_scales(),
+        qtensor.q_per_channel_zero_points(),
+        qtensor.q_per_channel_axis(),
+        qtensor.options());
+  } else {
+    TORCH_CHECK(false,
+                "QScheme not supported by empty_quantized:",
+                toString(qtensor.qscheme()));
+  }
+  return output;
+}
+
 } // namespace native
 } // namespace at
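
A short usage sketch of the new helper covering both supported qschemes
(values illustrative; assumes the dispatch entries above are registered):

    import torch

    # per-tensor affine: scale/zero_point are copied from the reference tensor
    q = torch.quantize_per_tensor(torch.rand(3), scale=0.1, zero_point=2, dtype=torch.quint8)
    out = torch.empty_quantized([6], q)
    assert out.q_scale() == q.q_scale() and out.q_zero_point() == q.q_zero_point()

    # per-channel affine: scales, zero_points, and axis are copied
    scales = torch.tensor([0.1, 0.2], dtype=torch.double)
    zero_points = torch.tensor([0, 1], dtype=torch.long)
    qc = torch.quantize_per_channel(torch.rand(2, 3), scales, zero_points, axis=0, dtype=torch.quint8)
    outc = torch.empty_quantized([2, 3], qc)
    assert outc.q_per_channel_axis() == qc.q_per_channel_axis()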
diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py
index 018f1f5..3043aa1 100644
--- a/test/quantization/test_quantized_tensor.py
+++ b/test/quantization/test_quantized_tensor.py
@@ -526,6 +526,17 @@
         with self.assertRaisesRegex(RuntimeError, "Squeeze is only possible on non-axis dimension for Per-Channel"):
             qz = qy.squeeze()
 
+    def test_repeat(self):
+        scale, zero_point, dtype = 1.0, 2, torch.uint8
+        for device in get_supported_device_types():
+            q_int = torch.randint(0, 100, [3], dtype=dtype, device=device)
+            q_int_repeat = q_int.repeat(4, 2)
+            q_ref = torch._make_per_tensor_quantized_tensor(q_int_repeat, scale=scale, zero_point=zero_point)
+
+            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
+            q_repeat = q.repeat(4, 2)
+            self.assertEqual(q_ref, q_repeat)
+
     def test_qscheme_pickle(self):
         f = Foo()
         buf = io.BytesIO()
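
The test builds its reference by repeating the raw integer representation and
re-wrapping it, so the invariant it pins down is roughly:

    # q.repeat(sizes).int_repr() == q.int_repr().repeat(sizes),
    # with scale and zero_point preserved on the result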
diff --git a/torch/_overrides.py b/torch/_overrides.py
index d320eb5..c5eb8e9 100644
--- a/torch/_overrides.py
+++ b/torch/_overrides.py
@@ -102,6 +102,7 @@
         torch.empty,
         torch.empty_meta,
         torch.empty_strided,
+        torch.empty_quantized,
         torch.eye,
         torch.from_file,
         torch.full,