blob: 708bcfc673e6a132740ae5560675d98bd8dcced4 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This transformation pass transforms region based control flow operations in
// the TensorFlow dialect to their functional counterparts, i.e.,
// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#define DEBUG_TYPE "tf-region-cf-to-functional"
namespace mlir {
namespace TF {
namespace {
// Names of the string attributes on tf.IfRegion that ConvertIfOp consults
// when outlining: if present and non-empty, their values seed the names of the
// outlined `else` and `then` functions.
constexpr char kElseFuncNameAttr[] = "_else_func_name";
constexpr char kThenFuncNameAttr[] = "_then_func_name";
// Name of the attribute that CopyAndOverrideAttributes sets to `true` on every
// functional op produced by this pass.
constexpr char kXlaPropagateCompileTimeConsts[] =
"_xla_propagate_compile_time_consts";
// Pass that rewrites tf.IfRegion and tf.WhileRegion ops into tf.If and
// tf.While, outlining region bodies into new functions when the regions are
// not already plain calls to existing functions.
struct RegionControlFlowToFunctional
: public TF::RegionControlFlowToFunctionalPassBase<
RegionControlFlowToFunctional> {
// Pass entry point: processes every function in the module, plus every
// function outlined along the way, until the worklist is empty.
void runOnOperation() override;
private:
// Replaces `if_region` with a functional tf.If op.
LogicalResult ConvertIfOp(IfRegionOp if_region);
// Replaces `while_region` with a functional tf.While op.
LogicalResult ConvertWhileOp(WhileRegionOp while_region);
// Get unique name by using the loc to name mapping.
std::string GetName(Operation* op, StringRef suffix);
// Source of the unique names used for outlined functions.
tensorflow::OpOrArgLocNameMapper mapper;
// Functions that still need to be scanned for region based control flow ops.
// Newly outlined functions are appended by ExtractSingleBlockRegion.
llvm::SmallVector<func::FuncOp, 4> worklist;
};
// Returns the unique name that `mapper` assigns to `op`, with `suffix`
// appended to it.
std::string RegionControlFlowToFunctional::GetName(Operation* op,
                                                   StringRef suffix) {
  std::string name = mapper.GetUniqueName(op).str();
  name.append(suffix.data(), suffix.size());
  return name;
}
// Gathers the values defined outside of `first` and `second` that are used
// inside them. Values produced by constant ops are not reported: the constant
// is cloned to the start of the using region and the uses inside that region
// are redirected to the clone. The result is de-duplicated and keeps
// first-seen order (`first` is scanned before `second`).
llvm::SmallVector<Value, 4> CollectExternValues(Region& first, Region& second) {
  llvm::SetVector<Value> captured;
  for (Region* region : {&first, &second}) {
    llvm::SetVector<Value> used_from_above;
    getUsedValuesDefinedAbove(*region, used_from_above);

    for (Value value : used_from_above) {
      if (matchPattern(value, m_Constant())) {
        // Sink the constant: clone it at the start of the region's entry block
        // and point the region's uses at the clone.
        OpBuilder sink_builder = OpBuilder::atBlockBegin(&region->front());
        Operation* sunk_const = sink_builder.clone(*value.getDefiningOp());
        replaceAllUsesInRegionWith(value, sunk_const->getResult(0), *region);
      } else {
        // A genuine capture; it must become an explicit operand later on.
        captured.insert(value);
      }
    }
  }
  return llvm::to_vector<4>(captured);
}
// Transfers the device and underscore-prefixed attributes of the region based
// op `src` onto the functional op `dst`, then forces the attributes that must
// be overridden on the functional form. `builder` is only used to create
// attribute values.
void CopyAndOverrideAttributes(Operation* src, Operation* dst,
                               OpBuilder* builder) {
  CopyDeviceAndUnderscoredAttributes(src, dst);

  // Always request propagation of constants into the functions before
  // compiling to XLA. Together with the conversion to functional form this is
  // required because inlined regions may have had loop invariant ops moved out
  // of the region, which can introduce new legalization failures.
  // TODO(b/126739593): Enable this attribute in TensorFlow by default. Also,
  // see b/185542519 for the context.
  auto propagate_consts = builder->getBoolAttr(true);
  dst->setAttr(kXlaPropagateCompileTimeConsts, propagate_consts);
}
// Outlines a single-block region into a new private function called `name`,
// created at the start of the enclosing module.
//
// The function arguments are the region arguments followed by one argument per
// entry of `extern_values` (the external values the region refers to). The
// function results are the operands of the region terminator, followed by the
// extern values themselves when `extern_values_passthrough` is true. The new
// function is appended to `worklist` so that region based ops nested in its
// body get converted as well.
void ExtractSingleBlockRegion(Region& region, StringRef name,
                              llvm::SmallVectorImpl<Value>& extern_values,
                              llvm::SmallVectorImpl<func::FuncOp>& worklist,
                              bool extern_values_passthrough) {
  ModuleOp module = region.getParentOfType<ModuleOp>();
  auto builder = OpBuilder::atBlockBegin(module.getBody());
  auto func_loc = region.getParentOp()->getLoc();

  Block& region_entry = region.front();
  const int region_arg_count = region_entry.getNumArguments();
  Operation* region_terminator = region_entry.getTerminator();

  // Types contributed by the extern values; used for the inputs and, when
  // passing through, for the results too.
  llvm::SmallVector<Type, 4> extern_types;
  for (Value extern_value : extern_values)
    extern_types.push_back(extern_value.getType());

  // Inputs: region arguments first, extern values after them.
  auto arg_types = llvm::to_vector<4>(region_entry.getArgumentTypes());
  arg_types.append(extern_types.begin(), extern_types.end());

  // Results: terminator operands, optionally followed by the extern values.
  auto result_types =
      llvm::to_vector<4>(region_terminator->getOperandTypes());
  if (extern_values_passthrough)
    result_types.append(extern_types.begin(), extern_types.end());

  auto func_type =
      FunctionType::get(region.getContext(), arg_types, result_types);

  // Create the function and move the region body into it.
  auto outlined_func = builder.create<func::FuncOp>(func_loc, name, func_type);
  Region& func_body = outlined_func.getBody();
  func_body.takeBody(region);
  Block& func_entry = func_body.front();

  // Each extern value gets a fresh trailing block argument that takes over its
  // uses inside the function.
  for (Value extern_value : extern_values) {
    Value replacement = func_entry.addArgument(extern_value.getType(), func_loc);
    replaceAllUsesInRegionWith(extern_value, replacement, func_body);
  }

  // Collect the values to return. This has to happen after the replacement
  // above so that a terminator operand that was an extern value is already
  // rewritten to the matching argument.
  Operation* old_terminator = func_entry.getTerminator();
  auto results = llvm::to_vector<4>(old_terminator->getOperands());
  if (extern_values_passthrough) {
    for (auto arg_it = func_entry.args_begin() + region_arg_count,
              arg_end = func_entry.args_end();
         arg_it != arg_end; ++arg_it)
      results.push_back(*arg_it);
  }

  // Swap the region terminator for a function return.
  builder.setInsertionPoint(old_terminator);
  builder.create<func::ReturnOp>(old_terminator->getLoc(), results);
  old_terminator->erase();

  outlined_func.setPrivate();

  // The outlined body may itself contain IfRegion or WhileRegion ops that
  // need to be converted.
  worklist.push_back(outlined_func);
}
// Returns the call for a region that consists of a single call whose results
// feed into the terminator of the region. If `allow_to_bool` is true, also
// allows a single ToBoolOp between the region yield and the call. The only
// other ops allowed in the region are non-truncating CastOps ahead of the
// call. Returns none if the region does not conform to this pattern, which
// includes a region whose only block is empty or does not end in a YieldOp.
llvm::Optional<func::CallOp> IsSingleCallRegion(Region& region,
                                                bool allow_to_bool = false) {
  if (!llvm::hasSingleElement(region)) return llvm::None;

  Block& block = region.front();
  // An empty block has no last operation to inspect; dereferencing rbegin()
  // would be invalid.
  if (block.empty()) return llvm::None;

  auto it = block.rbegin();
  // dyn_cast produces a null op if the last operation is not a YieldOp. Reject
  // the region here rather than dereferencing the null op further down.
  YieldOp yield = dyn_cast<YieldOp>(*it++);
  if (!yield) return llvm::None;

  // The yield must be preceded by at least one more operation (the call).
  if (it == block.rend()) return llvm::None;

  // Operation which is expected to consume all the call results.
  Operation* call_consumer = yield;

  // Allow a single ToBoolOp between the call and the yield (valid only
  // when the yield has a single operand).
  if (allow_to_bool && yield.getNumOperands() == 1 && isa<ToBoolOp>(*it)) {
    if (it->getResult(0) != yield.getOperand(0)) return llvm::None;
    call_consumer = cast<ToBoolOp>(*it);
    ++it;
    if (it == block.rend()) return llvm::None;
  }

  // Check if there is a Call before the consumer.
  func::CallOp call = dyn_cast<func::CallOp>(*it++);
  if (!call) return llvm::None;

  // All results of the call should feed, in order, into the expected consumer
  // (the yield, or the ToBoolOp in front of it).
  if (call.getNumResults() != call_consumer->getNumOperands())
    return llvm::None;
  for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands()))
    if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None;

