/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cmath>
/**
 * For an input tensor, use the scale and zero_point arguments to dequantize it.
*/
namespace torch {
namespace executor {
namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
namespace {
/**
* Asserts that the parameters are valid.
*/
void check_dequantize_per_tensor_args(
const Tensor& input,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType>& out_dtype,
Tensor& out) {
ET_CHECK_MSG(
input.scalar_type() == ScalarType::Byte ||
input.scalar_type() == ScalarType::Char ||
input.scalar_type() == ScalarType::Short ||
input.scalar_type() == ScalarType::Int,
"input.scalar_type() %" PRId8 " is not supported:",
static_cast<int8_t>(input.scalar_type()));
ET_CHECK_MSG(
input.scalar_type() == dtype,
"input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
static_cast<int8_t>(input.scalar_type()));
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out.scalar_type() == out_dtype.value(),
"output_dtype must match the dtype of the out tensor");
}
ET_CHECK_MSG(
quant_min <= quant_max,
"quant min: %" PRId64 " is greater than quant max: %" PRId64,
quant_min,
quant_max);
}
} // namespace
/**
* Dequantizes the input tensor according to the formula (input - zero_point) *
* scale
*
* NOTE: quant_min and quant_max are not used in computation, but rather
* metadata that is passed around which can be useful for pattern matching. See
* https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
* info.
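 *
 * For example (illustrative values): an int8 input value of 12 with
 * zero_point = 2 and scale = 0.5 dequantizes to (12 - 2) * 0.5f = 5.0f.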
*/
Tensor& dequantize_per_tensor_out(
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
torch::executor::Error err = resize_tensor(out, input.sizes());
ET_CHECK_MSG(
err == torch::executor::Error::Ok,
"Failed to resize out Tensor in dequantize_per_tensor_out");
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);
// calculate the dequantized output, cast scale to float to match fbgemm
// behavior
#define DEQUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype) \
case ScalarType::out_dtype: \
for (size_t i = 0; i < input.numel(); i++) { \
out.data_ptr<OUT_CTYPE>()[i] = static_cast<OUT_CTYPE>( \
(input.data_ptr<IN_CTYPE>()[i] - static_cast<int32_t>(zero_point)) * \
static_cast<float>(scale)); \
} \
break;
#define CALCULATE_INT_TYPE(IN_CTYPE, in_dtype) \
case ScalarType::in_dtype: \
switch (out.scalar_type()) { \
ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, DEQUANTIZE_IMPL); \
default: \
ET_CHECK_MSG( \
false, \
"Unhandled output dtype %" PRId8, \
static_cast<int8_t>(out.scalar_type())); \
} \
break;
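// As an illustration, for a Char input and a Float output the two macros
// above expand to a loop equivalent to:
//
//   for (size_t i = 0; i < input.numel(); i++) {
//     out.data_ptr<float>()[i] = static_cast<float>(
//         (input.data_ptr<int8_t>()[i] - static_cast<int32_t>(zero_point)) *
//         static_cast<float>(scale));
//   }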
switch (input.scalar_type()) {
ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
default:
ET_CHECK_MSG(
false,
"Unhandled input dtype %" PRId8,
static_cast<int8_t>(input.scalar_type()));
}
#undef CALCULATE_INT_TYPE
#undef DEQUANTIZE_IMPL
return out;
}
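// A minimal usage sketch (illustrative only; assumes the TensorFactory
// helper from exec_aten's testing utilities is available):
//
//   using namespace torch::executor;
//   testing::TensorFactory<ScalarType::Byte> tf_in;
//   testing::TensorFactory<ScalarType::Float> tf_out;
//   Tensor input = tf_in.make({2, 2}, {128, 129, 130, 131});
//   Tensor out = tf_out.zeros({2, 2});
//   native::dequantize_per_tensor_out(
//       input, /*scale=*/0.1, /*zero_point=*/128, /*quant_min=*/0,
//       /*quant_max=*/255, ScalarType::Byte, ScalarType::Float, out);
//   // out now holds approximately {0.0f, 0.1f, 0.2f, 0.3f}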
Tensor& dequantize_per_tensor_tensor_args_out(
const Tensor& input,
const Tensor& scale,
const Tensor& zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
ET_CHECK_MSG(
scale.scalar_type() == ScalarType::Double,
"Expected scale to be Double tensor received: %" PRId8,
static_cast<int8_t>(scale.scalar_type()));
ET_CHECK_MSG(
zero_point.scalar_type() == ScalarType::Long,
"Expected scale to be Long tensor received: %" PRId8,
static_cast<int8_t>(zero_point.scalar_type()));
ET_CHECK_MSG(
scale.numel() == 1,
"Exepcted scale to only have one element received: %zd",
ssize_t(scale.numel()));
ET_CHECK_MSG(
zero_point.numel() == 1,
"Exepcted zero_point to only have one element received: %zd",
ssize_t(zero_point.numel()));
dequantize_per_tensor_out(
input,
scale.data_ptr<double>()[0],
zero_point.data_ptr<int64_t>()[0],
quant_min,
quant_max,
dtype,
out_dtype,
out);
return out;
}
Tensor& dequantize_per_channel_out(
const Tensor& input,
const Tensor& scale,
const optional<Tensor>& opt_zero_points,
int64_t axis,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
torch::executor::Error err = resize_tensor(out, input.sizes());
// normalize axis
ET_CHECK_MSG(
tensor_has_dim(input, axis),
"axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd",
ssize_t(axis),
ssize_t(input.dim()));
if (axis < 0) {
axis += nonzero_dim(input);
}
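// E.g., for a 4-D input, axis = -1 normalizes to axis = 3.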
ET_CHECK_MSG(
err == torch::executor::Error::Ok,
"Failed to resize out Tensor in dequantize_per_channel_out");
ET_CHECK_MSG(
scale.scalar_type() == ScalarType::Double,
"scale.scalar_type() %" PRId8 " is not double type",
static_cast<int8_t>(scale.scalar_type()));
ET_CHECK_MSG(
scale.numel() == input.size(axis),
"scale.numel() %zd != input.size(axis) %zd",
ssize_t(scale.numel()),
ssize_t(input.size(axis)));
if (opt_zero_points.has_value()) {
auto zero_point = opt_zero_points.value();
ET_CHECK_MSG(
zero_point.scalar_type() == ScalarType::Long,
"zero_point.scalar_type() %" PRId8 " is not integer type",
static_cast<int8_t>(zero_point.scalar_type()));
ET_CHECK_MSG(
zero_point.numel() == input.size(axis),
"zero_point.numel() %zd != input.size(axis) %zd",
ssize_t(zero_point.numel()),
ssize_t(input.size(axis)));
}
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);
// Build a list of all dimensions except axis; apply_over_dim_list below
// iterates over these dims for each channel index.
int64_t dims[input.dim() - 1];
for (int64_t i = 0; i < input.dim() - 1; i++) {
if (i < axis) {
dims[i] = i;
} else {
dims[i] = i + 1;
}
}
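// E.g., for a 4-D input with axis = 1, dims = {0, 2, 3}.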
const double* scale_data = scale.const_data_ptr<double>();
const int64_t* zero_point_data;
if (opt_zero_points.has_value()) {
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
} else {
zero_point_data = nullptr;
}
exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
// Actual dequantization logic
// input, out are the input and output tensors
// channel_ix is the index along the axis dimension; 0 <= channel_ix <
// input.size(axis).
// E.g., if the tensor has shape (N, C, H, W) and axis is 1, then channel_ix
// ranges over 0, 1, 2, ..., C-1.
// in_ix is the flat index of the element being dequantized, i.e. the loop
// body dequantizes in_data[in_ix].
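// E.g., for a contiguous input of shape (3, 2) with axis = 1 and
// channel_ix = 0, in_ix takes the values 0, 2, 4.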
#define DEQUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
case ScalarType::out_dtype: \
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
double _scale = scale_data[channel_ix]; \
int64_t _zero_point = 0; \
if (zero_point_data != nullptr) { \
_zero_point = zero_point_data[channel_ix]; \
} \
apply_over_dim_list( \
[input, out, _scale, _zero_point](size_t in_ix) { \
out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
(input.const_data_ptr<CTYPE_IN>()[in_ix] - _zero_point) * \
_scale); \
}, \
input, \
optional_dim_list, \
channel_ix); \
} \
break;
#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
case ScalarType::in_dtype: \
switch (out.scalar_type()) { \
ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
default: \
ET_CHECK_MSG( \
false, \
"Unhandled output dtype %" PRId8, \
static_cast<int8_t>(out.scalar_type())); \
} \
break;
switch (input.scalar_type()) {
ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
default:
ET_CHECK_MSG(
false,
"Unhandled input dtype %" PRId8,
static_cast<int8_t>(input.scalar_type()));
}
#undef CALCULATE_FLOAT_TYPE
#undef DEQUANTIZE_IMPL
return out;
}
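// A minimal per-channel usage sketch (illustrative only; assumes the same
// TensorFactory testing helper as above):
//
//   testing::TensorFactory<ScalarType::Char> tf_in;
//   testing::TensorFactory<ScalarType::Double> tf_scale;
//   testing::TensorFactory<ScalarType::Long> tf_zp;
//   testing::TensorFactory<ScalarType::Float> tf_out;
//   Tensor input = tf_in.make({2, 2}, {2, 4, 10, 20});
//   Tensor scale = tf_scale.make({2}, {0.5, 0.1});
//   Tensor zero_point = tf_zp.make({2}, {0, 10});
//   Tensor out = tf_out.zeros({2, 2});
//   native::dequantize_per_channel_out(
//       input, scale, zero_point, /*axis=*/0, /*quant_min=*/-128,
//       /*quant_max=*/127, ScalarType::Char, ScalarType::Float, out);
//   // row 0 uses scale 0.5, zero_point 0:  {1.0f, 2.0f}
//   // row 1 uses scale 0.1, zero_point 10: {0.0f, 1.0f}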
Tensor& dequantize_per_channel_out(
RuntimeContext& context,
const Tensor& input,
const Tensor& scale,
const optional<Tensor>& opt_zero_points,
int64_t axis,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
(void)context;
return dequantize_per_channel_out(
input,
scale,
opt_zero_points,
axis,
quant_min,
quant_max,
dtype,
out_dtype,
out);
}
Tensor& dequantize_per_tensor_out(
RuntimeContext& context,
const Tensor& input,
double scale,
int64_t zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO(larryliu): Add a context arg to the real op function and remove this
// wrapper
(void)context;
return dequantize_per_tensor_out(
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}
Tensor& dequantize_per_tensor_tensor_args_out(
RuntimeContext& context,
const Tensor& input,
const Tensor& scale,
const Tensor& zero_point,
int64_t quant_min,
int64_t quant_max,
ScalarType dtype,
exec_aten::optional<ScalarType> out_dtype,
Tensor& out) {
// TODO(larryliu): Add a context arg to the real op function and remove this
// wrapper
(void)context;
return dequantize_per_tensor_tensor_args_out(
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
}
} // namespace native
} // namespace executor
} // namespace torch