Add materialization to lmhlo-to-gpu pass.
Plus some cleanup changes and renames.
This prepares the pass pipeline to allow memref ops to persist past the initial lmhlo-to-gpu pass. Specifically, we want to defer the conversion of gpu ops (gpu.memcpy and gpu.memset) to the ConvertGpuToTfrtGpuPass.
PiperOrigin-RevId: 399863409
Change-Id: Ie27150e7a935bc8bebaf2b0518357438c10890a3
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index 9a71c36..572f8ba 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -481,7 +481,7 @@
"//platforms/xla/tests/gpu:__pkg__",
],
deps = [
- "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+ "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core/platform:errors",
diff --git a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
index 8f7eab0..b04ab36 100644
--- a/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
+++ b/tensorflow/compiler/mlir/tfrt/lhlo-tfrt-opt.cc
@@ -36,7 +36,7 @@
mlir::mhlo::MhloDialect, tfrt::gpu::GpuDialect,
tfrt::gpu::conversion::GpuConversionDialect>();
tfrt::RegisterTFRTDialects(registry);
- tensorflow::registerTFRTGPUPasses();
+ tensorflow::registerXlirPasses();
tfrt::gpu::registerPasses();
return failed(mlir::MlirOptMain(argc, argv, "MHLO TFRT pass driver\n",
registry,
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
index 5a4c25a..aa47f40 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/BUILD
@@ -3,20 +3,8 @@
load("//tensorflow:tensorflow.bzl", "if_google")
# TF to TFRT kernels conversion.
-package(licenses = ["notice"])
-
-package_group(
- name = "friends",
- packages = if_google([
- "//learning/brain/experimental/mlir/tfrt_compiler/...",
- "//learning/brain/experimental/tfrt/...",
- "//platforms/xla/tests/gpu/...",
- "//third_party/tf_runtime_google/...",
- ]) + [
- "//tensorflow/compiler/...",
- "//tensorflow/core/runtime_fallback/...",
- "//tensorflow/core/tfrt/saved_model/...",
- ],
+package(
+ licenses = ["notice"],
)
gentbl_cc_library(
@@ -26,7 +14,7 @@
(
[
"-gen-pass-decls",
- "-name=TFRTGPU",
+ "-name=Xlir",
],
"gpu_passes.h.inc",
),
@@ -39,43 +27,31 @@
)
cc_library(
- name = "pass_details",
- hdrs = [
- "PassDetail.h",
- ],
- deps = [
- ":TfrtGpuPassIncGen",
- "//tensorflow/compiler/xla/service/gpu:xlir_opdefs",
- "@llvm-project//mlir:GPUDialect",
- "@llvm-project//mlir:Pass",
- "@tf_runtime//:basic_kernels_opdefs",
- "@tf_runtime//backends/gpu:gpu_opdefs",
- ],
-)
-
-cc_library(
- name = "pass",
+ name = "passes",
srcs = [
"ccl_pattern.cc",
- "ccl_pattern.h",
"custom_call_pattern.cc",
- "custom_call_pattern.h",
"gemm_pattern.cc",
- "gemm_pattern.h",
"gpu_passes.cc",
"memcpy_pattern.cc",
- "memcpy_pattern.h",
"memset_pattern.cc",
- "memset_pattern.h",
],
- hdrs = ["gpu_passes.h"],
+ hdrs = [
+ "gpu_passes.h",
+ "register_passes.h",
+ ],
tags = [
"gpu",
"no_oss",
],
- visibility = [":friends"],
+ visibility = if_google([
+ "//platforms/xla/tests/gpu:__pkg__",
+ ]) + [
+ "//tensorflow/compiler/mlir/tfrt:__pkg__",
+ "//tensorflow/compiler/xla/service/gpu:__pkg__",
+ ],
deps = [
- ":pass_details",
+ ":TfrtGpuPassIncGen",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/mlir/tensorflow",
@@ -92,6 +68,7 @@
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
"@llvm-project//mlir:StandardOps",
+ "@llvm-project//mlir:StandardOpsTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@tf_runtime//:basic_kernels_opdefs",
@@ -101,18 +78,3 @@
],
alwayslink = 1,
)
-
-cc_library(
- name = "passes",
- hdrs = ["register_passes.h"],
- tags = [
- "gpu",
- "no_oss",
- ],
- visibility = [":friends"],
- deps = [
- ":TfrtGpuPassIncGen",
- ":pass",
- "@llvm-project//mlir:Pass",
- ],
-)
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h
deleted file mode 100644
index 4e08993..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
-
-#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Pass/Pass.h"
-#include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
-#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
-#include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime
-
-namespace tensorflow {
-
-#define GEN_PASS_CLASSES
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h.inc"
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_PASSDETAIL_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
index 5a2433b..3717d5e 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.cc
@@ -18,8 +18,6 @@
// Pattern to lower lmhlo collective ops to tfrt_gpu/xlir dialect.
//
//===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"
-
#include <functional>
#include <string>
@@ -311,13 +309,14 @@
} // namespace
-void populateCclConversionPattern(RewritePatternSet& patterns) {
+void populateCclConversionPattern(RewritePatternSet& patterns,
+ TypeConverter& converter) {
patterns.add<CclRewritePattern<lmhlo::AllGatherOp>,
CclRewritePattern<lmhlo::AllReduceOp>,
CclRewritePattern<lmhlo::ReduceScatterOp>,
CclRewritePattern<lmhlo::AllToAllOp>,
CclRewritePattern<lmhlo::CollectivePermuteOp>>(
- patterns.getContext());
+ converter, patterns.getContext());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h
deleted file mode 100644
index 1185511..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace tensorflow {
-
-void populateCclConversionPattern(mlir::RewritePatternSet& patterns);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_CUDA_CCL_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
index e1c2a3d..e9e01c9 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.cc
@@ -18,8 +18,6 @@
// Pattern to lower lmhlo.custom_call op to tfrt_gpu/xlir dialect.
//
//===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h"
-
#include <functional>
#include <string>
@@ -85,8 +83,9 @@
} // namespace
-void populateCustomCallConversionPattern(RewritePatternSet& patterns) {
- patterns.add<CustomCallRewritePattern>(patterns.getContext());
+void populateCustomCallConversionPattern(RewritePatternSet& patterns,
+ TypeConverter& converter) {
+ patterns.add<CustomCallRewritePattern>(converter, patterns.getContext());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h
deleted file mode 100644
index 735fd24..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h"
-
-namespace tensorflow {
-
-void populateCustomCallConversionPattern(mlir::RewritePatternSet& patterns);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_CUSTOM_CALL_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
index 19c64f8..ef34270 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.cc
@@ -18,8 +18,6 @@
// Pattern to lower lhlogpu_gemm Ops to tfrt cuda dialect.
//
//===----------------------------------------------------------------------===//
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"
-
#include <assert.h>
#include <stdint.h>
@@ -345,10 +343,11 @@
} // namespace
-void populateGemmConversionPattern(RewritePatternSet& patterns) {
+void populateGemmConversionPattern(RewritePatternSet& patterns,
+ TypeConverter& converter) {
patterns.add<GemmRewritePattern<lmhlo_gpu::GEMMOp>,
GemmRewritePattern<lmhlo_gpu::GEMM_BiasOp>>(
- patterns.getContext());
+ converter, patterns.getContext());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h
deleted file mode 100644
index 07451af..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
-
-#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
-#include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/StringRef.h"
-
-namespace tensorflow {
-
-void populateGemmConversionPattern(mlir::RewritePatternSet& patterns);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_GEMM_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
index 22498f3..02e551f 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.cc
@@ -21,85 +21,105 @@
#include <memory>
#include <utility>
+#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/PassDetail.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/ccl_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/custom_call_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gemm_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h"
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
namespace tensorflow {
+
+void populateCclConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateCustomCallConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateGemmConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateMemcpyConversionPattern(RewritePatternSet&, TypeConverter&);
+void populateMemsetConversionPattern(RewritePatternSet&, TypeConverter&);
+
+#define GEN_PASS_CLASSES
+#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h.inc"
+
namespace {
-struct LmhloGpuAsyncConversionPass
- : public LmhloGpuAsyncConversionPassBase<LmhloGpuAsyncConversionPass> {
+struct ConvertLmhloToGpuPass
+ : public ConvertLmhloToGpuPassBase<ConvertLmhloToGpuPass> {
private:
- void runOnFunction() override {
- auto* context = &getContext();
-
- TypeConverter converter;
- converter.addConversion([](Type type) { return type; });
- auto buffer_type = tfrt::gpu::BufferType::get(context);
- converter.addConversion([&](BaseMemRefType) { return buffer_type; });
-
- ConversionTarget target(*context);
- target
- .addIllegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
- target.addLegalDialect<tfrt::compiler::TFRTDialect, tfrt::gpu::GpuDialect,
- xla::gpu::XlirDialect>();
- target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
- return converter.isSignatureLegal(op.getType()) &&
- converter.isLegal(&op.body());
- });
- target.addDynamicallyLegalOp<tfrt::gpu::conversion::AsyncExecuteOp>(
- [&](tfrt::gpu::conversion::AsyncExecuteOp op) {
- return converter.isLegal(&op.body());
- });
-
- RewritePatternSet patterns(context);
- populateCclConversionPattern(patterns);
- populateCustomCallConversionPattern(patterns);
- populateGemmConversionPattern(patterns);
- populateMemcpyConversionPattern(patterns);
- populateMemsetConversionPattern(patterns);
- populateFuncOpTypeConversionPattern(patterns, converter);
-
- ConversionTarget wrap_target(*context);
- wrap_target
- .addLegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
- wrap_target.addLegalOp<lmhlo::AllGatherOp, lmhlo::AllReduceOp,
- lmhlo::ReduceScatterOp, lmhlo::AllToAllOp,
- lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp>();
- tfrt::gpu::populateGpuAsyncConversionPatterns(patterns, converter,
- wrap_target);
-
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
- return signalPassFailure();
- }
+ void runOnFunction() override;
};
} // namespace
-std::unique_ptr<FunctionPass> createLmhloGpuAsyncConversionPass() {
- return std::make_unique<LmhloGpuAsyncConversionPass>();
+static Value MaterializeCast(OpBuilder& builder, Type type, ValueRange values,
+ Location loc) {
+ auto cast_op = builder.create<UnrealizedConversionCastOp>(loc, type, values);
+ return cast_op.getResult(0);
+}
+
+void ConvertLmhloToGpuPass::runOnFunction() {
+ auto* context = &getContext();
+
+ TypeConverter converter;
+ converter.addConversion([](Type type) { return type; });
+ auto buffer_type = tfrt::gpu::BufferType::get(context);
+ converter.addConversion([&](BaseMemRefType) { return buffer_type; });
+ converter.addArgumentMaterialization(MaterializeCast);
+ converter.addSourceMaterialization(MaterializeCast);
+ converter.addTargetMaterialization(MaterializeCast);
+
+ RewritePatternSet patterns(context);
+ populateCclConversionPattern(patterns, converter);
+ populateCustomCallConversionPattern(patterns, converter);
+ populateGemmConversionPattern(patterns, converter);
+ populateMemcpyConversionPattern(patterns, converter);
+ populateMemsetConversionPattern(patterns, converter);
+ populateFuncOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+
+ // Set of ops that need to be wrapped in tfrt_gpu_conversion.async.execute
+ // before lowering directly to tfrt_gpu ops (and therefore require some chain
+ // and stream, which the wrapper op provides as block arguments). On the other
+ // hand, ops which lower to the gpu dialect do not need to be wrapped.
+ ConversionTarget wrap_target(*context);
+ wrap_target
+ .addLegalDialect<lmhlo_gpu::LmhloGpuDialect, mlir::gpu::GPUDialect>();
+ wrap_target.addLegalOp<lmhlo::AllGatherOp, lmhlo::AllReduceOp,
+ lmhlo::ReduceScatterOp, lmhlo::AllToAllOp,
+ lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp>();
+ tfrt::gpu::populateGpuAsyncConversionPatterns(patterns, converter,
+ wrap_target);
+
+ ConversionTarget target(*context);
+ target.addIllegalOp<memref::ReinterpretCastOp, memref::ViewOp>();
+ target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+ return converter.isSignatureLegal(op.getType()) &&
+ converter.isLegal(&op.body());
+ });
+ target.addDynamicallyLegalOp<tfrt::gpu::conversion::AsyncExecuteOp>(
+ [&](tfrt::gpu::conversion::AsyncExecuteOp op) {
+ return converter.isLegal(&op.body());
+ });
+ target.markUnknownOpDynamicallyLegal([&](Operation* op) {
+ if (op->hasTrait<OpTrait::ReturnLike>()) return converter.isLegal(op);
+ return !wrap_target.isLegal(op); // Wrapped ops are immediately lowered.
+ });
+
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ return signalPassFailure();
+}
+
+std::unique_ptr<FunctionPass> createConvertLmhloToGpuPass() {
+ return std::make_unique<ConvertLmhloToGpuPass>();
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
index 1994a34..29f6e67 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h
@@ -24,7 +24,7 @@
// Creates a pass that lowers lmhlo_gpu ops to tfrt_gpu. Prepares the function
// to be consumed by MLIR's gpu-async-region pass.
-std::unique_ptr<mlir::FunctionPass> createLmhloGpuAsyncConversionPass();
+std::unique_ptr<mlir::FunctionPass> createConvertLmhloToGpuPass();
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
index 7cf1b7b..a84c662 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.td
@@ -15,13 +15,13 @@
include "mlir/Pass/PassBase.td"
-def LmhloGpuAsyncConversionPass : FunctionPass<"lmhlo-to-gpu"> {
+def ConvertLmhloToGpuPass : FunctionPass<"lmhlo-to-gpu"> {
let summary = "Convert lmhlo_gpu ops to tfrt_gpu dialect.";
let description = [{
Move lmhlo_gpu ops inside tfrt_gpu_conversion.execute ops and convert them
to tfrt_gpu dialect.
}];
- let constructor = "createLmhloGpuAsyncConversionPass()";
+ let constructor = "createConvertLmhloToGpuPass()";
let dependentDialects = [
"::mlir::gpu::GPUDialect",
"::tfrt::compiler::TFRTDialect",
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
index a1b55d7..10d53a3 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.cc
@@ -13,8 +13,6 @@
// limitations under the License.
// Pattern to lower gpu.memcpy ops to tfrt_gpu.mem.copy.
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h"
-
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
@@ -44,8 +42,9 @@
} // namespace
-void populateMemcpyConversionPattern(RewritePatternSet& patterns) {
- patterns.add<MemcpyRewritePattern>(patterns.getContext());
+void populateMemcpyConversionPattern(RewritePatternSet& patterns,
+ TypeConverter& converter) {
+ patterns.add<MemcpyRewritePattern>(converter, patterns.getContext());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h
deleted file mode 100644
index 27a9364..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memcpy_pattern.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2021 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h" // from @llvm-project
-
-namespace tensorflow {
-
-// Add a pattern to the given pattern list to convert from mlir::gpu::MemcpyOp
-// to tfrt::gpu::MemCopyOp.
-void populateMemcpyConversionPattern(mlir::RewritePatternSet& patterns);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMCPY_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
index 117c086..e3b5965 100644
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
+++ b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.cc
@@ -13,8 +13,6 @@
// limitations under the License.
// Pattern to lower mlir::gpu::memset Ops to tfrt cuda dialect.
-#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h"
-
#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
@@ -45,8 +43,9 @@
} // namespace
-void populateMemsetConversionPattern(RewritePatternSet& patterns) {
- patterns.add<MemsetRewritePattern>(patterns.getContext());
+void populateMemsetConversionPattern(RewritePatternSet& patterns,
+ TypeConverter& converter) {
+ patterns.add<MemsetRewritePattern>(converter, patterns.getContext());
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h b/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h
deleted file mode 100644
index f4e354b..0000000
--- a/tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/memset_pattern.h
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright 2020 The TensorFlow Runtime Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
-#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
-
-#include "mlir/IR/PatternMatch.h" // from @llvm-project
-
-namespace tensorflow {
-
-// Add a pattern to the given pattern list to convert from mlir::gpu::MemsetOp
-// to tfrt::gpu::MemSetOp.
-void populateMemsetConversionPattern(mlir::RewritePatternSet& patterns);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_LHLO_GPU_TO_TFRT_GPU_MEMSET_PATTERN_H_
diff --git a/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc b/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
index 376700b..4fee874 100644
--- a/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
+++ b/tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.cc
@@ -14,12 +14,13 @@
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/translate/convert_xla_gpu.h"
+#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
-#include "mlir/InitAllDialects.h" // from @llvm-project
+#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu/gpu_passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/core/platform/errors.h"
@@ -55,20 +56,15 @@
// LHLO -> TFRT Dialect (gpu kernels)
mlir::PassManager pm(&context, mlir::PassManager::Nesting::Implicit);
- pm.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+ pm.addPass(tensorflow::createConvertLmhloToGpuPass());
pm.addPass(mlir::createGpuAsyncRegionPass());
tfrt::gpu::populateGpuToTfrtGpuPasses(pm);
+ pm.addPass(mlir::createCanonicalizerPass());
if (pm.run(*module).failed()) {
return errors::Internal(
"Failed to lower LHLO to TFRT Dialect with gpu kernels.");
}
- // Perform DCE with empty pattern set.
- if (failed(mlir::applyPatternsAndFoldGreedily(*module,
- RewritePatternSet(&context)))) {
- return errors::Internal("Failed to remove dead ops.");
- }
-
// TFRT Dialect -> BEF
std::string bef;
llvm::raw_string_ostream bef_ostream(bef);
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 4afdae4..7fcb6ad 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -846,7 +846,7 @@
"@llvm-project//mlir:GPUTransforms",
"@llvm-project//mlir:Pass",
"//tensorflow/compiler/mlir/hlo:lhlo",
- "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+ "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
"//tensorflow/compiler/mlir/xla:attribute_exporter",
"//tensorflow/compiler/xla/service:collective_ops_utils",
"//tensorflow/core/tfrt/runtime:work_queue_interface",
@@ -1719,7 +1719,7 @@
"//tensorflow/stream_executor:stream_executor_headers",
] + select({
":is_bef_executable_enabled": [
- "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:pass",
+ "//tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu:passes",
"@tf_runtime//:mlirtobef_translate",
"@tf_runtime//:support",
"@tf_runtime//:bef",
diff --git a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
index 42b9cb8..541625a 100644
--- a/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/bef_thunk.cc
@@ -189,7 +189,7 @@
static Status RunLmhloGpuToTfrtConversionPipeline(mlir::ModuleOp module) {
mlir::PassManager pass_manager(module->getContext(),
mlir::PassManager::Nesting::Implicit);
- pass_manager.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+ pass_manager.addPass(tensorflow::createConvertLmhloToGpuPass());
pass_manager.addPass(mlir::createGpuAsyncRegionPass());
tfrt::gpu::populateGpuToTfrtGpuPasses(pass_manager);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 3c8cd8d..88b37f6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -40,6 +40,7 @@
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
+#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
@@ -749,20 +750,15 @@
// LHLO -> TFRT Dialect (gpu kernels)
mlir::PassManager pm(mlir_module.getContext(),
mlir::PassManager::Nesting::Implicit);
- pm.addPass(tensorflow::createLmhloGpuAsyncConversionPass());
+ pm.addPass(tensorflow::createConvertLmhloToGpuPass());
pm.addPass(mlir::createGpuAsyncRegionPass());
tfrt::gpu::populateGpuToTfrtGpuPasses(pm);
+ pm.addPass(mlir::createCanonicalizerPass());
if (pm.run(mlir_module).failed()) {
return InternalError(
"Failed to lower LHLO to TFRT Dialect with gpu kernels.");
}
- // Perform DCE with empty pattern set.
- if (failed(mlir::applyPatternsAndFoldGreedily(
- mlir_module, mlir::RewritePatternSet(mlir_module.getContext())))) {
- return InternalError("Failed to remove dead ops.");
- }
-
// TFRT Dialect -> BEF
std::string bef;
llvm::raw_string_ostream bef_ostream(bef);