blob: ddde4c0d045a77df30977e3b3761e196acd7fdfe [file] [log] [blame]
// Copyright 2022 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.h"
#include <cstdint>
#include <utility>
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/GPU/Passes.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_gpu_binary.h"
#include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
#include "tfrt/gpu/passes/passes.h" // from @tf_runtime
namespace tensorflow {
namespace {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/jitrt_passes.h.inc"
using lmhlo_gpu::GEMM_BiasOp;
using mlir::DialectRegistry;
using mlir::FunctionType;
using mlir::MLIRContext;
using mlir::ModuleOp;
using mlir::NamedAttribute;
using mlir::OperationPass;
using mlir::success;
using mlir::SymbolTable;
using mlir::Type;
using mlir::TypeRange;
using mlir::arith::IndexCastOp;
using mlir::detail::PassOptions;
using mlir::func::CallOp;
using mlir::func::FuncOp;
using mlir::func::ReturnOp;
using mlir::gpu::GPUModuleOp;
using mlir::gpu::LaunchFuncOp;
using mlir::lmhlo::TerminatorOp;
using mlir::lmhlo_gpu::GEMMOp;
// Pass converting compiled gpu binaries (gpu.module + gpu.launch_func
// operations) in a module to JitRt runtime custom calls. The rewrite itself
// is implemented in runOnOperation, defined later in this file.
class ConvertGpuBinaryToJitRtPass
    : public ConvertGpuBinaryToJitRtPassBase<ConvertGpuBinaryToJitRtPass> {
  void runOnOperation() override;

  // The lowering creates func.func/func.call and arith.index_cast operations
  // (see LaunchFuncOpLowering), so the corresponding dialects must be loaded
  // before the pass runs.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect>();
  }
};
// Pass converting lmhlo_gpu operations (GEMM, GEMM_Bias) in a module to JitRt
// runtime custom calls. The rewrite itself is implemented in runOnOperation,
// defined later in this file.
class ConvertLmhloGpuToJitRtPass
    : public ConvertLmhloGpuToJitRtPassBase<ConvertLmhloGpuToJitRtPass> {
  void runOnOperation() override;

  // The lowering creates func.func/func.call operations. The arith dialect is
  // declared for parity with ConvertGpuBinaryToJitRtPass —
  // NOTE(review): the gemm lowering below does not create arith ops; confirm
  // the dependency is still required.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect>();
  }
};
} // namespace
// -------------------------------------------------------------------------- //
// Erases gpu.module operations from the module: once gpu.launch_func is
// lowered to a custom call that carries the compiled binary as an attribute,
// the device module itself is no longer needed.
// NOTE(review): assumes launch_func ops are rewritten (and read the module's
// "binary" attribute, see LaunchFuncOpLowering) before the enclosing
// gpu.module is erased — confirm the conversion driver guarantees this
// ordering.
class GpuModuleOpLowering : public OpRewritePattern<GPUModuleOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GPUModuleOp op,
                                PatternRewriter& rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
// -------------------------------------------------------------------------- //
// Replaces lmhlo.terminator operations with func.return, so that function
// bodies lowered to the func dialect end with a valid terminator.
class TerminatorOpLowering : public OpRewritePattern<TerminatorOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TerminatorOp op,
                                PatternRewriter& rewriter) const override {
    rewriter.replaceOpWithNewOp<ReturnOp>(op);
    // Unqualified `success()` (in scope via the file-level using declaration)
    // for consistency with the sibling lowering patterns in this file.
    return success();
  }
};
// -------------------------------------------------------------------------- //
// Rewrites gpu.launch_func operations into calls to the "xla.gpu.func.launch"
// JitRt custom call. The call takes the grid/block dimensions (as i32) plus
// the kernel arguments, and carries the compiled device binary as a "ptx"
// string attribute and the kernel name as a "kernel" attribute.
class LaunchFuncOpLowering : public OpRewritePattern<LaunchFuncOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(LaunchFuncOp op,
                                PatternRewriter& rewriter) const override {
    MLIRContext* ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    ModuleOp module = op->getParentOfType<ModuleOp>();

    // Cast grid and block dimensions to i32 before passing to the custom call
    // (launch dimensions are `index`-typed on the gpu.launch_func op).
    auto cast = [&](mlir::Value value) {
      return b.create<IndexCastOp>(b.getI32Type(), value);
    };

    // Prepare arguments for the custom call: six launch dimensions first,
    // then the kernel operands.
    llvm::SmallVector<Value> args = {
        cast(op.gridSizeX()),  cast(op.gridSizeY()),  cast(op.gridSizeZ()),
        cast(op.blockSizeX()), cast(op.blockSizeY()), cast(op.blockSizeZ())};

    // Add kernel arguments.
    llvm::copy(op.operands(), std::back_inserter(args));

    // Types of the custom call arguments.
    llvm::SmallVector<Type> args_types = TypeRange(ValueRange(args));

    // Custom call target, attached as the "rt.direct_custom_call" attribute
    // on the function declaration created below.
    NamedAttribute target(b.getStringAttr("rt.direct_custom_call"),
                          b.getStringAttr("xla.gpu.func.launch"));

    // Create a private custom call function declaration.
    auto custom_call_type = FunctionType::get(ctx, args_types, TypeRange());
    auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
    auto custom_call = FuncOp::create(op.getLoc(), "launch_func",
                                      custom_call_type, custom_call_attrs);
    custom_call.setPrivate();

    // Insert the declaration into the module symbol table; `inserted` is the
    // resulting symbol name (uniqued if "launch_func" already exists). The
    // declaration was created outside the rewriter, so notify it explicitly.
    SymbolTable sym_table(module);
    auto inserted = sym_table.insert(custom_call);
    rewriter.notifyOperationInserted(custom_call);

    // Get the compiled gpu function referenced by the launch operation.
    auto* kernel = SymbolTable::lookupNearestSymbolFrom(op, op.kernel());
    assert(kernel && "kernel not found");

    // Get the compiled GPU binary from the device kernel module.
    auto gpu_module = kernel->getParentOfType<mlir::gpu::GPUModuleOp>();
    auto gpu_binary = gpu_module->getAttrOfType<mlir::StringAttr>("binary");

    // Create a function launch call operation carrying the binary and the
    // kernel name so the runtime can load and launch the kernel.
    auto call =
        rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(), args);
    call->setAttr(b.getStringAttr("ptx"), gpu_binary);
    call->setAttr(b.getStringAttr("kernel"), op.getKernelName());

    // Erase the original gpu launch operation.
    rewriter.eraseOp(op);

    return success();
  }
};
// -------------------------------------------------------------------------- //
// Assigns a unique, monotonically increasing id to every Gemm operation in
// the module. The id is forwarded to the custom call handler, which uses it
// to cache resources across invocations of the same gemm operation.
class GemmUidGenerator {
 public:
  // Returns the next unique id (0, 1, 2, ...); atomic, so safe to call
  // concurrently.
  int64_t uid() { return counter_++; }

