// blob: f96ee9e2cea9b1be4d4cb4d2afefa3bd9b8753b4
//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
#include "mlir/IR/Operation.h"
#include "mlir/Quantization/QuantOps.h"
#include "mlir/Quantization/QuantTypes.h"
#include "mlir/Quantization/UniformSupport.h"
#include <cmath>
namespace mlir {
namespace fxpmath {
namespace detail {
/// Looks up the quantized element type of 't' and returns it as a
/// quant::UniformQuantizedType. The result is null when the lookup yields a
/// null type or a quantized type that is not a UniformQuantizedType.
inline quant::UniformQuantizedType getUniformElementType(Type t) {
  auto quantizedElementType = quant::QuantizedType::getQuantizedElementType(t);
  return quantizedElementType.dyn_cast_or_null<quant::UniformQuantizedType>();
}
/// Returns true when the bit width of the storage type of 't' equals at least
/// one of the entries in 'checkWidths', and false otherwise (including when
/// 'checkWidths' is empty).
inline bool hasStorageBitWidth(quant::QuantizedType t,
                               llvm::ArrayRef<unsigned> checkWidths) {
  const unsigned storageWidth = t.getStorageType().getIntOrFloatBitWidth();
  for (auto it = checkWidths.begin(), end = checkWidths.end(); it != end;
       ++it) {
    if (*it == storageWidth)
      return true;
  }
  return false;
}
/// Computes log2(x), rounded to the nearest integral value, and stores it in
/// 'log2Result'. Returns whether 'x' can be considered an exact integral power
/// of two, i.e. whether log2(x) lies within a small tolerance of an integer.
///
/// Values for which log2 is undefined or non-finite (x <= 0, NaN, infinity)
/// are never considered exact: the function returns false and sets
/// 'log2Result' to 0.
template <typename F> bool integralLog2(F x, int &log2Result) {
  // Reject inputs whose logarithm is NaN or infinite up front. Converting such
  // a value to 'int' below would be undefined behavior.
  if (!(x > 0) || !std::isfinite(x)) {
    log2Result = 0;
    return false;
  }
  const F xLog2 = std::log2(x);
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}
/// Helper class for operating on binary operations where all operands
/// and the result are a UniformQuantizedType.
struct UniformBinaryOpInfo {
UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
lhsType(getUniformElementType(lhs->getType())),
rhsType(getUniformElementType(rhs->getType())),
resultType(getUniformElementType(*op->result_type_begin())),
lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
resultStorageType(
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
}
/// Returns whether this info is valid (all types defined, etc).
bool isValid() const {
return lhsType && rhsType && resultType && lhsStorageType &&
rhsStorageType && resultStorageType;
}
/// Gets the final quantized result type of the result.
Type getQuantizedResultType() const { return *op->result_type_begin(); }
/// Returns whether the storage type of all operands is identical.
bool isSameStorageType() const {
return lhsType.getStorageType() == rhsType.getStorageType() &&
lhsType.getStorageType() == resultType.getStorageType();
}
/// Returns whether all operands and result are considered fixedpoint power
/// of two, setting the lhs, rhs, and result log2 scale references.
bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
int &resultLog2Scale) const {
if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
!resultType.isFixedPoint()) {
return false;
}
if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
!integralLog2(rhsType.getScale(), rhsLog2Scale) ||
!integralLog2(resultType.getScale(), resultLog2Scale)) {
return false;
}
return true;
}
/// Gets the result integer clamp range given the result quantized type
// and any explicit clamp provided as attributes.
std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
int64_t typeMin = resultType.getStorageTypeMin();
int64_t typeMax = resultType.getStorageTypeMax();
if (clampMin || clampMax) {
quant::UniformQuantizedValueConverter conv(resultType);
if (clampMin) {
typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
}
if (clampMax) {
typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
}
}
// The quantized, integral ops expect clamps as 32bit ints.
return {
IntegerAttr::get(ty, typeMin),
IntegerAttr::get(ty, typeMax),
};
}
Operation *op;
Value *lhs;
Value *rhs;
Optional<APFloat> clampMin;
Optional<APFloat> clampMax;
// Element UniformQuantizedType for operands/result.
quant::UniformQuantizedType lhsType;
quant::UniformQuantizedType rhsType;
quant::UniformQuantizedType resultType;
// Full storage-based types.
Type lhsStorageType;
Type rhsStorageType;
Type resultStorageType;
};
/// Decomposes a real valued multiplier in the open interval (0, 1) into an
/// integer 'multiplier' and a power-of-two 'exponent' such that
/// realMultiplier is approximately multiplier * 2^(exponent - 31).
struct QuantizedMultiplierSmallerThanOneExp {
  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
    assert(realMultiplier < 1.0);
    assert(realMultiplier > 0.0);

