Add type materializations to the lmhlo-to-gpu pass.

Plus some cleanup and renames: LmhloGpuAsyncConversionPass becomes ConvertLmhloToGpuPass, registerTFRTGPUPasses becomes registerXlirPasses, and the :pass/:pass_details/:passes build targets are consolidated into a single :passes target.

This prepares the pass pipeline to allow memref ops to persist past the initial lmhlo-to-gpu pass. Specifically, we want to defer the conversion of gpu ops (gpu.memcpy and gpu.memset) to the ConvertGpuToTfrtGpuPass.
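For reference, a rough sketch of the resulting pipeline setup, as a fragment; it simply mirrors the updated convert_xla_gpu.cc / gpu_compiler.cc in the diff below and introduces no names beyond what that diff adds:

  mlir::PassManager pm(&context, mlir::PassManager::Nesting::Implicit);
  // Lowers lmhlo/lmhlo_gpu ops to tfrt_gpu, now with memref->buffer
  // materializations (was createLmhloGpuAsyncConversionPass()).
  pm.addPass(tensorflow::createConvertLmhloToGpuPass());
  pm.addPass(mlir::createGpuAsyncRegionPass());
  tfrt::gpu::populateGpuToTfrtGpuPasses(pm);
  // Replaces the previous DCE-with-empty-pattern-set step.
  pm.addPass(mlir::createCanonicalizerPass());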

PiperOrigin-RevId: 399863409
Change-Id: Ie27150e7a935bc8bebaf2b0518357438c10890a3
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index 9a71c36..572f8ba 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -481,7 +481,7 @@
         "//platforms/xla/tests/gpu:__pkg__",
     ],
     deps = [
-        "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+        "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core/platform:errors",
diff --git a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
index 8f7eab0..b04ab36 100644
--- a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
+++ b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
@@ -36,7 +36,7 @@
                   mlir::mhlo::MhloDialect, tfrt::gpu::GpuDialect,
                   tfrt::gpu::conversion::GpuConversionDialect>();
   tfrt::RegisterTFRTDialects(registry);
-  tensorflow::registerTFRTGPUPasses();
+  tensorflow::registerXlirPasses();
   tfrt::gpu::registerPasses();
   return failed(mlir::MlirOptMain(argc, argv, "MHLO TFRT pass driver\n",
                                   registry,
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
index 5a4c25a..aa47f40 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
@@ -3,20 +3,8 @@
 load("//tensorflow:tensorflow.bzl", "if_google")
 
 # TF to TFRT kernels conversion.
-package(licenses = ["notice"])
-
-package_group(
-    name = "friends",
-    packages = if_google([
-        "//learning/brain/experimental/mlir/tfrt_compiler/...",
-        "//learning/brain/experimental/tfrt/...",
-        "//platforms/xla/tests/gpu/...",
-        "//third_party/tf_runtime_google/...",
-    ]) + [
-        "//tensorflow/compiler/...",
-        "//tensorflow/core/runtime_fallback/...",
-        "//tensorflow/core/tfrt/saved_model/...",
-    ],
+package(
+    licenses = ["notice"],
 )
 
 gentbl_cc_library(
@@ -26,7 +14,7 @@
         (
             [
                 "-gen-pass-decls",
-                "-name=TFRTGPU",
+                "-name=Xlir",
             ],
             "gpu_passes.h.inc",
         ),
@@ -39,43 +27,31 @@
 )
 
 cc_library(
-    name = "pass_details",
-    hdrs = [
-        "PassDetail.h",
-    ],
-    deps = [
-        ":TfrtGpuPassIncGen",
-        "//tensorflow/compiler/xla/service/gpu:xlir_opdefs",
-        "@llvm-project//mlir:GPUDialect",
-        "@llvm-project//mlir:Pass",
-        "@tf_runtime//:basic_kernels_opdefs",
-        "@tf_runtime//backends/gpu:gpu_opdefs",
-    ],
-)
-
-cc_library(
-    name = "pass",
+    name = "passes",
     srcs = [
         "ccl_pattern.cc",
-        "ccl_pattern.h",
         "custom_call_pattern.cc",
-        "custom_call_pattern.h",
         "gemm_pattern.cc",
-        "gemm_pattern.h",
         "gpu_passes.cc",
         "memcpy_pattern.cc",
-        "memcpy_pattern.h",
         "memset_pattern.cc",
-        "memset_pattern.h",
     ],
-    hdrs = ["gpu_passes.h"],
+    hdrs = [
+        "gpu_passes.h",
+        "register_passes.h",
+    ],
     tags = [
         "gpu",
         "no_oss",
     ],
-    visibility = [":friends"],
+    visibility = if_google([
+        "//platforms/xla/tests/gpu:__pkg__",
+    ]) + [
+        "//tensorflow/compiler/mlir/tfrt:__pkg__",
+        "//tensorflow/compiler/xla/service/gpu:__pkg__",
+    ],
     deps = [
-        ":pass_details",
+        ":TfrtGpuPassIncGen",
         "//tensorflow/compiler/mlir/hlo:lhlo",
         "//tensorflow/compiler/mlir/hlo:lhlo_gpu",
         "//tensorflow/compiler/mlir/tensorflow",
@@ -92,6 +68,7 @@
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Rewrite",
         "@llvm-project//mlir:StandardOps",
+        "@llvm-project//mlir:StandardOpsTransforms",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:Transforms",
         "@tf_runtime//:basic_kernels_opdefs",
@@ -101,18 +78,3 @@
     ],
     alwayslink = 1,
 )
-
-cc_library(
-    name = "passes",
-    hdrs = ["register_passes.h"],
-    tags = [
-        "gpu",
-        "no_oss",
-    ],
-    visibility = [":friends"],
-    deps = [
-        ":TfrtGpuPassIncGen",
-        ":pass",
-        "@llvm-project//mlir:Pass",
-    ],
-)
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h
deleted file mode 100644
index 4e08993..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
-
-#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
-#include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
-#include "tfrt/basic_kernels/opdefs/tfrt_base.h"  // from @tf_runtime
-
-namespace tensorflow {
-
-#define GEN_PASS_CLASSES
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h.inc"
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
index 5a2433b..3717d5e 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
@@ -18,8 +18,6 @@
 // Pattern to lower lmhlo collective ops to tfrt_gpu/xlir dialect.
 //
 //===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"
-
 #include <functional>
 #include <string>
 
@@ -311,13 +309,14 @@
 
 }  // namespace
 
-void populateCclConversionPattern(RewritePatternSet& patterns) {
+void populateCclConversionPattern(RewritePatternSet& patterns,
+                                  TypeConverter& converter) {
   patterns.add<CclRewritePattern<lmhlo::AllGatherOp>,
                CclRewritePattern<lmhlo::AllReduceOp>,
                CclRewritePattern<lmhlo::ReduceScatterOp>,
                CclRewritePattern<lmhlo::AllToAllOp>,
                CclRewritePattern<lmhlo::CollectivePermuteOp>>(
-      patterns.getContext());
+      converter, patterns.getContext());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h
deleted file mode 100644
index 1185511..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace tensorflow {
-
-void populateCclConversionPattern(mlir::RewritePatternSet& patterns);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
index e1c2a3d..e9e01c9 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
@@ -18,8 +18,6 @@
 // Pattern to lower lmhlo.custom_call op to tfrt_gpu/xlir dialect.
 //
 //===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h"
-
 #include <functional>
 #include <string>
 
@@ -85,8 +83,9 @@
 
 }  // namespace
 
-void populateCustomCallConversionPattern(RewritePatternSet& patterns) {
-  patterns.add<CustomCallRewritePattern>(patterns.getContext());
+void populateCustomCallConversionPattern(RewritePatternSet& patterns,
+                                         TypeConverter& converter) {
+  patterns.add<CustomCallRewritePattern>(converter, patterns.getContext());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h
deleted file mode 100644
index 735fd24..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace tensorflow {
-
-void populateCustomCallConversionPattern(mlir::RewritePatternSet& patterns);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
index 19c64f8..ef34270 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
@@ -18,8 +18,6 @@
 // Pattern to lower lhlogpu_gemm Ops to tfrt cuda dialect.
 //
 //===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"
-
 #include <assert.h>
 #include <stdint.h>
 
@@ -345,10 +343,11 @@
 
 }  // namespace
 
-void populateGemmConversionPattern(RewritePatternSet& patterns) {
+void populateGemmConversionPattern(RewritePatternSet& patterns,
+                                   TypeConverter& converter) {
   patterns.add<GemmRewritePattern<lmhlo_gpu::GEMMOp>,
                GemmRewritePattern<lmhlo_gpu::GEMM_BiasOp>>(
-      patterns.getContext());
+      converter, patterns.getContext());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h
deleted file mode 100644
index 07451af..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
-
-#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/StringRef.h"
-
-namespace tensorflow {
-
-void populateGemmConversionPattern(mlir::RewritePatternSet& patterns);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
index 22498f3..02e551f 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
@@ -21,85 +21,105 @@
 #include <memory>
 #include <utility>
 
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Rewrite/PatternApplicator.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
 #include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
 #include "tfrt/gpu/passes/passes.h"  // from @tf_runtime
 #include "tfrt/basic_kernels/opdefs/basic_kernels.h"  // from @tf_runtime
 
 namespace tensorflow {
+
+void populateCclConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateCustomCallConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateGemmConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateMemcpyConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateMemsetConversionPattern(RewritePatternSet&, TypeConverter&);
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h.inc"
+
 namespace {
 
-struct LmhloGpuAsyncConversionPass
-    : public LmhloGpuAsyncConversionPassBase<LmhloGpuAsyncConversionPass> {
+struct ConvertLmhloToGpuPass
+    : public ConvertLmhloToGpuPassBase<ConvertLmhloToGpuPass> {
  private:
-  void runOnFunction() override {
-    auto* context = &getContext();
-
-    TypeConverter converter;
-    converter.addConversion([](Type type) { return type; });
-    auto buffer_type = tfrt::gpu::BufferType::get(context);
-    converter.addConversion([&](BaseMemRefType) { return buffer_type; });
-
-    ConversionTarget target(*context);
-    target
-        .addIllegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
-    target.addLegalDialect<tfrt::compiler::TFRTDialect, tfrt::gpu::GpuDialect,
-                           xla::gpu::XlirDialect>();
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return converter.isSignatureLegal(op.getType()) &&
-             converter.isLegal(&op.body());
-    });
-    target.addDynamicallyLegalOp<tfrt::gpu::conversion::AsyncExecuteOp>(
-        [&](tfrt::gpu::conversion::AsyncExecuteOp op) {
-          return converter.isLegal(&op.body());
-        });
-
-    RewritePatternSet patterns(context);
-    populateCclConversionPattern(patterns);
-    populateCustomCallConversionPattern(patterns);
-    populateGemmConversionPattern(patterns);
-    populateMemcpyConversionPattern(patterns);
-    populateMemsetConversionPattern(patterns);
-    populateFuncOpTypeConversionPattern(patterns, converter);
-
-    ConversionTarget wrap_target(*context);
-    wrap_target
-        .addLegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
-    wrap_target.addLegalOp<lmhlo::AllGatherOp, lmhlo::AllReduceOp,
-                           lmhlo::ReduceScatterOp, lmhlo::AllToAllOp,
-                           lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp>();
-    tfrt::gpu::populateGpuAsyncConversionPatterns(patterns, converter,
-                                                  wrap_target);
-
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
-      return signalPassFailure();
-  }
+  void runOnFunction() override;
 };
 
 }  // namespace
 
-std::unique_ptr<FunctionPass> createLmhloGpuAsyncConversionPass() {
-  return std::make_unique<LmhloGpuAsyncConversionPass>();
+static Value MaterializeCast(OpBuilder& builder, Type type, ValueRange values,
+                             Location loc) {
+  auto cast_op = builder.create<UnrealizedConversionCastOp>(loc, type, values);
+  return cast_op.getResult(0);
+}
+
+void ConvertLmhloToGpuPass::runOnFunction() {
+  auto* context = &getContext();
+
+  TypeConverter converter;
+  converter.addConversion([](Type type) { return type; });
+  auto buffer_type = tfrt::gpu::BufferType::get(context);
+  converter.addConversion([&](BaseMemRefType) { return buffer_type; });
+  converter.addArgumentMaterialization(MaterializeCast);
+  converter.addSourceMaterialization(MaterializeCast);
+  converter.addTargetMaterialization(MaterializeCast);
+
+  RewritePatternSet patterns(context);
+  populateCclConversionPattern(patterns, converter);
+  populateCustomCallConversionPattern(patterns, converter);
+  populateGemmConversionPattern(patterns, converter);
+  populateMemcpyConversionPattern(patterns, converter);
+  populateMemsetConversionPattern(patterns, converter);
+  populateFuncOpTypeConversionPattern(patterns, converter);
+  populateReturnOpTypeConversionPattern(patterns, converter);
+
+  // Set of ops that need to be wrapped in tfrt_gpu_conversion.async.execute
+  // before lowering directly to tfrt_gpu ops (and therefore require some chain
+  // and stream, which the wrapper op provides as block arguments). On the other
+  // hand, ops which lower to the gpu dialect do not need to be wrapped.
+  ConversionTarget wrap_target(*context);
+  wrap_target
+      .addLegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
+  wrap_target.addLegalOp<lmhlo::AllGatherOp, lmhlo::AllReduceOp,
+                         lmhlo::ReduceScatterOp, lmhlo::AllToAllOp,
+                         lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp>();
+  tfrt::gpu::populateGpuAsyncConversionPatterns(patterns, converter,
+                                                wrap_target);
+
+  ConversionTarget target(*context);
+  target.addIllegalOp<memref::ReinterpretCastOp, memref::ViewOp>();
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return converter.isSignatureLegal(op.getType()) &&
+           converter.isLegal(&op.body());
+  });
+  target.addDynamicallyLegalOp<tfrt::gpu::conversion::AsyncExecuteOp>(
+      [&](tfrt::gpu::conversion::AsyncExecuteOp op) {
+        return converter.isLegal(&op.body());
+      });
+  target.markUnknownOpDynamicallyLegal([&](Operation* op) {
+    if (op->hasTrait<OpTrait::ReturnLike>()) return converter.isLegal(op);
+    return !wrap_target.isLegal(op);  // Wrapped ops are immediately lowered.
+  });
+
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
+
+std::unique_ptr<FunctionPass> createConvertLmhloToGpuPass() {
+  return std::make_unique<ConvertLmhloToGpuPass>();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
index 1994a34..29f6e67 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
@@ -24,7 +24,7 @@
 
 // Creates a pass that lowers lmhlo_gpu ops to tfrt_gpu. Prepares the function
 // to be consumed by MLIR's gpu-async-region pass.
-std::unique_ptr<mlir::FunctionPass> createLmhloGpuAsyncConversionPass();
+std::unique_ptr<mlir::FunctionPass> createConvertLmhloToGpuPass();
 
 }  // namespace tensorflow
 
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
index 7cf1b7b..a84c662 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
@@ -15,13 +15,13 @@
 
 include "mlir/Pass/PassBase.td"
 
-def LmhloGpuAsyncConversionPass : FunctionPass<"lmhlo-to-gpu"> {
+def ConvertLmhloToGpuPass : FunctionPass<"lmhlo-to-gpu"> {
   let summary = "Convert lmhlo_gpu ops to tfrt_gpu dialect.";
   let description = [{
     Move lmhlo_gpu ops inside tfrt_gpu_conversion.execute ops and convert them
     to tfrt_gpu dialect.
   }];
-  let constructor = "createLmhloGpuAsyncConversionPass()";
+  let constructor = "createConvertLmhloToGpuPass()";
   let dependentDialects = [
     "::mlir::gpu::GPUDialect",
     "::tfrt::compiler::TFRTDialect",
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
index a1b55d7..10d53a3 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
@@ -13,8 +13,6 @@
 // limitations under the License.
 
 // Pattern to lower gpu.memcpy ops to tfrt_gpu.mem.copy.
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h"
-
 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
 #include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
 #include "tfrt/gpu/passes/passes.h"  // from @tf_runtime
@@ -44,8 +42,9 @@
 
 }  // namespace
 
-void populateMemcpyConversionPattern(RewritePatternSet& patterns) {
-  patterns.add<MemcpyRewritePattern>(patterns.getContext());
+void populateMemcpyConversionPattern(RewritePatternSet& patterns,
+                                     TypeConverter& converter) {
+  patterns.add<MemcpyRewritePattern>(converter, patterns.getContext());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h
deleted file mode 100644
index 27a9364..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"  // from @llvm-project
-
-namespace tensorflow {
-
-// Add a pattern to the given pattern list to convert from mlir::gpu::MemcpyOp
-// to tfrt::gpu::MemCopyOp.
-void populateMemcpyConversionPattern(mlir::RewritePatternSet& patterns);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
index 117c086..e3b5965 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
@@ -13,8 +13,6 @@
 // limitations under the License.
 
 // Pattern to lower mlir::gpu::memset Ops to tfrt cuda dialect.
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h"
-
 #include "mlir/Dialect/GPU/GPUDialect.h"  // from @llvm-project
 #include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
 #include "tfrt/gpu/passes/passes.h"  // from @tf_runtime
@@ -45,8 +43,9 @@
 
 }  // namespace
 
-void populateMemsetConversionPattern(RewritePatternSet& patterns) {
-  patterns.add<MemsetRewritePattern>(patterns.getContext());
+void populateMemsetConversionPattern(RewritePatternSet& patterns,
+                                     TypeConverter& converter) {
+  patterns.add<MemsetRewritePattern>(converter, patterns.getContext());
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h
deleted file mode 100644
index f4e354b..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"  // from @llvm-project
-
-namespace tensorflow {
-
-// Add a pattern to the given pattern list to convert from mlir::gpu::MemsetOp
-// to tfrt::gpu::MemSetOp.
-void populateMemsetConversionPattern(mlir::RewritePatternSet& patterns);
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc b/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
index 376700b..4fee874 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
@@ -14,12 +14,13 @@
 ==============================================================================*/
 #include "tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.h"
 
+#include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Dialect/GPU/Passes.h"  // from @llvm-project
-#include "mlir/InitAllDialects.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"
 #include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h"
 #include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
 #include "tensorflow/core/platform/errors.h"
@@ -55,20 +56,15 @@
 
   // LHLO -> TFRT Dialect (gpu kernels)
   mlir::PassManager pm(&context, mlir::PassManager::Nesting::Implicit);
-  pm.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+  pm.addPass(tensorflow::createConvertLmhloToGpuPass());
   pm.addPass(mlir::createGpuAsyncRegionPass());
   tfrt::gpu::populateGpuToTfrtGpuPasses(pm);
+  pm.addPass(mlir::createCanonicalizerPass());
   if (pm.run(*module).failed()) {
     return errors::Internal(
         "Failed to lower LHLO to TFRT Dialect with gpu kernels.");
   }
 
-  // Perform DCE with empty pattern set.
-  if (failed(mlir::applyPatternsAndFoldGreedily(*module,
-                                                RewritePatternSet(&context)))) {
-    return errors::Internal("Failed to remove dead ops.");
-  }
-
   // TFRT Dialect -> BEF
   std::string bef;
   llvm::raw_string_ostream bef_ostream(bef);
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 4afdae4..7fcb6ad 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -846,7 +846,7 @@
             "@llvm-project//mlir:GPUTransforms",
             "@llvm-project//mlir:Pass",
             "//tensorflow/compiler/mlir/hlo:lhlo",
-            "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+            "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
             "//tensorflow/compiler/mlir/xla:attribute_exporter",
             "//tensorflow/compiler/xla/service:collective_ops_utils",
             "//tensorflow/core/tfrt/runtime:work_queue_interface",
@@ -1719,7 +1719,7 @@
         "//tensorflow/stream_executor:stream_executor_headers",
     ] + select({
         ":is_bef_executable_enabled": [
-            "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+            "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
             "@tf_runtime//:mlirtobef_translate",
             "@tf_runtime//:support",
             "@tf_runtime//:bef",
diff --git a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
index 42b9cb8..541625a 100644
--- a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
@@ -189,7 +189,7 @@
 static Status RunLmhloGpuToTfrtConversionPipeline(mlir::ModuleOp module) {
   mlir::PassManager pass_manager(module->getContext(),
                                  mlir::PassManager::Nesting::Implicit);
-  pass_manager.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+  pass_manager.addPass(tensorflow::createConvertLmhloToGpuPass());
   pass_manager.addPass(mlir::createGpuAsyncRegionPass());
   tfrt::gpu::populateGpuToTfrtGpuPasses(pass_manager);
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3c8cd8d..88b37f6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -40,6 +40,7 @@
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "mlir/Pass/PassManager.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
@@ -749,20 +750,15 @@
   // LHLO -> TFRT Dialect (gpu kernels)
   mlir::PassManager pm(mlir_module.getContext(),
                        mlir::PassManager::Nesting::Implicit);
-  pm.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+  pm.addPass(tensorflow::createConvertLmhloToGpuPass());
   pm.addPass(mlir::createGpuAsyncRegionPass());
   tfrt::gpu::populateGpuToTfrtGpuPasses(pm);
+  pm.addPass(mlir::createCanonicalizerPass());
   if (pm.run(mlir_module).failed()) {
     return InternalError(
         "Failed to lower LHLO to TFRT Dialect with gpu kernels.");
   }
 
-  // Perform DCE with empty pattern set.
-  if (failed(mlir::applyPatternsAndFoldGreedily(
-          mlir_module, mlir::RewritePatternSet(mlir_module.getContext())))) {
-    return InternalError("Failed to remove dead ops.");
-  }
-
   // TFRT Dialect -> BEF
   std::string bef;
   llvm::raw_string_ostream bef_ostream(bef);