 private:
  std::atomic<int64_t> counter_{0};
};
// Lowers lmhlo_gpu GEMM operations to JitRt runtime custom calls. `Gemm` is
// either GEMMOp or GEMM_BiasOp; the static overloads below select the custom
// call target and the optional attributes for each case.
template <typename Gemm>
class GemmLowering : public OpRewritePattern<Gemm> {
 private:
  // Custom call target name for each supported operation.
  static StringRef CustomCallTarget(GEMMOp) { return "xla.gpu.gemm"; }
  static StringRef CustomCallTarget(GEMM_BiasOp) { return "xla.gpu.gemm.bias"; }

  // Plain GEMMOp has no optional attributes to copy.
  static void SetOptionalAttrs(ImplicitLocOpBuilder& b, GEMMOp op,
                               CallOp call) {}
  // GEMM_BiasOp additionally carries the `beta` scaling factor.
  static void SetOptionalAttrs(ImplicitLocOpBuilder& b, GEMM_BiasOp op,
                               CallOp call) {
    call->setAttr(b.getStringAttr("beta"), op.betaAttr());
  }

 public:
  GemmLowering(MLIRContext* ctx, GemmUidGenerator& uid)
      : OpRewritePattern<Gemm>(ctx), uid_(uid) {}

  LogicalResult matchAndRewrite(Gemm op,
                                PatternRewriter& rewriter) const override {
    MLIRContext* ctx = this->getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    ModuleOp module = op->template getParentOfType<ModuleOp>();

    // Custom call target, attached as the "rt.direct_custom_call" attribute
    // on the function declaration created below.
    NamedAttribute target(b.getStringAttr("rt.direct_custom_call"),
                          b.getStringAttr(CustomCallTarget(op)));

    // Create a private custom call function declaration whose signature
    // mirrors the gemm operands (no results).
    auto custom_call_type =
        FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
    auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
    auto custom_call = FuncOp::create(op.getLoc(), "gemm", custom_call_type,
                                      custom_call_attrs);
    custom_call.setPrivate();

    // Insert the declaration into the module symbol table; `inserted` is the
    // resulting symbol name (uniqued if a "gemm" symbol already exists). The
    // declaration was created outside the rewriter, so notify it explicitly.
    SymbolTable sym_table(module);
    auto inserted = sym_table.insert(custom_call);
    rewriter.notifyOperationInserted(custom_call);

    // Convert Gemm to a function call.
    auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
                                        op.getOperands());