  // There can only be non-truncating cast op's prior to the call.
  for (; it != block.rend(); ++it) {
    CastOp cast = dyn_cast<CastOp>(*it);
    if (!cast || cast.Truncate()) return llvm::None;
  }

  return call;
}
// Predicate deciding whether an argument of the first call (and the region
// containing that call) matches the corresponding argument of the second call
// (and its region).
using ArgMatcherFn = function_ref<bool(Value, Region&, Value, Region&)>;

// Returns whether the arguments of the two given calls match pairwise, after
// looking through cast ops. `matcher` is invoked once per argument pair, in
// argument order, and matching stops at the first pair it rejects.
bool MatchCallArgs(func::CallOp first, func::CallOp second,
                   ArgMatcherFn matcher) {
  if (first.getNumOperands() != second.getNumOperands()) return false;

  // Walks up through the CastOps that produce `value`. Only casts whose source
  // and destination element types agree are skipped, e.g.
  //   %cast = "tf.Cast"(%0) : (tensor<2xi64>) -> tensor<2xf32>
  // is not looked through.
  auto look_through_casts = [](Value value) {
    while (true) {
      auto cast_op = llvm::dyn_cast_or_null<CastOp>(value.getDefiningOp());
      if (!cast_op || cast_op.SrcT() != cast_op.DstT()) return value;
      value = cast_op.getOperand();
    }
  };

  Region& first_region = *first->getParentRegion();
  Region& second_region = *second->getParentRegion();

  for (auto arg_pair :
       llvm::zip(first.getArgOperands(), second.getArgOperands())) {
    Value lhs = look_through_casts(std::get<0>(arg_pair));
    Value rhs = look_through_casts(std::get<1>(arg_pair));
    const bool matched = matcher(lhs, first_region, rhs, second_region);
    if (!matched) return false;
  }
  return true;
}
// Describes whether a region based op can be converted to its functional form
// "trivially", i.e. by reusing functions that already exist. This is the case
// when each region is nothing more than a call to a function, so nothing has
// to be outlined.
struct TrivialTransformInfo {
  // True if the trivial transformation is possible.
  bool can_transform = false;