    // 2^31: the scale applied to the mantissa produced by frexp.
    constexpr int64_t kTwoPow31 = int64_t{1} << 31;

    const double mantissa = std::frexp(realMultiplier, &exponent);
    int64_t scaledMantissa =
        static_cast<int64_t>(std::round(mantissa * kTwoPow31));
    assert(scaledMantissa <= kTwoPow31);

    // Rounding can push the scaled mantissa up to exactly 2^31, which does not
    // fit in int32_t. Halve it and compensate by bumping the exponent.
    if (scaledMantissa == kTwoPow31) {
      scaledMantissa /= 2;
      ++exponent;
    }

    assert(scaledMantissa <= std::numeric_limits<int32_t>::max());
    multiplier = static_cast<int32_t>(scaledMantissa);
  }

  // Mantissa of the real multiplier scaled by 2^31 and rounded.
  int32_t multiplier;
  // Binary exponent paired with 'multiplier'.
  int exponent;
};
/// Returns a type shaped like 't' but with 'newElementType' as its element
/// type. Vector, ranked tensor and unranked tensor types are rebuilt around
/// the new element type; any other 't' is asserted to be an integer or float
/// type and is replaced by 'newElementType' itself.
inline Type castElementType(Type t, Type newElementType) {
  auto containerType = t.dyn_cast<VectorOrTensorType>();
  if (containerType) {
    const auto kind = containerType.getKind();
    if (kind == StandardTypes::Kind::Vector)
      return VectorType::get(containerType.getShape(), newElementType);
    if (kind == StandardTypes::Kind::RankedTensor)
      return RankedTensorType::get(containerType.getShape(), newElementType);
    if (kind == StandardTypes::Kind::UnrankedTensor)
      return UnrankedTensorType::get(newElementType);
    // Any other kind falls through to the scalar handling below.
  }
  assert(t.isIntOrFloat());
  return newElementType;
}
/// Creates an integer attribute holding 'value' whose type matches 't'.
///
/// If 't' is a vector or tensor type, its element type must be an IntegerType
/// and the result is a SplatElementsAttr of type 't'. Otherwise 't' itself
/// must be an IntegerType and the result is an IntegerAttr of that type.
inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
    assert(vt.getElementType().isa<IntegerType>() &&
           "integer broadcast element type must be integer");
    return SplatElementsAttr::get(vt,
                                  IntegerAttr::get(vt.getElementType(), value));
  }
  // Validate before casting: cast<> is expected to check the type itself, so
  // asserting afterwards would never surface this more descriptive message.
  assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
  return IntegerAttr::get(t.cast<IntegerType>(), value);
}
/// Converts 'value' to the float semantics of 'ft' using round-to-nearest,
/// ties-to-even, and returns the converted value. Precision loss is ignored;
/// only a divide-by-zero or invalid-operation status trips the assertion.
inline APFloat convertFloatToType(FloatType ft, APFloat value) {
  bool ignoredLosesInfo = false;
  const APFloat::opStatus conversionStatus = value.convert(
      ft.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)conversionStatus; // only read by the assert below
  assert((conversionStatus & (APFloat::opDivByZero | APFloat::opInvalidOp)) ==
             0 &&
         "could not convert to float const");
  return value;
}
/// Creates a float attribute holding 'value' whose type matches 't'.
///
/// If 't' is a vector or tensor type, its element type must be a FloatType and
/// the result is a SplatElementsAttr of type 't'. Otherwise 't' itself must be
/// a FloatType and the result is a FloatAttr of that type. In both cases
/// 'value' is first converted to the float semantics of the target type.
inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
    FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
    assert(floatElementType &&
           "float broadcast element type must be float like");
    APFloat apValue = convertFloatToType(floatElementType, value);
    return SplatElementsAttr::get(vt,
                                  FloatAttr::get(vt.getElementType(), apValue));
  }

  // Scalar case: 't' is expected to be the float type itself.
  auto floatType = t.dyn_cast<FloatType>();
  assert(floatType && "float broadcast must be of float type");
  APFloat apValue = convertFloatToType(floatType, value);
  return FloatAttr::get(floatType, apValue);
}
} // namespace detail
} // namespace fxpmath
} // namespace mlir
#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_