[GmlSt] Use tiling interface to tile Linalg to GmlSt ForOp.

PiperOrigin-RevId: 467224197
diff --git a/tensorflow/compiler/xla/mlir_hlo/BUILD b/tensorflow/compiler/xla/mlir_hlo/BUILD
index 42ac83f..4100bc9 100644
--- a/tensorflow/compiler/xla/mlir_hlo/BUILD
+++ b/tensorflow/compiler/xla/mlir_hlo/BUILD
@@ -1647,6 +1647,7 @@
         ":gml_st_fusion",
         ":gml_st_passes_inc_gen",
         ":gml_st_tiling",
+        ":gml_st_tiling_using_interface",
         ":gml_st_to_scf",
         ":group_reduction_dimensions",
         ":hlo_legalize_shape_ops_to_standard",
@@ -2579,6 +2580,31 @@
     ],
 )
 
+cc_library(
+    name = "tiling_interface_impl",
+    srcs = [
+        "lib/Dialect/gml_st/transforms/tiling_interface_impl.cc",
+    ],
+    hdrs = [
+        "include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":gml_st",
+        ":tiling_interface",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:LinalgUtils",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:TensorDialect",
+        "@llvm-project//mlir:TensorUtils",
+    ],
+)
+
 gentbl_cc_library(
     name = "fusion_interface_inc_gen",
     compatible_with = get_compatible_with_cloud(),
@@ -2749,6 +2775,9 @@
     deps = [
         ":gml_st",
         ":gml_st_passes_inc_gen",
+        ":gml_st_tiling_using_interface",
+        ":gml_st_transforms",
+        ":tiling_interface_impl",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:ArithmeticDialect",
         "@llvm-project//mlir:FuncDialect",
@@ -2761,6 +2790,31 @@
 )
 
 cc_library(
+    name = "gml_st_tiling_using_interface",
+    srcs = [
+        "include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h",
+        "lib/Dialect/gml_st/transforms/tiling_using_interface.cc",
+    ],
+    hdrs = [
+        "include/mlir-hlo/Dialect/gml_st/transforms/passes.h",
+        "include/mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":gml_st",
+        ":gml_st_passes_inc_gen",
+        ":gml_st_transforms",
+        ":tiling_interface",
+        "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:ArithmeticDialect",
+        "@llvm-project//mlir:ArithmeticUtils",
+        "@llvm-project//mlir:FuncDialect",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+    ],
+)
+
+cc_library(
     name = "gml_st_compose_set_ops",
     srcs = [
         "include/mlir-hlo/Dialect/gml_st/transforms/pass_detail.h",
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h
index 42fa9d4..f9732cc 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h
@@ -31,6 +31,10 @@
 std::unique_ptr<OperationPass<func::FuncOp>> createTilingPass(
     ArrayRef<int64_t> tileSizes = {});
 
+/// Pass to tile ops using TilingInterface and gml_st::ForOp.
+std::unique_ptr<OperationPass<func::FuncOp>> createTileToForPass(
+    ArrayRef<int64_t> tileSizes = {});
+
 /// Pass to compose set operations.
 std::unique_ptr<OperationPass<func::FuncOp>> createComposeSetOpsPass();
 
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td
index cfacaf6..6e9eb41 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.td
@@ -29,6 +29,15 @@
   ];
 }
 
+def TileToForPass : Pass<"gml-tile-to-for", "mlir::func::FuncOp"> {
+  let summary = "Tile operations using TilingInterface to produce gml_st.for";
+  let constructor = "::mlir::gml_st::createTileToForPass()";
+  let options = [
+    ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
+               "llvm::cl::ZeroOrMore">,
+  ];
+}
+
 def ComposeSetOpsPass : Pass<"gml-compose-set-ops", "mlir::func::FuncOp"> {
   let summary = "Compose set operations.";
   let constructor = "::mlir::gml_st::createComposeSetOpsPass()";
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
index 4c4faa2..d664e24 100644
--- a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface.td
@@ -89,7 +89,7 @@
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"SmallVector<Operation *>",
+        /*retType=*/"TilingInterface",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -104,28 +104,6 @@
       >,
       InterfaceMethod<
         /*desc=*/[{
-          Method to return the position of the result tile computed by the tiled operation.
-
-          Specifies what tile of the result of the original tensor is computed
-          by the tiled implementation. Expects the same `offsets` and `sizes` as
-          used to obtain the tiled implementation of the operation.
-        }],
-        /*retType=*/"LogicalResult",
-        /*methodName=*/"getResultTilePosition",
-        /*args=*/(ins
-          "OpBuilder &":$b,
-          "unsigned":$resultNumber,
-          "ArrayRef<OpFoldResult> ":$offsets,
-          "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return failure();
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
           Method to generate the code that produces a tile of the result.
 
           Generates the IR that computes the tile of a result of the
@@ -167,29 +145,7 @@
         /*defaultImplementation=*/[{
           return failure();
         }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
-          Generates the scalar implementation of the operation. 
-
-          Given the list `ivs` that represent points in the iteration space
-          (as specified by `getIterationDomain()`) returns the scalar operations
-          that represent the computation at that point in the iteration space.
-          This method is typically used as the "exit path", i.e. once all
-          transformations are done, this method can be used to lower to scalar 
-          code that can then be lowered to LLVM or SPIR-V dialects.
-        }],
-        /*retType=*/"LogicalResult",
-        /*methodName=*/"generateScalarImplementation",
-        /*args=*/(ins
-            "OpBuilder &":$b,
-            "Location ":$loc,
-            "ValueRange ":$ivs),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          return failure();
-        }]
       >