  // Names of the called functions, one per region, in the order the calls were
  // passed to the constructor. Only filled in when `can_transform` is true.
  llvm::SmallVector<StringRef, 2> callee_names;

  // Examines the calls found in two regions of the same parent op and decides
  // whether the parent op can be converted without outlining. Both calls must
  // be present (the single call region check is done by the caller) and their
  // arguments must match according to `arg_matcher`.
  //
  // On success the callee names are recorded and `can_transform` is set;
  // otherwise `can_transform` stays false.
  TrivialTransformInfo(llvm::Optional<func::CallOp> first_call,
                       llvm::Optional<func::CallOp> second_call,
                       ArgMatcherFn arg_matcher) {
    const bool both_are_calls = first_call && second_call;
    if (!both_are_calls) return;

    func::CallOp first = *first_call;
    func::CallOp second = *second_call;
    if (MatchCallArgs(first, second, arg_matcher)) {
      can_transform = true;
      callee_names.push_back(first.getCallee());
      callee_names.push_back(second.getCallee());
    }
  }
};
// Replaces `if_region` with a functional tf.If op. The branch functions are
// either the functions already called by the two regions (trivial case) or
// new functions outlined from the region bodies.
LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) {
  // Values passed as inputs to the generated tf.If.
  llvm::SmallVector<Value, 4> extern_values;

  // For IfOp, arguments of the calls in the then and else regions match when
  // they are the very same value. Every matched value (seen after looking
  // through cast ops) is recorded as an input of the functional op.
  auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) {
    const bool same_value = (first == second);
    if (same_value) extern_values.push_back(first);
    return same_value;
  };

  const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()),
                                 IsSingleCallRegion(if_region.else_branch()),
                                 if_arg_matcher);

  // Picks the name for an outlined branch function: the unique form of the
  // string stored in attribute `attr_name` if it is present and non-empty,
  // otherwise a loc based name ending in `fallback_suffix`.
  auto outlined_name = [&](StringRef attr_name,
                           StringRef fallback_suffix) -> std::string {
    auto name_attr = if_region->getAttrOfType<StringAttr>(attr_name);
    if (name_attr && !name_attr.getValue().empty())
      return mapper.GetUniqueName(name_attr.getValue()).str();
    return GetName(if_region, fallback_suffix);
  };

  std::string then_name, else_name;
  if (!tti.can_transform) {
    // Collect external values that are used within the else and then bodies.
    // This discards anything the matcher may have recorded before failing.
    extern_values =
        CollectExternValues(if_region.then_branch(), if_region.else_branch());

    // These external values become the inputs of the generated If, in the
    // order in which they appear in `extern_values`. Outline the `then` and
    // `else` regions into two new functions whose input signature follows
    // that order; tf.yield is replaced by a regular return.
    then_name = outlined_name(kThenFuncNameAttr, "_then");
    ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values,
                             worklist, /*extern_values_passthrough=*/false);

    else_name = outlined_name(kElseFuncNameAttr, "_else");
    ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values,
                             worklist, /*extern_values_passthrough=*/false);
  } else {
    // Trivial case: reuse the called functions, no outlining required.
    then_name = tti.callee_names[0].str();
    else_name = tti.callee_names[1].str();
  }

  // Look through ToBool operations for the condition.
  Value cond = if_region.cond();
  auto to_bool = dyn_cast_or_null<ToBoolOp>(cond.getDefiningOp());
  if (to_bool) cond = to_bool.getOperand();

  // With the `then` and `else` functions in place (outlined or pre-existing),
  // swap the region based op for the functional control flow op.
  OpBuilder builder(if_region);
  auto if_op = builder.create<IfOp>(
      if_region.getLoc(), if_region.getResultTypes(), cond, extern_values,
      then_name, else_name, if_region.is_stateless());
  CopyAndOverrideAttributes(if_region, if_op, &builder);

  if_region.replaceAllUsesWith(if_op.getResults());
  if_region.erase();

  // The ToBool feeding the old condition may have just lost its last user.
  if (to_bool && to_bool.use_empty()) to_bool.erase();
  return success();
}
// Replaces `while_region` with a functional tf.While op. The cond and body
// functions are either the functions already called by the two regions
// (trivial case) or new functions outlined from the region bodies.
LogicalResult RegionControlFlowToFunctional::ConvertWhileOp(
    WhileRegionOp while_region) {
  // For While, arguments of the calls in the cond and body regions match when
  // both are region arguments carrying the same argument number. Two calls
  // sharing an extern value as argument do not qualify for the trivial
  // transformation: after the transform that value would have to be passed to
  // the function as an extra argument, so the existing function could not be
  // reused unchanged.
  // NOTE(review): the matcher only compares the two calls with each other; it
  // does not check that the i-th call operand is region argument i, nor that
  // every region argument is forwarded to the call. Confirm this is guaranteed
  // elsewhere.
  auto while_arg_matcher = [](Value first, Region& first_region, Value second,
                              Region& second_region) {
    auto first_arg = first.dyn_cast<BlockArgument>();
    auto second_arg = second.dyn_cast<BlockArgument>();
    if (!first_arg || !second_arg) return false;

    // Same argument number, and each one is an argument of the entry block of
    // its own region.
    if (first_arg.getArgNumber() != second_arg.getArgNumber()) return false;
    return first_arg.getParentBlock() == &first_region.front() &&
           second_arg.getParentBlock() == &second_region.front();
  };

  llvm::Optional<func::CallOp> cond_call =
      IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true);
  llvm::Optional<func::CallOp> body_call =
      IsSingleCallRegion(while_region.body());
  const TrivialTransformInfo tti(cond_call, body_call, while_arg_matcher);

  // The functional while starts out with the inputs and result types of the
  // region based one.
  auto operands = llvm::to_vector<4>(while_region.getOperands());
  auto result_types = llvm::to_vector<4>(while_region.getResultTypes());

  std::string cond_name, body_name;
  if (!tti.can_transform) {
    // The regions can refer to region arguments as well as to external values
    // that are captured implicitly. In functional form every such external
    // value has to be an argument of the outlined functions, and additionally
    // a pass through result of the outlined body function.
    llvm::SmallVector<Value, 4> extern_values =
        CollectExternValues(while_region.cond(), while_region.body());

    // Move the `cond` and `body` region bodies into new functions; tf.yield is
    // replaced by a regular return.
    cond_name = GetName(while_region, "_cond");
    ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values,
                             worklist, /*extern_values_passthrough=*/false);

    body_name = GetName(while_region, "_body");
    ExtractSingleBlockRegion(while_region.body(), body_name, extern_values,
                             worklist, /*extern_values_passthrough=*/true);

    // Extern values are appended both to the inputs and to the result types
    // of the functional while.
    for (Value extern_value : extern_values) {
      operands.push_back(extern_value);
      result_types.push_back(extern_value.getType());
    }
  } else {
    // Trivial case: reuse the called functions, no outlining required.
    cond_name = tti.callee_names[0].str();
    body_name = tti.callee_names[1].str();
  }

  // With the `cond` and `body` functions in place (outlined or pre-existing),
  // create the functional op right before the region based one.
  OpBuilder builder(while_region);
  auto while_op = builder.create<WhileOp>(
      while_region.getLoc(), result_types, operands, cond_name, body_name,
      while_region.parallel_iterations(), while_region.is_stateless(),
      while_region.shape_invariant());
  CopyAndOverrideAttributes(while_region, while_op, &builder);

  // Redirect old results to the leading results of the new op; any trailing
  // results are the passed through extern values and have no old counterpart.
  const unsigned old_result_count = while_region->getNumResults();
  for (unsigned idx = 0; idx != old_result_count; ++idx)
    while_region->getResult(idx).replaceAllUsesWith(while_op->getResult(idx));

  while_region.erase();
  return success();
}
void RegionControlFlowToFunctional::runOnOperation() {
ModuleOp module = getOperation();
// Seed worklist with all functions in the module.
worklist = llvm::to_vector<4>(module.getOps<func::FuncOp>());
while (!worklist.empty()) {
func::FuncOp function = worklist.pop_back_val();
auto result = function.walk([&](Operation* op) {
if (auto if_region = llvm::dyn_cast<IfRegionOp>(op)) {
if (failed(ConvertIfOp(if_region))) {
op->emitOpError() << "failed to convert to functional form";
return WalkResult::interrupt();
}
} else if (auto while_region = llvm::dyn_cast<WhileRegionOp>(op)) {
if (failed(ConvertWhileOp(while_region))) {
op->emitOpError() << "failed to convert to functional form";
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
if (result.wasInterrupted()) return signalPassFailure();
}
}
} // namespace
// Factory for the pass that converts region based control flow ops to their
// functional counterparts.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFRegionControlFlowToFunctional() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<RegionControlFlowToFunctional>();
  return pass;
}
} // namespace TF
} // namespace mlir