blob: df323846cca0a65b80f64c9b5b9fbe0876ca773d [file] [log] [blame]
#include "fused_8bit_rowwise_conversion.h"
#include <c10/util/Half.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include "common.h"
namespace caffe2 {
// Portable (non-SIMD) kernel: quantizes every row of a row-major float matrix
// to one byte per element and appends that row's float scale and bias right
// after the quantized bytes.
//
//   input   input_rows * input_columns floats.
//   output  input_rows * (input_columns + 2 * sizeof(float)) bytes; each row
//           is laid out as [input_columns quantized bytes][scale][bias].
//
// Per row: bias = min(row), scale = (max(row) - min(row)) / 255 and
// q = lrintf((x - bias) * 255 / (max(row) - min(row) + kEpsilon)).
//
// A non-positive input_rows or a negative input_columns is treated as "nothing
// to do". A row with zero columns stores scale = bias = 0.
void FloatToFused8BitRowwiseQuantized__base(
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Keeps the reciprocal finite when every element of a row is identical.
  constexpr float kEpsilon = 1e-8f;

  // Validate while the sizes are still signed: converting a negative value to
  // std::size_t would turn it into a huge loop bound.
  if (input_rows <= 0 || input_columns < 0) {
    return;
  }
  const std::size_t num_rows = static_cast<std::size_t>(input_rows);
  const std::size_t num_columns = static_cast<std::size_t>(input_columns);
  const std::size_t output_columns = num_columns + 2 * sizeof(float);

  for (std::size_t row = 0; row < num_rows; ++row) {
    const float* input_row = input + row * num_columns;
    std::uint8_t* output_row = output + row * output_columns;

    // min_element/max_element return the end iterator for an empty range, so
    // only dereference their result when the row has at least one element.
    float minimum_element = 0.0f;
    float maximum_element = 0.0f;
    if (num_columns > 0) {
      minimum_element = *std::min_element(input_row, input_row + num_columns);
      maximum_element = *std::max_element(input_row, input_row + num_columns);
    }
    const float range = maximum_element - minimum_element;

    // The trailer starts at byte offset num_columns, which is generally not
    // aligned for float; memcpy stores it without a misaligned float access.
    const float scale_bias[2] = {range / 255.0f, minimum_element};
    std::memcpy(output_row + num_columns, scale_bias, sizeof(scale_bias));

    const float inverse_scale = 255.0f / (range + kEpsilon);
    for (std::size_t col = 0; col < num_columns; ++col) {
      output_row[col] = static_cast<std::uint8_t>(
          std::lrintf((input_row[col] - minimum_element) * inverse_scale));
    }
  }
}
// Portable (non-SIMD) kernel: expands rows of fused 8-bit data back to floats.
//
//   input   input_rows * input_columns bytes; each row is laid out as
//           [input_columns - 2 * sizeof(float) quantized bytes][scale][bias],
//           with scale and bias stored as floats.
//   output  input_rows * (input_columns - 2 * sizeof(float)) floats, where
//           each value is q * scale + bias.
//
// The call is a no-op when input_rows is not positive or when input_columns is
// too small to hold the scale/bias trailer.
void Fused8BitRowwiseQuantizedToFloat__base(
    const std::uint8_t* input,
    int input_rows,
    int input_columns,
    float* output) {
  constexpr int kScaleBiasBytes = static_cast<int>(2 * sizeof(float));

  // Reject short rows before subtracting: otherwise the payload width would
  // come out negative and wrap to a huge bound once compared as std::size_t.
  if (input_rows <= 0 || input_columns < kScaleBiasBytes) {
    return;
  }
  const std::size_t num_rows = static_cast<std::size_t>(input_rows);
  const std::size_t row_bytes = static_cast<std::size_t>(input_columns);
  const std::size_t output_columns = row_bytes - kScaleBiasBytes;

  for (std::size_t row = 0; row < num_rows; ++row) {
    const std::uint8_t* input_row = input + row * row_bytes;
    float* output_row = output + row * output_columns;

    // The trailer sits at an arbitrary byte offset; memcpy reads it without a
    // misaligned float access. Loading it once also keeps it out of the loop.
    float scale_bias[2];
    std::memcpy(scale_bias, input_row + output_columns, sizeof(scale_bias));
    const float scale = scale_bias[0];
    const float bias = scale_bias[1];

    for (std::size_t col = 0; col < output_columns; ++col) {
      output_row[col] = input_row[col] * scale + bias;
    }
  }
}
// Declares FloatToFused8BitRowwiseQuantized__avx2_fma with exactly the same
// function type as the __base kernel. No definition is visible in this file.
decltype(FloatToFused8BitRowwiseQuantized__base)
FloatToFused8BitRowwiseQuantized__avx2_fma;
// Public entry point for float -> fused 8-bit rowwise quantization. It passes
// its four arguments unchanged to the AVX2_FMA_DO and BASE_DO macros.
// NOTE(review): both macros are defined outside this file (presumably in
// common.h). Presumably AVX2_FMA_DO calls the __avx2_fma variant and returns
// when the CPU supports it, and BASE_DO then falls back to the __base
// variant -- confirm against the macro definitions.
void FloatToFused8BitRowwiseQuantized(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
AVX2_FMA_DO(
FloatToFused8BitRowwiseQuantized,
input,
input_rows,
input_columns,
output);
BASE_DO(
FloatToFused8BitRowwiseQuantized,
input,
input_rows,
input_columns,
output);
}
// Declares Fused8BitRowwiseQuantizedToFloat__avx2_fma with exactly the same
// function type as the __base kernel. No definition is visible in this file.
decltype(Fused8BitRowwiseQuantizedToFloat__base)
Fused8BitRowwiseQuantizedToFloat__avx2_fma;
// Public entry point for fused 8-bit rowwise -> float dequantization. It
// passes its four arguments unchanged to the AVX2_FMA_DO and BASE_DO macros.
// NOTE(review): both macros are defined outside this file (presumably in
// common.h). Presumably AVX2_FMA_DO calls the __avx2_fma variant and returns
// when the CPU supports it, and BASE_DO then falls back to the __base
// variant -- confirm against the macro definitions.
void Fused8BitRowwiseQuantizedToFloat(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output) {
AVX2_FMA_DO(
Fused8BitRowwiseQuantizedToFloat,
input,
input_rows,
input_columns,
output);
BASE_DO(
Fused8BitRowwiseQuantizedToFloat,
input,
input_rows,
input_columns,
output);
}
void FloatToFused8BitRowwiseQuantizedSBHalf__base(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
int output_columns = input_columns + 2 * sizeof(at::Half);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
at::Half* output_row_scale_bias =
reinterpret_cast<at::Half*>(output_row + input_columns);
float Xmin = *std::min_element(input_row, input_row + input_columns);
float Xmax = *std::max_element(input_row, input_row + input_columns);
Xmin = static_cast<at::Half>(Xmin);
const float range = Xmax - Xmin;
at::Half scale = range == 0 ? 1.0f : range / 255.0;
if (scale == 0) {
// Corner case handling when Xmax == Xmin
// Any scale would work because X - Xmin will be 0 for all X
scale = 1.0f;
}
output_row_scale_bias[0] = scale;
output_row_scale_bias[1] = Xmin;
for (std::size_t col = 0; col < input_columns; ++col) {
float X = input_row[col];
std::uint8_t quantized =
std::max(0, std::min<int>(std::lrintf((X - Xmin) / scale), 255));
output_row[col] = quantized;
}
}
}
void Fused8BitRowwiseQuantizedSBHalfToFloat__base(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output) {
int output_columns = input_columns - 2 * sizeof(at::Half);
for (std::size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const at::Half* input_row_scale_bias =
reinterpret_cast<const at::Half*>(input_row + output_columns);
float* output_row = output + row * output_columns;
for (std::size_t col = 0; col < output_columns; ++col) {
output_row[col] = (int)input_row[col] * input_row_scale_bias[0] +
input_row_scale_bias[1];
}
}
}
// Declares FloatToFused8BitRowwiseQuantizedSBHalf__avx2_fma with exactly the
// same function type as the __base kernel. The dispatcher below does not
// reference it yet (see the TODO).
decltype(FloatToFused8BitRowwiseQuantizedSBHalf__base)
FloatToFused8BitRowwiseQuantizedSBHalf__avx2_fma;
// Public entry point for float -> fused 8-bit rowwise quantization with
// at::Half scale and bias. It passes its four arguments unchanged to BASE_DO
// only.
// NOTE(review): BASE_DO is defined outside this file (presumably in common.h)
// and presumably forwards to the __base variant -- confirm.
void FloatToFused8BitRowwiseQuantizedSBHalf(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
// TODO: Add AVX2 version
BASE_DO(
FloatToFused8BitRowwiseQuantizedSBHalf,
input,
input_rows,
input_columns,
output);
}
// Declares Fused8BitRowwiseQuantizedSBHalfToFloat__avx2_fma with exactly the
// same function type as the __base kernel. The dispatcher below does not
// reference it yet (see the TODO).
decltype(Fused8BitRowwiseQuantizedSBHalfToFloat__base)
Fused8BitRowwiseQuantizedSBHalfToFloat__avx2_fma;
// Public entry point for fused 8-bit rowwise (at::Half scale and bias) ->
// float dequantization. It passes its four arguments unchanged to BASE_DO
// only.
// NOTE(review): BASE_DO is defined outside this file (presumably in common.h)
// and presumably forwards to the __base variant -- confirm.
void Fused8BitRowwiseQuantizedSBHalfToFloat(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output) {
// TODO: Add AVX2 version
BASE_DO(
Fused8BitRowwiseQuantizedSBHalfToFloat,
input,
input_rows,
input_columns,
output);
}
} // namespace caffe2