[mhlo:linalg] Lower base dilation on mhlo.reduce_window to interior padding
Also simplify the code a bit by lowering through mhlo.pad instead of building tensor.pad regions by hand.
PiperOrigin-RevId: 450655046
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 1935e3f..7cb4c95 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2266,10 +2266,8 @@
}
llvm::SmallVector<int64_t> base_dilations;
- if (op.window_dilations()) {
+ if (op.base_dilations()) {
base_dilations = Extract1DVector(*op.base_dilations());
- if (llvm::any_of(base_dilations, [](int64_t& x) { return x != 1; }))
- return failure();
}
llvm::SmallVector<int64_t> window_strides(window_dimensions.size(), 1);
@@ -2282,14 +2280,14 @@
window_dilations = Extract1DVector(*op.window_dilations());
}
- auto rank = window_dimensions.size();
+ auto rank = static_cast<int64_t>(window_dimensions.size());
SmallVector<AffineExpr, 2> src_exprs;
SmallVector<AffineExpr, 2> window_exprs;
SmallVector<AffineExpr, 2> dst_exprs;
SmallVector<int64_t> filtered_window_dims;
int window_dim = 0;
- for (int i = 0; i < rank; i++) {
+ for (int64_t i = 0; i < rank; i++) {
AffineExpr src_expr = mlir::getAffineDimExpr(i, ctx);
if (window_strides[i] != 1) src_expr = src_expr * window_strides[i];
@@ -2334,37 +2332,30 @@
llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.operands());
// Pad as necessary.
- if (llvm::any_of(padding, [](int32_t v) { return v != 0; })) {
+ if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
+ llvm::any_of(base_dilations, [](int64_t v) { return v != 1; })) {
llvm::SmallVector<int64_t> static_lows;
llvm::SmallVector<int64_t> static_highs;
for (int i = 0; i < padding.size(); i += 2) {
static_lows.push_back(padding[i]);
static_highs.push_back(padding[i + 1]);
}
+ // Translate base dilation into interior padding.
+ auto static_interiors = llvm::to_vector(llvm::map_range(
+ base_dilations, [](int64_t dilation) { return dilation - 1; }));
+
+ auto pad_attr_type =
+ RankedTensorType::get({rank}, rewriter.getIndexType());
+ auto pad_lows = DenseIntElementsAttr::get(pad_attr_type, static_lows);
+ auto pad_highs = DenseIntElementsAttr::get(pad_attr_type, static_highs);
+ auto pad_interiors =
+ DenseIntElementsAttr::get(pad_attr_type, static_interiors);
+
for (auto values : llvm::zip(inputs, init_values)) {
auto& input = std::get<0>(values);
auto& init_value = std::get<1>(values);
-
- // Extract the single element from init value. This mimic the lowering
- // behavior of mhlo.pad.
- Value padding_value =
- rewriter.createOrFold<tensor::ExtractOp>(loc, init_value);
-
- auto pad_op = rewriter.create<tensor::PadOp>(
- loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});
-
- SmallVector<Type, 4> block_arg_types;
- block_arg_types.assign(input.getType().cast<ShapedType>().getRank(),
- rewriter.getIndexType());
- auto& region = pad_op.region();
-
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.createBlock(
- ®ion, region.end(), block_arg_types,
- SmallVector<Location>(block_arg_types.size(), loc));
- rewriter.create<tensor::YieldOp>(loc, padding_value);
-
- input = pad_op.getResult();
+ input = rewriter.create<mhlo::PadOp>(loc, input, init_value, pad_lows,
+ pad_highs, pad_interiors);
}
}
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index ca997bb..fd8ea3e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -3608,6 +3608,26 @@
// -----
+// CHECK-LABEL: func @reduce_window_generic_base_dilation
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
+
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 9] : tensor<8x9xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32>
+
+func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+ %0 = "mhlo.reduce_window"(%arg0, %arg1) ({
+ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+ %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+ "mhlo.return"(%1) : (tensor<f32>) -> ()
+ }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+ func.return %0 : tensor<4x7xf32>
+}
+
+// -----
+
func.func @gather(%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xi32> {
%res = "mhlo.gather"(%operand, %start_indices) {
dimension_numbers = #mhlo.gather<