Revert D20896697: [pytorch][PR] QuantizedCUDA implementation

Test Plan: revert-hammer

Differential Revision:
D20896697

Original commit changeset: 163554efa23d

fbshipit-source-id: e3e370ef7c8be68ea34368dfcc7a7efc9d1f8761
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 7c9316b..ecccb35 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -6,26 +6,18 @@
 #include <c10/util/Exception.h>
 #include <ATen/core/DeprecatedTypeProperties.h>
 
-// Workaround for C10_UNUSED because CUDA 9.2 fails to handle unused attribute in the type aliasing context.
-// Keep name long and verbose to avoid macro collisions.
-#if defined(__CUDACC__) && CUDA_VERSION <= 9200
-#define C10_UNUSED_DISPATCH_CUDA9_WORKAROUND
-#else
-#define C10_UNUSED_DISPATCH_CUDA9_WORKAROUND C10_UNUSED
-#endif // defined(__CUDACC__) && CUDA_VERSION <= 9200
-
-#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...)              \
-  case enum_type: {                                             \
-    using scalar_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = type; \
-    return __VA_ARGS__();                                       \
+#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
+  case enum_type: {                                \
+    using scalar_t = type;                         \
+    return __VA_ARGS__();                          \
   }
 
 #define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_enum, underlying_type, ...) \
-  case enum_type: {                                                                       \
-    const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = underlying_enum;   \
-    using scalar_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = type;                           \
-    using underlying_t C10_UNUSED_DISPATCH_CUDA9_WORKAROUND = underlying_type;            \
-    return __VA_ARGS__();                                                                 \
+  case enum_type: {                                                     \
+    const auto& UNDERLYING_TYPE C10_UNUSED = underlying_enum;           \
+    using scalar_t C10_UNUSED = type;                                   \
+    using underlying_t C10_UNUSED = underlying_type;                    \
+    return __VA_ARGS__();                                               \
   }
 
 // This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
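
For context, AT_PRIVATE_CASE_TYPE above is the per-case body that the AT_DISPATCH_* macros expand into, and the QINT variant additionally binds the underlying integer type. A minimal usage sketch (the op name "example_op" and the tensor `qtensor` are hypothetical; scalar_t, underlying_t and UNDERLYING_TYPE are the aliases the macros define):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

void example_op(const at::Tensor& qtensor) {
  // The macro switches on the runtime ScalarType and instantiates the lambda
  // once per quantized type, with scalar_t/underlying_t bound for each case.
  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "example_op", [&]() {
    scalar_t* qdata = qtensor.data_ptr<scalar_t>();  // e.g. c10::quint8*
    if (qtensor.numel() > 0) {
      underlying_t first = qdata[0].val_;            // raw integer representation
      (void)first;
    }
  });
}
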
diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h
index 812642c..c39100f 100644
--- a/aten/src/ATen/cpu/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -2,7 +2,7 @@
 
 #include <ATen/cpu/vec256/intrinsics.h>
 #include <ATen/cpu/vec256/vec256_base.h>
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 #include <c10/util/qint8.h>
 #include <c10/util/quint8.h>
 #include <c10/util/qint32.h>
@@ -212,7 +212,7 @@
     dst[i] = nearbyint(clipped);
   }
 #else
-  at::native::quantize_vec<T>(
+  at::quantize_vec<T>(
       1.0f / inverse_scale, zero_point, src, reinterpret_cast<T*>(dst), len);
 #endif
 }
@@ -278,7 +278,7 @@
         float inverse_scale) {
       Vec256<c10::qint32> retval;
       auto rhs_data = (__m256)rhs[0];
-      at::native::quantize_vec<c10::qint32, /*precision=*/32>(
+      at::quantize_vec<c10::qint32, /*precision=*/32>(
           scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8);
       return retval;
     }
@@ -1094,7 +1094,7 @@
     for (int i = 0; i < float_num_vecs(); ++i) {
       for (int j = 0; j < 8; ++j) {
         rv[i][j] =
-            at::native::dequantize_val<T>(scale[j], zero_point[j], T(vals[8 * i + j]));
+            at::dequantize_val<T>(scale[j], zero_point[j], T(vals[8 * i + j]));
       }
     }
     return rv;
@@ -1152,7 +1152,7 @@
       rhs[i].store(float_vals + i * 8, 8);
     }
 
-    at::native::quantize_vec<c10::qint32, /*precision=*/32>(
+    at::quantize_vec<c10::qint32, /*precision=*/32>(
         scale,
         zero_point,
         float_vals,
@@ -1284,7 +1284,7 @@
       rhs[i].store(float_vals + i * 8, 8);
     }
 
-    at::native::quantize_vec<c10::qint8>(
+    at::quantize_vec<c10::qint8>(
         scale,
         zero_point,
         float_vals,
@@ -1404,7 +1404,7 @@
       rhs[i].store(float_vals + i * 8, 8);
     }
 
-    at::native::quantize_vec<c10::quint8>(
+    at::quantize_vec<c10::quint8>(
         scale,
         zero_point,
         float_vals,
diff --git a/aten/src/ATen/cuda/CUDAApplyUtils.cuh b/aten/src/ATen/cuda/CUDAApplyUtils.cuh
index 3e4ea5a..e89b1b1 100644
--- a/aten/src/ATen/cuda/CUDAApplyUtils.cuh
+++ b/aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -406,7 +406,7 @@
                                const Op op,
                                TensorArgType aType = TensorArgType::ReadWrite,
                                TensorArgType bType = TensorArgType::ReadOnly) {
-  checkDeviceType("CUDA_tensor_apply2", {a, b}, DeviceType::CUDA);
+  checkBackend("CUDA_tensor_apply2", {a, b}, Backend::CUDA);
   int64_t totalElements = a.numel();
 
   if (totalElements != b.numel()) {
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 5bc1a58..b7f10ba 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -171,14 +171,12 @@
 def backend_to_devicetype(backend):
     if backend == 'QuantizedCPU':
         return 'CPU'
-    elif backend == 'QuantizedCUDA':
-        return 'CUDA'
     return backend
 
 backends = ['CPU', 'CUDA']
 densities = ['Dense', 'Sparse', 'Mkldnn']  # TODO: layout instead of densities?
 
-quantized_backends = ['QuantizedCPU', 'QuantizedCUDA']
+quantized_backends = ['QuantizedCPU']
 
 # scalar_name, c_type, accreal, is_floating_type
 quantized_scalar_types = [
@@ -214,8 +212,6 @@
 def is_whitelisted_backend(backend):
     return options.backend_whitelist is None or backend in options.backend_whitelist
 
-def is_cuda_backend(backend):
-    return backend in ("QuantizedCUDA", "CUDA")
 
 def dict_representer(dumper, data):
     return dumper.represent_dict(data.items())
@@ -299,7 +295,7 @@
     top_env['type_ids'].append(tag + ',')
 
     env['legacy_th_headers'] = []
-    if is_cuda_backend(backend):
+    if backend == 'CUDA':
         env['extra_cuda_headers'] = []
         env['extra_cuda_headers'].append('#include <ATen/DeviceGuard.h>')
         if options.rocm:
@@ -408,7 +404,7 @@
         if not is_whitelisted_backend(full_backend):
             continue
         fm = file_manager
-        if is_cuda_backend(backend):
+        if backend == 'CUDA':
             fm = cuda_file_manager
         for kind in ["Type"]:
             if kind != 'Type' and density == "Sparse":
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index a99eed2..312805b 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -112,7 +112,7 @@
   }
 
   if (self.is_quantized() && !src.is_quantized()) {
-    return quantized_copy_from_float_cpu_(self, src);
+    return quantized_copy_from_float_(self, src);
   }
 
   if (self.is_quantized() && src.is_quantized()) {
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 3cda072..51d13e5 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -65,16 +65,9 @@
           copy_stream));
     }
   } else {
-    auto dtype = iter.dtype(0);
-    if (isQIntType(dtype)) {
-      AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
-        gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
-      });
-    } else {
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, dtype, "copy_", [&] {
-        gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
-      });
-    }
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
+      gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) { return x; });
+    });
   }
 
   if (src_device != dst_device) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 22143b4..a37b0af 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -436,7 +436,6 @@
     CPU: as_strided_tensorimpl
     CUDA: as_strided_tensorimpl
     QuantizedCPU: as_strided_qtensorimpl
-    QuantizedCUDA: as_strided_qtensorimpl
   device_guard: False
   supports_named_tensor: True
 
@@ -1165,8 +1164,7 @@
 - func: _empty_affine_quantized(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
   dispatch:
     CPU: empty_affine_quantized_other_backends_stub
-    QuantizedCPU: empty_affine_quantized
-    QuantizedCUDA: empty_affine_quantized
+    QuantizedCPU: empty_affine_quantized_cpu
 
 # it's a factory function receiving a tensor argument, thus overriding explicitly
 # other overrides are to provide a more helpful error message that dtype is required
@@ -3198,7 +3196,6 @@
     SparseCUDA: clone_sparse
     MkldnnCPU: mkldnn_clone
     QuantizedCPU: quantized_clone
-    QuantizedCUDA: quantized_clone
   supports_named_tensor: True
 
 - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
@@ -3666,8 +3663,7 @@
 - func: quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
   variants: function
   dispatch:
-    CPU: quantize_per_tensor
-    CUDA: quantize_per_tensor
+    CPU: quantize_per_tensor_cpu
 
 - func: quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]
   variants: function
@@ -3684,7 +3680,6 @@
   variants: function, method
   dispatch:
     QuantizedCPU: dequantize_quant
-    QuantizedCUDA: dequantize_quant
 
 - func: dequantize.tensors(Tensor[] tensors) -> Tensor[]
   variants: function
@@ -3696,14 +3691,12 @@
   variants: function, method
   dispatch:
     QuantizedCPU: q_scale_quant
-    QuantizedCUDA: q_scale_quant
 
 - func: q_zero_point(Tensor self) -> int
   use_c10_dispatcher: full
   variants: function, method
   dispatch:
     QuantizedCPU: q_zero_point_quant
-    QuantizedCUDA: q_zero_point_quant
 
 - func: q_per_channel_scales(Tensor self) -> Tensor
   variants: function, method
@@ -3724,14 +3717,12 @@
   use_c10_dispatcher: full
   variants: function, method
   dispatch:
-    QuantizedCPU: int_repr_quant_cpu
-    QuantizedCUDA: int_repr_quant_cuda
+    QuantizedCPU: int_repr_quant
 
 - func: _make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor
   use_c10_dispatcher: full
   dispatch:
     CPU: make_per_tensor_quantized_tensor_cpu
-    CUDA: make_per_tensor_quantized_tensor_cuda
 
 - func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor
   dispatch:
@@ -3742,7 +3733,6 @@
   variants: method
   dispatch:
     QuantizedCPU: qscheme_quant
-    QuantizedCUDA: qscheme_quant
 
 - func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
   use_c10_dispatcher: full
@@ -3917,8 +3907,7 @@
   dispatch:
     CPU: set_storage_cpu_
     CUDA: set_storage_cuda_
-    QuantizedCPU: set_storage_quantized_
-    QuantizedCUDA: set_storage_quantized_
+    QuantizedCPU: set_storage_quantized_cpu_
 
 - func: set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
   variants: method
@@ -3937,7 +3926,6 @@
   variants: method
   dispatch:
     QuantizedCPU: set_quantizer_
-    QuantizedCUDA: set_quantizer_
 
 - func: is_set_to(Tensor self, Tensor tensor) -> bool
   use_c10_dispatcher: full
@@ -3989,7 +3977,6 @@
     CUDA: view
     MkldnnCPU: mkldnn_view
     QuantizedCPU: view
-    QuantizedCUDA: view
 
 - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
   variants: method
@@ -5243,7 +5230,7 @@
   dispatch:
     CPU: legacy::cpu::_th_equal
     CUDA: legacy::cuda::_th_equal
-    QuantizedCPU: quantized_equal_cpu
+    QuantizedCPU: quantized_equal
   supports_named_tensor: True
 
 - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
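
As a rough sketch of the user-level calls these dispatch entries back after the revert (CPU path only; the function and variable names below are illustrative):

#include <ATen/ATen.h>

void yaml_dispatch_example() {
  at::Tensor x = at::randn({2, 3});
  // quantize_per_tensor -> CPU: quantize_per_tensor_cpu
  at::Tensor q = at::quantize_per_tensor(x, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);
  at::Tensor r = q.dequantize();    // QuantizedCPU: dequantize_quant
  double s     = q.q_scale();       // QuantizedCPU: q_scale_quant
  int64_t zp   = q.q_zero_point();  // QuantizedCPU: q_zero_point_quant
  (void)r; (void)s; (void)zp;
}
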
diff --git a/aten/src/ATen/native/quantized/Copy.cpp b/aten/src/ATen/native/quantized/Copy.cpp
index da95ff3..aa5b9b0 100644
--- a/aten/src/ATen/native/quantized/Copy.cpp
+++ b/aten/src/ATen/native/quantized/Copy.cpp
@@ -1,13 +1,13 @@
 #include <ATen/native/quantized/Copy.h>
 
 #include <ATen/ATen.h>
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 
 namespace at {
 namespace native {
 
 // Copying from float to QInt, used for assigning float value to QTensor
-Tensor& quantized_copy_from_float_cpu_(Tensor& self, const Tensor& src) {
+Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src) {
   TORCH_CHECK(
       src.scalar_type() == at::kFloat,
       "Quantized copy only works with kFloat as source Tensor");
@@ -17,9 +17,6 @@
   TORCH_CHECK(
       self.sizes().equals(src.sizes()),
       "Quantized copy only works with Tensors with the same shape");
-  TORCH_CHECK(
-      self.device().type() == kCPU,
-      "Quantized copy only works with QuantizedCPU Tensors");
   AT_DISPATCH_QINT_TYPES(self.scalar_type(), "Copy", [&]() {
     float* src_data = src.data_ptr<float>();
     scalar_t* self_data = self.data_ptr<scalar_t>();
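
For reference, the float-to-quantized assignment path that quantized_copy_from_float_ serves looks roughly like this (illustrative function and variable names; assumes CPU tensors):

#include <ATen/ATen.h>

void quantized_assign_example() {
  at::Tensor src = at::randn({4});
  at::Tensor dst = at::_empty_affine_quantized(
      {4},
      at::device(at::kCPU).dtype(at::kQUInt8),
      /*scale=*/0.1,
      /*zero_point=*/10);
  dst.copy_(src);  // routed through quantized_copy_from_float_ above
}
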
diff --git a/aten/src/ATen/native/quantized/Copy.h b/aten/src/ATen/native/quantized/Copy.h
index ef12a9a..a1bd290 100644
--- a/aten/src/ATen/native/quantized/Copy.h
+++ b/aten/src/ATen/native/quantized/Copy.h
@@ -5,7 +5,7 @@
 namespace at {
 namespace native {
 
-Tensor& quantized_copy_from_float_cpu_(Tensor& self, const Tensor& src);
+Tensor& quantized_copy_from_float_(Tensor& self, const Tensor& src);
 
 }
 } // namespace at
diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp
index 9f653ad..db4ae7c 100644
--- a/aten/src/ATen/native/quantized/QTensor.cpp
+++ b/aten/src/ATen/native/quantized/QTensor.cpp
@@ -5,12 +5,11 @@
 #include <ATen/native/quantized/cpu/quant_utils.h>
 #include <ATen/quantized/QTensorImpl.h>
 #include <ATen/quantized/Quantizer.h>
-#include <ATen/native/quantized/cpu/quant_utils.h>
 
 namespace at {
 namespace native {
 
-Tensor quantize_per_tensor(
+Tensor quantize_per_tensor_cpu(
     const Tensor& self,
     double scale,
     int64_t zero_point,
@@ -92,6 +91,49 @@
   return static_cast<PerChannelAffineQuantizer*>(quantizer.get())->axis();
 }
 
+// When the input Tensor is non-dense, i.e. the allocated memory
+// is larger than the memory used by all the elements, we'll
+// convert it to a dense tensor; otherwise we'll keep the memory
+// format of the output the same as that of the input.
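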
+Tensor int_repr_quant(const Tensor& self) {
+  Tensor dst;
+  AT_DISPATCH_QINT_TYPES(self.scalar_type(), "int_repr", [&]() {
+    dst = at::empty(
+        self.sizes(),
+        self.options().dtype(UNDERLYING_TYPE),
+        self.suggest_memory_format());
+    auto iter = TensorIterator();
+    iter.add_output(dst);
+    iter.add_input(self);
+    iter.dont_compute_common_dtype();
+    iter.build();
+    cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; });
+  });
+  return dst;
+}
+
+Tensor make_per_tensor_quantized_tensor_cpu(
+    const Tensor& self,
+    double scale,
+    int64_t zero_point) {
+  Tensor dst = at::_empty_affine_quantized(
+      self.sizes(),
+      self.options().dtype(toQIntType(self.scalar_type())),
+      scale,
+      zero_point);
+  Tensor self_contig = self.contiguous();
+  AT_DISPATCH_QINT_TYPES(
+      dst.scalar_type(), "make_per_tensor_quantized_tensor", [&]() {
+        underlying_t* self_data = self_contig.data_ptr<underlying_t>();
+        underlying_t* dst_data =
+            reinterpret_cast<underlying_t*>(dst.data_ptr<scalar_t>());
+        if (self.numel() > 0) {
+          memcpy(dst_data, self_data, self.nbytes());
+        }
+      });
+  return dst;
+}
+
 Tensor make_per_channel_quantized_tensor_cpu(
     const Tensor& self,
     const Tensor& scales,
@@ -116,7 +158,7 @@
   return dst;
 }
 
-Tensor& set_storage_quantized_(
+Tensor& set_storage_quantized_cpu_(
     Tensor& self,
     Storage storage,
     int64_t storage_offset,
@@ -171,9 +213,7 @@
   return dst;
 }
 
-bool quantized_equal_cpu(const Tensor& self, const Tensor& other) {
-  TORCH_CHECK(self.device().type() == kCPU && other.device().type() == kCPU,
-    "quantized_equal is implemented only for the QuantizedCPU backend");
+bool quantized_equal(const Tensor& self, const Tensor& other) {
   if (!other.is_quantized()) {
     return false;
   }
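
The int_repr_quant and make_per_tensor_quantized_tensor_cpu functions added above are inverses of each other at the storage level. A minimal round-trip sketch (assuming a per-tensor quantized tensor `q`; the helper name round_trip is hypothetical):

#include <ATen/ATen.h>

at::Tensor round_trip(const at::Tensor& q) {
  // int_repr() exposes the stored uint8/int8/int32 values as a plain tensor;
  // _make_per_tensor_quantized_tensor rebuilds a quantized tensor from them.
  at::Tensor repr = q.int_repr();
  return at::_make_per_tensor_quantized_tensor(repr, q.q_scale(), q.q_zero_point());
  // The result holds a fresh copy of the same values, scale, and zero_point.
}
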
diff --git a/aten/src/ATen/native/quantized/TensorFactories.cpp b/aten/src/ATen/native/quantized/TensorFactories.cpp
index 09ed776..ce63a2f 100644
--- a/aten/src/ATen/native/quantized/TensorFactories.cpp
+++ b/aten/src/ATen/native/quantized/TensorFactories.cpp
@@ -10,7 +10,7 @@
 // We explicitly pass in scale and zero_point because we don't have the infra
 // ready to support quantizer in python frontend, once that is ready, we'll
 // change to use quantizer
-Tensor empty_affine_quantized(
+Tensor empty_affine_quantized_cpu(
     IntArrayRef size,
     const TensorOptions& options_,
     double scale,
@@ -24,7 +24,7 @@
   TORCH_CHECK(
       options.has_dtype(),
       "Must provide data type for Tensor creation functions.");
-  return new_qtensor(
+  return new_qtensor_cpu(
       size,
       options,
       make_per_tensor_affine_quantizer(
@@ -49,7 +49,7 @@
   TORCH_CHECK(
       options.dtype() == kQInt8 || options.dtype() == kQUInt8,
       "Supported data type for tensor creation is int8 or uint8");
-  return new_qtensor(
+  return new_qtensor_cpu(
       size,
       options,
       make_per_channel_affine_quantizer(
diff --git a/aten/src/ATen/native/quantized/affine_quantizer.cpp b/aten/src/ATen/native/quantized/affine_quantizer.cpp
deleted file mode 100644
index a5a2e68..0000000
--- a/aten/src/ATen/native/quantized/affine_quantizer.cpp
+++ /dev/null
@@ -1,311 +0,0 @@
-#include <ATen/native/quantized/affine_quantizer.h>
-
-#ifdef USE_FBGEMM
-#include <fbgemm/QuantUtils.h>
-#endif
-#ifdef __ARM_NEON__
-#include <arm_neon.h>
-#endif
-
-namespace at {
-namespace native {
-
-DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_stub);
-DEFINE_DISPATCH(quantize_tensor_per_channel_affine_stub);
-DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_stub);
-DEFINE_DISPATCH(dequantize_tensor_per_channel_affine_stub);
-
-namespace {
-
-void checkCPUTensor(const std::string& fn_name, Tensor t) {
-  TORCH_CHECK(
-      t.device().type() == kCPU,
-      fn_name,
-      " only supports CPU device type.");
-}
-
-void checkFloatTensor(const std::string& fn_name, Tensor t) {
-  TORCH_CHECK(
-      t.scalar_type() == kFloat,
-      fn_name,
-      " expects a Float Tensor.");
-}
-
-void checkSameDevice(const std::string& fn_name, Tensor t1, Tensor t2) {
-  TORCH_CHECK(
-      t1.device() == t2.device(),
-      fn_name,
-      " expects a quantized and float tensors to be on the same device.");
-}
-
-template <typename T>
-void checkQuantizedTensor(const std::string& fn_name, Tensor t) {
-  TORCH_CHECK(t.is_quantized(),
-           fn_name,
-           " expects a quantized Tensor.");
-  TORCH_CHECK(t.scalar_type() == caffe2::TypeMeta::Make<T>(),
-           fn_name,
-           " expects a ",
-           caffe2::TypeMeta::Make<T>(),
-           " Tensor");
-}
-
-template <typename T>
-void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
-  TORCH_CHECK(zero_point <= std::numeric_limits<T>::max(),
-              fn_name,
-              " zero_point ",
-              zero_point,
-              " is out of range.");
-  TORCH_CHECK(zero_point >= std::numeric_limits<T>::min(),
-              fn_name,
-              " zero_point ",
-              zero_point,
-              " is out of range.");
-}
-
-template <typename T>
-void checkZeroPoints(const std::string& fn_name, Tensor zero_points) {
-  auto zero_points_data = zero_points.data_ptr<int64_t>();
-  for (size_t i = 0; i < zero_points.numel(); ++i) {
-    checkZeroPoint<T>(fn_name, zero_points_data[i]);
-  }
-}
-
-void checkSameSize(const std::string& fn_name, Tensor qt, Tensor rt) {
-  TORCH_CHECK(
-      qt.sizes().equals(rt.sizes()),
-      fn_name,
-      " only works with Tensors with the same shape");
-}
-
-} // anonymous namespace
-
-Tensor quantize_tensor_per_tensor_affine(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
-  static const auto fn_name = "quantize_tensor_per_tensor_affine";
-  checkFloatTensor(fn_name, rtensor);
-  checkSameDevice(fn_name, rtensor, qtensor);
-  checkSameSize(fn_name, qtensor, rtensor);
-
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
-    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
-    checkZeroPoint<underlying_t>(fn_name, zero_point);
-  });
-
-  quantize_tensor_per_tensor_affine_stub(rtensor.device().type(), rtensor, qtensor, scale, zero_point);
-  return qtensor;
-}
-
-Tensor quantize_tensor_per_channel_affine(Tensor rtensor,
-                                          Tensor qtensor,
-                                          Tensor scales,
-                                          Tensor zero_points,
-                                          int64_t axis) {
-  static const auto fn_name = "quantize_tensor_per_channel_affine";
-
-  checkFloatTensor(fn_name, rtensor);
-  checkCPUTensor(fn_name, rtensor);
-  checkSameDevice(fn_name, rtensor, qtensor);
-  checkSameSize(fn_name, qtensor, rtensor);
-
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
-    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
-    checkZeroPoints<underlying_t>(fn_name, zero_points);
-  });
-
-  TORCH_CHECK(0 <= axis && axis < rtensor.dim(), "Channel axis out of range in per channel affine quantization.");
-  int64_t channel = rtensor.size(axis);
-  TORCH_CHECK(channel == int64_t(scales.numel()), "length of scales must equal to channel");
-  TORCH_CHECK(channel == int64_t(zero_points.numel()), "length of zero_points must equal to channel");
-
-  quantize_tensor_per_channel_affine_stub(rtensor.device().type(), rtensor, qtensor, scales, zero_points, axis);
-  return qtensor;
-}
-
-Tensor dequantize_tensor_per_tensor_affine(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
-  static const auto fn_name = "dequantize_tensor_per_tensor_affine";
-  checkFloatTensor(fn_name, rtensor);
-  checkSameDevice(fn_name, rtensor, qtensor);
-  checkSameSize(fn_name, qtensor, rtensor);
-
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
-    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
-    checkZeroPoint<underlying_t>(fn_name, zero_point);
-  });
-
-  dequantize_tensor_per_tensor_affine_stub(qtensor.device().type(), qtensor, rtensor, scale, zero_point);
-  return rtensor;
-}
-
-Tensor dequantize_tensor_per_channel_affine(Tensor qtensor,
-                                            Tensor rtensor,
-                                            Tensor scales,
-                                            Tensor zero_points,
-                                            int64_t axis) {
-  static const auto fn_name = "dequantize_tensor_per_channel_affine";
-
-  checkFloatTensor(fn_name, rtensor);
-  checkCPUTensor(fn_name, rtensor);
-  checkSameDevice(fn_name, rtensor, qtensor);
-  checkSameSize(fn_name, qtensor, rtensor);
-
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
-    checkQuantizedTensor<scalar_t>(fn_name, qtensor);
-    checkZeroPoints<underlying_t>(fn_name, zero_points);
-  });
-
-  TORCH_CHECK(0 <= axis && axis < qtensor.dim(), "Channel axis out of range in per channel affine dequantization.");
-  int64_t channel = qtensor.size(axis);
-  TORCH_CHECK(channel == int64_t(scales.numel()), "length of scales must equal to channel");
-  TORCH_CHECK(channel == int64_t(zero_points.numel()), "length of zero_points must equal to channel");
-
-  dequantize_tensor_per_channel_affine_stub(qtensor.device().type(), qtensor, rtensor, scales, zero_points, axis);
-  return rtensor;
-
-}
-
-#ifdef USE_FBGEMM
-// Note: quantize_val is only explicitly used in test outside of this file
-template <typename T>
-T quantize_val(double scale, int64_t zero_point, float value) {
-  // Internally, fbgemm::Quantize uses std::nearbyint.
-  // std::nearbyint returns the nearest integer value according to the current
-  // rounding mode, and the default rounding mode is round-to-even for half-way
-  // cases on most popular processor architectures like x86 and ARM. This is
-  // typically faster than alternatives like std::round, which rounds half-way
-  // cases away from zero, and it can be consistent with SIMD implementations,
-  // for example _mm512_cvtps_epi32 on x86 or _mm512_round_ps with the
-  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
-  int32_t qvalue;
-  qvalue = fbgemm::Quantize<typename T::underlying>(
-      value,
-      static_cast<int32_t>(zero_point),
-      static_cast<double>(scale),
-      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
-  return static_cast<T>(qvalue);
-}
-
-template <typename T, int precision>
-void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
-  fbgemm::Quantize<typename T::underlying>(
-    src,
-    (typename T::underlying*)dst,
-    count,
-    fbgemm::TensorQuantizationParams{(float)scale, (int32_t)zero_point, precision}
-  );
-}
-
-template <typename T>
-inline float dequantize_val(double scale, int64_t zero_point, T value) {
-  fbgemm::TensorQuantizationParams qparams;
-  qparams.scale = static_cast<float>(scale);
-  qparams.zero_point = static_cast<int32_t>(zero_point);
-  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
-}
-#else  // USE_FBGEMM
-
-#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
-template <class T>
-inline float Round(const float x) {
-  return ::nearbyintf(x);
-}
-inline double Round(const double x) {
-  return ::nearbyint(x);
-}
-#else
-template <class T>
-inline T Round(const T x) {
-  return std::nearbyint(x);
-}
-#endif
-
-template <typename T>
-T quantize_val(double scale, int64_t zero_point, float value) {
-  // std::nearbyint returns the nearest integer value according to the current
-  // rounding mode, and the default rounding mode is round-to-even for half-way
-  // cases on most popular processor architectures like x86 and ARM. This is
-  // typically faster than alternatives like std::round, which rounds half-way
-  // cases away from zero, and it can be consistent with SIMD implementations,
-  // for example _mm512_cvtps_epi32 on x86 or _mm512_round_ps with the
-  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
-  int64_t qvalue;
-  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
-  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
-  qvalue = static_cast<int64_t>(Round(value / scale + zero_point));
-  qvalue = std::max<int64_t>(qvalue, qmin);
-  qvalue = std::min<int64_t>(qvalue, qmax);
-  return static_cast<T>(qvalue);
-}
-
-uint8_t quantize_val_arm(const float scale, const int32_t zero_point, const float value) {
-  const int32_t qmin = std::numeric_limits<uint8_t>::min();
-  const int32_t qmax = std::numeric_limits<uint8_t>::max();
-  auto r = zero_point + static_cast<int32_t>(Round(value / scale));
-  r = std::max(r, qmin);
-  r = std::min(r, qmax);
-  return static_cast<uint8_t>(r);
-}
-
-template <typename T, int precision>
-void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
-  checkZeroPoint<typename T::underlying>("quantize_vec", zero_point);
-  for (int64_t i = 0; i < count; ++i) {
-    dst[i] = quantize_val<T>(scale, zero_point, src[i]);
-  }
-}
-
-template <typename T>
-CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
-  // We need to convert the qint8 value to float to ensure the subtraction
-  // subexpression returns a float
-  return (static_cast<float>(value.val_) - zero_point) * scale;
-}
-#endif  // USE_FBGEMM
-
-template <typename SRC_T, typename DST_T>
-DST_T requantize_val(double src_scale, int64_t src_zero_point,
-                     double dst_scale, int64_t dst_zero_point,
-                     SRC_T src) {
-  const auto dq = dequantize_val<SRC_T>(src_scale, src_zero_point, src);
-  return quantize_val<DST_T>(dst_scale, dst_zero_point, dq);
-}
-
-template <typename DST_T>
-DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) {
-  int64_t quantize_down =
-      zero_point + lrintf(src * static_cast<float>(multiplier));
-  int32_t min = std::numeric_limits<typename DST_T::underlying>::min();
-  int32_t max = std::numeric_limits<typename DST_T::underlying>::max();
-  return static_cast<DST_T>(
-      std::min<int64_t>(std::max<int64_t>(quantize_down, min), max));
-}
-
-template CAFFE2_API qint8 quantize_val<qint8>(double scale, int64_t zero_point, float value);
-template CAFFE2_API quint8 quantize_val<quint8>(double scale, int64_t zero_point, float value);
-template CAFFE2_API qint32 quantize_val<qint32>(double scale, int64_t zero_point, float value);
-template CAFFE2_API void quantize_vec<c10::qint8>(double scale, int64_t zero_point, const float *src, c10::qint8 *dst, size_t count);
-template CAFFE2_API void quantize_vec<c10::quint8>(double scale, int64_t zero_point, const float *src, c10::quint8 *dst, size_t count);
-template CAFFE2_API void quantize_vec<c10::qint32, 32>(double scale, int64_t zero_point, const float *src, c10::qint32 *dst, size_t count);
-
-template CAFFE2_API float dequantize_val<qint8>(double scale, int64_t zero_point, qint8 value);
-template CAFFE2_API float dequantize_val<quint8>(double scale, int64_t zero_point, quint8 value);
-template CAFFE2_API float dequantize_val<qint32>(double scale, int64_t zero_point, qint32 value);
-
-template CAFFE2_API qint8 requantize_val<qint8, qint8>(double, int64_t, double, int64_t, qint8);
-template CAFFE2_API quint8 requantize_val<qint8, quint8>(double, int64_t, double, int64_t, qint8);
-template CAFFE2_API qint32 requantize_val<qint8, qint32>(double, int64_t, double, int64_t, qint8);
-template CAFFE2_API qint8 requantize_val<quint8, qint8>(double, int64_t, double, int64_t, quint8);
-template CAFFE2_API quint8 requantize_val<quint8, quint8>(double, int64_t, double, int64_t, quint8);
-template CAFFE2_API qint32 requantize_val<quint8, qint32>(double, int64_t, double, int64_t, quint8);
-template CAFFE2_API qint8 requantize_val<qint32, qint8>(double, int64_t, double, int64_t, qint32);
-template CAFFE2_API quint8 requantize_val<qint32, quint8>(double, int64_t, double, int64_t, qint32);
-template CAFFE2_API qint32 requantize_val<qint32, qint32>(double, int64_t, double, int64_t, qint32);
-
-template CAFFE2_API qint8 requantize_from_int<qint8>(double, int64_t, int64_t);
-template CAFFE2_API quint8
-requantize_from_int<quint8>(double, int64_t, int64_t);
-template CAFFE2_API qint32
-requantize_from_int<qint32>(double, int64_t, int64_t);
-
-} // native
-} // at
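
A worked example of the affine quantization arithmetic implemented by the deleted quantize_val / dequantize_val / requantize_from_int (the same arithmetic as the at::quantize_val family that call sites in this diff switch back to):

// q = clamp(round(value / scale + zero_point), qmin, qmax)
// With scale = 0.1, zero_point = 10, and quint8 (qmin = 0, qmax = 255):
//   value =  1.234  ->  round(12.34 + 10) = 22
//   value = -2.0    ->  round(-20.0 + 10) = -10  -> clamped to 0
// Dequantization approximately inverts it:
//   dequantize_val(0.1, 10, 22) = (22 - 10) * 0.1 = 1.2      (quantization error 0.034)
// requantize_from_int folds an int32 accumulator back into the output type:
//   q = clamp(zero_point + round(src * multiplier), qmin, qmax)
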
diff --git a/aten/src/ATen/native/quantized/affine_quantizer.h b/aten/src/ATen/native/quantized/affine_quantizer.h
deleted file mode 100644
index ce805a9..0000000
--- a/aten/src/ATen/native/quantized/affine_quantizer.h
+++ /dev/null
@@ -1,81 +0,0 @@
-#pragma once
-
-#include <ATen/ATen.h>
-#include <ATen/native/DispatchStub.h>
-
-namespace at {
-namespace native {
-
-Tensor quantize_tensor_per_tensor_affine(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
-Tensor quantize_tensor_per_channel_affine(Tensor qtensor,
-                                          Tensor rtensor,
-                                          Tensor scales,
-                                          Tensor zero_points,
-                                          int64_t axis);
-
-Tensor dequantize_tensor_per_tensor_affine(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
-Tensor dequantize_tensor_per_channel_affine(Tensor qtensor,
-                                            Tensor rtensor,
-                                            Tensor scales,
-                                            Tensor zero_points,
-                                            int64_t axis);
-
-using quantize_tensor_per_tensor_affine_fn = void (*)(
-    Tensor rtensor,
-    Tensor qtensor,
-    double scale,
-    int64_t zero_point);
-
-using quantize_tensor_per_channel_affine_fn = void (*)(
-    Tensor qtensor,
-    Tensor rtensor,
-    Tensor scales,
-    Tensor zero_points,
-    int64_t axis);
-
-using dequantize_tensor_per_tensor_affine_fn = void (*)(
-    Tensor qtensor,
-    Tensor rtensor,
-    double scale,
-    int64_t zero_point);
-
-using dequantize_tensor_per_channel_affine_fn = void (*)(
-    Tensor qtensor,
-    Tensor rtensor,
-    Tensor scales,
-    Tensor zero_points,
-    int64_t axis);
-
-DECLARE_DISPATCH(quantize_tensor_per_tensor_affine_fn, quantize_tensor_per_tensor_affine_stub);
-DECLARE_DISPATCH(quantize_tensor_per_channel_affine_fn, quantize_tensor_per_channel_affine_stub);
-
-DECLARE_DISPATCH(dequantize_tensor_per_tensor_affine_fn, dequantize_tensor_per_tensor_affine_stub);
-DECLARE_DISPATCH(dequantize_tensor_per_channel_affine_fn, dequantize_tensor_per_channel_affine_stub);
-
-// Quantize a float value into a uint value given scale and zero_point
-template <typename T>
-CAFFE2_API T quantize_val(double scale, int64_t zero_point, float value);
-// TODO combine this with quantize_val once the numerics for ARM are aligned with it
-uint8_t quantize_val_arm(const float scale, const int32_t zero_point, const float value);
-template <typename T, int precision=8>
-void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count=8);
-template <typename T>
-CAFFE2_API Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
-template <typename T>
-CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value);
-template <typename T>
-CAFFE2_API float dequantize_vec(double scale, int64_t zero_point, const T* src, float* dst, size_t count=8);
-template <typename T>
-CAFFE2_API Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
-template <typename SRC_T, typename DST_T>
-CAFFE2_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
-
-// Given a multiplier and a zero_point, requantize int32_t computed values back
-// to quantized values. See comment above
-// make_per_tensor_affine_quantizer function for the usage of int64_t
-template <typename DST_T>
-CAFFE2_API DST_T
-requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
-
-} // namespace native
-} // namespace at
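
The deleted header above follows ATen's DispatchStub pattern. A compressed sketch of that pattern (stub and kernel names are the ones from the deleted files; the surrounding scaffolding is illustrative):

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

// Header side: a function-pointer type plus a stub declaration.
using quantize_tensor_per_tensor_affine_fn =
    void (*)(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
DECLARE_DISPATCH(quantize_tensor_per_tensor_affine_fn,
                 quantize_tensor_per_tensor_affine_stub);

// A generic translation unit defines the stub once.
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_stub);

// Each backend's kernel file registers its implementation, e.g.
//   REGISTER_DISPATCH(quantize_tensor_per_tensor_affine_stub,
//                     &quantize_tensor_per_tensor_affine_cpu);
// and callers then dispatch on the tensor's device type:
//   quantize_tensor_per_tensor_affine_stub(
//       rtensor.device().type(), rtensor, qtensor, scale, zero_point);

}} // namespace at::native
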
diff --git a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp
deleted file mode 100644
index c0b0d2a..0000000
--- a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp
+++ /dev/null
@@ -1,29 +0,0 @@
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/cpu/Loops.h>
-
-namespace at {
-namespace native {
-
-// When the input Tensor is non-dense, i.e. the allocated memory
-// is larger than the memory used by all the elements, we'll
-// convert it to a dense tensor; otherwise we'll keep the memory
-// format of the output the same as that of the input.
-Tensor int_repr_quant_cpu(const Tensor& self) {
-  Tensor dst;
-  AT_DISPATCH_QINT_TYPES(self.scalar_type(), "int_repr", [&]() {
-    dst = at::empty(
-        self.sizes(),
-        self.options().dtype(UNDERLYING_TYPE),
-        self.suggest_memory_format());
-    auto iter = TensorIterator();
-    iter.add_output(dst);
-    iter.add_input(self);
-    iter.dont_compute_common_dtype();
-    iter.build();
-    cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; });
-  });
-  return dst;
-}
-
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 53d7a88..7bbf573 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -1,11 +1,10 @@
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
-#include <ATen/Parallel.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/UpSample.h>
 #include <ATen/native/cpu/Loops.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 #include <ATen/native/SortingUtils.h>
 
 #include <cmath>
@@ -15,10 +14,6 @@
 #ifdef _OPENMP
 #include <omp.h>
 #endif
-#ifdef __ARM_NEON__
-#include <arm_neon.h>
-#include <ATen/quantized/Quantizer.h>
-#endif
 
 namespace at {
 namespace native {
@@ -137,7 +132,7 @@
 
             // Scalar loop
             for (; c < curr_C; ++c) {
-              auto float_val = at::native::dequantize_val(
+              auto float_val = at::dequantize_val(
                   curr_scale,
                   curr_zero_pt,
                   reinterpret_cast<scalar_t*>(iptr)[c]);
@@ -145,7 +140,7 @@
                 float_val = std::max(0.0f, float_val);
               }
               optr[c] =
-                  at::native::quantize_val<scalar_t>(scale, zero_point, float_val).val_;
+                  at::quantize_val<scalar_t>(scale, zero_point, float_val).val_;
             } // for c
 
           } // for tidx
@@ -410,7 +405,7 @@
     using Vec = Vec256<scalar_t>;
     auto iter = TensorIterator::unary_op(qy, qx);
     scalar_t six =
-        at::native::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), 6.0);
+        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), 6.0);
     auto zero_point_vec = Vec(scalar_t(zero_point));
     auto six_vec = Vec(six);
     cpu_kernel_vec(
@@ -452,9 +447,9 @@
     cpu_kernel_vec(
         iter,
         [&](scalar_t value_qx) -> scalar_t {
-          auto value_dx = at::native::dequantize_val(i_scale, i_zp, value_qx);
+          auto value_dx = at::dequantize_val(i_scale, i_zp, value_qx);
           auto value_dy = value_dx > 0 ? value_dx : value_dx * negval;
-          return at::native::quantize_val<scalar_t>(o_scale, o_zp, value_dy);
+          return at::quantize_val<scalar_t>(o_scale, o_zp, value_dy);
         },
         [&](qVec qx_vec) -> qVec {
           /* Vectorized implementation creates a multiplicand vector, which has
@@ -510,10 +505,10 @@
     cpu_kernel_vec(
       iter,
       [&](scalar_t value_qx) -> scalar_t {
-        const auto value_dx = at::native::dequantize_val(scale, zero_point, value_qx);
+        const auto value_dx = at::dequantize_val(scale, zero_point, value_qx);
         const auto value_dy = 1.0f / (1.0 + std::exp((-value_dx)));
-        return at::native::quantize_val<scalar_t>(output_scale, output_zero_point,
-                                                  value_dy);
+        return at::quantize_val<scalar_t>(output_scale, output_zero_point,
+                                          value_dy);
       },
       [&](Vec value_qx) -> Vec {
         auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec,
@@ -573,9 +568,9 @@
     cpu_kernel_vec(
       iter,
       [&](scalar_t qx) -> scalar_t {
-        auto x = at::native::dequantize_val(scale, zero_point, qx);
+        auto x = at::dequantize_val(scale, zero_point, qx);
         const auto y = std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
-        return at::native::quantize_val<scalar_t>(output_scale, output_zero_point, y);
+        return at::quantize_val<scalar_t>(output_scale, output_zero_point, y);
       },
       [&](qVec value_qx) -> qVec {
         auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec,
@@ -610,9 +605,9 @@
     auto min = min_scalar.to<float>();
     auto max = max_scalar.to<float>();
     scalar_t min_q =
-        at::native::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), min);
+        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), min);
     scalar_t max_q =
-        at::native::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), max);
+        at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(), max);
     auto min_vec = Vec(min_q);
     auto max_vec = Vec(max_q);
     cpu_kernel_vec(
@@ -651,9 +646,9 @@
     cpu_kernel_vec(
         iter,
         [&](scalar_t value) -> scalar_t {
-          const auto x = at::native::dequantize_val(i_scale, i_zero_point, value);
+          const auto x = at::dequantize_val(i_scale, i_zero_point, value);
           const auto y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
-          return at::native::quantize_val<scalar_t>(o_scale, o_zero_point, y);
+          return at::quantize_val<scalar_t>(o_scale, o_zero_point, y);
         },
         [&](qVec value) -> qVec {
           auto value_dx = value.dequantize(i_scale_vec, i_zero_point_vec,
@@ -703,9 +698,9 @@
     cpu_kernel_vec(
       iter,
       [&](scalar_t value_qx) -> scalar_t {
-        const auto value_dx = at::native::dequantize_val(scale, zero_point, value_qx);
-        return at::native::quantize_val<scalar_t>(output_scale, output_zero_point,
-                                                  std::tanh(value_dx));
+        const auto value_dx = at::dequantize_val(scale, zero_point, value_qx);
+        return at::quantize_val<scalar_t>(output_scale, output_zero_point,
+                                          std::tanh(value_dx));
       },
       [&](Vec value_qx) -> Vec {
         const auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec,
@@ -753,13 +748,13 @@
       iter,
       [&](scalar_t value_qx) -> scalar_t {
         // dequantize
-        const auto x = at::native::dequantize_val(i_scale, i_zp, value_qx);
+        const auto x = at::dequantize_val(i_scale, i_zp, value_qx);
         // ELU
         const auto y = x >= 0
           ? x
           : (alpha_float * (std::exp(x) - 1));
         // quantize
-        return at::native::quantize_val<scalar_t>(o_scale, o_zp, y);
+        return at::quantize_val<scalar_t>(o_scale, o_zp, y);
       },
       [&](qVec value_qx) -> qVec {
         // dequantize
@@ -816,7 +811,7 @@
               static_cast<int32_t>(self_zero_point);
           int32_t c = a_sub_z + other_val;
           scalar_t res =
-              at::native::requantize_from_int<scalar_t>(multiplier, zero_point, c);
+              at::requantize_from_int<scalar_t>(multiplier, zero_point, c);
           if (ReLUFused) {
             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
           }
@@ -868,13 +863,13 @@
     cpu_kernel_vec(
         iter,
         [&](scalar_t a, scalar_t b) -> scalar_t {
-          const auto da = at::native::dequantize_val(self_scale, self_zero_point, a);
-          const auto db = at::native::dequantize_val(other_scale, other_zero_point, b);
+          const auto da = at::dequantize_val(self_scale, self_zero_point, a);
+          const auto db = at::dequantize_val(other_scale, other_zero_point, b);
           float c = da + db;
           if (ReLUFused) {
             c = std::max<float>(c, 0.0);
           }
-          return at::native::quantize_val<scalar_t>(scale, zero_point, c);
+          return at::quantize_val<scalar_t>(scale, zero_point, c);
         },
         [&](Vec a, Vec b) -> Vec {
           const auto da = a.dequantize(
@@ -929,7 +924,7 @@
               static_cast<int32_t>(other_zero_point);
           int32_t c = a_sub_z * b_sub_z;
           scalar_t res =
-              at::native::requantize_from_int<scalar_t>(multiplier, zero_point, c);
+              at::requantize_from_int<scalar_t>(multiplier, zero_point, c);
           if (ReLUFused) {
             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
           }
@@ -1165,7 +1160,7 @@
       float acc_fp[vec_width];
       acc.store(acc_int);
       vec256::convert(acc_int, acc_fp, vec_width);
-      at::native::quantize_vec<T>(
+      at::quantize_vec<T>(
           1.0f / multiplier,
           output_zero_point,
           acc_fp,
@@ -1254,7 +1249,7 @@
             }
           }
           // clamp
-          o_p[c] = at::native::quantize_val<scalar_t>(
+          o_p[c] = at::quantize_val<scalar_t>(
                        1.0f / multiplier, output_zero_point, acc_int32)
                        .val_;
         } // c
@@ -1354,7 +1349,7 @@
           }
           double acc_fp = acc_int32 * 1.0;
           // clamp
-          o_p[c] = at::native::quantize_val<scalar_t>(
+          o_p[c] = at::quantize_val<scalar_t>(
                        1.0f / multiplier, output_zero_point, acc_fp)
                        .val_;
         } // c
@@ -1470,7 +1465,7 @@
             }
             double acc_fp = acc_int32 * 1.0;
             // clamp
-            o_p[c] = at::native::quantize_val<scalar_t>(
+            o_p[c] = at::quantize_val<scalar_t>(
                          1.0f / multiplier, output_zero_point, acc_fp)
                          .val_;
           } // c
@@ -1530,7 +1525,7 @@
           input_zero_point_v;
       float result_fp[vec_width];
       result.store(result_fp);
-      at::native::quantize_vec<T>(
+      at::quantize_vec<T>(
           inverse_scale,
           output_zero_point,
           result_fp,
@@ -1623,11 +1618,11 @@
                     h1lambda *
                         (w0lambda * pos1[h1p * input_width * channels] +
                          w1lambda * pos1[(h1p * input_width + w1p) * channels]);
-                pos2[0] = at::native::quantize_val<scalar_t>(
-                                      inverse_scale,
-                                      output.q_zero_point(),
-                                      result - input.q_zero_point())
-                                      .val_;
+                pos2[0] = at::quantize_val<scalar_t>(
+                              inverse_scale,
+                              output.q_zero_point(),
+                              result - input.q_zero_point())
+                              .val_;
                 pos1 += 1;
                 pos2 += 1;
               } // c
@@ -1997,10 +1992,10 @@
           const float gamma_v = gamma_null ? 1.0f : gamma_data[remIdx];
           const float beta_v = beta_null ? 0.0f : beta_data[remIdx];
           auto qXVal = X_ptr[remIdx];
-          float dqXVal = at::native::dequantize_val(x_fake_scale, x_zp, qXVal);
+          float dqXVal = at::dequantize_val(x_fake_scale, x_zp, qXVal);
           float dqY =
             ((dqXVal - layer_mean_div_scale_x) * scale_x_div_layer_std) * gamma_v + beta_v;
-          Y_ptr[remIdx] = at::native::quantize_val<scalar_t>(y_scale, y_zp, dqY);
+          Y_ptr[remIdx] = at::quantize_val<scalar_t>(y_scale, y_zp, dqY);
         }
       }
     }); // parallel_for
@@ -2008,207 +2003,6 @@
   });
 }
 
-#ifdef USE_FBGEMM
-void quantize_tensor_per_tensor_affine_cpu(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
-    const float* rd = rtensor.data_ptr<float>();
-    auto qd = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
-    fbgemm::TensorQuantizationParams qparams;
-    qparams.scale = scale;
-    qparams.zero_point = zero_point;
-    qparams.precision = CHAR_BIT * sizeof(underlying_t);
-    int num_tasks = at::get_num_threads();
-    at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
-      for (int task_id = begin; task_id < end; ++task_id) {
-        fbgemm::Quantize<underlying_t>(
-            rd, /*src=*/
-            qd, /*dst=*/
-            rtensor.numel(), /*len*/
-            qparams, /*qparams=*/
-            task_id, /*thread_id*/
-            num_tasks /*num_threads*/);
-      }
-    });
-  });
-}
-
-void dequantize_tensor_per_tensor_affine_cpu(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
-    const auto* qd = reinterpret_cast<const underlying_t*>(qtensor.data_ptr<scalar_t>());
-    fbgemm::TensorQuantizationParams qparams;
-    qparams.scale = scale;
-    qparams.zero_point = zero_point;
-    qparams.precision = CHAR_BIT * sizeof(underlying_t);
-    float* rd = rtensor.data_ptr<float>();
-    int num_tasks = at::get_num_threads();
-    at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
-      for (int task_id = begin; task_id < end; ++task_id) {
-        fbgemm::Dequantize<underlying_t>(
-            qd, /*src=*/
-            rd, /*dst=*/
-            qtensor.numel(), /*len=*/
-            qparams, /*qparams=*/
-            task_id, /*thread_id*/
-            num_tasks /*num_threads*/);
-      }
-    });
-  });
-}
-#else  // USE_FBGEMM
-
-#ifdef __ARM_NEON__
-// Generic template defaults to naive quantize implementation
-template <typename T>
-void quantize_tensor_arm(
-    const float* in,
-    Tensor qtensor,
-    const int64_t N,
-    const float scale,
-    const int32_t zero_point) {
-  auto out = qtensor.data_ptr<T>();
-  for (int i = 0; i < N; ++i) {
-    out[i] = at::native::quantize_val<T>(scale, zero_point, in[i]);
-  }
-}
-
-// Specialized implementation from caffe2::Int8Quantize.
-// There may be slight accuracy difference between this and implementation of quantize_val
-// TODO Update quantize_tensor_arm implementation to follow quantize_val,
-// i.e. f = Round(value/scale + zero_point)
-// TODO Make quantize_tensor_arm work for other datatypes too (int8, int32).
-template <>
-void quantize_tensor_arm<c10::quint8>(
-    const float* in,
-    Tensor qtensor,
-    const int64_t N,
-    const float scale,
-    const int32_t zero_point) {
-  const float inv_scale = 1.0f / scale;
-  uint32_t i = 0;
-  auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
-  const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
-  // magic float and magic int to take care of rounding
-  // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
-  // Some detail:
-  // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
-  // add a small number to a large number, the result rounds to the precision of
-  // the least significant bit of the large number. For an IEEE-754
-  // single-precision number the mantissa has 23 bits, and adding 2**23 causes
-  // rounding to the nearest even integer. Then we cast to int and subtract the
-  // same number (0x4B400000 is the integer representation of 12582912.0f) to
-  // get only the mantissa. This works if -2**22 < x < 2**22, but preserves the
-  // sign for negative numbers.
-  const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
-  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
-  for (i = 0; i + 8 < N; i += 8) {
-    const float32x4_t vin0123 = vld1q_f32(in);
-    in += 4;
-    const float32x4_t vin4567 = vld1q_f32(in);
-    in += 4;
-    const int32x4_t vraw0123 = vaddq_s32(
-        voffset,
-        vreinterpretq_s32_f32(
-            vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
-    const int32x4_t vraw4567 = vaddq_s32(
-        voffset,
-        vreinterpretq_s32_f32(
-            vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
-    const int16x8_t vraw01234567 =
-        vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
-    const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
-    vst1_u8(out, vout01234567);
-    out += 8;
-  }
-  for (; i < N; ++i) {
-    (*out++) = at::native::quantize_val_arm(scale, zero_point, (*in++));
-  }
-}
-
-#endif // __ARM_NEON__
-
-void quantize_tensor_per_tensor_affine_cpu(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
-    TORCH_CHECK(rtensor.is_contiguous(), "Float tensor should be contiguous");
-    const float* const rdata = rtensor.data_ptr<float>();
-    // If QEngine is set to QNNPACK, use caffe2 specialized Int8Quantize implementation on ARM
-    #if defined(__ARM_NEON__)
-      if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
-        quantize_tensor_arm<scalar_t>(rdata, qtensor, rtensor.numel(), scale, zero_point);
-      }
-    #endif
-    auto qdata = qtensor.data_ptr<scalar_t>();
-    auto numel = rtensor.numel();
-    for (int i = 0; i < numel; ++i) {
-      qdata[i] = quantize_val<scalar_t>(scale, zero_point, rdata[i]);
-    }
-  });
-}
-
-void dequantize_tensor_per_tensor_affine_cpu(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
-    const auto* qd = qtensor.data_ptr<scalar_t>();
-    float* rd = rtensor.data_ptr<float>();
-    auto numel = qtensor.numel();
-    for (auto i = 0; i < numel; ++i) {
-      rd[i] = dequantize_val<scalar_t>(scale, zero_point, qd[i]);
-    }
-  });
-}
-#endif  // USE_FBGEMM
-
-// TODO: add fbgemm for per channel
-void quantize_tensor_per_channel_affine_cpu(
-    Tensor rtensor,
-    Tensor qtensor,
-    Tensor scales,
-    Tensor zero_points,
-    int64_t axis) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
-    int64_t batches = size_to_dim_(axis, rtensor.sizes());
-    int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
-    int64_t channel = rtensor.size(axis);
-    auto scales_data = scales.data_ptr<double>();
-    auto zero_points_data = zero_points.data_ptr<int64_t>();
-    const float* rdata = rtensor.data_ptr<float>();
-    auto qdata = qtensor.data_ptr<scalar_t>();
-    for (auto b = 0; b < batches; ++b) {
-      for (auto c = 0; c < channel; ++c) {
-        for (auto e = 0; e < elements_per_channel; ++e) {
-          auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
-          qdata[i] = quantize_val<scalar_t>(scales_data[c], zero_points_data[c], rdata[i]);
-        }
-      }
-    }
-  });
-}
-
-void dequantize_tensor_per_channel_affine_cpu(
-    Tensor qtensor,
-    Tensor rtensor,
-    Tensor scales,
-    Tensor zero_points,
-    int64_t axis) {
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor_per_channel_affine_cpu", [&]() {
-    int64_t batches = size_to_dim_(axis, rtensor.sizes());
-    int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
-    int64_t channel = rtensor.size(axis);
-    auto scales_data = scales.data_ptr<double>();
-    auto zero_points_data = zero_points.data_ptr<int64_t>();
-    const auto* qd = qtensor.data_ptr<scalar_t>();
-    float* rd = rtensor.data_ptr<float>();
-    for (auto b = 0; b < batches; ++b) {
-      for (auto c = 0; c < channel; ++c) {
-        for (auto e = 0; e < elements_per_channel; ++e) {
-          auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
-          // We need to convert the qint8 value to float to ensure the subtraction
-          // subexpression returns a float
-          rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) * scales_data[c];
-        }
-      }
-    }
-  });
-}
-
 } // namespace
 
 REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel);
@@ -2245,10 +2039,6 @@
 REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
 REGISTER_DISPATCH(fake_quant_grad_per_channel_stub, &fake_quant_grad_per_channel_cpu);
 REGISTER_DISPATCH(quantized_layer_norm_stub, &quantized_layer_norm_kernel);
-REGISTER_DISPATCH(quantize_tensor_per_tensor_affine_stub,  &quantize_tensor_per_tensor_affine_cpu);
-REGISTER_DISPATCH(quantize_tensor_per_channel_affine_stub, &quantize_tensor_per_channel_affine_cpu);
-REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub, &dequantize_tensor_per_tensor_affine_cpu);
-REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, &dequantize_tensor_per_channel_affine_cpu);
 
 } // namespace native
 } // namespace at
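
The deleted quantize_tensor_arm<c10::quint8> kernel relies on the magic-float rounding trick its comment describes. A standalone sketch of that trick (the helper name magic_round is hypothetical; valid for |x| < 2^22):

#include <cstdint>
#include <cstring>

// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000,
// i.e. round-to-nearest-even done entirely in float arithmetic.
inline int32_t magic_round(float x) {
  const float y = x + 12582912.0f;       // 12582912.0f == 2^23 + 2^22, bits 0x4B400000
  int32_t bits;
  std::memcpy(&bits, &y, sizeof(bits));  // reinterpret the float's bit pattern
  return bits - 0x4B400000;
}
// magic_round(2.5f)  == 2   (ties round to even, matching std::nearbyint)
// magic_round(-1.3f) == -1
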
diff --git a/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp b/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp
deleted file mode 100644
index 6d75c21..0000000
--- a/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp
+++ /dev/null
@@ -1,29 +0,0 @@
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/cpu/Loops.h>
-
-namespace at {
-namespace native {
-
-Tensor make_per_tensor_quantized_tensor_cpu(
-    const Tensor& self,
-    double scale,
-    int64_t zero_point) {
-  Tensor dst = at::_empty_affine_quantized(
-      self.sizes(),
-      self.options().dtype(toQIntType(self.scalar_type())),
-      scale,
-      zero_point);
-  Tensor self_contig = self.contiguous();
-  AT_DISPATCH_QINT_TYPES(dst.scalar_type(), "make_per_tensor_quantized_tensor", [&]() {
-    underlying_t* self_data = self_contig.data_ptr<underlying_t>();
-    underlying_t* dst_data =
-        reinterpret_cast<underlying_t*>(dst.data_ptr<scalar_t>());
-    if (self.numel() > 0) {
-      memcpy(dst_data, self_data, self.nbytes());
-    }
-  });
-  return dst;
-}
-
-} // namespace native
-} // namespace at
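
Functionally, the deleted overload is the inverse of int_repr(): it copies the raw integer data unchanged and attaches the scale and zero_point metadata. The same semantics apply to the CUDA overload removed further down. A short usage sketch built on the public ops exercised by the tests in this diff (the concrete scale and zero_point are arbitrary example values):

import torch

int_tensor = torch.randint(0, 100, size=(10,), dtype=torch.uint8)
q = torch._make_per_tensor_quantized_tensor(int_tensor, 0.5, 10)  # scale=0.5, zero_point=10

assert torch.equal(q.int_repr(), int_tensor)  # same raw values, now tagged as quint8
assert q.q_scale() == 0.5 and q.q_zero_point() == 10
assert torch.allclose(q.dequantize(), (int_tensor.float() - 10) * 0.5)
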
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index 15fd25b..3f5a3c2 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -3,7 +3,7 @@
 #include <ATen/core/op_registration/op_registration.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
 #include <ATen/native/quantized/cpu/init_qnnpack.h>
 #include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -142,8 +142,8 @@
     using Vec = Vec256<scalar_t>;
     auto iter = TensorIterator::unary_op(qx, qx);
     auto zero_point_vec = Vec(scalar_t(zero_point));
-    scalar_t six = at::native::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(),
-                                                      /*value=*/6.0);
+    scalar_t six = at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(),
+                                              /*value=*/6.0);
     auto six_vec = Vec(six);
     cpu_kernel_vec(
         iter,
diff --git a/aten/src/ATen/native/quantized/cpu/qupsample_bilinear2d.cpp b/aten/src/ATen/native/quantized/cpu/qupsample_bilinear2d.cpp
index 332c3c9..e3140ca 100644
--- a/aten/src/ATen/native/quantized/cpu/qupsample_bilinear2d.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qupsample_bilinear2d.cpp
@@ -1,7 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/native/UpSample.h>
 #include <ATen/native/quantized/cpu/quantized_ops.h>
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 
 #include <algorithm>
 #include <cmath>
@@ -78,9 +78,9 @@
                 (w0lambda * pos1[h1p * input_width] +
                  w1lambda * pos1[h1p * input_width + w1p]) - input.q_zero_point();
         // requantization
-        pos2[0] = at::native::quantize_val<scalar_t>(
-                              output_scale, output.q_zero_point(), result)
-                              .val_;
+        pos2[0] = at::quantize_val<scalar_t>(
+                      output_scale, output.q_zero_point(), result)
+                      .val_;
         pos1 += input_width * input_height;
         pos2 += output_width * output_height;
       }
diff --git a/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu b/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu
deleted file mode 100644
index 66ba70f..0000000
--- a/aten/src/ATen/native/quantized/cuda/affine_quantizer.cu
+++ /dev/null
@@ -1,49 +0,0 @@
-#include <math.h>
-#include <ATen/cuda/CUDAApplyUtils.cuh>
-#include <ATen/native/cuda/Loops.cuh>
-#include <ATen/native/quantized/affine_quantizer.h>
-#include <ATen/native/TensorIterator.h>
-
-namespace at {
-namespace native {
-namespace {
-
-void quantize_tensor_per_tensor_affine_cuda(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point){
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cuda", [&]() {
-    constexpr int64_t qmin = std::numeric_limits<underlying_t>::min();
-    constexpr int64_t qmax = std::numeric_limits<underlying_t>::max();
-    at::cuda::CUDA_tensor_apply2<float, scalar_t>(
-      /*a=*/rtensor,
-      /*b=*/qtensor,
-      [=] __device__ (
-        float& rtensor_val,
-        scalar_t& qtensor_val) {
-          int64_t qvalue;
-          qvalue = static_cast<int64_t>(nearbyint(rtensor_val / scale + zero_point));
-          qvalue = std::max<int64_t>(qvalue, qmin);
-          qvalue = std::min<int64_t>(qvalue, qmax);
-          qtensor_val.val_ = qvalue;
-    },
-    /*aType=*/at::cuda::TensorArgType::ReadOnly,
-    /*bType=*/at::cuda::TensorArgType::ReadWrite);
-  });
-}
-
-void dequantize_tensor_per_tensor_affine_cuda(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point){
-  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cuda", [&]() {
-    auto iter = TensorIterator();
-    iter.add_output(rtensor);
-    iter.add_input(qtensor);
-    iter.dont_compute_common_dtype();
-    iter.build();
-    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t value) -> float { return (static_cast<float>(value.val_) - zero_point) * scale; });
-  });
-}
-
-} // anonymous namespace
-
-REGISTER_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_cuda);
-REGISTER_DISPATCH(dequantize_tensor_per_tensor_affine_stub, &dequantize_tensor_per_tensor_affine_cuda);
-
-} // namespace native
-} // namespace at
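
The two removed CUDA kernels implement the same per-tensor affine mapping as the CPU path: q = clamp(nearbyint(r / scale + zero_point), qmin, qmax) on the way in and r = (q - zero_point) * scale on the way out. A pure-Python reference for those two lambdas (illustrative; Python's round() is round-half-to-even, matching std::nearbyint under the default rounding mode):

def quantize_per_tensor(r, scale, zero_point, qmin, qmax):
    # mirrors the CUDA_tensor_apply2 lambda above
    return [min(max(round(x / scale + zero_point), qmin), qmax) for x in r]

def dequantize_per_tensor(q, scale, zero_point):
    # mirrors the gpu_kernel lambda above
    return [(v - zero_point) * scale for v in q]

# quint8 example: qmin=0, qmax=255
q = quantize_per_tensor([-0.3, 0.0, 1.26], scale=0.01, zero_point=30, qmin=0, qmax=255)
assert q == [0, 30, 156]
# the round trip recovers each value to within scale / 2 when nothing was clamped
r = dequantize_per_tensor(q, scale=0.01, zero_point=30)
assert all(abs(a - b) <= 0.005 for a, b in zip(r, [-0.3, 0.0, 1.26]))
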
diff --git a/aten/src/ATen/native/quantized/cuda/int_repr_quant.cu b/aten/src/ATen/native/quantized/cuda/int_repr_quant.cu
deleted file mode 100644
index fa710df..0000000
--- a/aten/src/ATen/native/quantized/cuda/int_repr_quant.cu
+++ /dev/null
@@ -1,25 +0,0 @@
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/cuda/Loops.cuh>
-
-namespace at {                                                            
-namespace native {
-
-Tensor int_repr_quant_cuda(const Tensor& self) {
-  Tensor dst;
-  AT_DISPATCH_QINT_TYPES(self.scalar_type(), "int_repr_quant_cuda", [&]() {
-    dst = at::empty(
-        self.sizes(),
-        self.options().dtype(UNDERLYING_TYPE),
-        self.suggest_memory_format());
-    auto iter = TensorIterator();
-    iter.add_output(dst);
-    iter.add_input(self);
-    iter.dont_compute_common_dtype();
-    iter.build();
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t value) -> underlying_t { return value.val_; });
-  });
-  return dst;
-}
-
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/quantized/cuda/make_per_tensor_quantized_tensor.cu b/aten/src/ATen/native/quantized/cuda/make_per_tensor_quantized_tensor.cu
deleted file mode 100644
index e3584e1..0000000
--- a/aten/src/ATen/native/quantized/cuda/make_per_tensor_quantized_tensor.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-#include <ATen/native/TensorIterator.h>
-#include <ATen/native/cuda/Loops.cuh>
-
-namespace at {
-namespace native {
-
-Tensor make_per_tensor_quantized_tensor_cuda(
-    const Tensor& self,
-    double scale,
-    int64_t zero_point) {
-  Tensor dst = at::_empty_affine_quantized(
-    self.sizes(),
-    self.options().dtype(toQIntType(self.scalar_type())),
-    scale,
-    zero_point);
-  AT_DISPATCH_QINT_TYPES(dst.scalar_type(), "make_per_tensor_quantized_tensor_cuda", [&]() {
-    auto iter = TensorIterator();
-    iter.add_output(dst);
-    iter.add_input(self);
-    iter.dont_compute_common_dtype();
-    iter.build();
-    gpu_kernel(iter, []GPU_LAMBDA(underlying_t value) -> scalar_t { return scalar_t(value); });
-  });
-  return dst;
-}
-
-} // native
-} // at
diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py
index c6dd6ee..b7d5612 100644
--- a/aten/src/ATen/preprocess_declarations.py
+++ b/aten/src/ATen/preprocess_declarations.py
@@ -28,7 +28,7 @@
 all_types = type_map['floating_point'] + type_map['integral'] + type_map['quantized']
 type_map['all'] = all_types
 
-all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU', 'QuantizedCUDA']
+all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU']
 default_backends = ['CPU', 'CUDA']
 
 
@@ -44,7 +44,7 @@
 
         backend_types = {}
         for backend in backends:
-            if backend in ('QuantizedCPU', 'QuantizedCUDA'):
+            if backend == 'QuantizedCPU':
                 backend_types[backend] = type_map['quantized']
             else:
                 backend_types[backend] = option.get('types', all_types)
diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
index 07507d5..0b02097 100644
--- a/aten/src/ATen/quantized/Quantizer.cpp
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -9,10 +9,16 @@
 #include <ATen/native/utils/Allocator.h>
 #include <ATen/quantized/QTensorImpl.h>
 #include <ATen/core/Tensor.h>
-#include <ATen/native/quantized/affine_quantizer.h>
 #include <typeinfo>
 #include <cmath>
 
+#ifdef USE_FBGEMM
+#include <fbgemm/QuantUtils.h>
+#endif
+#ifdef __ARM_NEON__
+#include <arm_neon.h>
+#endif
+
 namespace at {
 
 // Note: this is not a native function as Quantizer is not exposed to python yet
@@ -22,6 +28,444 @@
   return get_qtensorimpl(*this)->quantizer();
 }
 
+void checkFloatCPUTensor(std::string fn_name, Tensor t) {
+  TORCH_CHECK(
+      t.scalar_type() == kFloat,
+      fn_name,
+      " expects a Float Tensor.");
+  TORCH_CHECK(
+      t.device() == kCPU,
+      fn_name,
+      " expects a CPU Tensor.");
+}
+
+template <typename T>
+void checkQuantizedCPUTensor(std::string fn_name, Tensor t) {
+  TORCH_CHECK(t.is_quantized(),
+           fn_name,
+           " expects a quantized Tensor.");
+  TORCH_CHECK(t.scalar_type() == caffe2::TypeMeta::Make<T>(),
+           fn_name,
+           " expects a ",
+           caffe2::TypeMeta::Make<T>(),
+           " Tensor");
+  TORCH_CHECK(t.device() == kCPU,
+           fn_name,
+           " expects a CPU quantized Tensor");
+}
+
+template <typename T>
+void checkZeroPoint(std::string fn_name, int64_t zero_point) {
+  TORCH_CHECK(zero_point <= std::numeric_limits<T>::max(),
+              fn_name,
+              " zero_point ",
+              zero_point,
+              " is out of range.");
+  TORCH_CHECK(zero_point >= std::numeric_limits<T>::min(),
+              fn_name,
+              " zero_point ",
+              zero_point,
+              " is out of range.");
+}
+
+template <typename T>
+void checkZeroPoints(std::string fn_name, Tensor zero_points) {
+  auto zero_points_data = zero_points.data_ptr<int64_t>();
+  for (size_t i = 0; i < zero_points.numel(); ++i) {
+    TORCH_CHECK(zero_points_data[i] <= std::numeric_limits<T>::max(),
+                fn_name,
+                " zero_point ",
+                i,
+                " is out of range.");
+    TORCH_CHECK(zero_points_data[i] >= std::numeric_limits<T>::min(),
+                fn_name,
+                " zero_point ",
+                i,
+                " is out of range.");
+  }
+}
+
+#ifdef USE_FBGEMM
+// Note: quantize_val is only used explicitly in tests outside of this file
+template <typename T>
+T quantize_val(double scale, int64_t zero_point, float value) {
+  // Internally, fbgemm::Quantize uses std::nearbyint.
+  // std::nearbyint results in the nearest integer value according to the current
+  // rounding mode, and the default rounding mode rounds to even in half-way
+  // cases on most popular processor architectures like x86 and ARM. This is
+  // typically faster than alternatives like std::round, which rounds half-way
+  // cases away from zero, and can be consistent with SIMD implementations, for
+  // example on x86 using _mm512_cvtps_epi32 or _mm512_round_ps with the
+  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
+  int32_t qvalue;
+  qvalue = fbgemm::Quantize<typename T::underlying>(
+      value,
+      static_cast<int32_t>(zero_point),
+      static_cast<double>(scale),
+      /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
+  return static_cast<T>(qvalue);
+}
+
+template <typename T, int precision>
+void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
+  fbgemm::Quantize<typename T::underlying>(
+    src,
+    (typename T::underlying*)dst,
+    count,
+    fbgemm::TensorQuantizationParams{(float)scale, (int32_t)zero_point, precision}
+  );
+}
+
+template <typename T>
+Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
+  auto fn_name = "quantize_tensor";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoint<typename T::underlying>(fn_name, zero_point);
+  const float* rd = rtensor.data_ptr<float>();
+  auto qd = reinterpret_cast<typename T::underlying*>(qtensor.data_ptr<T>());
+  fbgemm::TensorQuantizationParams qparams;
+  qparams.scale = scale;
+  qparams.zero_point = zero_point;
+  qparams.precision = CHAR_BIT * sizeof(typename T::underlying);
+  int num_tasks = at::get_num_threads();
+  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
+    for (int task_id = begin; task_id < end; ++task_id) {
+      fbgemm::Quantize<typename T::underlying>(
+          rd, /*src=*/
+          qd, /*dst=*/
+          rtensor.numel(), /*len*/
+          qparams, /*qparams=*/
+          task_id, /*thread_id*/
+          num_tasks /*num_threads*/);
+    }
+  });
+  return qtensor;
+}
+
+template <typename T>
+inline float dequantize_val(double scale, int64_t zero_point, T value) {
+  fbgemm::TensorQuantizationParams qparams;
+  qparams.scale = static_cast<float>(scale);
+  qparams.zero_point = static_cast<int32_t>(zero_point);
+  return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
+}
+
+template <typename T>
+Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
+  auto fn_name = "dequantize_tensor";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoint<typename T::underlying>(fn_name, zero_point);
+  const auto* qd = reinterpret_cast<const typename T::underlying*>(qtensor.data_ptr<T>());
+  fbgemm::TensorQuantizationParams qparams;
+  qparams.scale = scale;
+  qparams.zero_point = zero_point;
+  qparams.precision = CHAR_BIT * sizeof(typename T::underlying);
+  float* rd = rtensor.data_ptr<float>();
+  int num_tasks = at::get_num_threads();
+  at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
+    for (int task_id = begin; task_id < end; ++task_id) {
+      fbgemm::Dequantize<typename T::underlying>(
+          qd, /*src=*/
+          rd, /*dst=*/
+          qtensor.numel(), /*len=*/
+          qparams, /*qparams=*/
+          task_id, /*thread_id*/
+          num_tasks /*num_threads*/);
+    }
+  });
+  return rtensor;
+}
+#else  // USE_FBGEMM
+
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
+template <class T>
+inline float Round(const float x) {
+  return ::nearbyintf(x);
+}
+inline double Round(const double x) {
+  return ::nearbyint(x);
+}
+#else
+template <class T>
+inline T Round(const T x) {
+  return std::nearbyint(x);
+}
+#endif
+
+template <typename T>
+T quantize_val(double scale, int64_t zero_point, float value) {
+  // std::nearbyint results in the nearest integer value according to the current
+  // rounding mode, and the default rounding mode rounds to even in half-way
+  // cases on most popular processor architectures like x86 and ARM. This is
+  // typically faster than alternatives like std::round, which rounds half-way
+  // cases away from zero, and can be consistent with SIMD implementations, for
+  // example on x86 using _mm512_cvtps_epi32 or _mm512_round_ps with the
+  // _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
+  int64_t qvalue;
+  constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
+  constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
+  qvalue = static_cast<int64_t>(Round(value / scale + zero_point));
+  qvalue = std::max<int64_t>(qvalue, qmin);
+  qvalue = std::min<int64_t>(qvalue, qmax);
+  return static_cast<T>(qvalue);
+}
+
+template <typename T, int precision>
+void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
+  checkZeroPoint<typename T::underlying>("quantize_val", zero_point);
+  for (int64_t i = 0; i < count; ++i) {
+    dst[i] = quantize_val<T>(scale, zero_point, src[i]);
+  }
+}
+
+// TODO combine this with quantize_val once the numerics for ARM are aligned with it
+inline uint8_t quantize_val_arm(const float scale, const int32_t zero_point, const float value) {
+  const int32_t qmin = std::numeric_limits<uint8_t>::min();
+  const int32_t qmax = std::numeric_limits<uint8_t>::max();
+  auto r = zero_point + static_cast<int32_t>(Round(value / scale));
+  r = std::max(r, qmin);
+  r = std::min(r, qmax);
+  return static_cast<uint8_t>(r);
+}
+
+#ifdef __ARM_NEON__
+// Generic template defaults to naive quantize implementation
+template <typename T>
+void quantize_tensor_arm(
+    const float* in,
+    Tensor qtensor,
+    const int64_t N,
+    const float scale,
+    const int32_t zero_point) {
+  auto out = qtensor.data_ptr<T>();
+  for (int i = 0; i < N; ++i) {
+    out[i] = quantize_val<T>(scale, zero_point, in[i]);
+  }
+}
+
+// Specialized implementation from caffe2::Int8Quantize.
+// There may be a slight accuracy difference between this and the implementation of quantize_val.
+// TODO Update quantize_tensor_arm implementation to follow quantize_val,
+// i.e. f = Round(value/scale + zero_point)
+// TODO Make quantize_tensor_arm work for other datatypes too (int8, int32).
+template <>
+void quantize_tensor_arm<c10::quint8>(
+    const float* in,
+    Tensor qtensor,
+    const int64_t N,
+    const float scale,
+    const int32_t zero_point) {
+  const float inv_scale = 1.0f / scale;
+  uint32_t i = 0;
+  auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
+  const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
+  // magic float and magic int to take care of rounding
+  // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
+  // Some detail:
+  // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
+  // add a small number to a large number, the result rounds to the precision of
+  // the least significant bit of the large number. For IEEE-754
+  // single-precision number mantissa has 23 bits, and adding 2**23 would cause
+  // rounding to the nearest even integer. Then we cast to int and subtract the
+  // same number (0x4B400000 is the integer representation of 12582912.0f) to
+  // get only the mantissa. This works if -2**22 < x < 2**22, but preserves the
+  // sign for negative numbers.
+  const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
+  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+  for (i = 0; i + 8 < N; i += 8) {
+    const float32x4_t vin0123 = vld1q_f32(in);
+    in += 4;
+    const float32x4_t vin4567 = vld1q_f32(in);
+    in += 4;
+    const int32x4_t vraw0123 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+    const int32x4_t vraw4567 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+    const int16x8_t vraw01234567 =
+        vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+    const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+    vst1_u8(out, vout01234567);
+    out += 8;
+  }
+  for (; i < N; ++i) {
+    (*out++) = quantize_val_arm(scale, zero_point, (*in++));
+  }
+}
+#endif // __ARM_NEON__
+
+template <typename T>
+Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
+  auto fn_name = "quantize_tensor";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoint<typename T::underlying>(fn_name, zero_point);
+  TORCH_CHECK(rtensor.is_contiguous(), "Float tensor should be contiguous");
+  const float* const rdata = rtensor.data_ptr<float>();
+  // If QEngine is set to QNNPACK, use caffe2 specialized Int8Quantize implementation on ARM
+#if defined(__ARM_NEON__)
+  if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+    quantize_tensor_arm<T>(rdata, qtensor, rtensor.numel(), scale, zero_point);
+    return qtensor;
+  }
+#endif
+  auto qdata = qtensor.data_ptr<T>();
+  auto numel = rtensor.numel();
+  for (int i = 0; i < numel; ++i) {
+    qdata[i] = quantize_val<T>(scale, zero_point, rdata[i]);
+  }
+  return qtensor;
+}
+
+template <typename T>
+CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
+  // We need to convert the qint8 value to float to ensure the subtraction
+  // subexpression returns a float
+  return (static_cast<float>(value.val_) - zero_point) * scale;
+}
+
+template <typename T>
+Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
+  auto fn_name = "dequantize_tensor";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoint<typename T::underlying>(fn_name, zero_point);
+  const auto* qd = qtensor.data_ptr<T>();
+  float* rd = rtensor.data_ptr<float>();
+  auto numel = qtensor.numel();
+  for (auto i = 0; i < numel; ++i) {
+    rd[i] = dequantize_val<T>(scale, zero_point, qd[i]);
+  }
+  return rtensor;
+}
+#endif  // USE_FBGEMM
+
+template <typename SRC_T, typename DST_T>
+DST_T requantize_val(double src_scale, int64_t src_zero_point,
+                     double dst_scale, int64_t dst_zero_point,
+                     SRC_T src) {
+  const auto dq = dequantize_val<SRC_T>(src_scale, src_zero_point, src);
+  return quantize_val<DST_T>(dst_scale, dst_zero_point, dq);
+}
+
+template <typename DST_T>
+DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) {
+  int64_t quantize_down =
+      zero_point + lrintf(src * static_cast<float>(multiplier));
+  int32_t min = std::numeric_limits<typename DST_T::underlying>::min();
+  int32_t max = std::numeric_limits<typename DST_T::underlying>::max();
+  return static_cast<DST_T>(
+      std::min<int64_t>(std::max<int64_t>(quantize_down, min), max));
+}
+
+template CAFFE2_API qint8 quantize_val<qint8>(double scale, int64_t zero_point, float value);
+template CAFFE2_API quint8 quantize_val<quint8>(double scale, int64_t zero_point, float value);
+template CAFFE2_API qint32 quantize_val<qint32>(double scale, int64_t zero_point, float value);
+template CAFFE2_API void quantize_vec<c10::qint8>(double scale, int64_t zero_point, const float *src, c10::qint8 *dst, size_t count);
+template CAFFE2_API void quantize_vec<c10::quint8>(double scale, int64_t zero_point, const float *src, c10::quint8 *dst, size_t count);
+template CAFFE2_API void quantize_vec<c10::qint32, 32>(double scale, int64_t zero_point, const float *src, c10::qint32 *dst, size_t count);
+template CAFFE2_API Tensor quantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+template CAFFE2_API Tensor quantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+template CAFFE2_API Tensor quantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+
+template CAFFE2_API float dequantize_val<qint8>(double scale, int64_t zero_point, qint8 value);
+template CAFFE2_API float dequantize_val<quint8>(double scale, int64_t zero_point, quint8 value);
+template CAFFE2_API float dequantize_val<qint32>(double scale, int64_t zero_point, qint32 value);
+template CAFFE2_API Tensor dequantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+template CAFFE2_API Tensor dequantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+template CAFFE2_API Tensor dequantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+
+template CAFFE2_API qint8 requantize_val<qint8, qint8>(double, int64_t, double, int64_t, qint8);
+template CAFFE2_API quint8 requantize_val<qint8, quint8>(double, int64_t, double, int64_t, qint8);
+template CAFFE2_API qint32 requantize_val<qint8, qint32>(double, int64_t, double, int64_t, qint8);
+template CAFFE2_API qint8 requantize_val<quint8, qint8>(double, int64_t, double, int64_t, quint8);
+template CAFFE2_API quint8 requantize_val<quint8, quint8>(double, int64_t, double, int64_t, quint8);
+template CAFFE2_API qint32 requantize_val<quint8, qint32>(double, int64_t, double, int64_t, quint8);
+template CAFFE2_API qint8 requantize_val<qint32, qint8>(double, int64_t, double, int64_t, qint32);
+template CAFFE2_API quint8 requantize_val<qint32, quint8>(double, int64_t, double, int64_t, qint32);
+template CAFFE2_API qint32 requantize_val<qint32, qint32>(double, int64_t, double, int64_t, qint32);
+
+template CAFFE2_API qint8 requantize_from_int<qint8>(double, int64_t, int64_t);
+template CAFFE2_API quint8
+requantize_from_int<quint8>(double, int64_t, int64_t);
+template CAFFE2_API qint32
+requantize_from_int<qint32>(double, int64_t, int64_t);
+
+// TODO: add fbgemm for per channel
+template <typename T>
+Tensor quantize_tensor_per_channel_affine(Tensor rtensor,
+                                          Tensor qtensor,
+                                          Tensor scales,
+                                          Tensor zero_points,
+                                          int64_t axis) {
+  auto fn_name = "quantize_tensor_per_channel_affine";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoints<typename T::underlying>(fn_name, zero_points);
+  TORCH_CHECK(0 <= axis && axis < rtensor.dim(), "Channel axis out of range in per channel affine quantization.");
+  int64_t batches = size_to_dim_(axis, rtensor.sizes());
+  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+  int64_t channel = rtensor.size(axis);
+  auto scales_data = scales.data_ptr<double>();
+  auto zero_points_data = zero_points.data_ptr<int64_t>();
+  TORCH_CHECK(channel == int64_t(scales.numel()),
+              "length of scales must equal the number of channels");
+  TORCH_CHECK(channel == int64_t(zero_points.numel()),
+              "length of zero_points must equal the number of channels");
+  const float* rdata = rtensor.data_ptr<float>();
+  auto qdata = qtensor.data_ptr<T>();
+  for (auto b = 0; b < batches; ++b) {
+    for (auto c = 0; c < channel; ++c) {
+      for (auto e = 0; e < elements_per_channel; ++e) {
+        auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
+        qdata[i] = quantize_val<T>(scales_data[c], zero_points_data[c], rdata[i]);
+      }
+    }
+  }
+  return qtensor;
+}
+
+template <typename T>
+Tensor dequantize_tensor_per_channel_affine(Tensor qtensor,
+                                            Tensor rtensor,
+                                            Tensor scales,
+                                            Tensor zero_points,
+                                            int64_t axis) {
+  auto fn_name = "dequantize_tensor_per_channel_affine";
+  checkFloatCPUTensor(fn_name, rtensor);
+  checkQuantizedCPUTensor<T>(fn_name, qtensor);
+  checkZeroPoints<typename T::underlying>(fn_name, zero_points);
+  TORCH_CHECK(0 <= axis && axis < qtensor.dim(),
+              "Channel axis out of range in per channel affine dequantization.");
+  int64_t batches = size_to_dim_(axis, rtensor.sizes());
+  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+  int64_t channel = rtensor.size(axis);
+  auto scales_data = scales.data_ptr<double>();
+  auto zero_points_data = zero_points.data_ptr<int64_t>();
+  TORCH_CHECK(channel == int64_t(scales.numel()),
+              "length of scales must equal the number of channels");
+  TORCH_CHECK(channel == int64_t(zero_points.numel()),
+              "length of zero_points must equal the number of channels");
+  const auto* qd = qtensor.data_ptr<T>();
+  float* rd = rtensor.data_ptr<float>();
+  for (auto b = 0; b < batches; ++b) {
+    for (auto c = 0; c < channel; ++c) {
+      for (auto e = 0; e < elements_per_channel; ++e) {
+        auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
+        // We need to convert the qint8 value to float to ensure the subtraction
+        // subexpression returns a float
+        rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) * scales_data[c];
+      }
+    }
+  }
+  return rtensor;
+}
+
 QuantizerPtr make_per_tensor_affine_quantizer(
     double scale,
     int64_t zero_point,
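
The NEON fast path added above relies on the magic-float rounding trick documented in its comment: adding 12582912.0f (2**23 + 2**22) makes the FPU round the value to an integer inside the mantissa bits, and subtracting the bit pattern 0x4B400000 recovers that integer; the zero_point is folded into voffset so one integer add finishes the job. A small pure-Python sketch of just the bit trick (illustrative; struct is used here to reinterpret the float32 bits):

import struct

MAGIC_FLOAT = 12582912.0  # 2**23 + 2**22
MAGIC_INT = 0x4B400000    # bit pattern of 12582912.0f

def magic_round(x):
    # interpret_int32(float32(x + 12582912.0f)) - 0x4B400000, valid for |x| < 2**22
    bits = struct.unpack('<i', struct.pack('<f', x + MAGIC_FLOAT))[0]
    return bits - MAGIC_INT

for x in (-1.5, -0.5, 0.5, 1.5, 2.5, 3.7):
    assert magic_round(x) == round(x)  # Python's round() is also round-half-to-even
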
@@ -74,33 +518,36 @@
 
 #endif
 
-inline Tensor new_qtensor(
+inline Tensor new_qtensor_cpu(
     IntArrayRef sizes,
     const TensorOptions& options,
     QuantizerPtr quantizer) {
+  AT_ASSERT(options.device().is_cpu());
+
   auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous);
-  at::Allocator* allocator = GetAllocator(options.device().type());
 
-  #ifdef USE_PYTORCH_QNNPACK
-    if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
-      static QAllocator qallocator;
-      allocator = &qallocator;
-    }
-  #endif
+  at::Allocator* allocator = at::getCPUAllocator();
 
-  at::DispatchKey tensorDispatchKey = options.computeDispatchKey();
+#ifdef USE_PYTORCH_QNNPACK
+  if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+    static QAllocator qallocator;
+    allocator = &qallocator;
+  }
+#endif
+
   native::check_size_nonnegative(sizes);
   int64_t nelements = at::prod_intlist(sizes);
   auto dtype = options.dtype();
   TORCH_CHECK(isQIntType(typeMetaToScalarType(dtype)),
-           "ScalarType is not supported in new_qtensor.");
+           "ScalarType is not supported in new_qtensor_cpu.");
   auto storage = c10::make_intrusive<StorageImpl>(
       dtype,
       nelements,
       allocator->allocate(nelements * dtype.itemsize()),
       allocator,
       /*resizable=*/true);
-  auto tensor = detail::make_tensor<QTensorImpl>(storage, at::DispatchKeySet(tensorDispatchKey), quantizer);
+  auto tensor = detail::make_tensor<QTensorImpl>(
+      storage, at::DispatchKeySet(at::DispatchKey::QuantizedCPU), quantizer);
   get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
   get_qtensorimpl(tensor)->empty_tensor_restride(memory_format);
   return tensor;
@@ -108,43 +555,81 @@
 
 Tensor PerTensorAffineQuantizer::quantize(Tensor rtensor) {
   TORCH_CHECK(
-    rtensor.scalar_type() == kFloat,
-    "quantize only works on Float Tensor.");
+      rtensor.scalar_type() == kFloat,
+      "quantize only works on Float Tensor.");
+  TORCH_CHECK(
+      rtensor.device() == kCPU,
+      "quantize only works for CPU backend right now.");
   // Here we need a std::intrusive_ptr<Quantizer>.. but actually "this" is the
   // quantizer that can be reused, so I'm using intrusive_from_this here
-  Tensor qtensor = new_qtensor(
+  Tensor qtensor = new_qtensor_cpu(
       rtensor.sizes(),
       rtensor.options().dtype(scalar_type_),
       intrusive_from_this());
 
   rtensor = rtensor.contiguous();
-  native::quantize_tensor_per_tensor_affine(rtensor, qtensor, scale_, zero_point_);
+  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor", [&]() {
+    qtensor = quantize_tensor<scalar_t>(rtensor, qtensor, scale_, zero_point_);
+  });
   return qtensor;
 }
 
 Tensor PerTensorAffineQuantizer::dequantize(Tensor qtensor) {
+  TORCH_CHECK(qtensor.is_quantized(),
+           "dequantize is only supported on quantized Tensors.");
+  TORCH_CHECK(
+      qtensor.device() == kCPU,
+      "dequantize only works for CPU backend right now.");
   Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
   qtensor = qtensor.contiguous();
-  native::dequantize_tensor_per_tensor_affine(qtensor, rtensor, scale_, zero_point_);
+
+  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor", [&]() {
+    rtensor = dequantize_tensor<scalar_t>(qtensor, rtensor, scale_, zero_point_);
+  });
+
   return rtensor;
 }
 
 Tensor PerChannelAffineQuantizer::quantize(Tensor rtensor) {
+  TORCH_CHECK(
+      rtensor.scalar_type() == kFloat,
+      "quantize only works on Float Tensor.");
+  TORCH_CHECK(
+      rtensor.device() == kCPU,
+      "quantize only works for CPU backend right now.");
   // Here we need a std::intrusive_ptr<Quantizer>.. but actually "this" is the
   // quantizer that can be reused, so I'm using intrusive_from_this here
-  Tensor qtensor = new_qtensor(
+  Tensor qtensor = new_qtensor_cpu(
       rtensor.sizes(),
       rtensor.options().dtype(scalar_type_),
       intrusive_from_this());
+
   rtensor = rtensor.contiguous();
-  native::quantize_tensor_per_channel_affine(rtensor, qtensor, scales_, zero_points_, axis_);
+  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(),
+                         "quantize_tensor_per_channel_affine",
+                         [&]() {
+    qtensor = quantize_tensor_per_channel_affine<scalar_t>(
+        rtensor, qtensor, scales_, zero_points_, axis_);
+  });
   return qtensor;
 }
 
 Tensor PerChannelAffineQuantizer::dequantize(Tensor qtensor) {
+  TORCH_CHECK(qtensor.is_quantized(),
+           "dequantize is only supported on quantized Tensors.");
+  TORCH_CHECK(
+      qtensor.device() == kCPU,
+      "dequantize only works for CPU backend right now.");
   Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
   qtensor = qtensor.contiguous();
-  native::dequantize_tensor_per_channel_affine(qtensor, rtensor, scales_, zero_points_, axis_);
+
+  AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(),
+                         "dequantize_tensor_per_channel_affine",
+                         [&]() {
+    rtensor = dequantize_tensor_per_channel_affine<scalar_t>(
+        qtensor, rtensor, scales_, zero_points_, axis_);
+  });
+
   return rtensor;
 }
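
As defined above, requantize_val is a dequantize with the source parameters followed by a quantize with the destination parameters, and quantize_val itself is round-half-to-even plus a clamp to the underlying integer range. A pure-Python sketch of a quint8 to qint8 requantization (the concrete scales and zero_points are invented for the example):

def quantize_val(scale, zero_point, value, qmin, qmax):
    # round-half-to-even, then clamp to the destination range
    return min(max(round(value / scale + zero_point), qmin), qmax)

def dequantize_val(scale, zero_point, qvalue):
    return (qvalue - zero_point) * scale

def requantize_val(src_scale, src_zp, dst_scale, dst_zp, src_q, qmin, qmax):
    return quantize_val(dst_scale, dst_zp,
                        dequantize_val(src_scale, src_zp, src_q), qmin, qmax)

# a quint8 value (scale=0.1, zero_point=10) requantized to qint8 (scale=0.05, zero_point=0)
assert requantize_val(0.1, 10, 0.05, 0, 37, qmin=-128, qmax=127) == 54
# 37 dequantizes to (37 - 10) * 0.1 = 2.7, which quantizes to round(2.7 / 0.05) = 54
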
 
diff --git a/aten/src/ATen/quantized/Quantizer.h b/aten/src/ATen/quantized/Quantizer.h
index 298fa0e..547cf4d 100644
--- a/aten/src/ATen/quantized/Quantizer.h
+++ b/aten/src/ATen/quantized/Quantizer.h
@@ -8,12 +8,14 @@
 #include <c10/core/ScalarType.h>
 #include <c10/core/TensorOptions.h>
 
-#include <ATen/Tensor.h>
 #include <ATen/TensorUtils.h>
 
 #include <cmath>
 #include <memory>
 
+// TODO: move to c10 namespace once we have
+// unified caffe2::Tensor and at::Tensor
+
 namespace at {
 
 class Tensor;
@@ -229,6 +231,29 @@
 // This may be called repeatedly, so make sure it's pretty cheap.
 CAFFE2_API QTensorImpl* get_qtensorimpl(const Tensor& self);
 
+// Quantize a float value into a uint value given scale and zero_point
+template <typename T>
+CAFFE2_API T quantize_val(double scale, int64_t zero_point, float value);
+template <typename T, int precision=8>
+void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count=8);
+template <typename T>
+CAFFE2_API Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
+template <typename T>
+CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value);
+template <typename T>
+CAFFE2_API float dequantize_vec(double scale, int64_t zero_point, const T* src, float* dst, size_t count=8);
+template <typename T>
+CAFFE2_API Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
+template <typename SRC_T, typename DST_T>
+CAFFE2_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src);
+
+// Given a multiplier and a zero_point, requantize int32_t computed values back
+// to quantized values. See comment above
+// make_per_tensor_affine_quantizer function for the usage of int64_t
+template <typename DST_T>
+CAFFE2_API DST_T
+requantize_from_int(double multiplier, int64_t zero_point, int64_t src);
+
 // double and int64_t are because of the native function API, we only have these
 // argument types right now in native functions
 CAFFE2_API QuantizerPtr
@@ -242,7 +267,7 @@
     ScalarType scalar_type);
 
 // Create a Quantized Tensor given arguments for normal Tensor and a quantizer
-CAFFE2_API Tensor new_qtensor(
+CAFFE2_API Tensor new_qtensor_cpu(
     IntArrayRef sizes,
     const TensorOptions& options,
     QuantizerPtr quantizer);
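
requantize_from_int, declared in this header and defined in Quantizer.cpp above, maps an int32 value accumulated in higher precision back to a quantized value as clamp(zero_point + lrintf(src * multiplier), qmin, qmax). A pure-Python sketch under quint8 assumptions (the multiplier, zero_point, and inputs are invented for the example):

def requantize_from_int(multiplier, zero_point, src, qmin=0, qmax=255):
    # mirrors zero_point + lrintf(src * multiplier), clamped to the underlying range;
    # qmin/qmax default to the quint8 range for this sketch
    quantize_down = zero_point + round(src * multiplier)
    return min(max(quantize_down, qmin), qmax)

assert requantize_from_int(0.002, 5, 12345) == 30    # 5 + round(24.69) = 30
assert requantize_from_int(0.002, 5, 200000) == 255  # clamped to the quint8 max
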
diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp
index 555d94c..87bf735 100644
--- a/aten/src/ATen/test/quantized_test.cpp
+++ b/aten/src/ATen/test/quantized_test.cpp
@@ -8,7 +8,7 @@
 #include <sstream>
 #include <type_traits>
 // For quantize_val
-#include <ATen/native/quantized/affine_quantizer.h>
+#include <ATen/quantized/Quantizer.h>
 #include <c10/core/ScalarType.h>
 
 using namespace at;
@@ -36,7 +36,7 @@
   auto qr_data = qr.data_ptr<quint8>();
   for (auto i = 0; i < num_elements; ++i) {
     ASSERT_EQ(
-      native::quantize_val<quint8>(scale, zero_point, r_data[i]).val_,
+      quantize_val<quint8>(scale, zero_point, r_data[i]).val_,
       qr_data[i].val_);
   }
 
@@ -48,7 +48,7 @@
   }
   for (auto i = 0; i < num_elements; ++i) {
     ASSERT_EQ(r_data[i],
-              native::dequantize_val(qr.q_scale(), qr.q_zero_point(), qr_data[i]));
+              dequantize_val(qr.q_scale(), qr.q_zero_point(), qr_data[i]));
   }
 
   // Check for correct requantization
@@ -57,11 +57,11 @@
   Tensor reqr = at::quantize_per_tensor(r, new_scale, new_zero_point, kQInt8);
   auto reqr_data = reqr.data_ptr<qint8>();
   for (auto i = 0; i < num_elements; ++i) {
-    reqr_data[i].val_ = native::requantize_val<quint8, qint8>(scale, zero_point,
-                                                              new_scale, new_zero_point,
-                                                              qr_data[i]).val_;
-    const qint8 expected = native::quantize_val<qint8>(new_scale, new_zero_point,
-                                                       rqr_data[i]);
+    reqr_data[i].val_ = requantize_val<quint8, qint8>(scale, zero_point,
+                                                      new_scale, new_zero_point,
+                                                      qr_data[i]).val_;
+    const qint8 expected = quantize_val<qint8>(new_scale, new_zero_point,
+                                               rqr_data[i]);
     ASSERT_EQ(expected.val_, reqr_data[i].val_);
   }
 }
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 8f9ba79..2512af0 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -25,7 +25,7 @@
  * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
  * "CUDNN".
  */
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, QuantizedCUDA, Undefined, MkldnnCPU, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, QuantizedCPU, Undefined, MkldnnCPU, NumOptions };
 
 static inline Backend toSparse(Backend b) {
   switch (b) {
@@ -66,8 +66,6 @@
       return Backend::HIP;
     case Backend::QuantizedCPU:
       return Backend::QuantizedCPU;
-    case Backend::QuantizedCUDA:
-      return Backend::QuantizedCUDA;
     default:
       throw std::runtime_error("Unknown backend");
   }
@@ -94,8 +92,6 @@
     return Backend::MkldnnCPU;
   } else if (t == DispatchKey::QuantizedCPU) {
     return Backend::QuantizedCPU;
-  } else if (t == DispatchKey::QuantizedCUDA) {
-    return Backend::QuantizedCUDA;
   } else if (t == DispatchKey::Undefined) {
     return Backend::Undefined;
   } else {
@@ -125,8 +121,6 @@
       return DispatchKey::MkldnnCPU;
     case Backend::QuantizedCPU:
       return DispatchKey::QuantizedCPU;
-    case Backend::QuantizedCUDA:
-      return DispatchKey::QuantizedCUDA;
     case Backend::Undefined:
       return DispatchKey::Undefined;
     default:
@@ -155,8 +149,6 @@
     case Backend::MkldnnCPU:
     case Backend::QuantizedCPU:
       return DeviceType::CPU;
-    case Backend::QuantizedCUDA:
-      return DeviceType::CUDA;
     case Backend::Undefined:
       AT_ERROR("Undefined backend is not a valid device type");
     default:
@@ -185,8 +177,6 @@
       return Backend::MkldnnCPU;
     case Backend::QuantizedCPU:
       return Backend::QuantizedCPU;
-    case Backend::QuantizedCUDA:
-      return Backend::QuantizedCPU;
     case Backend::Undefined:
       return Backend::Undefined;
     default:
@@ -255,8 +245,6 @@
       return "MkldnnCPU";
     case Backend::QuantizedCPU:
       return "QuantizedCPU";
-    case Backend::QuantizedCUDA:
-      return "QuantizedCUDA";
     default:
       return "UNKNOWN_BACKEND";
   }
diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h
index 1760e59..fdf3472 100644
--- a/c10/core/DispatchKey.h
+++ b/c10/core/DispatchKey.h
@@ -67,10 +67,9 @@
 
   // Here are backends which specify more specialized operators
   // based on the dtype of the tensor.
-  QuantizedCPU,  // registered at build/aten/src/ATen/QuantizedCPUType.cpp
-  QuantizedCUDA, // registered at build/aten/src/ATen/QuantizedCUDAType.cpp
-  ComplexCPU,    // lives out of tree at https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex
-  ComplexCUDA,   // and https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex
+  QuantizedCPU, // registered at build/aten/src/ATen/QuantizedCPUType.cpp
+  ComplexCPU,   // lives out of tree at https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex
+  ComplexCUDA,  // and https://gitlab.com/pytorch-complex/pytorch-cuda-strided-complex
                         // tested at test/cpp_extensions/complex_registration_extension.cpp
                         // TODO: Remove Complex dispatch keys when Complex is moved in tree
 
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 515d195..eba122c 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -430,15 +430,13 @@
 
   bool is_quantized() const {
     // NB: This method is not virtual and avoid dispatches for performance reasons.
-    return key_set_.has(DispatchKey::QuantizedCPU) || 
-           key_set_.has(DispatchKey::QuantizedCUDA);
+    return key_set_.has(DispatchKey::QuantizedCPU);
   }
 
   bool is_cuda() const {
     // NB: This method is not virtual and avoid dispatches for performance reasons.
     return key_set_.has(DispatchKey::CUDA) ||
-           key_set_.has(DispatchKey::SparseCUDA) ||
-           key_set_.has(DispatchKey::QuantizedCUDA);
+           key_set_.has(DispatchKey::SparseCUDA);
   }
 
   bool is_hip() const {
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 670fd34..7b7ab17 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -380,13 +380,8 @@
             }
             return DispatchKey::CPU;
             }
-          case DeviceType::CUDA: {
-            auto dtype_tmp = typeMetaToScalarType(dtype());
-            if (isQIntType(dtype_tmp)) {
-              return DispatchKey::QuantizedCUDA;
-            }
+          case DeviceType::CUDA:
             return DispatchKey::CUDA;
-            }
           case DeviceType::MKLDNN:
             return DispatchKey::MKLDNN;
           case DeviceType::OPENGL:
diff --git a/c10/util/qint32.h b/c10/util/qint32.h
index 1217b68..0aa744e 100644
--- a/c10/util/qint32.h
+++ b/c10/util/qint32.h
@@ -9,7 +9,6 @@
 struct alignas(4) qint32 {
   using underlying = int32_t;
   int32_t val_;
-  qint32() = default;
   explicit qint32(int32_t val) : val_(val) {}
 };
 
diff --git a/c10/util/qint8.h b/c10/util/qint8.h
index 0fed7fb..27dd7b3 100644
--- a/c10/util/qint8.h
+++ b/c10/util/qint8.h
@@ -11,7 +11,6 @@
 struct alignas(1) qint8 {
   using underlying = int8_t;
   int8_t val_;
-  qint8() = default;
   explicit qint8(int8_t val) : val_(val) {}
 };
 
diff --git a/c10/util/quint8.h b/c10/util/quint8.h
index 93aa2cd..0dbef37 100644
--- a/c10/util/quint8.h
+++ b/c10/util/quint8.h
@@ -9,7 +9,6 @@
 struct alignas(1) quint8 {
   using underlying = uint8_t;
   uint8_t val_;
-  quint8() = default;
   explicit quint8(uint8_t val) : val_(val) {}
 };
 
diff --git a/test/quantization/test_quantized_tensor.py b/test/quantization/test_quantized_tensor.py
index 0017814..dc39706 100644
--- a/test/quantization/test_quantized_tensor.py
+++ b/test/quantization/test_quantized_tensor.py
@@ -2,12 +2,11 @@
 import math
 import torch
 import io
-import unittest
 from copy import deepcopy
 from hypothesis import given
 from hypothesis import strategies as st
 
-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM
+from torch.testing._internal.common_utils import TestCase, run_tests
 import torch.testing._internal.hypothesis_utils as hu
 
 hu.assert_deadline_disabled()
@@ -63,79 +62,62 @@
 
     return [scale.astype(np.float32), int(nudged_zero_point)]
 
-def get_supported_device_types():
-    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']
-
 class TestQuantizedTensor(TestCase):
     def test_qtensor(self):
         num_elements = 10
+        r = torch.ones(num_elements, dtype=torch.float)
         scale = 1.0
         zero_point = 2
-        for device in get_supported_device_types():
-            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-                r = torch.ones(num_elements, dtype=torch.float, device=device)
-                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
-                self.assertEqual(qr.q_scale(), scale)
-                self.assertEqual(qr.q_zero_point(), zero_point)
-                self.assertTrue(qr.is_quantized)
-                self.assertFalse(r.is_quantized)
-                self.assertEqual(qr.qscheme(), torch.per_tensor_affine)
-                self.assertTrue(isinstance(qr.qscheme(), torch.qscheme))
-                # slicing and int_repr
-                int_repr = qr.int_repr()
-                for num in int_repr:
-                    self.assertEqual(num, 3)
-                for num in qr[2:].int_repr():
-                    self.assertEqual(num, 3)
-                # dequantize
-                rqr = qr.dequantize()
-                for i in range(num_elements):
-                    self.assertEqual(r[i], rqr[i])
-                # we can also print a qtensor
-                empty_r = torch.ones((0, 1), dtype=torch.float, device=device)
-                empty_qr = torch.quantize_per_tensor(empty_r, scale, zero_point, dtype)
-
-                device_msg = "" if device == 'cpu' else "device='" + device + ":0', "
-                dtype_msg = str(dtype) + ", "
-                self.assertEqual(' '.join(str(empty_qr).split()),
-                                 "tensor([], " + device_msg + "size=(0, 1), dtype=" + dtype_msg +
-                                 "quantization_scheme=torch.per_tensor_affine, " +
-                                 "scale=1.0, zero_point=2)")
-
-    def test_qtensor_float_assignment(self):
+        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
+        self.assertEqual(qr.q_scale(), scale)
+        self.assertEqual(qr.q_zero_point(), zero_point)
+        self.assertTrue(qr.is_quantized)
+        self.assertFalse(r.is_quantized)
+        self.assertEqual(qr.qscheme(), torch.per_tensor_affine)
+        self.assertTrue(isinstance(qr.qscheme(), torch.qscheme))
+        # slicing and int_repr
+        int_repr = qr.int_repr()
+        for num in int_repr:
+            self.assertEqual(num, 3)
+        for num in qr[2:].int_repr():
+            self.assertEqual(num, 3)
+        # dequantize
+        rqr = qr.dequantize()
+        for i in range(num_elements):
+            self.assertEqual(r[i], rqr[i])
         # Scalar Tensor
         # item
-        scale = 1.0
-        zero_point = 2
         r = torch.ones(1, dtype=torch.float)
-        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
-            self.assertEqual(qr.item(), 1)
-            self.assertEqual(qr[0].item(), 1)
-            # assignment
-            self.assertTrue(qr[0].is_quantized)
-            qr[0] = 11.3  # float assignment
-            self.assertEqual(qr.item(), 11)
-            x = torch.ones(1, dtype=torch.float) * 15.3
-            # Copying from a float Tensor
-            qr[:] = x
-            self.assertEqual(qr.item(), 15)
-
-            dtype_msg = str(dtype) + ", "
-            self.assertEqual(' '.join(str(qr).split()),
-                             "tensor([15.], size=(1,), dtype=" + dtype_msg +
-                             "quantization_scheme=torch.per_tensor_affine, " +
-                             "scale=1.0, zero_point=2)")
+        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
+        self.assertEqual(qr.item(), 1)
+        self.assertEqual(qr[0].item(), 1)
+        # assignment
+        self.assertTrue(qr[0].is_quantized)
+        qr[0] = 11.3  # float assignment
+        self.assertEqual(qr.item(), 11)
+        x = torch.ones(1, dtype=torch.float) * 15.3
+        # Copying from a float Tensor
+        qr[:] = x
+        self.assertEqual(qr.item(), 15)
+        # we can also print a qtensor
+        self.assertEqual(' '.join(str(qr).split()),
+                         "tensor([15.], size=(1,), dtype=torch.quint8, " +
+                         "quantization_scheme=torch.per_tensor_affine, " +
+                         "scale=1.0, zero_point=2)")
+        empty_r = torch.ones((0, 1), dtype=torch.float)
+        empty_qr = torch.quantize_per_tensor(empty_r, scale, zero_point, torch.quint8)
+        self.assertEqual(' '.join(str(empty_qr).split()),
+                         "tensor([], size=(0, 1), dtype=torch.quint8, " +
+                         "quantization_scheme=torch.per_tensor_affine, " +
+                         "scale=1.0, zero_point=2)")
 
     def test_qtensor_quant_dequant(self):
+        r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
         scale = 0.02
         zero_point = 2
-        for device in get_supported_device_types():
-            r = torch.rand(3, 2, dtype=torch.float, device=device) * 4 - 2
-            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
-                rqr = qr.dequantize()
-                self.assertTrue(np.allclose(r.cpu().numpy(), rqr.cpu().numpy(), atol=2 / scale))
+        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
+        rqr = qr.dequantize()
+        self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
 
     # legacy constructor/new doesn't support qtensors
     def test_qtensor_legacy_new_failure(self):
@@ -154,12 +136,11 @@
         ch_axis = 0
         scales = torch.rand(numel)
         zero_points = torch.randint(0, 10, size=(numel,))
-        for dtype in [torch.qint8, torch.quint8]:
-            q = torch._empty_per_channel_affine_quantized(
-                [numel], scales=scales, zero_points=zero_points, axis=ch_axis, dtype=dtype)
-            self.assertEqual(scales, q.q_per_channel_scales())
-            self.assertEqual(zero_points, q.q_per_channel_zero_points())
-            self.assertEqual(ch_axis, q.q_per_channel_axis())
+        q = torch._empty_per_channel_affine_quantized(
+            [numel], scales=scales, zero_points=zero_points, axis=ch_axis, dtype=torch.quint8)
+        self.assertEqual(scales, q.q_per_channel_scales())
+        self.assertEqual(zero_points, q.q_per_channel_zero_points())
+        self.assertEqual(ch_axis, q.q_per_channel_axis())
 
         # create Tensor from uint8_t Tensor, scales and zero_points
         int_tensor = torch.randint(0, 100, size=(numel,), dtype=torch.uint8)
@@ -172,31 +153,29 @@
     def test_qtensor_creation(self):
         scale = 0.5
         zero_point = 10
+        val = 100
         numel = 10
-        for device in get_supported_device_types():
-            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
-                                              device=device, dtype=torch.quint8)
-            self.assertEqual(scale, q.q_scale())
-            self.assertEqual(zero_point, q.q_zero_point())
+        q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        self.assertEqual(scale, q.q_scale())
+        self.assertEqual(zero_point, q.q_zero_point())
 
-            # create Tensor from uint8_t Tensor, scale and zero_point
-            int_tensor = torch.randint(0, 100, size=(10,), device=device, dtype=torch.uint8)
-            q = torch._make_per_tensor_quantized_tensor(int_tensor, scale, zero_point)
-            self.assertEqual(int_tensor, q.int_repr())
-            self.assertEqual(scale, q.q_scale())
-            self.assertEqual(zero_point, q.q_zero_point())
+        # create Tensor from uint8_t Tensor, scale and zero_point
+        int_tensor = torch.randint(0, 100, size=(10,), dtype=torch.uint8)
+        q = torch._make_per_tensor_quantized_tensor(int_tensor, scale, zero_point)
+        self.assertEqual(int_tensor, q.int_repr())
+        self.assertEqual(scale, q.q_scale())
+        self.assertEqual(zero_point, q.q_zero_point())
 
-            # create via empty_like
-            q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
-                                              device=device, dtype=torch.quint8)
-            q_el = torch.empty_like(q)
-            self.assertEqual(q.q_scale(), q_el.q_scale())
-            self.assertEqual(q.q_zero_point(), q_el.q_zero_point())
-            self.assertEqual(q.dtype, q_el.dtype)
+        # create via empty_like
+        q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        q_el = torch.empty_like(q)
+        self.assertEqual(q.q_scale(), q_el.q_scale())
+        self.assertEqual(q.q_zero_point(), q_el.q_zero_point())
+        self.assertEqual(q.dtype, q_el.dtype)
 
-            # create via empty_like but change the dtype (currently not supported)
-            with self.assertRaises(RuntimeError):
-                torch.empty_like(q, dtype=torch.qint8)
+        # create via empty_like but change the dtype (currently not supported)
+        with self.assertRaises(RuntimeError):
+            torch.empty_like(q, dtype=torch.qint8)
 
     def test_qtensor_dtypes(self):
         r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
@@ -231,53 +210,49 @@
         self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))
 
     def test_qtensor_permute(self):
+        r = torch.rand(10, 30, 2, 2, dtype=torch.float) * 4 - 2
         scale = 0.02
         zero_point = 1
-        for device in get_supported_device_types():
-            r = torch.rand(10, 30, 2, 2, device=device, dtype=torch.float) * 4 - 2
-            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
-                qr = qr.transpose(0, 1)
-                rqr = qr.dequantize()
-                # compare transpose + dequantized result with orignal transposed result
-                self.assertTrue(np.allclose(r.cpu().numpy().transpose([1, 0, 2, 3]), rqr.cpu().numpy(), atol=2 / scale))
+        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint8)
+        qr = qr.transpose(0, 1)
+        rqr = qr.dequantize()
+        # compare transpose + dequantized result with original transposed result
+        self.assertTrue(np.allclose(r.numpy().transpose([1, 0, 2, 3]), rqr.numpy(), atol=2 / scale))
 
-                qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
-                qr1 = qr.permute([1, 0, 2, 3])
-                qr2 = qr.transpose(0, 1)
-                # compare int representation after transformations
-                self.assertEqual(qr1.int_repr(), qr2.int_repr())
-                self.assertEqual(qr1.q_scale(), qr2.q_scale())
-                self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point())
-                # compare dequantized result
-                self.assertEqual(qr1.dequantize(), qr2.dequantize())
-                # compare permuted + dequantized result with original transposed result
-                self.assertTrue(np.allclose(qr2.dequantize().cpu().numpy(),
-                                            r.cpu().numpy().transpose([1, 0, 2, 3]), atol=2 / scale))
-                # make permuted result contiguous
-                self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr())
+        qr = torch.quantize_per_tensor(r, scale, zero_point, torch.qint8)
+        qr1 = qr.permute([1, 0, 2, 3])
+        qr2 = qr.transpose(0, 1)
+        # compare int representation after transformations
+        self.assertEqual(qr1.int_repr(), qr2.int_repr())
+        self.assertEqual(qr1.q_scale(), qr2.q_scale())
+        self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point())
+        # compare dequantized result
+        self.assertEqual(qr1.dequantize(), qr2.dequantize())
+        # compare permuted + dequantized result with original transposed result
+        self.assertTrue(np.allclose(qr2.dequantize().numpy(), r.numpy().transpose([1, 0, 2, 3]), atol=2 / scale))
+        # make permuted result contiguous
+        self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr())
 
-                # change memory format
-                qlast = qr.contiguous(memory_format=torch.channels_last)
-                self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
-                self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
-                self.assertEqual(qr.int_repr(), qlast.int_repr())
-                self.assertEqual(qr.q_scale(), qlast.q_scale())
-                self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
-                self.assertEqual(qlast.dequantize(), qr.dequantize())
+        # change memory format
+        qlast = qr.contiguous(memory_format=torch.channels_last)
+        self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride()))))
+        self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride()))))
+        self.assertEqual(qr.int_repr(), qlast.int_repr())
+        self.assertEqual(qr.q_scale(), qlast.q_scale())
+        self.assertEqual(qr.q_zero_point(), qlast.q_zero_point())
+        self.assertEqual(qlast.dequantize(), qr.dequantize())
 
-                # permuting larger tensors
-                x = torch.randn(64, 64, device=device)
-                qx = torch.quantize_per_tensor(x, 1.0, 0, dtype)
-                # should work
-                qx.permute([1, 0])
+        # permuting larger tensors
+        x = torch.randn(64, 64)
+        qx = torch.quantize_per_tensor(x, 1.0, 0, torch.qint32)
+        # should work
+        qx.permute([1, 0])
 
     def test_qtensor_per_channel_permute(self):
         r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2
-        dtype = torch.qint8
         scales = torch.rand(10) * 0.02 + 0.01
         zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long)
-        qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
+        qr = torch.quantize_per_channel(r, scales, zero_points, 1, torch.qint8)
 
         # we can't reorder the axis
         with self.assertRaises(RuntimeError):
@@ -296,11 +271,9 @@
     def test_qtensor_load_save(self):
         scale = 0.2
         zero_point = 10
-        # storage is not accessible on the cuda right now
-        device = "cpu"
-        r = torch.rand(15, 2, dtype=torch.float32, device=device) * 2
-        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
+        r = torch.rand(15, 2, dtype=torch.float32) * 2
+        for dtype in [torch.quint8, torch.qint8, torch.qint32]:
+            qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
             qrv = qr[:, 1]
             with tempfile.NamedTemporaryFile() as f:
                 # Serializing and Deserializing Tensor
@@ -315,7 +288,7 @@
         r = torch.rand(20, 10, dtype=torch.float) * 4 - 2
         scales = torch.rand(10, dtype=torch.double) * 0.02 + 0.01
         zero_points = torch.round(torch.rand(10) * 20 + 1).to(torch.long)
-        # quint32, cuda is not supported yet
+        # qint32 is not supported yet
         for dtype in [torch.quint8, torch.qint8]:
             qr = torch.quantize_per_channel(r, scales, zero_points, 1, dtype)
             with tempfile.NamedTemporaryFile() as f:
@@ -328,123 +301,106 @@
     def test_qtensor_copy(self):
         scale = 0.5
         zero_point = 10
+        val = 100
         numel = 10
-        for device in get_supported_device_types():
-            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-                # copy from same scale and zero_point
-                q = torch._empty_affine_quantized([numel], scale=scale,
-                                                  zero_point=zero_point, device=device, dtype=dtype)
-                q2 = torch._empty_affine_quantized([numel], scale=scale,
-                                                   zero_point=zero_point, device=device, dtype=dtype)
-                q.copy_(q2)
-                self.assertEqual(q.int_repr(), q2.int_repr())
-                self.assertEqual(q.q_scale(), q2.q_scale())
-                self.assertEqual(q.q_zero_point(), q2.q_zero_point())
-                # copying from different scale and zero_point
-                scale = 3.2
-                zero_point = 5
-                q = torch._empty_affine_quantized([numel], scale=scale,
-                                                  zero_point=zero_point, device=device, dtype=dtype)
-                # check original scale and zero_points are set correctly
-                self.assertEqual(q.q_scale(), scale)
-                self.assertEqual(q.q_zero_point(), zero_point)
-                q.copy_(q2)
-                # check scale and zero_points has been copied
-                self.assertEqual(q, q2)
-                # can't copy from quantized tensor to non-quantized tensor
-                r = torch.empty([numel], dtype=torch.float)
-                q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
-                with self.assertRaisesRegex(RuntimeError, "please use dequantize"):
-                    r.copy_(q)
-
-    def test_torch_qtensor_deepcopy(self):
-        # cuda is not supported yet
-        device = "cpu"
-        q_int = torch.randint(0, 100, [3, 5], device=device, dtype=torch.uint8)
+        # copy from same scale and zero_point
+        q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        q2 = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        q.copy_(q2)
+        self.assertEqual(q.int_repr(), q2.int_repr())
+        self.assertEqual(q.q_scale(), q2.q_scale())
+        self.assertEqual(q.q_zero_point(), q2.q_zero_point())
+        # copying from different scale and zero_point
+        scale = 3.2
+        zero_point = 5
+        q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        # check that the original scale and zero_point are set correctly
+        self.assertEqual(q.q_scale(), scale)
+        self.assertEqual(q.q_zero_point(), zero_point)
+        q.copy_(q2)
+        # check that the scale and zero_point have been copied
+        self.assertEqual(q, q2)
+        # deep copy
+        scale, zero_point, dtype = 1.0, 2, torch.uint8
+        q_int = torch.randint(0, 100, [3, 5], dtype=dtype)
         scale, zero_point = 2.0, 3
         q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
         qc = deepcopy(q)
         self.assertEqual(qc, q)
 
+        # can't copy from quantized tensor to non-quantized tensor
+        r = torch.empty([numel], dtype=torch.float)
+        q = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        with self.assertRaisesRegex(RuntimeError, "please use dequantize"):
+            r.copy_(q)
+
     def test_qtensor_clone(self):
         numel = 10
         scale = 0.5
         zero_point = 10
-        for device in get_supported_device_types():
-            for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-                q2 = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point,
-                                                   device=device, dtype=dtype)
-                q = q2.clone()
-                # Check to make sure the scale and zero_point has been copied.
-                self.assertEqual(q, q2)
+        q2 = torch._empty_affine_quantized([numel], scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        q = q2.clone()
+        # Check to make sure the scale and zero_point have been copied.
+        self.assertEqual(q, q2)
 
     def test_qtensor_view(self):
         scale, zero_point, dtype = 1.0, 2, torch.uint8
-        for device in get_supported_device_types():
-            q_int = torch.randint(0, 100, [1, 2, 3], device=device, dtype=dtype)
-            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
-            q2 = q.view(1, 3, 2)
-            self.assertEqual(q.numel(), q2.numel())
-            # testing -1
-            self.assertEqual(q, q2.view(1, -1, 3))
+        q_int = torch.randint(0, 100, [1, 2, 3], dtype=dtype)
+        q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
+        q2 = q.view(1, 3, 2)
+        self.assertEqual(q.numel(), q2.numel())
+        # testing -1
+        self.assertEqual(q, q2.view(1, -1, 3))
 
-            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
-            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
-            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
-            c = a.view(1, 3, 2, 4)  # does not change tensor layout in memory
-            self.assertEqual(b.size(), c.size())
-            self.assertEqual(b.q_scale(), c.q_scale())
-            self.assertEqual(b.q_zero_point(), c.q_zero_point())
-            self.assertNotEqual(b.stride(), c.stride())
-            # size is the same but the underlying data is different
-            self.assertNotEqual(b.int_repr(), c.int_repr())
-            # torch.equal is not supported for the cuda backend
-            if device == 'cpu':
-                self.assertFalse(torch.equal(b, c))
-            else:
-                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))
+        a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype)
+        a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
+        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
+        c = a.view(1, 3, 2, 4)  # does not change tensor layout in memory
+        self.assertEqual(b.size(), c.size())
+        self.assertEqual(b.q_scale(), c.q_scale())
+        self.assertEqual(b.q_zero_point(), c.q_zero_point())
+        self.assertNotEqual(b.stride(), c.stride())
+        # size is the same but the underlying data is different
+        self.assertNotEqual(b.int_repr(), c.int_repr())
+        self.assertFalse(torch.equal(b, c))
 
-            # a case can't view non-contiguos Tensor
-            a_int = torch.randint(0, 100, [1, 2, 3, 4], device=device, dtype=dtype)
-            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
-            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
-            err_str = "view size is not compatible with input tensor's size and stride*"
-            with self.assertRaisesRegex(RuntimeError, err_str):
-                b.view(1, 4, 2, 3)
-            # view on contiguous tensor is fine
-            b.contiguous().view(1, 4, 2, 3)
+        # a case where we can't view a non-contiguous Tensor
+        a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype)
+        a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
+        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
+        err_str = "view size is not compatible with input tensor's size and stride*"
+        with self.assertRaisesRegex(RuntimeError, err_str):
+            b.view(1, 4, 2, 3)
+        # view on contiguous tensor is fine
+        b.contiguous().view(1, 4, 2, 3)
+
 
     def test_qtensor_reshape(self):
         scale, zero_point, dtype = 1.0, 2, torch.uint8
-        for device in get_supported_device_types():
-            q_int = torch.randint(0, 100, [3, 5], dtype=dtype, device=device)
-            q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
-            q2 = q.reshape([15])
-            self.assertEqual(q.numel(), q2.numel())
-            self.assertEqual(q2.size(), [15])
-            # testing -1
-            self.assertEqual(q, q2.reshape([3, -1]))
+        q_int = torch.randint(0, 100, [3, 5], dtype=dtype)
+        q = torch._make_per_tensor_quantized_tensor(q_int, scale=scale, zero_point=zero_point)
+        q2 = q.reshape([15])
+        self.assertEqual(q.numel(), q2.numel())
+        self.assertEqual(q2.size(), [15])
+        # testing -1
+        self.assertEqual(q, q2.reshape([3, -1]))
 
-            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
-            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
-            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
-            c = a.reshape(1, 3, 2, 4)  # does not change tensor layout
-            self.assertEqual(b.size(), c.size())
-            self.assertEqual(b.q_scale(), c.q_scale())
-            self.assertEqual(b.q_zero_point(), c.q_zero_point())
-            self.assertNotEqual(b.stride(), c.stride())
-            self.assertNotEqual(b.int_repr(), c.int_repr())
-            # torch.equal is not supported for the cuda backend
-            if device == 'cpu':
-                self.assertFalse(torch.equal(b, c))
-            else:
-                self.assertRaises(RuntimeError, lambda: torch.equal(b, c))
+        a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype)
+        a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
+        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
+        c = a.reshape(1, 3, 2, 4)  # does not change tensor layout
+        self.assertEqual(b.size(), c.size())
+        self.assertEqual(b.q_scale(), c.q_scale())
+        self.assertEqual(b.q_zero_point(), c.q_zero_point())
+        self.assertNotEqual(b.stride(), c.stride())
+        self.assertNotEqual(b.int_repr(), c.int_repr())
+        self.assertFalse(torch.equal(b, c))
 
-            # we can use reshape for non-contiguous Tensor
-            a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype, device=device)
-            a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
-            b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
-            c = b.reshape(1, 4, 2, 3)
+        # we can use reshape for non-contiguous Tensor
+        a_int = torch.randint(0, 100, [1, 2, 3, 4], dtype=dtype)
+        a = torch._make_per_tensor_quantized_tensor(a_int, scale=scale, zero_point=zero_point)
+        b = a.transpose(1, 2)  # swaps 2nd and 3rd dimension
+        c = b.reshape(1, 4, 2, 3)
 
     def test_qscheme_pickle(self):
         f = Foo()
@@ -469,18 +425,5 @@
         np.testing.assert_array_almost_equal(X_scale, qparams[0], decimal=3)
         self.assertEqual(X_zp, qparams[1])
 
-    @unittest.skipIf(not torch.cuda.is_available() or TEST_WITH_ROCM, 'CUDA is not available')
-    def test_cuda_cpu_implementation_consistency(self):
-        numel, zero_point, scale = 100, 2, 0.02
-        r = torch.rand(numel, dtype=torch.float32, device='cpu') * 25 - 4
-        for dtype in [torch.qint8, torch.quint8, torch.qint32]:
-            qr_cpu = torch.quantize_per_tensor(r, scale, zero_point, dtype=dtype)
-            qr_cuda = torch.quantize_per_tensor(r.cuda(), scale, zero_point, dtype=dtype)
-            # intr repr must be the same
-            np.testing.assert_equal(qr_cpu.int_repr().numpy(), qr_cuda.int_repr().cpu().numpy())
-            # dequantized values must be the same
-            r_cpu, r_cuda = qr_cpu.dequantize().numpy(), qr_cuda.dequantize().cpu().numpy()
-            np.testing.assert_almost_equal(r_cuda, r_cpu, decimal=5)
-
 if __name__ == "__main__":
     run_tests()
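
Note: a minimal sketch (not part of the patch) of the per-tensor affine round-trip that the CPU tests above exercise; the shape, scale, and zero_point values below are illustrative only.

import torch

# quantize_per_tensor stores round(x / scale) + zero_point as the integer
# representation; dequantize maps it back as (int_repr - zero_point) * scale.
r = torch.rand(15, 2) * 2
scale, zero_point = 0.2, 10
q = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)

# the dequantized values follow directly from the affine mapping above ...
assert torch.allclose(q.dequantize(), (q.int_repr().float() - zero_point) * scale)
# ... and the per-element round-trip error is bounded by roughly scale / 2
assert torch.allclose(q.dequantize(), r, atol=scale)
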
diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp
index b506eb7..6fcd84f 100644
--- a/torch/csrc/utils/tensor_layouts.cpp
+++ b/torch/csrc/utils/tensor_layouts.cpp
@@ -24,7 +24,6 @@
   registerLayoutObject((THPLayout*)strided_layout, at::Backend::MSNPU);
   registerLayoutObject((THPLayout*)strided_layout, at::Backend::XLA);
   registerLayoutObject((THPLayout*)strided_layout, at::Backend::QuantizedCPU);
-  registerLayoutObject((THPLayout*)strided_layout, at::Backend::QuantizedCUDA);
 
   PyObject *sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
   Py_INCREF(sparse_coo_layout);
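
For context, a small illustrative snippet (not part of the patch, tensor values arbitrary) of what the remaining QuantizedCPU registration provides on the Python side: a quantized CPU tensor reports the ordinary strided layout object registered above.

import torch

# per-tensor quantized CPU tensor; its .layout resolves through the strided
# layout object registered for the QuantizedCPU backend
q = torch.quantize_per_tensor(torch.rand(3), 0.1, 0, torch.quint8)
print(q.layout)        # torch.strided
print(q.is_quantized)  # True
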