-  ];  
+  ];
 }
 #endif // GML_ST_TILING_INTERFACE
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h
new file mode 100644
index 0000000..f99cbf2
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h
@@ -0,0 +1,30 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H
+#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace gml_st {
+
+void registerGmlStTilingInterfaceExternalModels(DialectRegistry &registry);
+
+}  // namespace gml_st
+}  // namespace mlir
+
+#endif  // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_INTERFACE_IMPL_H
diff --git a/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h
new file mode 100644
index 0000000..c7270f7
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h
@@ -0,0 +1,71 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_USING_INTERFACE_H
+#define MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_USING_INTERFACE_H
+
+#include <functional>
+
+#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace gml_st {
+
+/// Options to use to control tiling.
+struct GmlStTilingOptions {
+  using TileSizeComputationFunction =
+      std::function<SmallVector<Value>(OpBuilder &, Operation *)>;
+
+  /// Computation function that returns the tile sizes for each operation.
+  TileSizeComputationFunction tileSizeComputationFunction = nullptr;
+};
+
+struct GmlStTilingResult {
+  TilingInterface tiledOp;
+  gml_st::ForOp loop;
+};
+
+/// Pattern to tile an op that implements the `TilingInterface` using
+/// `gml_st.for` for iterating over the tiles.
+struct TileToGmlStLoops : public OpInterfaceRewritePattern<TilingInterface> {
+  TileToGmlStLoops(MLIRContext *context, GmlStTilingOptions options,
+                   PatternBenefit benefit = 1);
+
+  FailureOr<GmlStTilingResult> returningMatchAndRewrite(
+      TilingInterface op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(TilingInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (hasTransformationAttr(op)) return failure();
+
+    auto tilingResult = returningMatchAndRewrite(op, rewriter);
+    if (failed(tilingResult)) return failure();
+
+    setTransformationAttr(rewriter, tilingResult->tiledOp);
+    return success();
+  }
+
+ private:
+  GmlStTilingOptions options;
+};
+
+}  // namespace gml_st
+}  // namespace mlir
+
+#endif  // MLIR_HLO_DIALECT_GML_ST_TRANSFORMS_TILING_USING_INTERFACE_H
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
index b754d4f..92af8c8 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/CMakeLists.txt
@@ -43,6 +43,22 @@
   MLIRGmlStFusionInterfaceIncGen
 )
 
