Convert per channel fake quant attributes to type

For per channel fake quant attributes, the returned type should be
UniformQuantizedPerAxisType. Currently, this method isn't under test because we
haven't added the quant_ConstFakeQuantPerAxis op and the convert method.

PiperOrigin-RevId: 268084017
diff --git a/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
index 560b632..23e2967 100644
--- a/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
+++ b/third_party/mlir/include/mlir/Dialect/QuantOps/FakeQuantSupport.h
@@ -62,6 +62,14 @@
                                           bool narrowRange, Type expressedType,
                                           bool isSigned = false);
 
+/// Converts per-channel FakeQuant attributes to the corresponding type.
+/// In the event that the parameters cannot be converted, returns a nullptr
+/// convertible Type and issues an appropriate error.
+UniformQuantizedPerAxisType
+fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
+                     ArrayRef<double> rmins, ArrayRef<double> rmax,
+                     bool narrowRange, Type expressedType,
+                     bool isSigned = false);
 } // namespace quant
 } // namespace mlir
 
diff --git a/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 637f6a0..02f803a 100644
--- a/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/third_party/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -18,71 +18,48 @@
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 
-using namespace mlir;
-using namespace mlir::quant;
-
-UniformQuantizedType
-mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
-                                  double rmax, bool narrowRange,
-                                  Type expressedType, bool isSigned) {
-  MLIRContext *ctx = expressedType.getContext();
-  Type storageType;
-  unsigned flags;
-  int64_t qmin;
-  int64_t qmax;
-
+namespace mlir {
+namespace quant {
+namespace {
+bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
+                             MLIRContext *ctx, Type &storageType, int64_t &qmin,
+                             int64_t &qmax) {
   // Hard-coded type mapping from TFLite.
   if (numBits <= 8) {
     storageType = IntegerType::get(8, ctx);
     if (isSigned) {
-      flags = QuantizationFlags::Signed;
       qmin = -128;
       qmax = 127;
     } else {
-      flags = 0;
       qmin = 0;
       qmax = 255;
     }
   } else if (numBits <= 16) {
     storageType = IntegerType::get(16, ctx);
     if (isSigned) {
-      flags = QuantizationFlags::Signed;
       qmin = -32768;
       qmax = 32767;
     } else {
-      flags = 0;
       qmin = 0;
       qmax = 65535;
     }
   } else {
-    emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
-    return nullptr;
+    return true;
   }
 
   // Handle narrowRange.
   if (narrowRange) {
     qmin += 1;
   }
+  return false;
+}
 
-  // Range must straddle zero.
-  if (rmin > 0.0 || rmax < 0.0) {
-    return (emitError(loc, "FakeQuant range must straddle zero: [")
-                << rmin << "," << rmax << "]",
-            nullptr);
-  }
-
-  // Special case where min/max is close enough. The tensor contents are all
-  // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
-  // points and dequantized to 0.0.
-  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
-    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
-                                            1.0, qmin, qmin, qmax, loc);
-  }
-
+void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
+                          double &scale, int64_t &nudgedZeroPoint) {
   // Determine the scale.
   const double qminDouble = qmin;
   const double qmaxDouble = qmax;
-  const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+  scale = (rmax - rmin) / (qmaxDouble - qminDouble);
 
   // Zero point computation.
   // In float, solve the affine equation for any known pair
@@ -103,7 +80,7 @@
                                      : zeroPointFromMax;
 
   // Now nudge the zero point to be an integer.
-  int64_t nudgedZeroPoint = 0;
+  nudgedZeroPoint = 0;
   if (zeroPointDouble < qminDouble) {
     nudgedZeroPoint = qmin;
   } else if (zeroPointDouble > qmaxDouble) {
@@ -115,8 +92,97 @@
   // By construction, the nudged zero point should always be in range.
   assert(nudgedZeroPoint >= qmin);
   assert(nudgedZeroPoint <= qmax);
+}
+
+} // end namespace
+
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+                                          double rmin, double rmax,
+                                          bool narrowRange, Type expressedType,
+                                          bool isSigned) {
+  // Range must straddle zero.
+  // TODO(b/140641593): remove this constraint.
+  if (rmin > 0.0 || rmax < 0.0) {
+    return (emitError(loc, "FakeQuant range must straddle zero: [")
+                << rmin << "," << rmax << "]",
+            nullptr);
+  }
+
+  MLIRContext *ctx = expressedType.getContext();
+  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+  Type storageType;
+  int64_t qmin;
+  int64_t qmax;
+  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+                              qmin, qmax)) {
+    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+            nullptr);
+  }
+
+  // Special case where min/max is close enough. The tensor contents are all
+  // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
+  // points and dequantized to 0.0.
+  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                            1.0, qmin, qmin, qmax, loc);
+  }
+
+  double scale;
+  int64_t nudgedZeroPoint;
+  getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                           scale, nudgedZeroPoint, qmin, qmax,
                                           loc);
 }
+
+// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
+UniformQuantizedPerAxisType
+fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
+                     ArrayRef<double> rmins, ArrayRef<double> rmaxs,
+                     bool narrowRange, Type expressedType, bool isSigned) {
+  size_t axis_size = rmins.size();
+  if (axis_size != rmaxs.size()) {
+    return (emitError(loc, "mismatched per-axis min and max size: ")
+                << axis_size << " vs. " << rmaxs.size(),
+            nullptr);
+  }
+
+  MLIRContext *ctx = expressedType.getContext();
+  Type storageType;
+  int64_t qmin;
+  int64_t qmax;
+  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
+                              qmin, qmax)) {
+    return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
+            nullptr);
+  }
+
+  SmallVector<double, 4> scales;
+  SmallVector<int64_t, 4> zeroPoints;
+  scales.reserve(axis_size);
+  zeroPoints.reserve(axis_size);
+  for (size_t axis = 0; axis != axis_size; ++axis) {
+    double rmin = rmins[axis];
+    double rmax = rmaxs[axis];
+    if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
+      scales.push_back(1.0);
+      zeroPoints.push_back(qmin);
+      continue;
+    }
+
+    double scale;
+    int64_t nudgedZeroPoint;
+    getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+    scales.push_back(scale);
+    zeroPoints.push_back(nudgedZeroPoint);
+  }
+
+  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
+  return UniformQuantizedPerAxisType::getChecked(
+      flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
+      quantizedDimension, loc);
+}
+
+} // namespace quant
+} // namespace mlir