[mhlo:linalg] Lower base dilation on mhlo.reduce_window to interior padding

Also simplify the code a bit by lowering through mhlo.pad.
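
A base dilation of d along a dimension is equivalent to d - 1 elements of
interior padding between neighboring input elements, so the padded size
along that dimension is

  size + (size - 1) * (d - 1) + low + high

As a minimal sketch (shapes and values are illustrative, not taken from this
change), a 1-D input of size 3 with base dilation 2 and edge padding [0, 1]
maps onto

  %padded = "mhlo.pad"(%input, %init) {
    edge_padding_low = dense<0> : tensor<1xi64>,
    edge_padding_high = dense<1> : tensor<1xi64>,
    interior_padding = dense<1> : tensor<1xi64>
  } : (tensor<3xf32>, tensor<f32>) -> tensor<6xf32>

where the reduction's init value doubles as the padding value, as it did in
the old tensor.pad-based lowering.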

PiperOrigin-RevId: 450655046
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 1935e3f..7cb4c95 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -2266,10 +2266,8 @@
     }
 
     llvm::SmallVector<int64_t> base_dilations;
-    if (op.window_dilations()) {
+    if (op.base_dilations()) {
       base_dilations = Extract1DVector(*op.base_dilations());
-      if (llvm::any_of(base_dilations, [](int64_t& x) { return x != 1; }))
-        return failure();
     }
 
     llvm::SmallVector<int64_t> window_strides(window_dimensions.size(), 1);
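
Note on the hunk above: the presence check read window_dilations while the
body extracted base_dilations, and any non-unit base dilation made the
pattern return failure(). The two attributes are easy to confuse; a small
hypothetical example (not from this change) carrying both on one op:

  "mhlo.reduce_window"(%input, %init) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
    %s = mhlo.add %a, %b : tensor<f32>
    "mhlo.return"(%s) : (tensor<f32>) -> ()
  }) {
    base_dilations = dense<2> : tensor<1xi64>,     // dilates the input
    window_dilations = dense<1> : tensor<1xi64>,   // dilates the window
    window_dimensions = dense<2> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>,
    padding = dense<0> : tensor<1x2xi64>
  } : (tensor<3xf32>, tensor<f32>) -> tensor<4xf32>

Only base_dilations is rewritten into interior padding below;
window_dilations stays in the indexing maps of the generated linalg.generic.
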
@@ -2282,14 +2280,14 @@
       window_dilations = Extract1DVector(*op.window_dilations());
     }
 
-    auto rank = window_dimensions.size();
+    auto rank = static_cast<int64_t>(window_dimensions.size());
     SmallVector<AffineExpr, 2> src_exprs;
     SmallVector<AffineExpr, 2> window_exprs;
     SmallVector<AffineExpr, 2> dst_exprs;
     SmallVector<int64_t> filtered_window_dims;
 
     int window_dim = 0;
-    for (int i = 0; i < rank; i++) {
+    for (int64_t i = 0; i < rank; i++) {
       AffineExpr src_expr = mlir::getAffineDimExpr(i, ctx);
 
       if (window_strides[i] != 1) src_expr = src_expr * window_strides[i];
@@ -2334,37 +2332,30 @@
     llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.operands());
 
     // Pad as necessary.
-    if (llvm::any_of(padding, [](int32_t v) { return v != 0; })) {
+    if (llvm::any_of(padding, [](int64_t v) { return v != 0; }) ||
+        llvm::any_of(base_dilations, [](int64_t v) { return v != 1; })) {
       llvm::SmallVector<int64_t> static_lows;
       llvm::SmallVector<int64_t> static_highs;
       for (int i = 0; i < padding.size(); i += 2) {
         static_lows.push_back(padding[i]);
         static_highs.push_back(padding[i + 1]);
       }
+      // Translate base dilation into interior padding.
+      auto static_interiors = llvm::to_vector(llvm::map_range(
+          base_dilations, [](int64_t dilation) { return dilation - 1; }));
+
+      auto pad_attr_type =
+          RankedTensorType::get({rank}, rewriter.getIndexType());
+      auto pad_lows = DenseIntElementsAttr::get(pad_attr_type, static_lows);
+      auto pad_highs = DenseIntElementsAttr::get(pad_attr_type, static_highs);
+      auto pad_interiors =
+          DenseIntElementsAttr::get(pad_attr_type, static_interiors);
+
       for (auto values : llvm::zip(inputs, init_values)) {
         auto& input = std::get<0>(values);
         auto& init_value = std::get<1>(values);
-
-        // Extract the single element from init value. This mimic the lowering
-        // behavior of mhlo.pad.
-        Value padding_value =
-            rewriter.createOrFold<tensor::ExtractOp>(loc, init_value);
-
-        auto pad_op = rewriter.create<tensor::PadOp>(
-            loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});
-
-        SmallVector<Type, 4> block_arg_types;
-        block_arg_types.assign(input.getType().cast<ShapedType>().getRank(),
-                               rewriter.getIndexType());
-        auto& region = pad_op.region();
-
-        OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.createBlock(
-            &region, region.end(), block_arg_types,
-            SmallVector<Location>(block_arg_types.size(), loc));
-        rewriter.create<tensor::YieldOp>(loc, padding_value);
-
-        input = pad_op.getResult();
+        input = rewriter.create<mhlo::PadOp>(loc, input, init_value, pad_lows,
+                                             pad_highs, pad_interiors);
       }
     }
 
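For the test added below (input tensor<3x6xf32>, base_dilations [2, 1],
padding [[0, 3], [1, 2]]), the pattern now emits an mhlo.pad along these
lines (a sketch of the intermediate IR, not captured compiler output):

  %padded = "mhlo.pad"(%arg0, %arg1) {
    edge_padding_low = dense<[0, 1]> : tensor<2xi64>,
    edge_padding_high = dense<[3, 2]> : tensor<2xi64>,
    interior_padding = dense<[1, 0]> : tensor<2xi64>
  } : (tensor<3x6xf32>, tensor<f32>) -> tensor<8x9xf32>

Dim 0: 3 + (3 - 1) * 1 + 0 + 3 = 8; dim 1: 6 + 0 + 1 + 2 = 9, matching the
tensor<8x9xf32> in the CHECK lines. The existing mhlo.pad legalization then
produces the linalg.fill of the init value and the strided
tensor.insert_slice that the test checks for.
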
diff --git a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
index ca997bb..fd8ea3e 100644
--- a/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir
@@ -3608,6 +3608,26 @@
 
 // -----
 
+// CHECK-LABEL: func @reduce_window_generic_base_dilation
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+
+// CHECK: %[[PADVAL:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [8, 9] : tensor<8x9xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[PADVAL]] : f32) outs(%[[INIT]] : tensor<8x9xf32>) -> tensor<8x9xf32>
+// CHECK: %[[PAD:.+]] = tensor.insert_slice %[[ARG0]] into %[[FILL]][0, 1] [3, 6] [2, 1] : tensor<3x6xf32> into tensor<8x9xf32>
+
+func.func @reduce_window_generic_base_dilation(%arg0: tensor<3x6xf32>, %arg1: tensor<f32>) -> tensor<4x7xf32> {
+  %0 = "mhlo.reduce_window"(%arg0, %arg1) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {base_dilations = dense<[2, 1]> : tensor<2xi64>, padding = dense<[[0, 3], [1, 2]]> : tensor<2x2xi64>, window_dilations = dense<[1, 2]> : tensor<2xi64>, window_dimensions = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (tensor<3x6xf32>, tensor<f32>) -> tensor<4x7xf32>
+  func.return %0 : tensor<4x7xf32>
+}
+
+// -----
+
 func.func @gather(%operand : tensor<1x4x8xi32>, %start_indices : tensor<1x8x2xi32>) -> tensor<1x8x8xi32> {
   %res = "mhlo.gather"(%operand, %start_indices) {
     dimension_numbers = #mhlo.gather<