Specalize f32->i8/u8 Quanitization with C++ native arithmetic to optimize performance.
The CL adds a rounding mode flag to the class and changes the default to rmNearestTiesToAway from rmNearestTiesToEven because 1) Tensorflow QuantizeV2 ops uses rmNearestTiesToAway; 2) the specialization only implements rmNearestTiesToAway.
PiperOrigin-RevId: 270600739
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
index 0ce76b1..f1fb329 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -68,37 +68,58 @@
class UniformQuantizedValueConverter {
public:
explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
- : scale(uniformType.getScale()),
- zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
- clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
- clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
- storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
- isSigned(uniformType.isSigned()) {
+ : UniformQuantizedValueConverter(
+ uniformType.getScale(),
+ static_cast<double>(uniformType.getZeroPoint()),
+ static_cast<double>(uniformType.getStorageTypeMin()),
+ static_cast<double>(uniformType.getStorageTypeMax()),
+ uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
assert(uniformType.getExpressedType().isa<FloatType>());
assert(uniformType.getStorageType().isa<IntegerType>());
}
UniformQuantizedValueConverter(double scale, double zeroPoint,
+ double clampMin, double clampMax,
+ uint32_t storageBitWidth, bool isSigned)
+ : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
+ clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+ clampMinDouble(clampMin), clampMaxDouble(clampMax),
+ storageBitWidth(storageBitWidth), isSigned(isSigned),
+ roundMode(llvm::APFloat::rmNearestTiesToAway) {}
+
+ UniformQuantizedValueConverter(double scale, double zeroPoint,
APFloat clampMin, APFloat clampMax,
uint32_t storageBitWidth, bool isSigned)
: scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
- clampMax(clampMax), storageBitWidth(storageBitWidth),
- isSigned(isSigned) {}
+ clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+ clampMinDouble(clampMin.convertToDouble()),
+ clampMaxDouble(clampMax.convertToDouble()),
+ storageBitWidth(storageBitWidth), isSigned(isSigned),
+ roundMode(llvm::APFloat::rmNearestTiesToAway) {}
virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+ // This function is a performance critical code path in quantization
+ // since it runs for each single float parameter value.
+
+ // Specalize f32->u8/i8 case to optimize performance.
+ if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
+ storageBitWidth == 8 &&
+ roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
+ return quantizeF32ToInt8(expressedValue);
+ }
+
bool lossy;
- expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
- &lossy);
+ expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
// fixedpoint = clamp(clampMin, clampMax, (
// roundHalfToEven(expressed / scale) + zeroPoint))
APFloat scaled = (expressedValue / scale);
- scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
- scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+ scaled.roundToIntegral(roundMode);
+ scaled.add(zeroPoint, roundMode);
APFloat fixedpoint = llvm::minimum(scaled, clampMax);
fixedpoint = llvm::maximum(fixedpoint, clampMin);
llvm::APSInt result(storageBitWidth, !isSigned);
- fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+ fixedpoint.convertToInteger(result, roundMode, &lossy);
return std::move(result);
}
@@ -111,12 +132,48 @@
virtual ~UniformQuantizedValueConverter() {}
private:
+ // An optimized implementation to quantize f32 to i8/u8 with C++ native
+ // arithmetic.
+ virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
+ assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
+ assert(storageBitWidth == 8);
+ assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
+
+ const float realValue = expressedValue.convertToFloat();
+
+ const double scaled = realValue / scaleDouble + zeroPointDouble;
+ // Round to nearest integer with halfway cases rounded away from zero.
+ const double scaledRounded = std::round(scaled);
+ const double clamped =
+ std::clamp(scaledRounded, clampMinDouble, clampMaxDouble);
+
+ uint64_t signlessResult;
+ if (isSigned) {
+ int64_t clampedInt = static_cast<int8_t>(clamped);
+ memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
+ } else {
+ signlessResult = static_cast<uint8_t>(clamped);
+ }
+ llvm::APInt result(storageBitWidth, signlessResult);
+ return result;
+ }
+
+ // Keep both APFloat and double versions of the quantization parameters
+ // around since they will be used in generic and specialized arithmetic,
+ // respectively.
const APFloat scale;
const APFloat zeroPoint;
const APFloat clampMin;
const APFloat clampMax;
+
+ const double scaleDouble;
+ const double zeroPointDouble;
+ const double clampMinDouble;
+ const double clampMaxDouble;
+
const uint32_t storageBitWidth;
const bool isSigned;
+ const llvm::APFloat::roundingMode roundMode;
};
/// An utility class to quantize an attribute by the per-axis quantization