blob: c50b3075b6905ff9fb7b81eaa1de39dd92a16d54 [file] [log] [blame]
//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/Dialect/QuantOps/QuantizeUtils.h"
#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
namespace quant {
/// Converts a primitive, real expressed value attribute to the corresponding
/// storage attribute (currently only FloatAttr -> IntegerAttr).
///
/// \param origRealValue the attribute holding a real (expressed) value.
/// \param quantizedElementType the QuantizedType that describes how
///        origRealValue is expressed; its storage type is the type of the
///        returned attribute.
/// \param converter performs the float -> integer quantization.
/// \param outConvertedType set to the storage type on success; left untouched
///        when the conversion is not possible.
/// \returns the converted Attribute, or nullptr if origRealValue is not a
///          FloatAttr.
static Attribute convertPrimitiveValueAttr(
    Attribute origRealValue, QuantizedType quantizedElementType,
    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
  // Only floating-point primitives are handled here.
  auto floatAttr = origRealValue.dyn_cast<FloatAttr>();
  if (!floatAttr)
    return nullptr;

  Type storageType = quantizedElementType.getStorageType();
  outConvertedType = storageType;
  return IntegerAttr::get(storageType,
                          converter.quantizeFloatToInt(floatAttr.getValue()));
}
/// Converts a real expressed DenseFPElementsAttr to a corresponding
/// DenseElementsAttr (typically a DenseIntElementsAttr) holding quantized
/// storage values, using the given quantizedElementType and converter.
///
/// The shape is preserved; only the element type changes from the expressed
/// type to the storage type (e.g. tensor<4xf32> -> tensor<4xi8>).
///
/// \returns the converted attribute, or nullptr if the attribute's type cannot
///          be cast to a storage-based ShapedType.
static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
                           QuantizedType quantizedElementType,
                           const UniformQuantizedValueConverter &converter) {
  // Resolve the result type first so an unsupported type bails out before any
  // per-element quantization work is done.
  ShapedType newDenseType =
      quantizedElementType
          .castExpressedToStorageType(realFPElementsAttr.getType())
          .dyn_cast_or_null<ShapedType>();
  if (!newDenseType)
    return nullptr;

  // Quantize every element into its storage representation.
  SmallVector<APInt, 8> quantValues;
  quantValues.reserve(realFPElementsAttr.size());
  for (APFloat realVal : realFPElementsAttr)
    quantValues.push_back(converter.quantizeFloatToInt(realVal));

  return DenseIntElementsAttr::get(newDenseType, quantValues);
}
/// Quantizes a SplatElementsAttr whose single splat value is a real
/// (expressed) value, producing a SplatElementsAttr of storage values.
///
/// The splat value is converted with convertPrimitiveValueAttr. The shaped type
/// is rewritten from expressed-based to storage-based while keeping its shape
/// (e.g. tensor<4xf32> -> tensor<4xi8>).
///
/// \returns nullptr if either the splat value or the shaped type cannot be
///          converted.
static SplatElementsAttr
convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
                         QuantizedType quantizedElementType,
                         const UniformQuantizedValueConverter &converter) {
  // The splat references a single primitive value, so reuse the primitive
  // converter. NOTE: When implementing per-channel, we will need to promote
  // the splat to a dense and handle channels individually.
  Type ignoredElementType;
  Attribute quantizedScalar =
      convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
                                converter, ignoredElementType);
  if (!quantizedScalar)
    return nullptr;

  Type storageBasedType =
      quantizedElementType.castExpressedToStorageType(realSplatAttr.getType());
  auto splatType = storageBasedType.dyn_cast_or_null<ShapedType>();
  if (!splatType)
    return nullptr;

  return SplatElementsAttr::get(splatType, quantizedScalar);
}
/// Converts a real expressed SparseElementsAttr to a corresponding
/// SparseElementsAttr containing quantized storage values, using the given
/// quantizedElementType and converter.
///
/// Only the values are quantized (via convertDenseFPElementsAttr). The indices
/// are reused unchanged, and the sparse shaped type is cast from
/// expressed-based to storage-based (e.g. tensor<4xf32> -> tensor<4xi8>).
///
/// \returns nullptr if the values are not a DenseFPElementsAttr, or if either
///          the values or the sparse type cannot be converted.
static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
                          QuantizedType quantizedElementType,
                          const UniformQuantizedValueConverter &converter) {
  // Only floating-point values can be quantized.
  auto realFPValues =
      realSparseAttr.getValues().dyn_cast<DenseFPElementsAttr>();
  if (!realFPValues)
    return nullptr;

  DenseElementsAttr quantDenseAttr =
      convertDenseFPElementsAttr(realFPValues, quantizedElementType, converter);
  if (!quantDenseAttr)
    return nullptr;

  // Cast from an expressed-type-based type to a storage-type-based type,
  // preserving the sparse shape.
  ShapedType newSparseType =
      quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
          .dyn_cast_or_null<ShapedType>();
  if (!newSparseType)
    return nullptr;

  return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
                                 quantDenseAttr);
}
/// Converts a real expressed Attribute to a corresponding Attribute containing
/// quantized storage values, using the given uniform quantizedElementType and
/// converter.
///
/// Handled inputs: SplatElementsAttr, DenseFPElementsAttr, SparseElementsAttr,
/// and otherwise a primitive value (see convertPrimitiveValueAttr).
///
/// \param outConvertedType on success, receives the type of the returned
///        attribute; left unchanged on failure.
/// \returns the converted attribute, or nullptr if realValue cannot be
///          converted.
Attribute quantizeAttrUniform(Attribute realValue,
                              UniformQuantizedType quantizedElementType,
                              const UniformQuantizedValueConverter &converter,
                              Type &outConvertedType) {
  // Fork to handle the different variants of supported constants. Each helper
  // returns nullptr on failure, so the result type is only queried once the
  // conversion is known to have succeeded.
  if (auto splatAttr = realValue.dyn_cast<SplatElementsAttr>()) {
    // Splatted tensor or vector constant.
    auto converted =
        convertSplatElementsAttr(splatAttr, quantizedElementType, converter);
    if (converted)
      outConvertedType = converted.getType();
    return converted;
  }
  if (auto denseAttr = realValue.dyn_cast<DenseFPElementsAttr>()) {
    // Dense tensor or vector constant.
    auto converted =
        convertDenseFPElementsAttr(denseAttr, quantizedElementType, converter);
    if (converted)
      outConvertedType = converted.getType();
    return converted;
  }
  if (auto sparseAttr = realValue.dyn_cast<SparseElementsAttr>()) {
    // Sparse tensor or vector constant.
    auto converted =
        convertSparseElementsAttr(sparseAttr, quantizedElementType, converter);
    if (converted)
      outConvertedType = converted.getType();
    return converted;
  }
  // Nothing else matched: try to convert a primitive.
  return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
                                   outConvertedType);
}
/// Converts realValue from a type based on
/// quantizedElementType.getExpressedType() to one based on
/// quantizedElementType.getStorageType().
///
/// Only UniformQuantizedType is currently supported; any other quantized type
/// yields nullptr.
///
/// \param outConvertedType receives the converted type on success.
/// \returns the converted attribute, or nullptr if the conversion is not
///          supported.
Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
                       Type &outConvertedType) {
  // Hard-coded to just support UniformQuantizedType. This will need to be
  // generalized when there is more than one.
  if (auto uniformType =
          quantizedElementType.dyn_cast<UniformQuantizedType>()) {
    UniformQuantizedValueConverter uniformConverter(uniformType);
    return quantizeAttrUniform(realValue, uniformType, uniformConverter,
                               outConvertedType);
  }
  return nullptr;
}
} // namespace quant
} // namespace mlir