blob: 0ebdf7a18ee9d8f2b4102f9786b91dc3e475e4c5 [file] [log] [blame]
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
#include "tensorflow/compiler/xla/mlir/transforms/runtime/rt_passes.h"
namespace xla {
namespace runtime {
using namespace mlir; // NOLINT
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/xla/mlir/transforms/runtime/rt_gen_passes.h.inc"
// Pass that rewrites every function carrying `kEntrypointAttrName` into a
// runtime entrypoint (see `ConvertToEntrypoint`). The pass boilerplate is
// provided by the tablegen-generated `ConvertToEntrypointBase`.
class ConvertToEntrypointPass
    : public ConvertToEntrypointBase<ConvertToEntrypointPass> {
 private:
  void runOnOperation() override;
};
// Rewrites every `func.call` in `func` whose callee declaration carries an
// `rt.custom_call` (indirect) or `rt.direct_custom_call` (direct) string
// attribute into an `rt.custom_call` operation that receives `exec_ctx`. The
// returned status is guarded with a `cf.assert`. Users of the original call
// results are redirected to the custom call results that follow the leading
// status result.
//
// Converted callee declarations are erased from the module only once no
// symbol uses of them remain.
static void ConvertCustomCallOperations(func::FuncOp func, Value exec_ctx) {
  MLIRContext* ctx = func->getContext();

  ModuleOp module = func->getParentOfType<ModuleOp>();
  SymbolTable sym_table(module);

  struct CustomCall {
    func::CallOp call;
    func::FuncOp callee;
    llvm::StringRef target;
    bool direct;
  };

  // Collect function calls that have to be converted to custom calls.
  llvm::SmallVector<CustomCall> custom_calls;
  func.walk([&](func::CallOp op) {
    // `lookup` yields null for an unresolved callee; plain `dyn_cast` would
    // assert on a null value.
    auto callee =
        dyn_cast_or_null<func::FuncOp>(sym_table.lookup(op.getCallee()));
    if (!callee) return;

    // Check if the call is an indirect custom call ...
    if (auto target = callee->getAttrOfType<StringAttr>("rt.custom_call")) {
      custom_calls.push_back({op, callee, target.strref(), /*direct=*/false});
      // Record each call at most once. A callee tagged with both attributes
      // would otherwise be converted twice, and its call erased twice.
      return;
    }

    // ... or a direct custom call.
    if (auto target =
            callee->getAttrOfType<StringAttr>("rt.direct_custom_call"))
      custom_calls.push_back({op, callee, target.strref(), /*direct=*/true});
  });

  // Callee declarations of converted calls; erased at the end if unused.
  llvm::DenseSet<func::FuncOp> erase_declarations;

  // Rewrite function calls to `rt.custom_call` operations.
  for (CustomCall& custom_call : custom_calls) {
    ImplicitLocOpBuilder b(custom_call.call.getLoc(), custom_call.call);

    // Custom call intrinsic always returns the status flag.
    llvm::SmallVector<Type> results = {StatusType::get(ctx)};
    results.append(custom_call.call->getResultTypes().begin(),
                   custom_call.call->getResultTypes().end());

    // Rewrite function call with a custom call, and check the return status.
    auto call = b.create<CustomCallOp>(results, exec_ctx, custom_call.target,
                                       custom_call.direct,
                                       custom_call.call.getOperands());

    // Copy optional attributes from the custom call function declaration.
    // Skip attributes owned by the runtime dialect and the declaration's
    // inherent attributes.
    llvm::ArrayRef<llvm::StringRef> callee_attrs =
        custom_call.callee.getAttributeNames();
    for (auto& attr : custom_call.callee->getAttrs()) {
      if (isa_and_nonnull<RuntimeDialect>(attr.getNameDialect())) continue;
      if (llvm::find(callee_attrs, attr.getName()) == callee_attrs.end())
        call->setAttr(attr.getName(), attr.getValue());
    }

    // Copy optional (non-inherent) attributes from the call operation. These
    // are set last, so they override same-named attributes copied above.
    llvm::ArrayRef<llvm::StringRef> orig_attrs =
        custom_call.call.getAttributeNames();
    for (auto& attr : custom_call.call->getAttrs()) {
      if (llvm::find(orig_attrs, attr.getName()) == orig_attrs.end())
        call->setAttr(attr.getName(), attr.getValue());
    }

    b.create<cf::AssertOp>(
        b.create<IsOkOp>(TypeRange(b.getI1Type()), call.status()),
        b.getStringAttr("custom call '" + custom_call.target + "' failed"));

    // Forward users of the original results to custom call results.
    auto rets = llvm::zip(custom_call.call.getResults(),
                          llvm::drop_begin(call.getResults()));
    llvm::for_each(rets, [](auto ret) {
      std::get<0>(ret).replaceAllUsesWith(std::get<1>(ret));
    });

    // Keep track of custom call declaration to erase.
    erase_declarations.insert(custom_call.callee);

    // Erase the original function call operation.
    custom_call.call.erase();
  }

  // Erase converted declarations only when nothing in the module refers to
  // them anymore. The same declaration may still be called from another
  // function (e.g. an entrypoint that is converted later). Erasing it here
  // would leave that call pointing at a missing symbol.
  // `symbolKnownUseEmpty` is conservative: a declaration whose uses cannot be
  // proven absent is kept.
  for (func::FuncOp decl : erase_declarations) {
    if (SymbolTable::symbolKnownUseEmpty(decl.getOperation(),
                                         module.getOperation()))
      sym_table.erase(decl);
  }
}
// Replaces every `func.return` in `func` with:
//   - one `rt.set_output` per returned value, indexed by operand position;
//   - followed by an operand-less `func.return`.
// The function type is then updated to have no results.
//
// NOTE(review): only the function type is updated here. Any per-result
// attributes on `func` are left as-is; confirm entrypoints never carry them.
static void ConvertReturnOperations(func::FuncOp func, Value exec_ctx) {
  // Snapshot the returns first so the rewrite below does not race the walk.
  llvm::SmallVector<func::ReturnOp> returns;
  func.walk([&](func::ReturnOp op) { returns.push_back(op); });

  for (func::ReturnOp ret : returns) {
    ImplicitLocOpBuilder builder(ret.getLoc(), ret);

    // Hand each returned value to the runtime, keyed by its position.
    for (size_t idx = 0, e = ret->getNumOperands(); idx < e; ++idx)
      builder.create<SetOutputOp>(exec_ctx, idx, ret->getOperand(idx));

    // Values now flow through the runtime, so the terminator carries none.
    builder.create<func::ReturnOp>();
    ret.erase();
  }

  // Keep the argument list, drop all results from the signature.
  func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
                                 /*results=*/{}));
}
// Lowers every `cf.assert` in `func` to explicit control flow:
//   - the enclosing block is split at the assert;
//   - the head block is terminated with a `cf.cond_br` on the asserted
//     condition;
//   - the success edge continues with the operations that followed the
//     assert;
//   - the failure edge goes to a fresh block, appended to the function body,
//     that reports the assert message via `rt.set_error` and returns.
//
// NOTE(review): the failure block is always appended to the function's body.
// An assert nested inside another region would therefore branch across
// regions. Presumably only top-level asserts reach this point; confirm.
static void ConvertAssertOperations(func::FuncOp func, Value exec_ctx) {
  // Snapshot the asserts first; the loop below restructures the CFG.
  llvm::SmallVector<cf::AssertOp> assert_ops;
  func.walk([&](cf::AssertOp op) { assert_ops.push_back(op); });

  for (cf::AssertOp assert_op : assert_ops) {
    ImplicitLocOpBuilder builder(assert_op.getLoc(), assert_op);

    // The assert and everything after it move into the continuation block.
    Block* head = assert_op->getBlock();
    Block* success = head->splitBlock(assert_op);

    // Failure path: record the error message and leave the function.
    Block* failure = func.addBlock();
    builder.setInsertionPointToStart(failure);
    builder.create<SetErrorOp>(exec_ctx, assert_op.getMsg());
    builder.create<func::ReturnOp>();

    // Terminate the head block with a branch on the asserted condition.
    builder.setInsertionPointToEnd(head);
    builder.create<cf::CondBranchOp>(assert_op.getArg(), success, failure);

    // The branch now encodes the check, so the assert itself goes away.
    assert_op.erase();
  }
}
// Inserts a new first argument of type `KernelContextType` into `func`. The
// argument gets an empty attribute dictionary and the function's own
// location. Returns the newly inserted argument.
static Value PrependExecutionContextArgument(func::FuncOp func) {
  MLIRContext* context = func.getContext();
  constexpr unsigned kArgIndex = 0;

  Type ctx_type = KernelContextType::get(context);
  DictionaryAttr no_attrs = DictionaryAttr::get(context);
  func.insertArguments({kArgIndex}, {ctx_type}, {no_attrs}, {func.getLoc()});

  return func.getArgument(kArgIndex);
}
// Converts one function tagged with `kEntrypointAttrName` into a runtime
// entrypoint, in this order:
//   1. a context argument is prepended;
//   2. custom calls, returns and asserts are rewritten;
//   3. the tag is removed.
// Asserts are lowered last because custom call conversion emits new
// `cf.assert` operations.
static void ConvertToEntrypoint(func::FuncOp func) {
  assert(func->hasAttr(kEntrypointAttrName));

  Value ctx_arg = PrependExecutionContextArgument(func);

  ConvertCustomCallOperations(func, ctx_arg);
  ConvertReturnOperations(func, ctx_arg);
  ConvertAssertOperations(func, ctx_arg);

  // From here on the leading context argument, not the attribute, marks the
  // function as an entrypoint.
  func->removeAttr(kEntrypointAttrName);
}
void ConvertToEntrypointPass::runOnOperation() {
llvm::SmallVector<func::FuncOp> entry_points;
// Collect entrypoint functions.
getOperation().walk([&](func::FuncOp op) {
if (op->hasAttr(kEntrypointAttrName)) entry_points.push_back(op);
});
llvm::for_each(entry_points, ConvertToEntrypoint);
}
// Creates a fresh instance of the pass that converts tagged functions into
// runtime entrypoints.
std::unique_ptr<OperationPass<ModuleOp>> CreateConvertToEntrypoint() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertToEntrypointPass>();
  return pass;
}
} // namespace runtime
} // namespace xla