[GML] Simplify stride and offset compostion
PiperOrigin-RevId: 457086467
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
index 31efa79..3d71c5f 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/gml_st/transforms/compose_subset_ops.cc
@@ -29,89 +29,6 @@
namespace gml_st {
namespace {
-struct OperandOrIntegerRange {
- OperandOrIntegerRange(ValueRange dynamicValues, ArrayAttr staticValues,
- int64_t dynamicIntPlaceholder)
- : dynamicValues(dynamicValues),
- staticValues(staticValues),
- dynamicIntPlaceholder(dynamicIntPlaceholder) {}
-
- struct Iterator {
- public:
- using iterator_category = std::forward_iterator_tag;
- using difference_type = std::ptrdiff_t;
- using value_type = OpFoldResult;
- using pointer = value_type*;
- using reference = value_type&;
-
- private:
- using StaticValuesIteratorTy = const mlir::Attribute*;
- using DynamicValuesIteratorTy = llvm::detail::indexed_accessor_range_base<
- ValueRange,
- llvm::PointerUnion<const Value*, OpOperand*,
- mlir::detail::OpResultImpl*>,
- Value, Value, Value>::iterator;
-
- public:
- Iterator(DynamicValuesIteratorTy dynamicValuesIterator,
- StaticValuesIteratorTy staticValuesIterator,
- int64_t dynamicIntPlaceholder)
- : dynamicValuesIterator(dynamicValuesIterator),
- staticValuesIterator(staticValuesIterator),
- dynamicIntPlaceholder(dynamicIntPlaceholder) {}
-
- OpFoldResult operator*() const {
- if (staticValuesIterator->cast<IntegerAttr>().getInt() ==
- dynamicIntPlaceholder) {
- return *dynamicValuesIterator;
- }
- return *staticValuesIterator;
- }
-
- // Increments.
- Iterator& operator++() {
- int64_t integer = staticValuesIterator->cast<IntegerAttr>().getInt();
- if (integer == dynamicIntPlaceholder) dynamicValuesIterator++;
- staticValuesIterator++;
- return *this;
- }
- Iterator operator++(int) {
- Iterator tmp = *this;
- ++(*this);
- return tmp;
- }
-
- // Equivalence.
- friend bool operator==(const Iterator& a, const Iterator& b) {
- return a.staticValuesIterator == b.staticValuesIterator &&
- a.dynamicValuesIterator == b.dynamicValuesIterator;
- }
- friend bool operator!=(const Iterator& a, const Iterator& b) {
- return !(a == b);
- }
-
- private:
- DynamicValuesIteratorTy dynamicValuesIterator;
- StaticValuesIteratorTy staticValuesIterator;
- int64_t dynamicIntPlaceholder;
- };
-
- public:
- Iterator begin() {
- return Iterator(dynamicValues.begin(), staticValues.begin(),
- dynamicIntPlaceholder);
- }
- Iterator end() {
- return Iterator(dynamicValues.end(), staticValues.end(),
- dynamicIntPlaceholder);
- }
-
- private:
- ValueRange dynamicValues;
- ArrayAttr staticValues;
- int64_t dynamicIntPlaceholder;
-};
-
OpFoldResult multiplyOperandsOrIntegers(PatternRewriter& rewriter, Location loc,
OpFoldResult lhs, OpFoldResult rhs) {
// Both operands are static.
@@ -168,66 +85,48 @@
}
// Compose offsets with newOffset = argOffset + argStride * offset.
-std::pair<SmallVector<Value>, ArrayAttr> composeOffsets(
- ValueRange dynamicOffsets, ArrayAttr staticOffsets,
- ValueRange dynamicStrides, ArrayAttr staticStrides,
- ValueRange argDynamicOffsets, ArrayAttr argStaticOffsets, Location loc,
+SmallVector<OpFoldResult> composeOffsets(
+ const llvm::SmallVectorImpl<OpFoldResult>& argOffsets,
+ const llvm::SmallVectorImpl<OpFoldResult>& argStrides,
+ const llvm::SmallVectorImpl<OpFoldResult>& offsets, Location loc,
PatternRewriter& rewriter) {
- // Create ranges.
- OperandOrIntegerRange offsets(dynamicOffsets, staticOffsets,
- ShapedType::kDynamicStrideOrOffset);
- OperandOrIntegerRange argStrides(dynamicStrides, staticStrides,
- ShapedType::kDynamicStrideOrOffset);
- OperandOrIntegerRange argOffsets(argDynamicOffsets, argStaticOffsets,
- ShapedType::kDynamicStrideOrOffset);
-
- // Compose.
- SmallVector<Value> composedDynamicOffsets;
- SmallVector<int64_t> composedStaticOffsets;
+ SmallVector<OpFoldResult> composedOffsets;
for (auto it : llvm::zip(argOffsets, argStrides, offsets)) {
- auto composed = addOperandsOrIntegers(
+ composedOffsets.push_back(addOperandsOrIntegers(
rewriter, loc, std::get<0>(it),
multiplyOperandsOrIntegers(rewriter, loc, std::get<1>(it),
- std::get<2>(it)));
- if (composed.is<Attribute>()) {
- composedStaticOffsets.push_back(
- composed.get<Attribute>().cast<IntegerAttr>().getInt());
- } else {
- composedStaticOffsets.push_back(ShapedType::kDynamicStrideOrOffset);
- composedDynamicOffsets.push_back(composed.get<Value>());
- }
+ std::get<2>(it))));
}
- return {composedDynamicOffsets,
- rewriter.getI64ArrayAttr(composedStaticOffsets)};
+ return composedOffsets;
}
// Compose strides with newStride = argStride * stride.
-std::pair<SmallVector<Value>, ArrayAttr> composeStrides(
- PatternRewriter& rewriter, Location loc, ValueRange argDynamicStrides,
- ArrayAttr argStaticStrides, ValueRange dynamicStrides,
- ArrayAttr staticStrides) {
- // Create ranges.
- OperandOrIntegerRange argStrides(argDynamicStrides, argStaticStrides,
- ShapedType::kDynamicStrideOrOffset);
- OperandOrIntegerRange strides(dynamicStrides, staticStrides,
- ShapedType::kDynamicStrideOrOffset);
-
- // Compose.
- SmallVector<Value> composedDynamicStrides;
- SmallVector<int64_t> composedStaticStrides;
+SmallVector<OpFoldResult> composeStrides(
+ PatternRewriter& rewriter, Location loc,
+ const llvm::SmallVectorImpl<OpFoldResult>& argStrides,
+ const llvm::SmallVectorImpl<OpFoldResult>& strides) {
+ SmallVector<OpFoldResult> composedStrides;
for (auto it : llvm::zip(argStrides, strides)) {
- auto product = multiplyOperandsOrIntegers(rewriter, loc, std::get<0>(it),
- std::get<1>(it));
- if (product.is<Attribute>()) {
- composedStaticStrides.push_back(
- product.get<Attribute>().cast<IntegerAttr>().getInt());
+ composedStrides.push_back(multiplyOperandsOrIntegers(
+ rewriter, loc, std::get<0>(it), std::get<1>(it)));
+ }
+ return composedStrides;
+}
+
+// TODO(frgossen): Move this upstream to the ViewLikeInterface
+std::pair<ArrayAttr, SmallVector<Value>> decomposeMixedStridesOrOffsets(
+ OpBuilder& b, const SmallVectorImpl<OpFoldResult>& mixedValues) {
+ SmallVector<int64_t> staticValues;
+ SmallVector<Value> dynamicValues;
+ for (const auto& it : mixedValues) {
+ if (it.is<Attribute>()) {
+ staticValues.push_back(it.get<Attribute>().cast<IntegerAttr>().getInt());
} else {
- composedStaticStrides.push_back(ShapedType::kDynamicStrideOrOffset);
- composedDynamicStrides.push_back(product.get<Value>());
+ staticValues.push_back(ShapedType::kDynamicStrideOrOffset);
+ dynamicValues.push_back(it.get<Value>());
}
}
- return {composedDynamicStrides,
- rewriter.getI64ArrayAttr(composedStaticStrides)};
+ return {b.getI64ArrayAttr(staticValues), dynamicValues};
}
struct ComposeTilesPattern : public OpRewritePattern<TileOp> {
@@ -240,24 +139,21 @@
// Compose offsets with newOffset = argOffset + argStride * offset.
auto loc = op.getLoc();
- auto composedOffsets =
- composeOffsets(op.offsets(), op.static_offsets(), argOp.strides(),
- argOp.static_strides(), argOp.offsets(),
- argOp.static_offsets(), loc, rewriter);
-
- // Reuse sizes.
- std::pair composedSizes = {op.sizes(), op.static_sizes()};
+ auto composedOffsets = decomposeMixedStridesOrOffsets(
+ rewriter,
+ composeOffsets(argOp.getMixedOffsets(), argOp.getMixedStrides(),
+ op.getMixedOffsets(), loc, rewriter));
// Compose strides with newStride = argStride * stride.
- auto newStrides =
- composeStrides(rewriter, loc, argOp.strides(), argOp.static_strides(),
- op.strides(), op.static_strides());
+ auto composedStrides = decomposeMixedStridesOrOffsets(
+ rewriter, composeStrides(rewriter, loc, argOp.getMixedStrides(),
+ op.getMixedStrides()));
// Build the composed tile op.
rewriter.replaceOpWithNewOp<TileOp>(
- op, argOp.subset(), composedOffsets.first, composedSizes.first,
- newStrides.first, composedOffsets.second, composedSizes.second,
- newStrides.second);
+ op, argOp.subset(), composedOffsets.second, op.sizes(),
+ composedStrides.second, composedOffsets.first, op.static_sizes(),
+ composedStrides.first);
return success();
}
};