[GML] Simplify stride and offset compostion

PiperOrigin-RevId: 457086467
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
index 31efa79..3d71c5f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
@@ -29,89 +29,6 @@
 namespace gml_st {
 namespace {
 
-struct OperandOrIntegerRange {
-  OperandOrIntegerRange(ValueRange dynamicValues, ArrayAttr staticValues,
-                        int64_t dynamicIntPlaceholder)
-      : dynamicValues(dynamicValues),
-        staticValues(staticValues),
-        dynamicIntPlaceholder(dynamicIntPlaceholder) {}
-
-  struct Iterator {
-   public:
-    using iterator_category = std::forward_iterator_tag;
-    using difference_type = std::ptrdiff_t;
-    using value_type = OpFoldResult;
-    using pointer = value_type*;
-    using reference = value_type&;
-
-   private:
-    using StaticValuesIteratorTy = const mlir::Attribute*;
-    using DynamicValuesIteratorTy = llvm::detail::indexed_accessor_range_base<
-        ValueRange,
-        llvm::PointerUnion<const Value*, OpOperand*,
-                           mlir::detail::OpResultImpl*>,
-        Value, Value, Value>::iterator;
-
-   public:
-    Iterator(DynamicValuesIteratorTy dynamicValuesIterator,
-             StaticValuesIteratorTy staticValuesIterator,
-             int64_t dynamicIntPlaceholder)
-        : dynamicValuesIterator(dynamicValuesIterator),
-          staticValuesIterator(staticValuesIterator),
-          dynamicIntPlaceholder(dynamicIntPlaceholder) {}
-
-    OpFoldResult operator*() const {
-      if (staticValuesIterator->cast<IntegerAttr>().getInt() ==
-          dynamicIntPlaceholder) {
-        return *dynamicValuesIterator;
-      }
-      return *staticValuesIterator;
-    }
-
-    // Increments.
-    Iterator& operator++() {
-      int64_t integer = staticValuesIterator->cast<IntegerAttr>().getInt();
-      if (integer == dynamicIntPlaceholder) dynamicValuesIterator++;
-      staticValuesIterator++;
-      return *this;
-    }
-    Iterator operator++(int) {
-      Iterator tmp = *this;
-      ++(*this);
-      return tmp;
-    }
-
-    // Equivalence.
-    friend bool operator==(const Iterator& a, const Iterator& b) {
-      return a.staticValuesIterator == b.staticValuesIterator &&
-             a.dynamicValuesIterator == b.dynamicValuesIterator;
-    }
-    friend bool operator!=(const Iterator& a, const Iterator& b) {
-      return !(a == b);
-    }
-
-   private:
-    DynamicValuesIteratorTy dynamicValuesIterator;
-    StaticValuesIteratorTy staticValuesIterator;
-    int64_t dynamicIntPlaceholder;
-  };
-
- public:
-  Iterator begin() {
-    return Iterator(dynamicValues.begin(), staticValues.begin(),
-                    dynamicIntPlaceholder);
-  }
-  Iterator end() {
-    return Iterator(dynamicValues.end(), staticValues.end(),
-                    dynamicIntPlaceholder);
-  }
-
- private:
-  ValueRange dynamicValues;
-  ArrayAttr staticValues;
-  int64_t dynamicIntPlaceholder;
-};
-
 OpFoldResult multiplyOperandsOrIntegers(PatternRewriter& rewriter, Location loc,
                                         OpFoldResult lhs, OpFoldResult rhs) {
   // Both operands are static.
@@ -168,66 +85,48 @@
 }
 
 // Compose offsets with newOffset = argOffset + argStride * offset.
-std::pair<SmallVector<Value>, ArrayAttr> composeOffsets(
-    ValueRange dynamicOffsets, ArrayAttr staticOffsets,
-    ValueRange dynamicStrides, ArrayAttr staticStrides,
-    ValueRange argDynamicOffsets, ArrayAttr argStaticOffsets, Location loc,
+SmallVector<OpFoldResult> composeOffsets(
+    const llvm::SmallVectorImpl<OpFoldResult>& argOffsets,
+    const llvm::SmallVectorImpl<OpFoldResult>& argStrides,
+    const llvm::SmallVectorImpl<OpFoldResult>& offsets, Location loc,
     PatternRewriter& rewriter) {
-  // Create ranges.
-  OperandOrIntegerRange offsets(dynamicOffsets, staticOffsets,
-                                ShapedType::kDynamicStrideOrOffset);
-  OperandOrIntegerRange argStrides(dynamicStrides, staticStrides,
-                                   ShapedType::kDynamicStrideOrOffset);
-  OperandOrIntegerRange argOffsets(argDynamicOffsets, argStaticOffsets,
-                                   ShapedType::kDynamicStrideOrOffset);
-
-  // Compose.
-  SmallVector<Value> composedDynamicOffsets;
-  SmallVector<int64_t> composedStaticOffsets;
+  SmallVector<OpFoldResult> composedOffsets;
   for (auto it : llvm::zip(argOffsets, argStrides, offsets)) {
-    auto composed = addOperandsOrIntegers(
+    composedOffsets.push_back(addOperandsOrIntegers(
         rewriter, loc, std::get<0>(it),
         multiplyOperandsOrIntegers(rewriter, loc, std::get<1>(it),
-                                   std::get<2>(it)));
-    if (composed.is<Attribute>()) {
-      composedStaticOffsets.push_back(
-          composed.get<Attribute>().cast<IntegerAttr>().getInt());
-    } else {
-      composedStaticOffsets.push_back(ShapedType::kDynamicStrideOrOffset);
-      composedDynamicOffsets.push_back(composed.get<Value>());
-    }
+                                   std::get<2>(it))));
   }
-  return {composedDynamicOffsets,
-          rewriter.getI64ArrayAttr(composedStaticOffsets)};
+  return composedOffsets;
 }
 
 // Compose strides with newStride = argStride * stride.
-std::pair<SmallVector<Value>, ArrayAttr> composeStrides(
-    PatternRewriter& rewriter, Location loc, ValueRange argDynamicStrides,
-    ArrayAttr argStaticStrides, ValueRange dynamicStrides,
-    ArrayAttr staticStrides) {
-  // Create ranges.
-  OperandOrIntegerRange argStrides(argDynamicStrides, argStaticStrides,
-                                   ShapedType::kDynamicStrideOrOffset);
-  OperandOrIntegerRange strides(dynamicStrides, staticStrides,
-                                ShapedType::kDynamicStrideOrOffset);
-
-  // Compose.
-  SmallVector<Value> composedDynamicStrides;
-  SmallVector<int64_t> composedStaticStrides;
+SmallVector<OpFoldResult> composeStrides(
+    PatternRewriter& rewriter, Location loc,
+    const llvm::SmallVectorImpl<OpFoldResult>& argStrides,
+    const llvm::SmallVectorImpl<OpFoldResult>& strides) {
+  SmallVector<OpFoldResult> composedStrides;
   for (auto it : llvm::zip(argStrides, strides)) {
-    auto product = multiplyOperandsOrIntegers(rewriter, loc, std::get<0>(it),
-                                              std::get<1>(it));
-    if (product.is<Attribute>()) {
-      composedStaticStrides.push_back(
-          product.get<Attribute>().cast<IntegerAttr>().getInt());
+    composedStrides.push_back(multiplyOperandsOrIntegers(
+        rewriter, loc, std::get<0>(it), std::get<1>(it)));
+  }
+  return composedStrides;
+}
+
+// TODO(frgossen): Move this upstream to the ViewLikeInterface
+std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedStridesOrOffsets(
+    OpBuilder& b, const SmallVectorImpl<OpFoldResult>& mixedValues) {
+  SmallVector<int64_t> staticValues;
+  SmallVector<Value> dynamicValues;
+  for (const auto& it : mixedValues) {
+    if (it.is<Attribute>()) {
+      staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
     } else {
-      composedStaticStrides.push_back(ShapedType::kDynamicStrideOrOffset);
-      composedDynamicStrides.push_back(product.get<Value>());
+      staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
+      dynamicValues.push_back(it.get<Value>());
     }
   }
-  return {composedDynamicStrides,
-          rewriter.getI64ArrayAttr(composedStaticStrides)};
+  return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
 struct ComposeTilesPattern : public OpRewritePattern<TileOp> {
@@ -240,24 +139,21 @@
 
     // Compose offsets with newOffset = argOffset + argStride * offset.
     auto loc = op.getLoc();
-    auto composedOffsets =
-        composeOffsets(op.offsets(), op.static_offsets(), argOp.strides(),
-                       argOp.static_strides(), argOp.offsets(),
-                       argOp.static_offsets(), loc, rewriter);
-
-    // Reuse sizes.
-    std::pair composedSizes = {op.sizes(), op.static_sizes()};
+    auto composedOffsets = decomposeMixedStridesOrOffsets(
+        rewriter,
+        composeOffsets(argOp.getMixedOffsets(), argOp.getMixedStrides(),
+                       op.getMixedOffsets(), loc, rewriter));
 
     // Compose strides with newStride = argStride * stride.
-    auto newStrides =
-        composeStrides(rewriter, loc, argOp.strides(), argOp.static_strides(),
-                       op.strides(), op.static_strides());
+    auto composedStrides = decomposeMixedStridesOrOffsets(
+        rewriter, composeStrides(rewriter, loc, argOp.getMixedStrides(),
+                                 op.getMixedStrides()));
 
     // Build the composed tile op.
     rewriter.replaceOpWithNewOp<TileOp>(
-        op, argOp.subset(), composedOffsets.first, composedSizes.first,
-        newStrides.first, composedOffsets.second, composedSizes.second,
-        newStrides.second);
+        op, argOp.subset(), composedOffsets.second, op.sizes(),
+        composedStrides.second, composedOffsets.first, op.static_sizes(),
+        composedStrides.first);
     return success();
   }
 };