blob: cf7ae8e6b7cadd2bf04134b0e63a2c2008609060 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass lifts resource variable operations outside of device computation.
#include <cstdint>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Block.h" // TF:llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Diagnostics.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Module.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/IR/TypeUtilities.h" // TF:llvm-project
#include "mlir/IR/Types.h" // TF:llvm-project
#include "mlir/IR/Value.h" // TF:llvm-project
#include "mlir/IR/Visitors.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Support/LLVM.h" // TF:llvm-project
#include "mlir/Support/LogicalResult.h" // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
namespace mlir {
namespace TFDevice {
namespace {
// Name of the "dtype" attribute. NOTE(review): this constant is not referenced
// anywhere in this file; it looks like a leftover — consider removing it.
constexpr char kDTypeAttr[]{"dtype"};
// This pass lifts resource variable operations outside of device computation.
// This is useful because many accelerator devices cannot interact with
// resource variables directly.
//
// Here is a simple example in TensorFlow where a device doubles the value of a
// TensorFlow resource variable and returns new value:
//
// %resource_handle = "tf.VarHandleOp"()
// %1 = "tf_device.launch"() ( {
// %init_value = "tf.ReadVariableOp"(%resource_handle)
// "tf.AssignAddVariableOp"(%resource_handle, %init_value)
// %new_value = "tf.ReadVariableOp"(%resource_handle)
// tf_device.return %new_value
// })
//
// After this pass, the computation would become:
//
// %resource_handle = "tf.VarHandleOp"()
// %init_value = "tf.ReadVariableOp"(%resource_handle)
// %1:2 = "tf_device.launch"() ( {
// %new_value = "tf.AddV2"(%init_value, %init_value)
// tf_device.return %new_value, %new_value
// })
// "tf.AssignVariableOp"(%resource_handle, %1#1)
//
// You can see that there are a few main changes applied:
// 1) All the resource variable reads and writes are now outside of
// tf_device.launch op.
// 2) Instead of taking resource handles as input, this device computation now
// takes snapshotted values of those resources.
// 3) Some resource load operations are eliminated with store-load forwarding.
// 4) Updated values to resource are appended to `tf_device.return` and used by
// external resource store operations so that resources are still updated
// after the computation.
//
// If the launch body contains functional control flow, the pass first lifts the
// loads/stores in the body/cond/branch functions to the launch body, then
// performs the above lifting. E.g.,
//
// func @launch_with_loop() -> () {
// %0 = "tf.VarHandleOp"() ...
// "tf_device.launch"() ( {
// %1 = "tf.While"(%0) {body = @while_body, cond = @while_cond}
// tf_device.return
// })
// return
// }
// func @while_body(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
// %constant = "tf.Const"() ...
// "tf.AssignVariableOp"(%arg0, %constant)
// return %arg0
// }
// func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
// %read = "tf.ReadVariableOp"(%arg0)
// return %read
// }
//
// will be transformed to:
//
// func @launch_with_loop() {
// %0 = "tf.VarHandleOp"() ...
// %1 = "tf.ReadVariableOp"(%0)
// %2 = "tf_device.launch"() ( {
// %3 = "tf.While"(%1) {body = @while_body, cond = @while_cond}
// tf_device.return %3 : tensor<f32>
// }) : () -> tensor<f32>
// "tf.AssignVariableOp"(%0, %2)
// return
// }
// func @while_body(%arg0: tensor<f32>) {
// %0 = "tf.Const"() ...
// return %0 : tensor<f32>
// }
// func @while_cond(%arg0: tensor<f32>) {
// return %arg0
// }
//
// Module pass implementing the resource op lifting; the per-module driver is
// runOnModule().
class ResourceOpLiftingPass : public ModulePass<ResourceOpLiftingPass> {
 public:
  void runOnModule() override;
};
// Erases every tf.Identity / tf.IdentityN op directly inside `block`, first
// redirecting the uses of its results to its operands. The device computation
// does not need such nodes to carry information, and removing them avoids
// aliasing of resource handles through identities.
void RemoveIdentity(Block* block) {
  for (Operation& candidate : llvm::make_early_inc_range(*block)) {
    const bool is_identity = llvm::isa<TF::IdentityOp>(&candidate) ||
                             llvm::isa<TF::IdentityNOp>(&candidate);
    if (!is_identity) continue;
    // Forward results to operands before the op goes away.
    candidate.replaceAllUsesWith(candidate.getOperands());
    candidate.erase();
  }
}
// Performs store-load forwarding on the ops directly inside `block`. This
// effectively removes
// 1) any resource load after a store to that same resource (its uses are
//    redirected to the stored value), and
// 2) any resource store except the last one to each resource.
// Any other op that takes a resource handle as an operand may read or modify
// that resource, so forwarding never crosses such an op for that handle: the
// store before it is kept, and loads after it are left in place.
// TODO(ycao): Store-load forwarding implemented here is only correct when
// computation is purely sequential (no concurrency). Need to support concurrent
// computation as well.
void ForwardStoreToLoad(Block* block) {
  // resource_handle_to_last_store_op keeps track of the most recent (last)
  // store to each resource that is still eligible for forwarding. A missing
  // entry means the resource has not been stored to yet, or some other op has
  // used the handle since the last store.
  llvm::SmallDenseMap<Value, TF::AssignVariableOp>
      resource_handle_to_last_store_op;
  // Only iterate through ops directly in the block as we can't handle ops
  // nested deeper in regions.
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto read_variable_op = dyn_cast<TF::ReadVariableOp>(&op)) {
      // Use find() rather than operator[] so reads do not insert entries.
      auto it =
          resource_handle_to_last_store_op.find(read_variable_op.resource());
      if (it == resource_handle_to_last_store_op.end()) continue;
      // Use the value stored by the last store in place of this load's
      // result, then erase the now-redundant load.
      read_variable_op.value().replaceAllUsesWith(it->second.value());
      read_variable_op.erase();
      continue;
    }
    if (auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op)) {
      auto& last_store =
          resource_handle_to_last_store_op[assign_variable_op.resource()];
      // The previous store to the same resource is overwritten before anything
      // could observe it, so it can be erased.
      if (last_store) last_store.erase();
      last_store = assign_variable_op;
      continue;
    }
    // Any other op (e.g. tf.AssignAddVariableOp, a call) might observe or
    // modify the resources it receives; stop forwarding across it for them.
    for (Value operand : op.getOperands()) {
      resource_handle_to_last_store_op.erase(operand);
    }
  }
}
// Moves resource load operations with the provided `move_load` function. This
// assumes store-load forwarding has already been performed on `block`, so
// every load of a resource observes its initial value. `skip_load` decides
// whether a given load is left untouched. For each resource, only the first
// non-skipped load is handed to `move_load`; later loads of the same resource
// are erased and their uses are redirected to that first load.
void HoistResourceLoads(
    Block* block, llvm::function_ref<bool(TF::ReadVariableOp)> skip_load,
    llvm::function_ref<void(TF::ReadVariableOp)> move_load) {
  llvm::SmallDenseMap<Value, TF::ReadVariableOp> first_read_of_resource;
  // Only ops directly inside `block` are considered; ops nested deeper in
  // regions are not handled.
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    auto read = dyn_cast<TF::ReadVariableOp>(&op);
    if (!read || skip_load(read)) continue;
    auto inserted = first_read_of_resource.try_emplace(read.resource(), read);
    if (inserted.second) {
      // First load of this resource: the caller decides what to do with it.
      move_load(read);
      continue;
    }
    // A load of this resource has already been handled; reuse its result.
    op.replaceAllUsesWith(inserted.first->second);
    op.erase();
  }
}
// If there are any stores to resources defined outside of launch_op's body
// region, the stored values must be returned by launch_op and its return op so
// that the new values can be used by the sunk resource stores. In that case
// the return op is replaced by a new tf_device.return whose operands are the
// original ones followed by one stored value per such store, in body order.
// Returns true if any resource variable stored values are appended, otherwise
// false (the return op is then left untouched).
bool AppendResourceStoreValueToReturn(tf_device::LaunchOp launch_op) {
  Block& body = launch_op.GetBody();
  Operation* old_return = body.getTerminator();
  llvm::SmallVector<Value, 4> new_return_operands(old_return->getOperands());
  const size_t num_original_operands = new_return_operands.size();
  // Only ops directly in launch_op's body are considered; ops nested deeper in
  // regions are not handled.
  for (Operation& op : body) {
    auto store = dyn_cast<TF::AssignVariableOp>(&op);
    if (!store) continue;
    Value resource = store.resource();
    // Resources created inside launch_op stay local; nothing to return.
    if (!resource || resource.getParentRegion() == &launch_op.body()) continue;
    // TODO(ycao): Prevent same value from being returned multiple times.
    // TODO(ycao): Do not return resource store value if it is defined outside
    // of launch_op.
    new_return_operands.push_back(store.value());
  }
  // No external resource store was found: keep the existing return op.
  if (new_return_operands.size() == num_original_operands) return false;
  OpBuilder return_builder(old_return);
  return_builder.create<tf_device::ReturnOp>(old_return->getLoc(),
                                             new_return_operands);
  old_return->erase();
  return true;
}
// Moves resource store operations to after launch_op. This assumes store-load
// forwarding has been performed on this launch_op such that there is at most
// one resource store operation per resource, carrying its final value.
// If there is no store to an external resource, launch_op is returned as is.
// Otherwise launch_op is replaced by (and erased in favor of) a new launch op
// whose extra results carry the stored values; the new op is returned and the
// caller must not use launch_op afterwards.
tf_device::LaunchOp SinkResourceStores(tf_device::LaunchOp launch_op,
                                       OpBuilder* builder) {
  // Update ReturnOp inside launch_op's body to output final values of updated
  // external resources.
  bool has_resource_store = AppendResourceStoreValueToReturn(launch_op);
  if (!has_resource_store) return launch_op;
  auto new_return_op = launch_op.GetBody().getTerminator();
  llvm::SmallVector<Type, 4> new_launch_return_types(
      new_return_op->getOperandTypes());
  builder->setInsertionPoint(launch_op);
  auto new_launch_op = builder->create<tf_device::LaunchOp>(
      launch_op.getLoc(), new_launch_return_types,
      /*operands=*/llvm::SmallVector<Value, 4>(), launch_op.getAttrs());
  new_launch_op.body().takeBody(launch_op.body());
  // Replace uses of old launch_op results with those of new_launch_op.
  for (auto p : llvm::zip(launch_op.getResults(), new_launch_op.getResults())) {
    std::get<0>(p).replaceAllUsesWith(std::get<1>(p));
  }
  // Map each operand of new_return_op to the corresponding new_launch_op
  // result.
  BlockAndValueMapping mapper;
  for (auto p :
       llvm::zip(new_return_op->getOperands(), new_launch_op.getResults())) {
    mapper.map(std::get<0>(p), std::get<1>(p));
  }
  // Clone the stores of external resources after new_launch_op (the builder
  // still inserts right before the old launch_op) with their operands mapped
  // to the values returned from new_launch_op, then erase the originals.
  // Stores to resources created inside the launch body must stay inside: their
  // values were not appended to the return op, and a clone outside the body
  // would reference a handle that is not visible there.
  for (Operation& op : llvm::make_early_inc_range(new_launch_op.GetBody())) {
    auto assign_variable_op = dyn_cast<TF::AssignVariableOp>(&op);
    if (!assign_variable_op) continue;
    if (assign_variable_op.resource().getParentRegion() ==
        &new_launch_op.body()) {
      continue;
    }
    builder->clone(op, mapper);
    op.erase();
  }
  launch_op.erase();
  return new_launch_op;
}
// Hoists resource variable loads and sinks stores from launch_op. Identity ops
// are removed and store-load forwarding is applied first. Loads of external
// resources are moved before the launch op, and stores to them are moved
// after it (the launch op may be replaced in the process). Fails, emitting an
// error on the resulting launch op, if any resource-typed value defined above
// the launch body is still used inside it.
LogicalResult HoistResourceOpsFromLaunchOp(tf_device::LaunchOp launch_op) {
  OpBuilder builder(launch_op.getParentOfType<ModuleOp>());
  Block* launch_body = &launch_op.GetBody();
  // Remove identity nodes to avoid aliasing.
  RemoveIdentity(launch_body);
  // Forward stores to loads so that each resource is only loaded with its
  // initial value and only stored with its final value.
  ForwardStoreToLoad(launch_body);
  // Pull loads of external resources in front of launch_op; loads of resources
  // created inside launch_op are left alone.
  HoistResourceLoads(
      launch_body,
      /*skip_load=*/
      [&](TF::ReadVariableOp read) {
        return read.resource().getParentRegion() == &launch_op.body();
      },
      /*move_load=*/
      [&](TF::ReadVariableOp read) {
        read.getOperation()->moveBefore(launch_op);
      });
  // Push stores of external resources behind the launch op.
  tf_device::LaunchOp new_launch_op = SinkResourceStores(launch_op, &builder);
  llvm::SetVector<Value> captured_values;
  getUsedValuesDefinedAbove(new_launch_op.body(), new_launch_op.body(),
                            captured_values);
  auto is_resource_tensor = [](Value value) {
    auto tensor_type = value.getType().dyn_cast<TensorType>();
    return tensor_type &&
           tensor_type.getElementType().isa<TF::ResourceType>();
  };
  if (llvm::any_of(captured_values, is_resource_tensor)) {
    return new_launch_op.emitOpError()
           << "has remaining resource inputs that can not be lifted";
  }
  return success();
}
// Holds information about a function's use of a resource argument.
struct ResourceArgUseInfo {
  // Whether the argument has any use other than being forwarded to the
  // function's return op. Defaults to false so no instance is ever left with
  // an indeterminate value.
  bool used = false;
  // Type of the value read from / written to the resource; a null Type until a
  // use is recorded.
  Type data_type;
};
// Finds the ResourceArgUseInfo for each resource argument of `func_op` and
// stores it in `*result`, keyed by argument number. Forwarding to the output
// (i.e., the argument is an operand of the return op) is not considered a use.
// Only tf.ReadVariableOp and tf.AssignVariableOp users are supported; any
// other user emits an error and makes this return failure. This doesn't
// support nesting of ops, so nested ops/functions need to be resource-lifted
// before calling this.
LogicalResult FindResourceArgUseInfo(
    FuncOp func_op, llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>* result) {
  Operation* return_op = func_op.front().getTerminator();
  for (auto arg : func_op.getArguments()) {
    if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
    ResourceArgUseInfo& info = (*result)[arg.getArgNumber()];
    info.used = false;
    for (Operation* user : arg.getUsers()) {
      // Passing the handle straight to the return op is not a use.
      if (user == return_op) continue;
      Type data_type;
      if (auto read = llvm::dyn_cast<TF::ReadVariableOp>(user)) {
        data_type = read.getType();
      } else if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(user)) {
        data_type = assign.value().getType();
      } else {
        user->emitError("Found unsupported operations on resource.");
        return failure();
      }
      info.used = true;
      info.data_type = data_type;
    }
  }
  return success();
}
// Merges two sets of resource arg use infos. An argument is considered used in
// the merged result as long as either set marks it as used. If only `infos1`
// marks it used, the data type is taken from `infos1`. This is used to merge
// results from functions that have aliasing inputs, e.g., a while loop's body
// and condition. The keys of the result are exactly those of `infos0`. The two
// maps are expected to have the same keys; a key missing from `infos1` is
// treated as unused there instead of dereferencing an end iterator.
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> MergeArgResourceUseInfo(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos0,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos1) {
  auto result = infos0;
  for (auto& entry : result) {
    if (entry.getSecond().used) continue;
    auto info1_it = infos1.find(entry.getFirst());
    if (info1_it == infos1.end() || !info1_it->getSecond().used) continue;
    entry.getSecond().used = true;
    entry.getSecond().data_type = info1_it->getSecond().data_type;
  }
  return result;
}
// Removes the unused resource arguments of `func_op` (according to `infos`),
// and the return values that merely forward those arguments, then updates the
// function type to match.
// If old_to_new_arg_indices is provided, one entry per original argument is
// appended: the new index of that argument, or -1 if it was removed.
// If remaining_resource_data_types is provided, it receives the data type of
// each kept argument that has an entry in `infos`, keyed by the argument's new
// (post-removal) index.
void RemoveUnusedResourceArgumentsAndForwardedRetvals(
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& infos,
    FuncOp func_op,
    llvm::SmallVector<int64_t, 4>* old_to_new_arg_indices = nullptr,
    llvm::SmallDenseMap<int64_t, Type>* remaining_resource_data_types =
        nullptr) {
  // True if argument `index` is a resource argument that is not used.
  auto is_unused_resource = [&](int64_t index) {
    auto it = infos.find(index);
    return it != infos.end() && !it->getSecond().used;
  };

  // Drop return operands that just forward an unused resource argument.
  Operation* return_op = func_op.front().getTerminator();
  auto old_return_vals = llvm::to_vector<8>(return_op->getOperands());
  int64_t num_erased_retvals = 0;
  for (int64_t i = 0, e = old_return_vals.size(); i < e; ++i) {
    auto arg = old_return_vals[i].dyn_cast<BlockArgument>();
    if (!arg || !is_unused_resource(arg.getArgNumber())) continue;
    // Earlier erasures shift the remaining operands to the left.
    return_op->eraseOperand(i - num_erased_retvals);
    ++num_erased_retvals;
  }

  // Decide, per argument, whether it is erased or kept (and at which index).
  llvm::SmallVector<unsigned int, 4> indices_to_erase;
  llvm::SmallVector<Type, 4> new_types;
  for (auto arg : func_op.getArguments()) {
    const int64_t old_index = arg.getArgNumber();
    if (is_unused_resource(old_index)) {
      indices_to_erase.push_back(arg.getArgNumber());
      if (old_to_new_arg_indices) old_to_new_arg_indices->push_back(-1);
      continue;
    }
    // The new index equals the number of arguments kept so far.
    const int64_t new_index = new_types.size();
    new_types.push_back(arg.getType());
    if (old_to_new_arg_indices) old_to_new_arg_indices->push_back(new_index);
    auto it = infos.find(old_index);
    if (it != infos.end() && remaining_resource_data_types) {
      (*remaining_resource_data_types)[new_index] = it->second.data_type;
    }
  }
  func_op.eraseArguments(indices_to_erase);
  func_op.setType(FunctionType::get(
      new_types, llvm::to_vector<4>(return_op->getOperandTypes()),
      func_op.getContext()));
}
// Lifts reads/writes of resource arguments out of func_op and changes its
// signature. resource_data_types maps the index of each resource argument to
// lift to its data type. Each such argument is retyped to its data type, and
// its (single remaining) read is replaced by the argument itself. If the
// argument is written, the written value is passed to handle_updated_arg_value
// together with the argument index, and the write is erased. Finally the
// function type is rebuilt from the new argument types and the return op's
// operand types.
void LiftArgRetResourcesForFunction(
    FuncOp func_op,
    const llvm::SmallDenseMap<int64_t, Type>& resource_data_types,
    llvm::function_ref<void(int64_t, Value)> handle_updated_arg_value) {
  Block& entry_block = func_op.front();
  ForwardStoreToLoad(&entry_block);
  // Resource argument -> its first (after CSE, only) read.
  llvm::SmallDenseMap<Value, TF::ReadVariableOp, 4> first_read;
  // Resource argument -> its last (after forwarding, only) write.
  llvm::SmallDenseMap<Value, TF::AssignVariableOp, 4> last_write;
  // HoistResourceLoads CSEs the loads of each argument; the `move_load`
  // callback does not move anything, it only records the surviving load.
  HoistResourceLoads(
      &entry_block,
      /*skip_load=*/
      [](TF::ReadVariableOp read) {
        return !read.resource().isa<BlockArgument>();
      },
      /*move_load=*/
      [&](TF::ReadVariableOp read) { first_read[read.resource()] = read; });
  // Record the stores to arguments in last_write.
  for (Operation& op : entry_block) {
    auto write = llvm::dyn_cast<TF::AssignVariableOp>(&op);
    if (!write) continue;
    auto arg = write.resource().dyn_cast<BlockArgument>();
    if (!arg) continue;
    last_write[arg] = write;
  }
  // Retype the resource arguments and remove the internal loads/stores.
  auto input_types = llvm::to_vector<8>(func_op.getType().getInputs());
  for (const auto& entry : resource_data_types) {
    auto arg = func_op.getArgument(entry.getFirst());
    Type data_type = entry.getSecond();
    arg.setType(data_type);
    input_types[arg.getArgNumber()] = data_type;
    auto read_it = first_read.find(arg);
    if (read_it != first_read.end()) {
      TF::ReadVariableOp read = read_it->getSecond();
      read.replaceAllUsesWith(arg);
      read.erase();
    }
    auto write_it = last_write.find(arg);
    if (write_it != last_write.end()) {
      TF::AssignVariableOp write = write_it->getSecond();
      handle_updated_arg_value(arg.getArgNumber(), write.value());
      write.erase();
    }
  }
  func_op.setType(FunctionType::get(
      input_types,
      llvm::to_vector<4>(entry_block.getTerminator()->getOperandTypes()),
      func_op.getContext()));
}
// Returns the elements of `range`, in order, skipping every position whose
// index is a resource argument marked as unused in resource_arg_uses.
// Positions without an entry in resource_arg_uses are always kept.
template <typename T, typename Range>
llvm::SmallVector<T, 4> FilterRange(
    Range range,
    const llvm::SmallDenseMap<int64_t, ResourceArgUseInfo>& resource_arg_uses) {
  llvm::SmallVector<T, 4> kept;
  int64_t index = 0;
  for (auto element : range) {
    auto it = resource_arg_uses.find(index++);
    const bool is_unused_resource =
        it != resource_arg_uses.end() && !it->getSecond().used;
    if (!is_unused_resource) kept.push_back(element);
  }
  return kept;
}
// Changes the types of the control flow op (e.g., while, if) and adds loads and
// stores around it. arg_data_type_and_updated_output_index maps the index of an
// operand to change to (data type, index of the updated value among the
// caller's results, or -1 if the resource is not updated). For each entry, a
// tf.ReadVariableOp of the original operand is inserted before `caller` and
// becomes the new operand. When an output index is given, a
// tf.AssignVariableOp writing that result back to the resource is inserted
// after `caller`.
void AddLoadsStoresOutsideControlFlowOp(
    Operation* caller,
    const llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>&
        arg_data_type_and_updated_output_index) {
  OpBuilder builder(caller);
  for (const auto& entry : arg_data_type_and_updated_output_index) {
    const int64_t operand_index = entry.getFirst();
    Type data_type = entry.getSecond().first;
    const int64_t output_index = entry.getSecond().second;
    Value resource = caller->getOperand(operand_index);
    // Read the resource right before the caller and feed the value in.
    builder.setInsertionPoint(caller);
    Value loaded = builder.create<TF::ReadVariableOp>(
        caller->getLoc(), ArrayRef<Type>{data_type}, ArrayRef<Value>{resource},
        ArrayRef<NamedAttribute>{});
    caller->setOperand(operand_index, loaded);
    if (output_index < 0) continue;
    // Write the updated output back to the resource right after the caller.
    builder.setInsertionPointAfter(caller);
    builder.create<TF::AssignVariableOp>(
        caller->getLoc(), ArrayRef<Type>{},
        ArrayRef<Value>{resource, caller->getResult(output_index)},
        ArrayRef<NamedAttribute>{});
  }
}
// Lifts loads/stores from while loop's body and cond functions.
//
// Steps, in order:
//  1) Removes identity ops from `body` and `cond`, then checks that every
//     resource argument of `body` is returned unchanged at the same position
//     and that `cond` never writes a resource.
//  2) Merges the resource argument uses of body and cond, redirects the while
//     op's resource results to its corresponding operands, and removes unused
//     resource arguments (and their forwarded return values) from both.
//  3) Turns the remaining resource arguments into data arguments; a resource
//     written in the body has the written value returned in place of the
//     handle.
//  4) Replaces `while_op` with a new tf.While over the data values. Reads are
//     added right before it and writes of updated values right after it, and
//     "output_shapes" is updated for the lifted entries.
// Returns failure (with an error emitted) if any check fails. `body` and
// `cond` are modified in place. NOTE(review): if these functions have other
// callers, they will observe the new signatures too — confirm that cannot
// happen.
LogicalResult HanldeWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
  // Remove identity nodes to avoid aliasing.
  RemoveIdentity(&body.front());
  RemoveIdentity(&cond.front());
  auto return_op = body.front().getTerminator();
  // Sanity check: body resource input/output should alias each other.
  for (auto arg : body.getArguments()) {
    if (!getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) continue;
    if (return_op->getOperand(arg.getArgNumber()) != arg) {
      return_op->emitError(
          "Resource used in while loop is only supported when the resource "
          "input and output alias each other in the loop body.");
      return failure();
    }
  }
  // FindResourceArgUseInfo will check supported resource ops (read and assign),
  // but loop condition has additional requirement that it cannot write
  // resources.
  if (cond.walk([&](TF::AssignVariableOp assign) {
            assign.emitError("Found resource write in loop condition.");
            return WalkResult::interrupt();
          })
          .wasInterrupted()) {
    return failure();
  }
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> body_use_info;
  llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> cond_use_info;
  if (failed(FindResourceArgUseInfo(body, &body_use_info)) ||
      failed(FindResourceArgUseInfo(cond, &cond_use_info))) {
    return failure();
  }
  // A resource is considered used as long as it is used in either body or cond.
  auto resource_arg_uses =
      MergeArgResourceUseInfo(body_use_info, cond_use_info);
  if (resource_arg_uses.empty()) return success();
  for (const auto& entry : resource_arg_uses) {
    // Replace output resource uses with the input, so that we can later freely
    // change the output type.
    while_op.getResult(entry.getFirst())
        .replaceAllUsesWith(while_op.getOperand(entry.getFirst()));
  }
  // Remove unused resources in functions.
  llvm::SmallVector<int64_t, 4> old_to_new_indices;
  llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
  RemoveUnusedResourceArgumentsAndForwardedRetvals(
      resource_arg_uses, body, &old_to_new_indices,
      &remaining_resource_data_types);
  RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond);
  LiftArgRetResourcesForFunction(
      body, remaining_resource_data_types,
      [&](int64_t index, Value value) { return_op->setOperand(index, value); });
  LiftArgRetResourcesForFunction(cond, remaining_resource_data_types,
                                 [](int64_t, Value) {
                                   // We already checked that cond should not
                                   // have variable writes.
                                   assert(false && "Should not happen");
                                 });
  // Recreate the while op.
  OpBuilder builder(while_op);
  auto new_output_shapes = FilterRange<Attribute, ArrayRef<Attribute>>(
      while_op.output_shapes().getValue(), resource_arg_uses);
  // Now use the filtered original operands, which will be replaced by
  // AddLoadsStoresOutsideControlFlowOp().
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getType().getResults(),
      FilterRange<Value, OperandRange>(while_op.getOperands(),
                                       resource_arg_uses),
      while_op.getAttrs());
  // Prepare for AddLoadsStoresOutsideControlFlowOp() and update
  // new_output_shapes.
  llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
      arg_data_type_and_updated_output_index;
  for (const auto& entry : remaining_resource_data_types) {
    // A resource is updated iff the body no longer returns the argument
    // itself at its position (LiftArgRetResourcesForFunction replaced it).
    int64_t update_index = return_op->getOperand(entry.getFirst()) ==
                                   body.getArgument(entry.getFirst())
                               ? -1
                               : entry.getFirst();
    arg_data_type_and_updated_output_index[entry.getFirst()] = {
        entry.getSecond(), update_index};
    if (!new_output_shapes.empty()) {
      tensorflow::TensorShapeProto shape_proto;
      tensorflow::ConvertTypeToTensorShape(entry.getSecond())
          .AsProto(&shape_proto);
      new_output_shapes[entry.getFirst()] = builder.getStringAttr(
          tensorflow::mangling_util::MangleShape(shape_proto));
    }
  }
  AddLoadsStoresOutsideControlFlowOp(new_while,
                                     arg_data_type_and_updated_output_index);
  new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes));
  // Replace uses of the old results with the matching new results. Removed
  // (unused) resources map to -1; their results were already redirected to
  // the operands above.
  for (auto index_pair : llvm::enumerate(old_to_new_indices)) {
    const int64_t new_index = index_pair.value();
    if (new_index < 0) continue;
    while_op.getResult(index_pair.index())
        .replaceAllUsesWith(new_while.getResult(new_index));
  }
  while_op.erase();
  return success();
}
// Hoists resource loads/stores from control flow ops in `block` outside the
// body/cond/branch functions. Control flow nested inside the callee functions
// is handled first, recursively. Returns failure (with an error emitted) if a
// callee is not a function in `module` or if any lifting, including a nested
// one, fails. Previously the results of the recursive calls were discarded,
// so nested failures did not fail the pass.
LogicalResult HoistForFunctionalControlFlow(Block* block, ModuleOp module) {
  for (Operation& op : llvm::make_early_inc_range(*block)) {
    if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      auto body = llvm::dyn_cast_or_null<FuncOp>(
          module.lookupSymbol(while_op.body()));
      auto cond = llvm::dyn_cast_or_null<FuncOp>(
          module.lookupSymbol(while_op.cond()));
      if (!body || !cond) {
        return while_op.emitOpError(
            "body or cond does not refer to a function in the module");
      }
      // Recursively handle the nested control flow first, then this loop.
      if (failed(HoistForFunctionalControlFlow(&body.front(), module)) ||
          failed(HoistForFunctionalControlFlow(&cond.front(), module)) ||
          failed(HanldeWhileLoop(while_op, body, cond))) {
        return failure();
      }
    } else if (llvm::isa<TF::IfOp>(&op)) {
      // TODO(yuanzx): Add support for IfOp.
    }
  }
  return success();
}
} // namespace
// Lifts resource operation from tf_device.launch_func ops nested in `op`
// outside. Returns failure if there are remaining resource-type values that can
// not be lifted.
void ResourceOpLiftingPass::runOnModule() {
auto result = getModule().walk([&](FuncOp func_op) {
return func_op.walk([&](tf_device::LaunchOp launch_op) {
if (failed(HoistForFunctionalControlFlow(&launch_op.GetBody(),
getModule())) ||
failed(HoistResourceOpsFromLaunchOp(launch_op))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
});
if (result.wasInterrupted()) {
signalPassFailure();
}
}
// Creates an instance of the resource op lifting module pass.
std::unique_ptr<OpPassBase<ModuleOp>> CreateResourceOpLiftingPass() {
  std::unique_ptr<OpPassBase<ModuleOp>> lifting_pass =
      std::make_unique<ResourceOpLiftingPass>();
  return lifting_pass;
}
// Static registration of the pass under the argument name
// "tf-resource-op-lifting".
static PassRegistration<ResourceOpLiftingPass> pass{
    "tf-resource-op-lifting",
    "Lifting resource operations out of device computation"};
} // namespace TFDevice
} // namespace mlir