#include <ATen/ATen.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/Allocator.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <typeinfo>
#ifdef USE_FBGEMM
#include <fbgemm/QuantUtils.h>
#endif
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif
namespace at {
// Note: this is not a native function as Quantizer is not exposed to Python yet
QuantizerPtr Tensor::quantizer() const {
// This is a terrible hack to emulate what VariableType is doing
at::AutoNonVariableTypeMode non_var_type_mode(true);
return get_qtensorimpl(*this)->quantizer();
}
void checkFloatCPUTensor(const std::string& fn_name, const Tensor& t) {
TORCH_CHECK(
t.scalar_type() == kFloat,
fn_name,
" expects a Float Tensor.");
TORCH_CHECK(
t.device() == kCPU,
fn_name,
" expects a CPU Tensor.");
}
template <typename T>
void checkQuantizedCPUTensor(const std::string& fn_name, const Tensor& t) {
TORCH_CHECK(t.is_quantized(),
fn_name,
" expects a quantized Tensor.");
TORCH_CHECK(t.scalar_type() == caffe2::TypeMeta::Make<T>(),
fn_name,
" expects a ",
caffe2::TypeMeta::Make<T>(),
" Tensor");
TORCH_CHECK(t.device() == kCPU,
fn_name,
" expects a CPU quantized Tensor");
}
template <typename T>
void checkZeroPoint(const std::string& fn_name, int64_t zero_point) {
TORCH_CHECK(zero_point <= std::numeric_limits<T>::max(),
fn_name,
" zero_point ",
zero_point,
" is out of range.");
TORCH_CHECK(zero_point >= std::numeric_limits<T>::min(),
fn_name,
" zero_point ",
zero_point,
" is out of range.");
}
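// For reference (illustrative), with the common underlying quantized types the
// check above accepts:
//   checkZeroPoint<uint8_t>(fn, zp)  -> zp in [0, 255]
//   checkZeroPoint<int8_t>(fn, zp)   -> zp in [-128, 127]
//   checkZeroPoint<int32_t>(fn, zp)  -> zp in [-2147483648, 2147483647]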
template <typename T>
void checkZeroPoints(const std::string& fn_name, const std::vector<int64_t>& zero_points) {
for (size_t i = 0; i < zero_points.size(); ++i) {
TORCH_CHECK(zero_points[i] <= std::numeric_limits<T>::max(),
fn_name,
" zero_point ",
i,
" is out of range.");
TORCH_CHECK(zero_points[i] >= std::numeric_limits<T>::min(),
fn_name,
" zero_point ",
i,
" is out of range.");
}
}
#ifdef USE_FBGEMM
// Note: quantize_val is only explicitly used in tests outside of this file
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
// Internally, fbgemm::Quantize uses std::nearbyint.
// std::nearbyint results in the nearest integer value according to the current
// rounding mode and the default rounding mode rounds to even in half-way cases
// on most popular processor architectures like x86 and ARM. This is typically
// faster than alternatives like std::round, which rounds half-way cases away
// from zero, and it can be consistent with SIMD implementations, for example
// on x86 using _mm512_cvtps_epi32 or _mm512_round_ps with the
// _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
int32_t qvalue;
qvalue = fbgemm::Quantize<typename T::underlying>(
value,
static_cast<int32_t>(zero_point),
static_cast<double>(scale),
/*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
return static_cast<T>(qvalue);
}
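// Worked example (illustrative, assuming the default round-to-nearest-even mode
// described above; zero_point = 0 so the order of rounding and shifting does not
// matter):
//   quantize_val<c10::quint8>(/*scale=*/0.5, /*zero_point=*/0, 1.25f) -> Round(2.5) == 2
//   quantize_val<c10::quint8>(/*scale=*/0.5, /*zero_point=*/0, 1.75f) -> Round(3.5) == 4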
template <typename T, int precision>
void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
fbgemm::Quantize<typename T::underlying>(
src,
(typename T::underlying*)dst,
count,
fbgemm::TensorQuantizationParams{(float)scale, (int32_t)zero_point, precision}
);
}
// TODO: dequantize_val?
template <typename T>
Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
auto fn_name = "quantize_tensor";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoint<typename T::underlying>(fn_name, zero_point);
const float* rd = rtensor.data_ptr<float>();
auto qd = reinterpret_cast<typename T::underlying*>(qtensor.data_ptr<T>());
fbgemm::TensorQuantizationParams qparams;
qparams.scale = scale;
qparams.zero_point = zero_point;
qparams.precision = CHAR_BIT * sizeof(typename T::underlying);
fbgemm::Quantize<typename T::underlying>(/*src=*/rd,
/*dst=*/qd,
/*len=*/rtensor.numel(),
/*qparams=*/qparams);
return qtensor;
}
template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
fbgemm::TensorQuantizationParams qparams = {
.scale = static_cast<float>(scale),
.zero_point = static_cast<int32_t>(zero_point)
};
return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
template <typename T>
Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
auto fn_name = "dequantize_tensor";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoint<typename T::underlying>(fn_name, zero_point);
const auto* qd = reinterpret_cast<const typename T::underlying*>(qtensor.data_ptr<T>());
fbgemm::TensorQuantizationParams qparams;
qparams.scale = scale;
qparams.zero_point = zero_point;
qparams.precision = CHAR_BIT * sizeof(typename T::underlying);
float* rd = rtensor.data_ptr<float>();
fbgemm::Dequantize<typename T::underlying>(/*src=*/qd,
/*dst=*/rd,
/*len=*/qtensor.numel(),
/*qparams=*/qparams);
return rtensor;
}
#else // USE_FBGEMM
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
template <class T>
inline float Round(const float x) {
return ::nearbyintf(x);
}
inline double Round(const double x) {
return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
return std::nearbyint(x);
}
#endif
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
// std::nearbyint results in the nearest integer value according to the current
// rounding mode and the default rounding mode rounds to even in half-way cases
// on most popular processor architectures like x86 and ARM. This is typically
// faster than alternatives like std::round, which rounds half-way cases away
// from zero, and it can be consistent with SIMD implementations, for example
// on x86 using _mm512_cvtps_epi32 or _mm512_round_ps with the
// _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
int64_t qvalue;
constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
qvalue = static_cast<int64_t>(Round(value / scale + zero_point));
qvalue = std::max<int64_t>(qvalue, qmin);
qvalue = std::min<int64_t>(qvalue, qmax);
return static_cast<T>(qvalue);
}
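// Worked examples of the formula above (illustrative), for c10::quint8 whose
// underlying range is [0, 255]:
//   quantize_val<c10::quint8>(/*scale=*/0.5, /*zero_point=*/10, 2.0f)
//     -> Round(2.0 / 0.5 + 10) == 14
//   quantize_val<c10::quint8>(/*scale=*/0.5, /*zero_point=*/0, 200.0f)
//     -> Round(400.0) == 400, clamped to qmax == 255
//   quantize_val<c10::quint8>(/*scale=*/0.5, /*zero_point=*/0, -3.0f)
//     -> Round(-6.0) == -6, clamped to qmin == 0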
template <typename T, int precision>
void quantize_vec(double scale, int64_t zero_point, const float *src, T *dst, size_t count) {
checkZeroPoint<typename T::underlying>("quantize_vec", zero_point);
for (size_t i = 0; i < count; ++i) {
dst[i] = quantize_val<T>(scale, zero_point, src[i]);
}
}
// TODO combine this with quantize_val once the numerics for ARM are aligned with it
inline uint8_t quantize_val_arm(const float scale, const int32_t zero_point, const float value) {
const int32_t qmin = std::numeric_limits<uint8_t>::min();
const int32_t qmax = std::numeric_limits<uint8_t>::max();
auto r = zero_point + static_cast<int32_t>(Round(value / scale));
r = std::max(r, qmin);
r = std::min(r, qmax);
return static_cast<uint8_t>(r);
}
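// Note the different rounding order from quantize_val above: quantize_val rounds
// Round(value / scale + zero_point), while this rounds Round(value / scale) and
// then adds zero_point. The two can differ on half-way cases (illustrative), e.g.
// with scale = 1.0, zero_point = 1, value = 0.5f:
//   quantize_val<c10::quint8>(1.0, 1, 0.5f) -> Round(1.5) == 2
//   quantize_val_arm(1.0f, 1, 0.5f)         -> 1 + Round(0.5) == 1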
#ifdef __ARM_NEON__
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_arm(
const float* in,
Tensor qtensor,
const int64_t N,
const float scale,
const int32_t zero_point) {
auto out = qtensor.data_ptr<T>();
for (int i = 0; i < N; ++i) {
out[i] = quantize_val<T>(scale, zero_point, in[i]);
}
}
// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the implementation
// of quantize_val above.
// TODO Update the quantize_tensor_arm implementation to follow quantize_val,
// i.e. f = Round(value / scale + zero_point)
// TODO Make quantize_tensor_arm work for other datatypes too (int8, int32).
template <>
void quantize_tensor_arm<c10::quint8>(
const float* in,
Tensor qtensor,
const int64_t N,
const float scale,
const int32_t zero_point) {
const float inv_scale = 1.0f / scale;
uint32_t i = 0;
auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
// magic float and magic int to take care of rounding
// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
// Some detail:
// 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
// add a small number to a large number, the result rounds to the precision of
// the least significant bit of the large number. For an IEEE-754
// single-precision number the mantissa has 23 bits, so adding 2**23 causes
// rounding to the nearest even integer. Then we cast to int and subtract the
// same number (0x4B400000 is the integer representation of 12582912.0f) to
// get only the mantissa. This works only if -2**22 < x < 2**22, and it
// preserves the sign for negative numbers.
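// Worked example of the trick (illustrative): for x = 2.5f,
//   2.5f + 12582912.0f = 12582914.5f, which rounds to 12582914.0f (ties to even;
//   the ULP in [2**23, 2**24) is 1.0). Its bit pattern is 0x4B400002, and
//   0x4B400002 - 0x4B400000 = 2 == round-to-even(2.5).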
const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
for (i = 0; i + 8 < N; i += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; i < N; ++i) {
(*out++) = quantize_val_arm(scale, zero_point, (*in++));
}
}
#endif // __ARM_NEON__
template <typename T>
Tensor quantize_tensor(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point) {
auto fn_name = "quantize_tensor";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoint<typename T::underlying>(fn_name, zero_point);
TORCH_CHECK(rtensor.is_contiguous(), "Float tensor should be contiguous");
const float* const rdata = rtensor.data_ptr<float>();
// If the QEngine is set to QNNPACK, use the caffe2-specialized Int8Quantize implementation on ARM
#if defined(__ARM_NEON__)
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
quantize_tensor_arm<T>(rdata, qtensor, rtensor.numel(), scale, zero_point);
return qtensor;
}
#endif
auto qdata = qtensor.data_ptr<T>();
for (int64_t i = 0; i < rtensor.numel(); ++i) {
qdata[i] = quantize_val<T>(scale, zero_point, rdata[i]);
}
return qtensor;
}
template <typename T>
CAFFE2_API float dequantize_val(double scale, int64_t zero_point, T value) {
// We need to convert the quantized value to float to ensure the subtraction
// subexpression returns a float
return (static_cast<float>(value.val_) - zero_point) * scale;
}
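// Worked example (illustrative): with scale = 0.5 and zero_point = 10, a stored
// quantized value of 36 dequantizes to (36 - 10) * 0.5 == 13.0f, the inverse of
// quantize_val up to rounding error.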
template <typename T>
Tensor dequantize_tensor(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point) {
auto fn_name = "dequantize_tensor";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoint<typename T::underlying>(fn_name, zero_point);
const auto* qd = qtensor.data_ptr<T>();
float* rd = rtensor.data_ptr<float>();
for (int64_t i = 0; i < qtensor.numel(); ++i) {
rd[i] = dequantize_val<T>(scale, zero_point, qd[i]);
}
return rtensor;
}
#endif // USE_FBGEMM
template <typename SRC_T, typename DST_T>
DST_T requantize_val(double src_scale, int64_t src_zero_point,
double dst_scale, int64_t dst_zero_point,
SRC_T src) {
const auto dq = dequantize_val<SRC_T>(src_scale, src_zero_point, src);
return quantize_val<DST_T>(dst_scale, dst_zero_point, dq);
}
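// requantize_val maps a value quantized with (src_scale, src_zero_point) into
// the (dst_scale, dst_zero_point) domain by dequantizing and re-quantizing.
// Worked example (illustrative): with src_scale = 0.5, src_zero_point = 0, a
// stored src value of 4 represents 2.0f; with dst_scale = 0.25 and
// dst_zero_point = 10 it re-quantizes to Round(2.0 / 0.25 + 10) == 18.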
template CAFFE2_API qint8 quantize_val<qint8>(double scale, int64_t zero_point, float value);
template CAFFE2_API quint8 quantize_val<quint8>(double scale, int64_t zero_point, float value);
template CAFFE2_API qint32 quantize_val<qint32>(double scale, int64_t zero_point, float value);
template CAFFE2_API void quantize_vec<c10::qint8>(double scale, int64_t zero_point, const float *src, c10::qint8 *dst, size_t count);
template CAFFE2_API void quantize_vec<c10::quint8>(double scale, int64_t zero_point, const float *src, c10::quint8 *dst, size_t count);
template CAFFE2_API void quantize_vec<c10::qint32, 32>(double scale, int64_t zero_point, const float *src, c10::qint32 *dst, size_t count);
template CAFFE2_API Tensor quantize_tensor<qint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
template CAFFE2_API Tensor quantize_tensor<quint8>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
template CAFFE2_API Tensor quantize_tensor<qint32>(Tensor rtensor, Tensor qtensor, double scale, int64_t zero_point);
template CAFFE2_API float dequantize_val<qint8>(double scale, int64_t zero_point, qint8 value);
template CAFFE2_API float dequantize_val<quint8>(double scale, int64_t zero_point, quint8 value);
template CAFFE2_API float dequantize_val<qint32>(double scale, int64_t zero_point, qint32 value);
template CAFFE2_API Tensor dequantize_tensor<qint8>(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
template CAFFE2_API Tensor dequantize_tensor<quint8>(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
template CAFFE2_API Tensor dequantize_tensor<qint32>(Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point);
template CAFFE2_API qint8 requantize_val<qint8, qint8>(double, int64_t, double, int64_t, qint8);
template CAFFE2_API quint8 requantize_val<qint8, quint8>(double, int64_t, double, int64_t, qint8);
template CAFFE2_API qint32 requantize_val<qint8, qint32>(double, int64_t, double, int64_t, qint8);
template CAFFE2_API qint8 requantize_val<quint8, qint8>(double, int64_t, double, int64_t, quint8);
template CAFFE2_API quint8 requantize_val<quint8, quint8>(double, int64_t, double, int64_t, quint8);
template CAFFE2_API qint32 requantize_val<quint8, qint32>(double, int64_t, double, int64_t, quint8);
template CAFFE2_API qint8 requantize_val<qint32, qint8>(double, int64_t, double, int64_t, qint32);
template CAFFE2_API quint8 requantize_val<qint32, quint8>(double, int64_t, double, int64_t, qint32);
template CAFFE2_API qint32 requantize_val<qint32, qint32>(double, int64_t, double, int64_t, qint32);
// TODO: add fbgemm for per channel
template <typename T>
Tensor quantize_tensor_per_channel_affine(Tensor rtensor,
Tensor qtensor,
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
int64_t axis) {
auto fn_name = "quantize_tensor_per_channel_affine";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoints<typename T::underlying>(fn_name, zero_points);
TORCH_CHECK(0 <= axis && axis < rtensor.dim(), "Channel axis out of range in per channel affine quantization.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
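// The loops below assume a contiguous layout where the quantized axis is
// preceded by `batches` outer elements and followed by `elements_per_channel`
// inner elements. For example (illustrative), a tensor of shape (2, 3, 4)
// quantized along axis 1 gives batches = 2, channel = 3,
// elements_per_channel = 4, and element (b, c, e) lives at linear index
// b * 3 * 4 + c * 4 + e.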
TORCH_CHECK(channel == int64_t(scales.size()),
"length of scales must equal the number of channels");
TORCH_CHECK(channel == int64_t(zero_points.size()),
"length of zero_points must equal the number of channels");
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<T>();
for (int64_t b = 0; b < batches; ++b) {
for (int64_t c = 0; c < channel; ++c) {
for (int64_t e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
qdata[i] = quantize_val<T>(scales[c], zero_points[c], rdata[i]);
}
}
}
return qtensor;
}
template <typename T>
Tensor dequantize_tensor_per_channel_affine(Tensor qtensor,
Tensor rtensor,
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
int64_t axis) {
auto fn_name = "dequantize_tensor_per_channel_affine";
checkFloatCPUTensor(fn_name, rtensor);
checkQuantizedCPUTensor<T>(fn_name, qtensor);
checkZeroPoints<typename T::underlying>(fn_name, zero_points);
TORCH_CHECK(0 <= axis && axis < qtensor.dim(),
"Channel axis out of range in per channel affine dequantization.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
TORCH_CHECK(channel == int64_t(scales.size()),
"length of scales must equal the number of channels");
TORCH_CHECK(channel == int64_t(zero_points.size()),
"length of zero_points must equal the number of channels");
const auto* qd = qtensor.data_ptr<T>();
float* rd = rtensor.data_ptr<float>();
for (int64_t b = 0; b < batches; ++b) {
for (int64_t c = 0; c < channel; ++c) {
for (int64_t e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel + c * elements_per_channel + e;
// We need to convert the quantized value to float to ensure the subtraction
// subexpression returns a float
rd[i] = (static_cast<float>(qd[i].val_) - zero_points[c]) * scales[c];
}
}
}
return rtensor;
}
QuantizerPtr make_per_tensor_affine_quantizer(
double scale,
int64_t zero_point,
ScalarType scalar_type) {
return c10::make_intrusive<PerTensorAffineQuantizer>(scalar_type,
scale, zero_point);
}
QuantizerPtr make_per_channel_affine_quantizer(
const std::vector<double>& scales,
const std::vector<int64_t>& zero_points,
int64_t axis,
ScalarType scalar_type) {
return c10::make_intrusive<PerChannelAffineQuantizer>(scalar_type,
scales, zero_points, axis);
}
QuantizerPtr make_per_channel_affine_quantizer(
const Tensor& scales,
const Tensor& zero_points,
int64_t axis,
ScalarType scalar_type) {
TORCH_CHECK(scales.dim() == 1, "scale tensor must have dimension 1");
TORCH_CHECK(
zero_points.dim() == 1, "zero_points tensor must have dimension 1");
TORCH_CHECK(
scales.numel() == zero_points.numel(),
"number of elements in scales and zero_points must match");
TORCH_CHECK(
isFloatingType(scales.scalar_type()),
"scale tensor must be floating point");
TORCH_CHECK(
isIntegralType(zero_points.scalar_type(), false /*includeBool*/),
"zero_points tensor must have integral type");
Tensor scales_double = scales.to(kDouble).contiguous();
Tensor zero_points_int64 = zero_points.to(kLong).contiguous();
double* scales_data = scales_double.data_ptr<double>();
int64_t* zero_points_data = zero_points_int64.data_ptr<int64_t>();
std::vector<double> scale_vals(scales_data, scales_data + scales.numel());
std::vector<int64_t> zero_point_vals(
zero_points_data, zero_points_data + zero_points.numel());
return make_per_channel_affine_quantizer(
scale_vals, zero_point_vals, axis, scalar_type);
}
QTensorImpl* get_qtensorimpl(const Tensor& self) {
// TODO: remove this when Variable and Tensor are merged
TORCH_INTERNAL_ASSERT(
!self.is_variable(),
"_internal_get_QTensorImpl: should not be a variable");
TORCH_INTERNAL_ASSERT(self.is_quantized(), "get_qtensorimpl: not a quantized tensor");
return static_cast<QTensorImpl*>(self.unsafeGetTensorImpl());
}
inline Tensor new_qtensor_cpu(
IntArrayRef sizes,
const TensorOptions& options,
QuantizerPtr quantizer,
MemoryFormat memory_format=MemoryFormat::Contiguous) {
AT_ASSERT(options.device().is_cpu());
native::check_size_nonnegative(sizes);
auto* allocator = at::getCPUAllocator();
int64_t nelements = at::prod_intlist(sizes);
auto dtype = options.dtype();
TORCH_CHECK(isQIntType(typeMetaToScalarType(dtype)),
"ScalarType is not supported in new_qtensor_cpu.");
auto storage = c10::make_intrusive<StorageImpl>(
dtype,
nelements,
allocator->allocate(nelements * dtype.itemsize()),
allocator,
/*resizable=*/true);
auto tensor = detail::make_tensor<QTensorImpl>(
storage, at::TensorTypeSet(at::TensorTypeId::QuantizedCPUTensorId), quantizer);
get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
get_qtensorimpl(tensor)->empty_tensor_restride(memory_format);
return tensor;
}
Tensor PerTensorAffineQuantizer::quantize(Tensor rtensor) {
TORCH_CHECK(
rtensor.scalar_type() == kFloat,
"quantize only works on Float Tensor.");
TORCH_CHECK(
rtensor.device() == kCPU,
"quantize only works for CPU backend right now.");
// Here we need a c10::intrusive_ptr<Quantizer>, but "this" is the quantizer
// that can be reused, so we use intrusive_from_this here
Tensor qtensor = new_qtensor_cpu(
rtensor.sizes(),
rtensor.options().dtype(scalar_type_),
intrusive_from_this());
rtensor = rtensor.contiguous();
AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "quantize_tensor", [&]() {
qtensor = quantize_tensor<scalar_t>(rtensor, qtensor, scale_, zero_point_);
});
return qtensor;
}
Tensor PerTensorAffineQuantizer::dequantize(Tensor qtensor) {
TORCH_CHECK(qtensor.is_quantized(),
"dequantize is only supported on quantized Tensors.");
TORCH_CHECK(
qtensor.device() == kCPU,
"dequantize only works for CPU backend right now.");
Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
qtensor = qtensor.contiguous();
AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), "dequantize_tensor", [&]() {
rtensor = dequantize_tensor<scalar_t>(qtensor, rtensor, scale_, zero_point_);
});
return rtensor;
}
Tensor PerChannelAffineQuantizer::quantize(Tensor rtensor) {
TORCH_CHECK(
rtensor.scalar_type() == kFloat,
"quantize only works on Float Tensor.");
TORCH_CHECK(
rtensor.device() == kCPU,
"quantize only works for CPU backend right now.");
// Here we need a c10::intrusive_ptr<Quantizer>, but "this" is the quantizer
// that can be reused, so we use intrusive_from_this here
Tensor qtensor = new_qtensor_cpu(
rtensor.sizes(),
rtensor.options().dtype(scalar_type_),
intrusive_from_this());
rtensor = rtensor.contiguous();
AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(),
"quantize_tensor_per_channel_affine",
[&]() {
qtensor = quantize_tensor_per_channel_affine<scalar_t>(
rtensor, qtensor, scales_, zero_points_, axis_);
});
return qtensor;
}
Tensor PerChannelAffineQuantizer::dequantize(Tensor qtensor) {
TORCH_CHECK(qtensor.is_quantized(),
"dequantize is only supported on quantized Tensors.");
TORCH_CHECK(
qtensor.device() == kCPU,
"dequantize only works for CPU backend right now.");
Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
qtensor = qtensor.contiguous();
AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(),
"dequantize_tensor_per_channel_affine",
[&]() {
rtensor = dequantize_tensor_per_channel_affine<scalar_t>(
qtensor, rtensor, scales_, zero_points_, axis_);
});
return rtensor;
}
Quantizer::~Quantizer() {}
} // namespace at