blob: 775ee3c0c9a4078cdf8295e4b1f8bd9465a9f907 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering LHLO GPU dialect to TFRT CUDA
// dialect.
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_gpu.h"
#include <memory>
#include <utility>
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/xlir_ops.h"
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
namespace tensorflow {
void populateCclConversionPattern(RewritePatternSet&, TypeConverter&);
void populateCholeskyConversionPattern(RewritePatternSet&, TypeConverter&);
void populateCustomCallConversionPattern(RewritePatternSet&, TypeConverter&);
void populateGemmConversionPattern(RewritePatternSet&, TypeConverter&);
void populateTriangularSolveConversionPattern(RewritePatternSet&,
TypeConverter&);
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/gpu_passes.h.inc"
// Function pass that lowers LMHLO / LMHLO GPU ops. The base class comes from
// the gpu_passes.h.inc include guarded by GEN_PASS_CLASSES above.
struct ConvertLmhloToGpuPass
    : public ConvertLmhloToGpuPassBase<ConvertLmhloToGpuPass> {
 private:
  // Registers the dialects this pass depends on, one insertion per dialect.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<mlir::gpu::GPUDialect>();
    registry.insert<tfrt::compiler::TFRTDialect>();
    registry.insert<tfrt::gpu::GpuDialect>();
    registry.insert<tfrt::gpu::conversion::GpuConversionDialect>();
    registry.insert<xla::gpu::XlirDialect>();
  }

  // Performs the conversion on the current function (defined out of line).
  void runOnFunction() override;
};
} // namespace
void ConvertLmhloToGpuPass::runOnFunction() {
auto* context = &getContext();
TypeConverter converter = tfrt::gpu::createMemrefToTfrtGpuConverter();
RewritePatternSet patterns(context);
populateCclConversionPattern(patterns, converter);
populateCholeskyConversionPattern(patterns, converter);
populateCustomCallConversionPattern(patterns, converter);
populateGemmConversionPattern(patterns, converter);
populateTriangularSolveConversionPattern(patterns, converter);
populateFuncOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
// Set of ops that need to be wrapped in tfrt_gpu_conversion.async.execute
// before lowering directly to tfrt_gpu ops (and therefore require some chain
// and stream, which the wrapper op provides as block arguments). On the other
// hand, ops which lower to the gpu dialect do not need to be wrapped.
ConversionTarget wrap_target(*context);
wrap_target.addLegalDialect<lmhlo_gpu::LmhloGpuDialect>();
wrap_target.addLegalOp<lmhlo::AllGatherOp, lmhlo::AllReduceOp,
lmhlo::ReduceScatterOp, lmhlo::AllToAllOp,
lmhlo::CollectivePermuteOp, lmhlo::CustomCallOp,
lmhlo::TriangularSolveOp>();
tfrt::gpu::populateGpuAsyncConversionPatterns(patterns, converter,
wrap_target);
ConversionTarget target(*context);
target.addIllegalOp<memref::ReinterpretCastOp, memref::ViewOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return converter.isSignatureLegal(op.getType()) &&
converter.isLegal(&op.body());
});
target.addDynamicallyLegalOp<tfrt::gpu::conversion::AsyncExecuteOp>(
[&](tfrt::gpu::conversion::AsyncExecuteOp op) {
return converter.isLegal(&op.body());
});
target.markUnknownOpDynamicallyLegal([&](Operation* op) {
if (op->hasTrait<OpTrait::ReturnLike>()) return converter.isLegal(op);
return !wrap_target.isLegal(op); // Wrapped ops are immediately lowered.
});
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
}
// Factory: returns a newly allocated ConvertLmhloToGpuPass owned by the caller.
std::unique_ptr<FunctionPass> createConvertLmhloToGpuPass() {
  std::unique_ptr<FunctionPass> pass =
      std::make_unique<ConvertLmhloToGpuPass>();
  return pass;
}
// Registers the pass with MLIR, using createConvertLmhloToGpuPass as the
// allocator callback.
void registerConvertLmhloToGpuPass() {
  const auto make_pass = [] { return createConvertLmhloToGpuPass(); };
  ::mlir::registerPass(make_pass);
}
} // namespace tensorflow