[quant] Quantized flip dispatch (#46235)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46235
Test Plan: Imported from OSS
Reviewed By: vkuzo
Differential Revision: D24689161
Pulled By: z-a-f
fbshipit-source-id: 6833c2639b29ea5f6c81c880b8928c5a1951c7b8
diff --git a/aten/src/ATen/native/TensorTransformations.cpp b/aten/src/ATen/native/TensorTransformations.cpp
index 1b86b3f..fdee519 100644
--- a/aten/src/ATen/native/TensorTransformations.cpp
+++ b/aten/src/ATen/native/TensorTransformations.cpp
@@ -61,15 +61,30 @@
}
}
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, in_tensor.scalar_type(), "flip_cpu", [&] {
- flip_cpu_kernel<scalar_t>(
- total_dims,
- stride_contiguous_v,
- flip_dims_b,
- in_tensor,
- out_tensor
- );
- });
+ if (in_tensor.is_quantized()) {
+ AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(in_tensor.scalar_type(),
+ "flip_quantized_cpu", [&] {
+ flip_cpu_kernel<scalar_t>(
+ total_dims,
+ stride_contiguous_v,
+ flip_dims_b,
+ in_tensor,
+ out_tensor
+ );
+ });
+ } else {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool,
+ in_tensor.scalar_type(),
+ "flip_cpu", [&] {
+ flip_cpu_kernel<scalar_t>(
+ total_dims,
+ stride_contiguous_v,
+ flip_dims_b,
+ in_tensor,
+ out_tensor
+ );
+ });
+ }
return out_tensor;
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 398aa74..3492564 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3882,7 +3882,7 @@
use_c10_dispatcher: full
variants: function, method
dispatch:
- CPU: flip_cpu
+ CPU, QuantizedCPU: flip_cpu
CUDA: flip_cuda
- func: fliplr(Tensor self) -> Tensor