Specialize f32->i8/u8 quantization with C++ native arithmetic to optimize performance.

The CL adds a rounding mode flag to the class and changes the default from rmNearestTiesToEven to rmNearestTiesToAway because 1) the TensorFlow QuantizeV2 op uses rmNearestTiesToAway; 2) the specialization only implements rmNearestTiesToAway.

PiperOrigin-RevId: 270600739
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
index 0ce76b1..f1fb329 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -68,37 +68,58 @@
 class UniformQuantizedValueConverter {
 public:
   explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
-      : scale(uniformType.getScale()),
-        zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
-        clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
-        clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
-        storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
-        isSigned(uniformType.isSigned()) {
+      : UniformQuantizedValueConverter(
+            uniformType.getScale(),
+            static_cast<double>(uniformType.getZeroPoint()),
+            static_cast<double>(uniformType.getStorageTypeMin()),
+            static_cast<double>(uniformType.getStorageTypeMax()),
+            uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
     assert(uniformType.getExpressedType().isa<FloatType>());
     assert(uniformType.getStorageType().isa<IntegerType>());
   }
 
   UniformQuantizedValueConverter(double scale, double zeroPoint,
+                                 double clampMin, double clampMax,
+                                 uint32_t storageBitWidth, bool isSigned)
+      : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
+        clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+        clampMinDouble(clampMin), clampMaxDouble(clampMax),
+        storageBitWidth(storageBitWidth), isSigned(isSigned),
+        roundMode(llvm::APFloat::rmNearestTiesToAway) {}
+
+  UniformQuantizedValueConverter(double scale, double zeroPoint,
                                  APFloat clampMin, APFloat clampMax,
                                  uint32_t storageBitWidth, bool isSigned)
       : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
-        clampMax(clampMax), storageBitWidth(storageBitWidth),
-        isSigned(isSigned) {}
+        clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
+        clampMinDouble(clampMin.convertToDouble()),
+        clampMaxDouble(clampMax.convertToDouble()),
+        storageBitWidth(storageBitWidth), isSigned(isSigned),
+        roundMode(llvm::APFloat::rmNearestTiesToAway) {}
 
   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+    // This function is a performance-critical code path in quantization
+    // since it runs for each float parameter value.
+
+    // Specialize f32->u8/i8 case to optimize performance.
+    if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
+        storageBitWidth == 8 &&
+        roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
+      return quantizeF32ToInt8(expressedValue);
+    }
+
     bool lossy;
-    expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
-                           &lossy);
+    expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
     // fixedpoint = clamp(clampMin, clampMax, (
     //   roundHalfToEven(expressed / scale) + zeroPoint))
     APFloat scaled = (expressedValue / scale);
-    scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
-    scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+    scaled.roundToIntegral(roundMode);
+    scaled.add(zeroPoint, roundMode);
     APFloat fixedpoint = llvm::minimum(scaled, clampMax);
     fixedpoint = llvm::maximum(fixedpoint, clampMin);
 
     llvm::APSInt result(storageBitWidth, !isSigned);
-    fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+    fixedpoint.convertToInteger(result, roundMode, &lossy);
 
     return std::move(result);
   }
@@ -111,12 +132,48 @@
   virtual ~UniformQuantizedValueConverter() {}
 
 private:
+  // An optimized implementation to quantize f32 to i8/u8 with C++ native
+  // arithmetic.
+  virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
+    assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
+    assert(storageBitWidth == 8);
+    assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
+
+    const float realValue = expressedValue.convertToFloat();
+
+    const double scaled = realValue / scaleDouble + zeroPointDouble;
+    // Round to nearest integer with halfway cases rounded away from zero.
+    const double scaledRounded = std::round(scaled);
+    const double clamped =
+        std::clamp(scaledRounded, clampMinDouble, clampMaxDouble);
+
+    uint64_t signlessResult;
+    if (isSigned) {
+      int64_t clampedInt = static_cast<int8_t>(clamped);
+      memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
+    } else {
+      signlessResult = static_cast<uint8_t>(clamped);
+    }
+    llvm::APInt result(storageBitWidth, signlessResult);
+    return result;
+  }
+
+  // Keep both APFloat and double versions of the quantization parameters
+  // around since they will be used in generic and specialized arithmetic,
+  // respectively.
   const APFloat scale;
   const APFloat zeroPoint;
   const APFloat clampMin;
   const APFloat clampMax;
+
+  const double scaleDouble;
+  const double zeroPointDouble;
+  const double clampMinDouble;
+  const double clampMaxDouble;
+
   const uint32_t storageBitWidth;
   const bool isSigned;
+  const llvm::APFloat::roundingMode roundMode;
 };
 
 /// An utility class to quantize an attribute by the per-axis quantization