blob: 64272d0f59a378ef48b28082e80d57fd9c82731e [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cassert>
#include <memory>
#include <queue>
#include <set>
#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/SymbolTable.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h"
namespace mlir {
namespace tf_executor {
namespace {
// File-local shorthands for the resource ID type and the two special IDs
// declared by `TF::detail::ResourceAliasAnalysisInfo`.
using TF::ResourceId;
// In this file: the key under which ops with unknown side effects are looked
// up in the resource-to-ops map (see ConvertControlToDataOutputs).
static constexpr ResourceId kUnknownResourceId =
TF::detail::ResourceAliasAnalysisInfo::kUnknownResourceId;
// In this file: sentinel meaning "no previous resource seen yet" while
// iterating the resources of an op (see CollectChainResources).
static constexpr ResourceId kInvalidResourceId =
TF::detail::ResourceAliasAnalysisInfo::kInvalidResourceId;
// Small set of operations; also used for sets of functions, which are stored
// as their underlying `Operation*`.
using OperationSetTy = SmallPtrSet<Operation*, 4>;
// Maps a resource ID to the set of operations that access that resource.
using ResourceToOpsMapTy = DenseMap<ResourceId, OperationSetTy>;
// Pass that converts the control outputs of while body functions into data
// outputs (see ConvertControlToDataOutputs in this file). The base class is
// the `TF::ExecutorConvertControlToDataOutputsPassBase` template instantiated
// on this class.
class ConvertControlToDataOutputsPass
: public TF::ExecutorConvertControlToDataOutputsPassBase<
ConvertControlToDataOutputsPass> {
public:
// Runs the pass on the operation returned by `getOperation()`; the definition
// in this file treats that operation as a `ModuleOp`.
void runOnOperation() override;
};
// Returns a vector of all tf.WhileOp(s) which use `func` as while body.
//
// An empty vector is returned, meaning `func` must not be optimized, when:
//   * any of the uses is as a while condition, or
//   * `func` has any symbol user that is not a tf.WhileOp. The pass appends
//     arguments and results to the while body and only rewrites the calling
//     tf.WhileOp(s); any other user of the symbol would be left referring to a
//     function whose signature no longer matches.
SmallVector<TF::WhileOp> GetWhileCallers(func::FuncOp func,
                                         SymbolUserMap& symbol_map) {
  SmallVector<TF::WhileOp> while_callers;
  for (Operation* user : symbol_map.getUsers(func)) {
    auto while_caller = dyn_cast<TF::WhileOp>(user);
    // A non-while user cannot be rewritten by this pass; skip optimizing this
    // function rather than invalidate that user.
    if (!while_caller) return {};
    // If used as while conditional anywhere, then skip optimizing this
    // function. Return empty vector.
    if (while_caller.cond_function() == func) return {};
    assert(while_caller.body_function() == func);
    while_callers.push_back(while_caller);
  }
  return while_callers;
}
// Fills in two outputs describing the resources of `func` that need chaining:
//   * `chain_resource_to_ops_map`: for every such resource, the set of ops
//     that access it.
//   * `resource_equivalence_classes`: two resources end up in the same class
//     when a common op accesses both; one chain is later created per class.
// Resource IDs for which `IsUniqueResourceAllocationId` is true are skipped.
void CollectChainResources(
    func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
    llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  auto graph = cast<GraphOp>(func.front().front());

  // Visit the op wrapped by every island and record which resources it uses.
  graph.walk([&](IslandOp island) {
    // The pass verifies up front that functions are suitable for export, so
    // every island is expected to wrap exactly one op.
    assert(island.WrapsSingleOp());
    Operation* wrapped_op = &island.GetBody().front();

    // Last chained resource seen for `wrapped_op`, if any.
    ResourceId last_chained_id = kInvalidResourceId;
    for (auto id_and_read_only :
         side_effect_analysis.GetResourceIds(wrapped_op)) {
      const ResourceId current_id = id_and_read_only.first;

      // Resources allocated by an op with the `UniqueResourceAllocation` trait
      // are fresh in every iteration, so accesses to them do not need to be
      // chained across iterations. Skipping them allows more parallelism.
      if (side_effect_analysis.IsUniqueResourceAllocationId(current_id))
        continue;

      chain_resource_to_ops_map[current_id].insert(wrapped_op);
      if (last_chained_id == kInvalidResourceId) {
        // First chained resource of this op: just make sure a class exists.
        resource_equivalence_classes.insert(current_id);
      } else {
        // Both resources are accessed by `wrapped_op`, so merge their classes.
        resource_equivalence_classes.unionSets(last_chained_id, current_id);
      }
      last_chained_id = current_id;
    }
  });
}
// A "NoOp control barrier" is an island that merges several control
// dependencies into one: it wraps a single tf.NoOp, consumes multiple control
// inputs and produces only a control output.
//
// For example,
// ```
// %merged_control = "tf_executor.island"(%control_a, %control_b) ({
//   "tf.NoOp"() : () -> ()
//   "tf_executor.yield"() : () -> ()
// }) : (!tf_executor.control, !tf_executor.control) -> (!tf_executor.control)
// ```
//
// Here `%merged_control` is a NoOp control barrier.
//
// Returns true iff `control` is a control-typed value defined by an island
// that has no data outputs and wraps a tf.NoOp.
bool IsNoOpControlBarrier(Value control) {
  if (!control.getType().isa<ControlType>()) return false;

  auto island = dyn_cast_or_null<IslandOp>(control.getDefiningOp());
  if (!island) return false;

  // The pass checks at its very beginning that every island perfectly wraps a
  // single op.
  assert(island.WrapsSingleOp());

  if (!island.outputs().empty()) return false;
  return isa<TF::NoOp>(island.GetBody().front());
}
// Removes all control outputs of the function. Starting from the FetchOp, it
// also walks backwards through chains of NoOp control barriers and erases
// every barrier that is left without uses.
//
// Returns true iff at least one control output is deleted.
bool RemoveAllControlOutputs(func::FuncOp func) {
  auto graph_op = cast<GraphOp>(func.front().front());
  FetchOp fetch = graph_op.GetFetch();
  const unsigned num_data_outputs = graph_op->getNumResults();

  // Return early if no control outputs exist.
  if (fetch.getNumOperands() == num_data_outputs) return false;

  // The worklist holds the islands defining NoOp control barriers. The same
  // island may be queued more than once (e.g. it feeds two barriers, or it is
  // fetched twice).
  std::queue<Operation*> control_barrier_worklist;
  for (Value control_output : fetch.fetches().drop_front(num_data_outputs)) {
    if (IsNoOpControlBarrier(control_output))
      control_barrier_worklist.push(control_output.getDefiningOp());
  }

  // Erase all control outputs at the end from fetch.
  fetch.fetchesMutable().erase(num_data_outputs,
                               fetch.getNumOperands() - num_data_outputs);

  // Islands erased so far. A queued duplicate of an erased island is a
  // dangling pointer, so membership must be checked before the pointer is
  // dereferenced. No ops are created in this function, so comparing addresses
  // is sufficient.
  llvm::SmallPtrSet<Operation*, 8> erased_barriers;

  // Iterate the worklist to remove all NoOp control barriers at the end of the
  // function body that are used to merge two or more control dependencies.
  while (!control_barrier_worklist.empty()) {
    Operation* barrier_op = control_barrier_worklist.front();
    control_barrier_worklist.pop();
    if (erased_barriers.contains(barrier_op)) continue;

    // Only operations that are IslandOps were inserted in the worklist.
    auto current_island = cast<IslandOp>(barrier_op);

    // We can only erase control barriers whose uses have been erased as well.
    // A barrier has no data outputs, so its control result is its only result.
    if (!current_island.control().use_empty()) continue;

    for (Value control_input : current_island.controlInputs()) {
      if (IsNoOpControlBarrier(control_input))
        control_barrier_worklist.push(control_input.getDefiningOp());
    }
    erased_barriers.insert(barrier_op);
    current_island.erase();
  }
  return true;
}
// Adds `num_resources` trailing arguments of type `chaining_data_type` to
// `func`, then updates the function type to match the new argument list. The
// result types are left as they were.
void AppendFunctionArguments(func::FuncOp func, int num_resources,
                             ShapedType chaining_data_type) {
  for (int remaining = num_resources; remaining > 0; --remaining)
    func.getRegion().addArgument(chaining_data_type, func.getLoc());

  // Rebuild the signature from the argument types now present on the body.
  func.setType(FunctionType::get(func.getContext(),
                                 func.getBody().getArgumentTypes(),
                                 func.getFunctionType().getResults()));
}
// Appends `num_resources` results of type `chaining_data_type` to `func`.
//
// This updates the function type, replaces the GraphOp in the function body
// with a new GraphOp that has the extra results (the old body is moved into
// it), and appends the new graph results to the operands of the function's
// return op. Feeding the new results from the fetch is left to the caller.
void AppendFunctionResults(func::FuncOp func, int num_resources,
                           ShapedType chaining_data_type) {
  Block& block = func.front();
  auto graph_op = cast<GraphOp>(block.front());

  // Note that func result types are same as the result types of
  // GraphOp in the function `func`. The comparison must use the results of the
  // function *type*: `func->getResultTypes()` would be the (always empty)
  // result list of the func operation itself and make this check vacuous.
  auto old_result_types = func.getFunctionType().getResults();
  assert(old_result_types.size() == graph_op->getNumResults() &&
         std::equal(old_result_types.begin(), old_result_types.end(),
                    graph_op->getResultTypes().begin()));

  auto new_result_types = llvm::to_vector<4>(old_result_types);
  for (int i = 0; i < num_resources; ++i) {
    new_result_types.push_back(chaining_data_type);
  }
  FunctionType ftype = FunctionType::get(
      func.getContext(), func.getArgumentTypes(), new_result_types);
  func.setType(ftype);

  // Rewrite GraphOp to have same number of results as the
  // function.
  OpBuilder builder(graph_op);
  auto new_graph_op =
      builder.create<GraphOp>(graph_op.getLoc(), new_result_types);
  new_graph_op.getRegion().takeBody(graph_op.getRegion());
  // The old results map one-to-one onto the leading results of the new graph.
  graph_op->replaceAllUsesWith(
      new_graph_op->getResults().drop_back(num_resources));
  graph_op.erase();

  // Return the newly added graph results as well.
  func::ReturnOp return_op = cast<func::ReturnOp>(block.getTerminator());
  int num_old_results = return_op.getNumOperands();
  for (int i = 0; i < num_resources; ++i) {
    return_op.operandsMutable().append(
        new_graph_op.getResult(num_old_results + i));
  }
}
// Wraps `sub_op` in a newly created island that depends on `control_inputs`.
// The island is created at the insertion point of `builder`; `sub_op` is moved
// into the island body and all of its previous uses are redirected to the
// island outputs. Returns the new island.
IslandOp CreateIsland(Operation* sub_op, ValueRange control_inputs,
                      OpBuilder builder) {
  assert(sub_op);
  Location loc = sub_op->getLoc();

  auto island = builder.create<IslandOp>(
      loc, sub_op->getResultTypes(), ControlType::get(builder.getContext()),
      control_inputs);

  Block* island_body = new Block;
  island.body().push_back(island_body);
  builder.setInsertionPointToEnd(island_body);

  // Redirect existing users first; the yield created below is a new user of
  // `sub_op` and must keep referring to it.
  sub_op->replaceAllUsesWith(island.outputs());
  sub_op->moveBefore(island_body, island_body->begin());
  builder.create<YieldOp>(loc, sub_op->getResults());
  return island;
}
// Wires the chain arguments/results of `func` to the ops that access chained
// resources. For every resource equivalence class two identity islands are
// created, both fed by the class's chain argument:
//   * a chain source at the start of the graph body, and
//   * a chain sink right before the fetch, whose output is appended to the
//     fetch.
// Every op that accesses a resource of the class then (1) receives a control
// dependency on the chain source and (2) gives a control dependency to the
// chain sink. Chain arguments are expected to start at index
// `num_old_outputs`.
void ChainResourceOps(
    func::FuncOp func, ResourceToOpsMapTy& chain_resource_to_ops_map,
    llvm::EquivalenceClasses<ResourceId>& resource_equivalence_classes,
    int num_old_outputs) {
  assert(num_old_outputs + resource_equivalence_classes.getNumClasses() ==
         func.getNumArguments());

  auto graph = cast<GraphOp>(func.front().front());
  auto fetch = graph.GetFetch();

  // Sources go to the start of the fetch's block, sinks right before the
  // fetch.
  OpBuilder src_builder(fetch);
  src_builder.setInsertionPointToStart(fetch->getBlock());
  OpBuilder sink_builder(fetch);

  int next_chain_arg = num_old_outputs;
  for (auto leader_it = resource_equivalence_classes.begin();
       leader_it != resource_equivalence_classes.end(); ++leader_it) {
    // Each class is handled exactly once, through its leader.
    if (!leader_it->isLeader()) continue;

    auto chain_arg = func.getArgument(next_chain_arg);
    ++next_chain_arg;

    // Source and sink identity islands for this class.
    auto src_identity = src_builder.create<TF::IdentityOp>(
        chain_arg.getLoc(), chain_arg.getType(), chain_arg);
    IslandOp src_island = CreateIsland(src_identity, {}, src_builder);

    auto sink_identity = sink_builder.create<TF::IdentityOp>(
        chain_arg.getLoc(), chain_arg.getType(), chain_arg);
    IslandOp sink_island = CreateIsland(sink_identity, {}, sink_builder);

    // The sink's data output becomes a new fetched value.
    fetch.fetchesMutable().append(sink_island.outputs().front());

    // An op may access several resources of the class; chain it only once.
    llvm::SmallDenseSet<Operation*> already_chained;
    for (ResourceId resource_id : llvm::make_range(
             resource_equivalence_classes.member_begin(leader_it),
             resource_equivalence_classes.member_end())) {
      auto ops_it = chain_resource_to_ops_map.find(resource_id);
      if (ops_it == chain_resource_to_ops_map.end()) continue;

      for (Operation* accessing_op : ops_it->getSecond()) {
        if (!already_chained.insert(accessing_op).second) continue;
        IslandOp enclosing_island =
            accessing_op->getParentOfType<IslandOp>();
        assert(enclosing_island);
        enclosing_island.controlInputsMutable().append(src_island.control());
        sink_island.controlInputsMutable().append(enclosing_island.control());
      }
    }
  }

  VLOG(2) << "Added " << resource_equivalence_classes.getNumClasses()
          << " chains for " << chain_resource_to_ops_map.size() << " resources";
}
// Builds a tf.Const of type `const_type` holding the value 1 at the insertion
// point of `builder`, wraps it in an island without control inputs and returns
// that island.
IslandOp GetDummyConstant(OpBuilder builder, ShapedType const_type,
                          Location loc) {
  DenseIntElementsAttr one = DenseIntElementsAttr::get(const_type, 1);
  auto dummy_const = builder.create<TF::ConstOp>(loc, one);
  return CreateIsland(dummy_const, {}, builder);
}
// Replaces `while_op` by a new tf.While that carries `num_resource_inputs`
// additional chaining operands and results. All new operands are fed by one
// dummy constant of type `const_type`. The island wrapping the old while op is
// erased and its uses are redirected to the new island. Returns the new while
// op.
TF::WhileOp RewriteWhileOp(TF::WhileOp while_op, int num_resource_inputs,
                           ShapedType const_type) {
  IslandOp old_island = while_op->getParentOfType<IslandOp>();
  assert(old_island && "While op is expected to be wrapped in a IslandOp");

  // Everything new is created right before the old island.
  OpBuilder builder(old_island);

  // Dummy constant feeding every chaining operand.
  auto dummy_loc = NameLoc::get(
      builder.getStringAttr("chain_control_outputs@" + while_op.body()));
  IslandOp dummy_island = GetDummyConstant(builder, const_type, dummy_loc);
  Value chain_init = dummy_island.outputs()[0];

  // Old operands/result types followed by one entry per chain.
  auto operands = llvm::to_vector<4>(while_op->getOperands());
  auto result_types = llvm::to_vector<4>(while_op->getResultTypes());
  for (int chain = 0; chain < num_resource_inputs; ++chain) {
    operands.push_back(chain_init);
    result_types.push_back(chain_init.getType());
  }

  // Create the replacement while op, keeping the attributes and the control
  // inputs of the original.
  auto new_while_op = builder.create<TF::WhileOp>(
      while_op.getLoc(), result_types, operands, while_op->getAttrs());
  IslandOp new_island =
      CreateIsland(new_while_op, old_island.controlInputs(), builder);

  // Redirect data and control uses of the old island, then drop it.
  for (OpResult old_output : old_island.outputs()) {
    old_output.replaceAllUsesWith(
        new_island.outputs()[old_output.getResultNumber()]);
  }
  old_island.control().replaceAllUsesWith(new_island.control());
  old_island.erase();
  return new_while_op;
}
// Converts the control outputs of the while body to data outputs, thus
// removing control barrier at the end of while loop body.
//
// `while_body`: the while body function to rewrite.
// `while_callers`: the tf.While ops using `while_body` as body. On return
//     every entry refers to the rewritten while op, since the original ops are
//     erased by the rewrite.
// `recompute_analysis_for_funcs`: receives the parent function of every
//     rewritten while op; their side effect analysis is invalidated.
// `side_effect_analysis`: side effect analysis of `while_body`.
//
// Nothing is changed if there are no callers, if the body contains an op with
// unknown side effects, or if the body has no control outputs.
void ConvertControlToDataOutputs(
    func::FuncOp while_body, SmallVectorImpl<TF::WhileOp>& while_callers,
    OperationSetTy& recompute_analysis_for_funcs,
    const TF::SideEffectAnalysis::Info& side_effect_analysis) {
  if (while_callers.empty()) return;

  // Collect access information for each resource in the while body that needs
  // to be chained, along with equivalence classes (resources in one class will
  // use the same chain).
  ResourceToOpsMapTy chain_resource_to_ops_map;
  llvm::EquivalenceClasses<ResourceId> resource_equivalence_classes;
  CollectChainResources(while_body, chain_resource_to_ops_map,
                        resource_equivalence_classes, side_effect_analysis);

  // Check for presence of unknown side-effecting ops within the while loop
  // body. These ops act as barriers and the optimization would not yield much
  // inter iteration parallelism for this while loop body. So return with
  // warning.
  auto unknown_it = chain_resource_to_ops_map.find(kUnknownResourceId);
  if (unknown_it != chain_resource_to_ops_map.end()) {
    // Log each blocking op type only once.
    std::set<std::string> blocking_ops;
    for (Operation* op : unknown_it->getSecond()) {
      std::string op_name = op->getName().getStringRef().str();
      if (blocking_ops.insert(op_name).second) {
        LOG(INFO) << "[`tf-executor-convert-control-to-data-outputs` disabled] "
                     "Op type '"
                  << op_name
                  << "' has unknown side effects and blocks inter iteration "
                     "parallelism for the while loop. Consider modeling side "
                     "effects of this op.";
      }
    }
    return;
  }

  // First remove all control outputs of while loop body.
  bool changed = RemoveAllControlOutputs(while_body);
  // If there was no control output to be removed, return early.
  if (!changed) return;

  int num_chains = resource_equivalence_classes.getNumClasses();
  RankedTensorType chaining_data_type =
      RankedTensorType::get({}, OpBuilder(while_body).getI32Type());

  // Create new while body. The old result count must be read before results
  // are appended: the chain arguments start at that index.
  int num_old_outputs = while_body.getNumResults();
  AppendFunctionArguments(while_body, num_chains, chaining_data_type);
  AppendFunctionResults(while_body, num_chains, chaining_data_type);

  // Insert identity ops with control dep
  ChainResourceOps(while_body, chain_resource_to_ops_map,
                   resource_equivalence_classes, num_old_outputs);

  // Modify all the while ops referencing the body function and the
  // corresponding while condition functions. Note that each while condition
  // needs to be modified only once.
  OperationSetTy visited;
  // Iterate by reference so that `while_callers` ends up holding the rewritten
  // ops instead of handles to the erased originals.
  for (TF::WhileOp& while_op : while_callers) {
    // If the while callers are modified as part of the optimization, then the
    // side effect analysis of their parent functions are invalidated. They
    // need to be recomputed.
    recompute_analysis_for_funcs.insert(
        while_op->getParentOfType<func::FuncOp>());

    // Look up the condition before the while op is replaced.
    func::FuncOp while_cond = while_op.cond_function();

    // Rewrite while op with extra chaining arguments and results.
    while_op = RewriteWhileOp(while_op, num_chains, chaining_data_type);

    bool first_visit = visited.insert(while_cond).second;
    if (!first_visit) continue;
    // Modify while condition function with extra chaining arguments.
    AppendFunctionArguments(while_cond, num_chains, chaining_data_type);
  }
}
// Pass entry point. Verifies the module, collects every function that is used
// as a tf.While body together with its while callers, and converts the control
// outputs of each such body into data outputs.
void ConvertControlToDataOutputsPass::runOnOperation() {
  ModuleOp module = getOperation();

  // This pass assumes that all functions are suitable for export i.e., each
  // function has a single tf_executor.graph op and all islands wrap the
  // internal op perfectly. Verify that in the beginning once.
  if (failed(tensorflow::VerifyExportSuitable(module))) {
    signalPassFailure();
    return;
  }

  TF::SideEffectAnalysis side_effect_analysis(module);
  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);
  llvm::SmallDenseMap<func::FuncOp, SmallVector<TF::WhileOp>>
      while_body_func_to_while_ops;

  // Get all the while body functions and the corresponding while ops first
  // because the symbol user map is invalidated once we start deleting while
  // ops.
  for (auto func : module.getOps<func::FuncOp>()) {
    if (func.isExternal()) continue;
    SmallVector<TF::WhileOp> while_callers = GetWhileCallers(func, symbol_map);
    if (while_callers.empty()) continue;
    while_body_func_to_while_ops[func] = while_callers;
  }

  // Keep track of functions whose side effect analysis is invalidated because
  // of modifications to that function.
  OperationSetTy recompute_analysis_for_funcs;
  for (auto& entry : while_body_func_to_while_ops) {
    func::FuncOp while_body = entry.getFirst();
    SmallVector<TF::WhileOp>& while_callers = entry.getSecond();
    if (recompute_analysis_for_funcs.contains(while_body)) {
      // TODO(b/202540801): Recomputing side effect analysis for the entire
      // module is wasteful. It would be better to just recompute analysis for
      // specific functions but the current side effect analysis interface
      // does not allow that.
      side_effect_analysis = TF::SideEffectAnalysis(module);
      // The analysis was just recomputed for the whole module, so no function
      // is stale any more. Forgetting the stale set avoids redundant
      // whole-module recomputations for later while bodies.
      recompute_analysis_for_funcs.clear();
    }
    ConvertControlToDataOutputs(
        while_body, while_callers, recompute_analysis_for_funcs,
        side_effect_analysis.GetAnalysisForFunc(while_body));
  }
}
} // namespace
// Factory: returns a newly allocated ConvertControlToDataOutputsPass, owned by
// the caller through the generic module-pass pointer type.
std::unique_ptr<OperationPass<ModuleOp>>
CreateTFExecutorConvertControlToDataOutputsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertControlToDataOutputsPass>();
  return pass;
}
} // namespace tf_executor
} // namespace mlir