[GmlSt] Use arith.minsi for tiled loop upper bound.
This allows the upstream `convert-parallel-loops-to-gpu` pass to derive a static upper bound; see https://reviews.llvm.org/D132354.
PiperOrigin-RevId: 469133395
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
index cfb2e3e..d3574a4 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
@@ -29,7 +29,6 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace gml_st {
@@ -60,12 +59,9 @@
}
// Otherwise, compute the tile size dynamically.
- auto ivNext = b.create<arith::AddIOp>(loc, ivs[i], steps[i]);
- auto isPartialTileInDim = b.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sgt, ivNext, upperBounds[i]);
auto remainderInDim = b.create<arith::SubIOp>(loc, upperBounds[i], ivs[i]);
- auto tileSizeInDim = b.create<arith::SelectOp>(loc, isPartialTileInDim,
- remainderInDim, steps[i]);
+ auto tileSizeInDim =
+ b.create<arith::MinSIOp>(loc, steps[i], remainderInDim);
staticSizes.push_back(ShapedType::kDynamicSize);
dynamicSizes.push_back(tileSizeInDim);
}
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir
index 94b37a9..3b47d5ff 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/nested_tiling.mlir
@@ -50,14 +50,10 @@
// CHECK-IMPERFECT-SAME: (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
// CHECK-IMPERFECT-SAME: to (%[[C64]], %[[C32]])
// CHECK-IMPERFECT-SAME: step (%[[C17]], %[[C9]])
-// CHECK-IMPERFECT: %[[ADDI:.*]] = arith.addi %[[I]], %[[C17]]
-// CHECK-IMPERFECT: %[[CMPI:.*]] = arith.cmpi sgt, %[[ADDI]], %[[C64]]
// CHECK-IMPERFECT: %[[SUBI:.*]] = arith.subi %[[C64]], %[[I]]
-// CHECK-IMPERFECT: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[SUBI]], %[[C17]]
-// CHECK-IMPERFECT: %[[ADDI_0:.*]] = arith.addi %[[J]], %[[C9]]
-// CHECK-IMPERFECT: %[[CMPI_0:.*]] = arith.cmpi sgt, %[[ADDI_0]], %[[C32]]
+// CHECK-IMPERFECT: %[[SELECT:.*]] = arith.minsi %[[C17]], %[[SUBI]]
// CHECK-IMPERFECT: %[[SUBI_0:.*]] = arith.subi %[[C32]], %[[J]]
-// CHECK-IMPERFECT: %[[SELECT_0:.*]] = arith.select %[[CMPI_0]], %[[SUBI_0]], %[[C9]]
+// CHECK-IMPERFECT: %[[SELECT_0:.*]] = arith.minsi %[[C9]], %[[SUBI_0]]
// CHECK-IMPERFECT: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[I]], %[[J]]] [%[[SELECT]], %[[SELECT_0]]] [1, 1]
// CHECK-IMPERFECT: %[[MED_ARG:.*]] = gml_st.materialize %[[ARG]][%[[TILE]]]
// CHECK-IMPERFECT: %[[MED_INIT:.*]] = gml_st.materialize %[[INIT]][%[[TILE]]]
@@ -70,14 +66,10 @@
// CHECK-IMPERFECT-SAME: (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[C0]], %[[C0]])
// CHECK-IMPERFECT-SAME: to (%[[DIM]], %[[DIM_0]])
// CHECK-IMPERFECT-SAME: step (%[[C3]], %[[C3]])
-// CHECK-IMPERFECT: %[[ADDI_1:.*]] = arith.addi %[[ARG3]], %[[C3]]
-// CHECK-IMPERFECT: %[[CMPI_1:.*]] = arith.cmpi sgt, %[[ADDI_1]], %[[DIM]]
// CHECK-IMPERFECT: %[[SUBI_1:.*]] = arith.subi %[[DIM]], %[[ARG3]]
-// CHECK-IMPERFECT: %[[SELECT_1:.*]] = arith.select %[[CMPI_1]], %[[SUBI_1]], %[[C3]]
-// CHECK-IMPERFECT: %[[ADDI_2:.*]] = arith.addi %[[ARG4]], %[[C3]]
-// CHECK-IMPERFECT: %[[CMPI_2:.*]] = arith.cmpi sgt, %[[ADDI_2]], %[[DIM_0]]
+// CHECK-IMPERFECT: %[[SELECT_1:.*]] = arith.minsi %[[C3]], %[[SUBI_1]]
// CHECK-IMPERFECT: %[[SUBI_2:.*]] = arith.subi %[[DIM_0]], %[[ARG4]]
-// CHECK-IMPERFECT: %[[SELECT_2:.*]] = arith.select %[[CMPI_2]], %[[SUBI_2]], %[[C3]]
+// CHECK-IMPERFECT: %[[SELECT_2:.*]] = arith.minsi %[[C3]], %[[SUBI_2]]
// CHECK-IMPERFECT: %[[INNER_TILE:.*]] = gml_st.tile %[[INNER_SPACE]] [%[[ARG3]], %[[ARG4]]] [%[[SELECT_1]], %[[SELECT_2]]] [1, 1]
// CHECK-IMPERFECT: %[[INNER_MED_ARG:.*]] = gml_st.materialize %[[MED_ARG]][%[[INNER_TILE]]]
// CHECK-IMPERFECT: gml_st.set_yield %[[INNER_MED_ARG]] into %[[MED_INIT]][%[[INNER_TILE]]]
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir
index 38755f6..8fcc0a2 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling.mlir
@@ -57,14 +57,10 @@
// CHECK-TILE-SAME: (%[[I:.*]], %[[J:.*]]) = (%[[LB_C0]], %[[LB_C0]])
// CHECK-TILE-SAME: to (%[[GENERIC_D0]], %[[GENERIC_D1]])
// CHECK-TILE-SAME: step (%[[C256]], %[[C512]])
-// CHECK-TILE: %[[I_PLUS_256:.*]] = arith.addi %[[I]], %[[C256]]
-// CHECK-TILE: %[[IS_PARTIAL_D0:.*]] = arith.cmpi sgt, %[[I_PLUS_256]], %[[GENERIC_D0]]
// CHECK-TILE: %[[REMAINDER_D0:.*]] = arith.subi %[[GENERIC_D0]], %[[I]]
-// CHECK-TILE: %[[TILE_SIZE_D0:.*]] = arith.select %[[IS_PARTIAL_D0]], %[[REMAINDER_D0]], %[[C256]]
-// CHECK-TILE: %[[I_PLUS_256_0:.*]] = arith.addi %[[J]], %[[C512]]
-// CHECK-TILE: %[[IS_PARTIAL_D1:.*]] = arith.cmpi sgt, %[[I_PLUS_256_0]], %[[GENERIC_D1]]
+// CHECK-TILE: %[[TILE_SIZE_D0:.*]] = arith.minsi %[[C256]], %[[REMAINDER_D0]]
// CHECK-TILE: %[[REMAINDER_D1:.*]] = arith.subi %[[GENERIC_D1]], %[[J]]
-// CHECK-TILE: %[[TILE_SIZE_D1:.*]] = arith.select %[[IS_PARTIAL_D1]], %[[REMAINDER_D1]], %[[C512]]
+// CHECK-TILE: %[[TILE_SIZE_D1:.*]] = arith.minsi %[[C512]], %[[REMAINDER_D1]]
// CHECK-TILE: %[[TILE:.*]] = gml_st.tile %[[SPACE]] [%[[I]], %[[J]]] [%[[TILE_SIZE_D0]], %[[TILE_SIZE_D1]]] [1, 1]
// CHECK-TILE: %[[INNER_GENERIC:.*]] = gml_st.materialize %[[GENERIC]][%[[TILE]]]
// CHECK-TILE: gml_st.set_yield %[[INNER_GENERIC]] into %[[INIT]][%[[TILE]]]
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir
index 776fe31..c9ab1f9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tiling_and_fusion.mlir
@@ -79,7 +79,7 @@
// CHECK-TILE-SAME: %[[ARG0:.*]]:
// CHECK-TILE: %[[INIT:.*]] = linalg.init_tensor
// CHECK-TILE: gml_st.parallel (%[[I:.*]], %[[J:.*]]) =
-// CHECK-TILE: %[[D0_SIZE:.*]] = arith.select {{.*}}, %c256
+// CHECK-TILE: %[[D0_SIZE:.*]] = arith.minsi {{.*}}, %c256
// CHECK-TILE: %[[OUTPUT_TILE:.*]] = gml_st.tile %{{.*}} [%[[I]], %[[J]]]
// CHECK-TILE: %[[INPUT_SPACE:.*]] = gml_st.space [%[[D0_SIZE]]]
// CHECK-TILE: %[[INPUT_TILE:.*]] = gml_st.tile %[[INPUT_SPACE]] [%[[I]]]