    // Assign a unique id to this instance of a gemm operation.
    call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid()));

    // Copy backend specific attributes.
    call->setAttr(b.getStringAttr("algorithm"), op.algorithmAttr());
    call->setAttr(b.getStringAttr("alpha_imag"), op.alpha_imagAttr());
    call->setAttr(b.getStringAttr("alpha_real"), op.alpha_realAttr());

    // Set optional arguments that are defined only for some Gemm ops.
    SetOptionalAttrs(b, op, call);

    // TODO(ezhulenev): Once custom calls support passing structured attributes
    // we should be able to pass `mhlo.dot` attribute directly.
    auto dot = op.dot_dimension_numbers();
    auto lhs_batch = b.getI64TensorAttr(dot.getLhsBatchingDimensions());
    auto lhs_contract = b.getI64TensorAttr(dot.getLhsContractingDimensions());
    auto rhs_batch = b.getI64TensorAttr(dot.getRhsBatchingDimensions());
    auto rhs_contract = b.getI64TensorAttr(dot.getRhsContractingDimensions());
    call->setAttr(b.getStringAttr("lhs_batching_dimensions"), lhs_batch);
    call->setAttr(b.getStringAttr("lhs_contracting_dimensions"), lhs_contract);
    call->setAttr(b.getStringAttr("rhs_batching_dimensions"), rhs_batch);
    call->setAttr(b.getStringAttr("rhs_contracting_dimensions"), rhs_contract);

    // Erase the original gemm operation.
    rewriter.eraseOp(op);

    return success();
  }

 private:
  // Shared across all pattern instances; see GemmUidGenerator.
  GemmUidGenerator& uid_;
};
// Lowering for lmhlo_gpu.gemm (no bias); see GemmLowering for the details.
class GemmOpLowering : public GemmLowering<GEMMOp> {
 public:
  using GemmLowering<GEMMOp>::GemmLowering;
};
// Lowering for lmhlo_gpu.gemm_bias; see GemmLowering for the details.
class GemmBiasOpLowering : public GemmLowering<GEMM_BiasOp> {
 public:
  using GemmLowering<GEMM_BiasOp>::GemmLowering;
};
// -------------------------------------------------------------------------- //
// Converts compiled gpu kernels (gpu.module + gpu.launch_func operations) in
// the module to JitRt gpu runtime custom calls.
void ConvertGpuBinaryToJitRtPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext* ctx = module.getContext();

  // Convert gpu operations to JitRt gpu runtime custom calls.
  RewritePatternSet patterns(ctx);
  patterns.insert<GpuModuleOpLowering, LaunchFuncOpLowering>(ctx);

  // Set up conversion target to rewrite gpu operations. The operations the
  // lowering patterns create (index casts, function declarations and calls)
  // must be explicitly legal for the partial conversion to accept them.
  ConversionTarget target(*ctx);
  target.addIllegalOp<GPUModuleOp, LaunchFuncOp>();
  target.addLegalOp<IndexCastOp, FuncOp, CallOp, ReturnOp>();

  // `return signalPassFailure()` for consistency with
  // ConvertLmhloGpuToJitRtPass::runOnOperation.
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    return signalPassFailure();
}
void ConvertLmhloGpuToJitRtPass::runOnOperation() {
ModuleOp module = getOperation();
MLIRContext* ctx = module.getContext();
GemmUidGenerator uid;
// Convert lmhlo_gpu operations to JitRt gpu runtime custom calls.
RewritePatternSet patterns(ctx);
patterns.insert<GemmOpLowering, GemmBiasOpLowering>(ctx, uid);
patterns.insert<TerminatorOpLowering>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
return signalPassFailure();
}
// Creates the pass converting compiled gpu binaries to JitRt custom calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertGpuBinaryToJitRtPass() {
  auto pass = std::make_unique<ConvertGpuBinaryToJitRtPass>();
  return pass;
}
// Creates the pass converting lmhlo_gpu operations to JitRt custom calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloGpuToJitRtPass() {
  auto pass = std::make_unique<ConvertLmhloGpuToJitRtPass>();
  return pass;
}
// Populates `pm` with the full LMHLO-to-JitRt lowering pipeline. Order
// matters: kernels are first compiled to a gpu binary, then launch operations
// are converted to custom calls (LaunchFuncOpLowering reads the "binary"
// attribute — presumably produced by the first pass; verify), and finally the
// remaining lmhlo_gpu operations are converted.
void populateLmhloToJitRtPasses(mlir::OpPassManager& pm) {
  pm.addPass(createConvertLmhloToGpuBinaryPass());
  pm.addPass(createConvertGpuBinaryToJitRtPass());
  pm.addPass(createConvertLmhloGpuToJitRtPass());
}
// Registers the individual passes and the "lmhlo-to-jitrt" pipeline with the
// global MLIR pass registry, making them available to command line tools.
void registerLmhloToJitRtPasses() {
  mlir::registerPass([] { return createConvertGpuBinaryToJitRtPass(); });
  mlir::registerPass([] { return createConvertLmhloGpuToJitRtPass(); });

  mlir::registerPassPipeline(
      "lmhlo-to-jitrt", "Lower LMHLO to JitRt IR",
      [](OpPassManager& pm, StringRef options,
         function_ref<LogicalResult(const Twine&)> errorHandler) {
        // The pipeline takes no options; `options` and `errorHandler` are
        // intentionally unused.
        populateLmhloToJitRtPasses(pm);
        return success();
      },
      /*optHandler=*/[](function_ref<void(const PassOptions&)>) {});
}
} // namespace tensorflow