// blob: 86ffd3ab8ff32a75b3a080c7640586302360b826 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms functional control flow operations in the
// TensorFlow dialect to MLIR Control Flow Graph (CFG) form.
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
namespace mlir {
namespace TF {
namespace {
// Pass over a function that lowers the functional control flow ops handled in
// this file (IfOp and WhileOp) into explicit branches between basic blocks.
// The actual work is done in runOnOperation().
class FunctionalControlFlowToCFG
    : public FunctionalControlFlowToCFGPassBase<FunctionalControlFlowToCFG> {
 public:
  void runOnOperation() override;
};
// Turns `value`, a tensor used as the condition of a functional control flow
// op, into an i1 value usable by a conditional branch.
//
// Two ops are emitted at the builder's current insertion point: a ToBoolOp
// applied to `value`, followed by a tensor::ExtractOp (with no indices) that
// reads the scalar out of the ToBoolOp result. The extracted scalar is
// returned.
static Value LowerCondition(Location loc, Value value, OpBuilder* builder) {
  ToBoolOp as_bool = builder->create<ToBoolOp>(loc, value);
  return builder->create<tensor::ExtractOp>(loc, as_bool).getResult();
}
// Emits a func::CallOp that invokes `fn` and returns it as an Operation*.
//
// `get_arg(i)` supplies the value for the i-th input of `fn`. If a supplied
// value's type differs from the declared input type, a TF::CastOp (with
// Truncate=false) to the declared type is emitted first and its result is
// passed to the call instead.
//
// The caller must make sure `get_arg` yields a value for every input of `fn`
// and that each value can legally be cast to the corresponding input type.
static Operation* CallFn(Location loc, const std::function<Value(int)>& get_arg,
                         func::FuncOp fn, OpBuilder* builder) {
  FunctionType signature = fn.getFunctionType();
  int input_count = signature.getNumInputs();

  llvm::SmallVector<Value, 4> call_args;
  call_args.reserve(input_count);
  for (int idx = 0; idx < input_count; ++idx) {
    Value supplied = get_arg(idx);
    Type wanted = signature.getInput(idx);

    // Types already agree: forward the value untouched.
    if (supplied.getType() == wanted) {
      call_args.push_back(supplied);
      continue;
    }

    Value casted =
        builder->create<TF::CastOp>(loc, wanted, supplied,
                                    /*Truncate=*/builder->getBoolAttr(false));
    call_args.push_back(casted);
  }

  func::CallOp call = builder->create<func::CallOp>(loc, fn, call_args);
  return call.getOperation();
}
// Builds the operand list for a branch to `block`.
//
// `get_val(i)` supplies the value destined for the i-th argument of `block`.
// If a supplied value's type differs from the block argument's type, a
// TF::CastOp (with Truncate=false) to the argument type is emitted at the
// builder's insertion point and its result is used instead. The returned
// vector has one entry per block argument, in argument order.
//
// The caller must make sure `get_val` yields a value for every block argument
// and that each value can legally be cast to the matching argument type.
static llvm::SmallVector<Value, 4> PrepareValsForJump(
    Location loc, const std::function<Value(int)>& get_val, Block* block,
    OpBuilder* builder) {
  int arg_count = block->getNumArguments();

  llvm::SmallVector<Value, 4> jump_vals;
  jump_vals.reserve(arg_count);
  for (int idx = 0; idx < arg_count; ++idx) {
    Value supplied = get_val(idx);
    Type wanted = block->getArgument(idx).getType();

    // Types already agree: forward the value untouched.
    if (supplied.getType() == wanted) {
      jump_vals.push_back(supplied);
      continue;
    }

    Value casted =
        builder->create<TF::CastOp>(loc, wanted, supplied,
                                    /*Truncate=*/builder->getBoolAttr(false));
    jump_vals.push_back(casted);
  }
  return jump_vals;
}
// Emits an unconditional cf::BranchOp to `block`.
//
// `get_arg(i)` supplies the value for the i-th argument of `block`; the values
// are first run through PrepareValsForJump, which inserts casts wherever a
// value's type differs from the block argument's type.
//
// The caller must make sure `get_arg` yields a value for every block argument
// and that each value can legally be cast to the matching argument type.
static void JumpToBlock(Location loc, const std::function<Value(int)>& get_arg,
                        Block* block, OpBuilder* builder) {
  llvm::SmallVector<Value, 4> branch_args =
      PrepareValsForJump(loc, get_arg, block, builder);
  builder->create<cf::BranchOp>(loc, block, branch_args);
}
// Redirects every use of each result of `op` to the block argument of `block`
// at the same index.
//
// If the block argument's type differs from the result's type, a TF::CastOp
// (with Truncate=false) from the argument to the result type is emitted at the
// builder's insertion point, and uses are redirected to the cast instead.
//
// `block` must have exactly as many arguments as `op` has results (asserted),
// and each argument must either have the result's type or be castable to it.
static void ReplaceOpResultWithBlockArgs(Location loc, Operation* op,
                                         Block* block, OpBuilder* builder) {
  assert(op->getNumResults() == block->getNumArguments());

  const unsigned result_count = op->getNumResults();
  for (unsigned idx = 0; idx < result_count; ++idx) {
    Value old_result = op->getResult(idx);
    Type wanted = old_result.getType();

    Value replacement = block->getArgument(idx);
    if (replacement.getType() != wanted) {
      replacement =
          builder->create<TF::CastOp>(loc, wanted, replacement,
                                      /*Truncate=*/builder->getBoolAttr(false));
    }
    old_result.replaceAllUsesWith(replacement);
  }
}
// Removes a functional IfOp from the IR by expanding it into explicit control
// flow: the condition is lowered to an i1, the enclosing block is split at the
// op, and a 'then' and an 'else' block are created that each call the
// respective branch function and jump to the merge block carrying the call
// results.
//
// Returns failure() if the condition could not be lowered, success()
// otherwise.
static LogicalResult LowerIfOp(IfOp op) {
  Operation* if_inst = op.getOperation();
  Location loc = if_inst->getLoc();
  OpBuilder builder(if_inst);

  // Materialize the condition as a boolean (i1) ahead of the op.
  Value branch_cond = LowerCondition(loc, op.cond(), &builder);
  if (!branch_cond) return failure();

  // Split the enclosing block right before the 'if'. The op and everything
  // after it end up in the new block, which serves as the merge point.
  Block* head_block = if_inst->getBlock();
  Block* merge_block = head_block->splitBlock(if_inst);

  // Give the merge block one argument per op result and let those arguments
  // stand in for the results everywhere they were used.
  for (Type result_type : if_inst->getResultTypes())
    merge_block->addArgument(result_type, loc);
  ReplaceOpResultWithBlockArgs(loc, if_inst, merge_block, &builder);

  // Operand 0 is the condition; the branch functions receive the remaining
  // operands.
  auto branch_operand = [&](int i) { return if_inst->getOperand(i + 1); };

  // Creates a new block in front of the merge block that calls `callee` and
  // then branches to the merge block with the call results.
  auto emit_arm = [&](func::FuncOp callee) -> Block* {
    Block* arm = builder.createBlock(merge_block);
    Operation* call = CallFn(loc, branch_operand, callee, &builder);
    JumpToBlock(
        loc, [&](int i) { return call->getResult(i); }, merge_block, &builder);
    return arm;
  };
  Block* then_block = emit_arm(op.then_function());
  Block* else_block = emit_arm(op.else_function());

  // With both arms in place, terminate the head block with a conditional
  // branch that selects between them. Neither arm takes block arguments.
  builder.setInsertionPointToEnd(head_block);
  builder.create<cf::CondBranchOp>(loc, branch_cond, then_block,
                                   llvm::ArrayRef<Value>(), else_block,
                                   llvm::ArrayRef<Value>());

  // The original op has no remaining uses; drop it.
  if_inst->erase();
  return success();
}
// Removes a functional WhileOp from the IR by expanding it into an explicit
// loop in the CFG: a block that calls the condition function, a block that
// calls the body function and jumps back to the condition, and the remainder
// of the original block as the loop exit. Always returns success().
static LogicalResult LowerWhileOp(WhileOp op) {
  Operation* while_inst = op.getOperation();
  Location loc = while_inst->getLoc();
  OpBuilder builder(while_inst);

  auto cond_fn = op.cond_function();
  auto body_fn = op.body_function();

  // The block holding the While op is split in two: `head_block` keeps the
  // operations before the op, `tail_block` receives the op and everything
  // after it. Two fresh blocks, placed between them, call the condition and
  // body functions.
  //
  // Resulting control flow graph:
  //
  // ...
  // head_block(...):
  //   ...
  //   br cond_block(...)
  // cond_block(...):
  //   %A = call @cond(...)
  //   cond br %A, body_block(...), tail_block(...)
  // body_block(...):
  //   %B = call @body(...)
  //   br cond_block(...)
  // tail_block(...):
  //   ...
  //
  Block* head_block = while_inst->getBlock();
  Block* tail_block = head_block->splitBlock(while_inst);
  Block* cond_block = builder.createBlock(tail_block);
  Block* body_block = builder.createBlock(tail_block);

  // cond_block takes arguments typed like the condition function's inputs.
  // body_block and tail_block both take arguments typed like the body
  // function's inputs: each is reached only from the conditional branch in
  // cond_block, which forwards the same operand list to both, so they can
  // share argument types.
  for (Type input_type : cond_fn.getFunctionType().getInputs()) {
    cond_block->addArgument(input_type, loc);
  }
  for (Type input_type : body_fn.getFunctionType().getInputs()) {
    body_block->addArgument(input_type, loc);
    tail_block->addArgument(input_type, loc);
  }

  // Enter the loop: the head block branches unconditionally to the condition
  // block, passing the While op's operands as the initial loop values.
  auto loop_init = [&](int i) { return while_inst->getOperand(i); };
  builder.setInsertionPointToEnd(head_block);
  JumpToBlock(loc, loop_init, cond_block, &builder);

  // Condition block: call the condition function on the block arguments, lower
  // its single result to an i1, and branch either into the body or out to the
  // tail. Both targets receive the same (possibly cast) loop values.
  builder.setInsertionPointToEnd(cond_block);
  auto cond_arg = [&](int i) { return cond_block->getArgument(i); };
  Operation* cond_call = CallFn(loc, cond_arg, cond_fn, &builder);
  assert(cond_call->getNumResults() == 1);
  Value keep_looping = LowerCondition(loc, cond_call->getResult(0), &builder);
  auto forwarded_vals =
      PrepareValsForJump(loc, cond_arg, body_block, &builder);
  builder.create<cf::CondBranchOp>(loc, keep_looping, body_block,
                                   forwarded_vals, tail_block, forwarded_vals);

  // Body block: call the body function on the block arguments and loop back to
  // the condition block with the call results.
  builder.setInsertionPointToEnd(body_block);
  auto body_arg = [&](int i) { return body_block->getArgument(i); };
  Operation* body_call = CallFn(loc, body_arg, body_fn, &builder);
  auto body_result = [&](int i) { return body_call->getResult(i); };
  JumpToBlock(loc, body_result, cond_block, &builder);

  // Tail block: route every use of the While results to the tail block's
  // arguments (any needed casts go at the start of the block), then remove the
  // now-unused While op.
  builder.setInsertionPoint(&tail_block->front());
  ReplaceOpResultWithBlockArgs(loc, while_inst, tail_block, &builder);
  while_inst->erase();
  return success();
}
// Walks every block of the function and lowers the first IfOp or WhileOp found
// in each one; signals pass failure as soon as a lowering fails.
void FunctionalControlFlowToCFG::runOnOperation() {
  for (Block& block : getOperation()) {
    for (Operation& op : block) {
      // Only the control flow ops we know how to lower are of interest; skip
      // everything else.
      //
      // TODO: Use PatternRewriter to eliminate these function control flow ops.
      LogicalResult lowered = success();
      if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
        lowered = LowerIfOp(if_op);
      } else if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
        lowered = LowerWhileOp(while_op);
      } else {
        continue;
      }

      if (failed(lowered)) {
        return signalPassFailure();
      }

      // Lowering split the current block and erased the op, so stop scanning
      // this block. The rest of the original block now lives in later blocks,
      // which the outer loop visits next.
      break;
    }
  }
}
} // namespace
// Creates a new instance of the pass that lowers functional control flow ops
// to CFG form. Ownership of the pass is transferred to the caller.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTFFunctionalControlFlowToCFG() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<FunctionalControlFlowToCFG>();
  return pass;
}
} // namespace TF
} // namespace mlir