+add_mlir_library(GmlStTilingInterfaceImpl
+  tiling_interface_impl.cc
+
+  LINK_LIBS PUBLIC
+  GmlStTilingInterface
+  MLIRArithmeticDialect
+  MLIRAffineDialect
+  MLIRLinalgDialect
+  MLIRLinalgTransforms
+  MLIRLinalgUtils
+  MLIRTensorDialect
+  MLIRTensorUtils
+  MLIRIR
+  MLIRSupport
+)
+
 add_mlir_library(GmlStFusionInterfaceImpl
   fusion_interface_impl.cc
 
@@ -79,6 +95,7 @@
   fusion.cc
   gml_st_to_scf.cc
   tiling.cc
+  tiling_using_interface.cc
   vectorization.cc
 
   DEPENDS
@@ -91,6 +108,8 @@
   GmlStComposeSetInterface
   GmlStFusionInterface
   GmlStFusionInterfaceImpl
+  GmlStTilingInterface
+  GmlStTilingInterfaceImpl
   MLIRFuncDialect
   MLIRIR
   MLIRLinalgDialect
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
index 3ccf515..7ba03b5 100644
--- a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling.cc
@@ -16,13 +16,12 @@
 #include <memory>
 #include <utility>
 
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Casting.h"
 #include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
 #include "mlir-hlo/Dialect/gml_st/transforms/pass_detail.h"
 #include "mlir-hlo/Dialect/gml_st/transforms/passes.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/transforms.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -212,6 +211,43 @@
   }
 };
 
+struct TileToForPass : public TileToForPassBase<TileToForPass> {
+  TileToForPass() = default;
+  explicit TileToForPass(llvm::ArrayRef<int64_t> sizes) { tileSizes = sizes; }
+
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<GmlStDialect>();
+    registerGmlStTilingInterfaceExternalModels(registry);
+  }
+
+  void runOnOperation() override {
+    func::FuncOp f = getOperation();
+    MLIRContext *ctx = &getContext();
+
+    GmlStTilingOptions opts;
+    SmallVector<int64_t> ts(tileSizes.begin(), tileSizes.end());
+    opts.tileSizeComputationFunction = [ts](OpBuilder &b, Operation *op) {
+      OpBuilder::InsertionGuard guard(b);
+      b.setInsertionPointToStart(
+          &op->getParentOfType<func::FuncOp>().getBody().front());
+      return llvm::to_vector<4>(llvm::map_range(ts, [&](int64_t s) {
+        Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
+        return v;
+      }));
+    };
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<TileToGmlStLoops>(ctx, opts);
+
+    if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) {
+      return signalPassFailure();
+    }
+
+    // Clean up by removing temporary attributes.
+    f.walk([](Operation *op) { removeTransformationAttr(op); });
+  }
+};
+
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createTilingPass(
@@ -219,5 +255,10 @@
   return std::make_unique<TilingPass>(tileSizes);
 }
 
+std::unique_ptr<OperationPass<func::FuncOp>> createTileToForPass(
+    ArrayRef<int64_t> tileSizes) {
+  return std::make_unique<TileToForPass>(tileSizes);
+}
+
 }  // namespace gml_st
 }  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc
new file mode 100644
index 0000000..6c8736e
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_interface_impl.cc
@@ -0,0 +1,187 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface_impl.h"
+
+#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_interface.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+using linalg::LinalgOp;
+using linalg::SliceParameters;
+
+///////////////////////////////////////////////////////////////////////////////
+/// Linalg Tiling Interface.
+///////////////////////////////////////////////////////////////////////////////
+
+SmallVector<OpFoldResult> getMixedValues(Location loc, OpBuilder &b,
+                                         Value tensor) {
+  SmallVector<OpFoldResult> tensorDims;
+
+  auto tensorType = tensor.getType().cast<RankedTensorType>();
+  int64_t rank = tensorType.getRank();
+  for (auto i = 0; i < rank; ++i) {
+    tensorDims.push_back(
+        tensorType.isDynamicDim(i)
+            ? OpFoldResult{b.createOrFold<tensor::DimOp>(loc, tensor, i)}
+            : OpFoldResult{b.getI64IntegerAttr(tensorType.getDimSize(i))});
+  }
+  return tensorDims;
+}
+
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+                                            LinalgOpTy> {
+  /// Return the destination operands.
+  SmallVector<Value> getDestinationOperands(Operation *op,
+                                            OpBuilder & /*b*/) const {
+    return cast<LinalgOp>(op).getOutputOperands();
+  }
+
+  /// Return the loop iterator type.
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+    return llvm::to_vector(
+        llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) {
+          return strAttr.cast<StringAttr>().getValue();
+        }));
+  }
+
+  /// Return the iteration domain range.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<OpFoldResult> allShapesSizes =
+        linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+
+    IRRewriter rewriter(b);
+    return llvm::to_vector(
+        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
+          OpFoldResult ofr = makeComposedFoldedAffineApply(
+              rewriter, loc, loopExpr, allShapesSizes);
+          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
+        }));
+  }
+
+  // Instantiate the tiled implementation of the operation.
+  TilingInterface getTiledImplementation(Operation *op, OpBuilder &b,
+                                         ValueRange /*dest*/,
+                                         ArrayRef<OpFoldResult> offsets,
+                                         ArrayRef<OpFoldResult> sizes,
+                                         bool /*tileDestOperands*/) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Optional<SliceParameters>> allSliceParams =
+        linalg::computeAllSliceParameters(b, loc, linalgOp, valuesToTile,
+                                          offsets, sizes, {}, true);
+
+    SmallVector<Value> tiledOperands;
+    for (auto item : llvm::zip(valuesToTile, allSliceParams)) {
+      Value valueToTile = std::get<0>(item);
+      const Optional<linalg::SliceParameters> &sliceParams = std::get<1>(item);
+
+      SmallVector<OpFoldResult> tensorDims =
+          getMixedValues(loc, b, valueToTile);
+      Value set = b.create<SpaceOp>(loc, tensorDims);
+      if (sliceParams.has_value()) {
+        set = b.create<TileOp>(loc, set, sliceParams->offsets,
+                               sliceParams->sizes, sliceParams->strides);
+      }
+      Value materializedTile = b.create<MaterializeOp>(loc, valueToTile, set);
+      tiledOperands.push_back(materializedTile);
+    }
+
+    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
+        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+          return tiledOperands[opOperand->getOperandNumber()].getType();
+        }));
+
+    Operation *tiledOp =
+        linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
+
+    return {tiledOp};
+  }
+
+  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
+                                           unsigned resultNumber,
+                                           ValueRange dest,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           bool tileDestOperands) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the output is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the result to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled tiled implementation generation when result is not "
+          "accessed using a permuted projection");
+    }
+
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
+        iterationTileSizes(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        iterationTileOffsets[range.index()] = range.value().offset;
+        iterationTileSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          resultExpr.value().cast<AffineDimExpr>().getPosition();
+      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
+      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+
+    TilingInterface tiledOp = tilingInterfaceOp.getTiledImplementation(
+        b, dest, iterationTileOffsets, iterationTileSizes, tileDestOperands);
+
+    return tiledOp->getResult(resultNumber);
+  }
+};
+
+}  // namespace
+
+void registerGmlStTilingInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *ctx, linalg::LinalgDialect * /*dialect*/) {
+        linalg::GenericOp::attachInterface<
+            LinalgOpTilingInterface<linalg::GenericOp>>(*ctx);
+      });
+}
+
+}  // namespace gml_st
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_using_interface.cc b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_using_interface.cc
new file mode 100644
index 0000000..ca10573
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir_hlo/lib/Dialect/gml_st/transforms/tiling_using_interface.cc
@@ -0,0 +1,159 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir-hlo/Dialect/gml_st/transforms/tiling_using_interface.h"
+
+#include <memory>
+#include <tuple>
+#include <utility>
+
+#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/IR/MLIRContext.h"
+
+namespace mlir {
+namespace gml_st {
+namespace {
+
+/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops.
+/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
+/// the tile processed within the inner most loop.
+gml_st::ForOp generateTileLoopNest(OpBuilder &builder, Location loc,
+                                   ArrayRef<Range> loopRanges,
+                                   ArrayRef<Value> tileSizeVals,
+                                   ArrayRef<Value> dstOperands,
+                                   SmallVector<OpFoldResult> &offsets,
+                                   SmallVector<OpFoldResult> &sizes) {
+  assert(!loopRanges.empty() && "expected at least one loop range");
+  assert(loopRanges.size() == tileSizeVals.size() &&
+         "expected as many tile sizes as loop ranges");
+  OpBuilder::InsertionGuard guard(builder);
+
+  // The tile size to use (to avoid out of bounds access) is  minimum of
+  // `tileSize` and `ub - iv`, where `iv` is the induction variable
+  // of the tiled loop.
+  AffineExpr s0, s1, d0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, builder.getContext());
+
+  SmallVector<Value> lbs, ubs, steps;
+  SmallVector<unsigned> nonemptyRangeIndices;
+  for (auto &loopRange : llvm::enumerate(loopRanges)) {
+    Value offset =
+        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
+    Value size =
+        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
+    // No loops if tile size is zero. Set offset and size to the loop
+    // offset and size.
+    offsets.push_back(offset);
+    sizes.push_back(size);
+    if (matchPattern(tileSizeVals[loopRange.index()], m_Zero())) continue;
+    lbs.push_back(offset);
+    ubs.push_back(size);
+    steps.push_back(tileSizeVals[loopRange.index()]);
+    nonemptyRangeIndices.push_back(loopRange.index());
+  }
+
+  auto loop = builder.create<gml_st::ForOp>(
+      loc, TypeRange{ValueRange{dstOperands}}, lbs, ubs, steps, dstOperands,
+      [&](OpBuilder &nestedBuilder, Location bodyLoc, ValueRange ivs,
+          ValueRange /*inits*/) {
+        for (const auto &en : llvm::enumerate(ivs)) {
+          Value iv = en.value();
+          size_t index = en.index();
+          Value boundedTileSize = nestedBuilder.create<AffineMinOp>(
+              bodyLoc, minMap, ValueRange{iv, steps[index], ubs[index]});
+          sizes[nonemptyRangeIndices[index]] = boundedTileSize;
+          offsets[nonemptyRangeIndices[index]] = iv;
+        }
+      });
+  return loop;
+}
+
+}  // namespace
+
+TileToGmlStLoops::TileToGmlStLoops(MLIRContext *context,
+                                   GmlStTilingOptions options,
+                                   PatternBenefit benefit)
+    : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+      options(std::move(options)) {}
+
+FailureOr<GmlStTilingResult> TileToGmlStLoops::returningMatchAndRewrite(
+    TilingInterface op, PatternRewriter &rewriter) const {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(op);
+
+  if (!options.tileSizeComputationFunction) {
+    return rewriter.notifyMatchFailure(
+        op, "missing tile size computation function");
+  }
+
+  // 1. Get the range of the loops that are represented by the operation.
+  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+  size_t numLoops = iterationDomain.size();
+  if (numLoops == 0)
+    return rewriter.notifyMatchFailure(op, "missing iteration domain");
+
+  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
+  // skips tiling a particular dimension. This convention is significantly
+  // simpler to handle instead of adjusting affine maps to account for missing
+  // dimensions.
+  SmallVector<Value> tileSizeVector =
+      options.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < iterationDomain.size()) {
+    auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
+    tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
+  }
+
+  // 3. Materialize an empty loop nest that iterates over the tiles.
+  auto dstOperands = op.getDestinationOperands(rewriter);
+  SmallVector<OpFoldResult> offsets, sizes;
+  GmlStTilingResult tilingResult;
+  tilingResult.loop =
+      generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
+                           tileSizeVector, dstOperands, offsets, sizes);
+  Block *loopBody = tilingResult.loop.getBody();
+  Operation *terminator = loopBody->getTerminator();
+  rewriter.setInsertionPoint(terminator);
+
+  // 4. Insert the tiled implementation within the loop.
+  tilingResult.tiledOp =
+      op.getTiledImplementation(rewriter, dstOperands, offsets, sizes, true);
+
+  // 5. Add `gml_st.set_yield` terminator.
+  SmallVector<Value> dstSubsets;
+  for (Value dst : tilingResult.tiledOp.getDestinationOperands(rewriter))
+    dstSubsets.push_back(dst.getDefiningOp<MaterializeOp>().set());
+  rewriter.replaceOpWithNewOp<SetYieldOp>(
+      terminator, tilingResult.tiledOp->getResults(), dstOperands, dstSubsets);
+
+  // 6. Replace the uses of `outputs` with the output block arguments.
+  for (auto [dst, regionArg] :
+       llvm::zip(dstOperands, tilingResult.loop.getRegionOutputArgs())) {
+    dst.replaceUsesWithIf(regionArg, [&](OpOperand &operand) {
+      return operand.getOwner()->getBlock() == loopBody;
+    });
+  }
+  rewriter.replaceOp(op, tilingResult.loop.getResults());
+  return tilingResult;
+}
+
+}  // namespace gml_st
+}  // namespace mlir
diff --git a/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tile_linalg.mlir b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tile_linalg.mlir
new file mode 100644
index 0000000..93f4c1c
--- /dev/null
+++ b/tensorflow/compiler/xla/mlir_hlo/tests/Dialect/gml_st/tile_linalg.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-hlo-opt %s --split-input-file \
+// RUN: --gml-tile-to-for="tile-sizes=256,512" | FileCheck %s
+
+#id_map = affine_map<(d0, d1) -> (d0, d1)>
+
+func.func @add(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %lhs, %c1 : tensor<?x?xf32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+  %add = linalg.generic {
+      indexing_maps = [#id_map, #id_map, #id_map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%init : tensor<?x?xf32>) {
+  ^bb0(%lhs_scalar: f32, %rhs_scalar: f32, %_: f32):
+    %add_scalar = arith.addf %lhs_scalar, %rhs_scalar : f32
+    linalg.yield %add_scalar : f32
+  } -> tensor<?x?xf32>
+  func.return %add : tensor<?x?xf32>
+}
+// CHECK-LABEL: func.func @add
+// CHECK-SAME:    (%[[LHS:.*]]: tensor<?x?xf32>, %[[RHS:.*]]: tensor<?x?xf32>)
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+
+// CHECK:     %[[INIT:.*]] = linalg.init_tensor
+// CHECK:     %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
+// CHECK:     %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
+
+// CHECK:     gml_st.for (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME:    to (%[[DIM_0]], %[[DIM_1]])
+// CHECK-SAME:    step (%[[C256]], %[[C512]])
+// CHECK-SAME:    outs (%[[INIT_:.*]] = %[[INIT]]: tensor<?x?xf32>) {
+
+// CHECK:       %[[SIZE_0:.*]] = affine.min #map0(%[[I]])[%[[C256]], %[[DIM_0]]]
+// CHECK:       %[[SIZE_1:.*]] = affine.min #map1(%[[J]])[%[[C512]], %[[DIM_1]]]
+
+// CHECK:       %[[LHS_T:.*]] = gml_st.tile
+// CHECK-SAME:    [%[[I]], %[[J]]] [%[[SIZE_0]], %[[SIZE_1]]] [1, 1]
+// CHECK:       %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[LHS_T]]]
+
+// CHECK:       %[[RHS_T:.*]] = gml_st.tile
+// CHECK-SAME:    [%[[I]], %[[J]]] [%[[SIZE_0]], %[[SIZE_1]]] [1, 1]
+// CHECK:       %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[RHS_T]]]
+
+// CHECK:       %[[INIT_T:.*]] = gml_st.tile
+// CHECK-SAME:    [%[[I]], %[[J]]] [%[[SIZE_0]], %[[SIZE_1]]] [1, 1]
+// CHECK:       %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT_]][%[[INIT_T]]]
+
+// CHECK:       %[[SUM:.*]] = linalg.generic
+// CHECK-SAME:    ins(%[[LHS_SUB]], %[[RHS_SUB]]
+// CHECK-SAME:    outs(%[[INIT_SUB]] : tensor<?x?xf32>)
+// CHECK:       gml_st.set_yield %[[SUM:.*]] into %[[INIT_]][%[[INIT_T]]
+
+// -----
+
+func.func @reduce_row(%lhs: tensor<?x?xf32>,
+                      %rhs: tensor<?x?xf32>) -> tensor<?xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %lhs, %c0 : tensor<?x?xf32>
+
+  %init = linalg.init_tensor [%0] : tensor<?xf32>
+  %fill = linalg.fill ins(%cst : f32)
+                      outs(%init : tensor<?xf32>) -> tensor<?xf32>
+  %sum_of_prod = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0)>],
+    iterator_types = ["parallel", "reduction"]}
+    ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%fill : tensor<?xf32>) {
+  ^bb0(%l: f32, %r: f32, %o: f32):
+    %prod = arith.mulf %l, %r : f32
+    %add = arith.addf %prod, %o : f32
+    linalg.yield %add : f32
+  } -> tensor<?xf32>
+  func.return %sum_of_prod : tensor<?xf32>
+}
+// CHECK:   func.func @reduce_row(%[[LHS:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                    %[[RHS:.*]]: tensor<?x?xf32>)
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG: %[[C512:.*]] = arith.constant 512 : index
+// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+
+// CHECK:     %[[FILL:.*]] = linalg.fill ins(%[[C0_F32]] : f32)
+// CHECK:     %[[DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]]
+// CHECK:     %[[DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]]
+
+// CHECK:     gml_st.for (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
+// CHECK-SAME:    to (%[[DIM_0]], %[[DIM_1]])
+// CHECK-SAME:    step (%[[C256]], %[[C512]])
+// CHECK-SAME:    outs (%[[INIT_:.*]] = %[[FILL]]: tensor<?xf32>) {
+
+// CHECK:      %[[SIZE_0:.*]] = affine.min {{.*}}(%[[I]])[%[[C256]], %[[DIM_0]]]
+// CHECK:      %[[SIZE_1:.*]] = affine.min {{.*}}(%[[J]])[%[[C512]], %[[DIM_1]]]
+
+// CHECK:      %[[LHS_T:.*]] = gml_st.tile
+// CHECK-SAME:   [%[[I]], %[[J]]] [%[[SIZE_0]], %[[SIZE_1]]] [1, 1]
+// CHECK:      %[[LHS_SUB:.*]] = gml_st.materialize %[[LHS]][%[[LHS_T]]]
+
+// CHECK:      %[[RHS_T:.*]] = gml_st.tile
+// CHECK-SAME:   [%[[I]], %[[J]]] [%[[SIZE_0]], %[[SIZE_1]]] [1, 1]
+// CHECK:      %[[RHS_SUB:.*]] = gml_st.materialize %[[RHS]][%[[RHS_T]]]
+
+// CHECK:      %[[INIT_T:.*]] = gml_st.tile %{{.*}}[%[[I]]] [%[[SIZE_0]]] [1]
+// CHECK:      %[[INIT_SUB:.*]] = gml_st.materialize %[[INIT_]][%[[INIT_T]]]
+
+// CHECK:      %[[REDUCE:.*]] = linalg.generic
+// CHECK-SAME:   ins(%[[LHS_SUB]], %[[RHS_SUB]]
+// CHECK-SAME:   outs(%[[INIT_SUB]] : tensor<?xf32>)
+// CHECK:      gml_st.set_yield %[[REDUCE:.*]] into %[[INIT_]][%[[INIT_T]]]
+