blob: e6677556c58617619d217d67df220f4857363a05 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#define DEBUG_TYPE "tf-hoist-replicate-invariant-resource-writes"
namespace mlir {
namespace TF {
namespace {
// Pass that, for every tf_device.replicate op under the operation it runs on,
// finds the trailing tf.AssignVariableOp writes to resources defined outside
// the replicate op and moves those writes to after the replicate op.
struct HoistReplicateInvariantResourceWritesPass
: public TF::HoistReplicateInvariantResourceWritesPassBase<
HoistReplicateInvariantResourceWritesPass> {
// Pass entry point; defined out of line below.
void runOnOperation() override;
};
// TODO(prakalps): This is a common utility and other passes use something
// similar. Move to common utils.
//
// Returns true when `type` is itself a TF resource type, or is a tensor type
// whose element type is a TF resource type. Returns false for everything else.
bool IsResourceType(Type type) {
  // A bare resource type qualifies directly.
  if (type.isa<TF::ResourceType>()) return true;
  // Otherwise only a tensor whose elements are resources qualifies.
  if (auto tensor_type = type.dyn_cast<TensorType>())
    return tensor_type.getElementType().isa<TF::ResourceType>();
  return false;
}
// Returns the operands of `op` whose type is a resource type (as decided by
// IsResourceType), in operand order. Only direct operands are inspected.
SmallVector<Value> GetAccessedResources(Operation& op) {
  // Declared with exactly the return type so the vector can be returned
  // directly (eligible for NRVO) instead of being converted from a
  // SmallVector with a different inline capacity via `std::move`.
  SmallVector<Value> accessed_resources;
  for (Value operand : op.getOperands()) {
    if (IsResourceType(operand.getType()))
      accessed_resources.push_back(operand);
  }
  return accessed_resources;
}
// Lifts the tail writes outside of tf_device.replicate. The value written by
// each op in `tail_assign_variable_ops` is appended to the values returned from
// the replicate region, the replicate op is rebuilt with the extra result
// types, and each assign op is moved after the new replicate op and rewired to
// write the result corresponding to the first replica.
//
// `replicate_op` is erased; all uses of its results are redirected to the
// leading results of the newly created replicate op.
void MoveTailWritesAfterReplicate(
    tf_device::ReplicateOp replicate_op,
    llvm::ArrayRef<TF::AssignVariableOp> tail_assign_variable_ops) {
  const auto num_replicas = replicate_op.n();
  // The terminator is expected to be a tf_device.return. Use `cast` (which
  // asserts) rather than `dyn_cast`, since the result is dereferenced
  // unconditionally below and a null op must not go unnoticed.
  auto return_op = llvm::cast<tf_device::ReturnOp>(
      replicate_op.getRegion().front().getTerminator());

  // Get the new result types: every newly returned value contributes
  // `num_replicas` results of its type, appended after the existing results.
  // TODO(prakalps): Do not add a value to returned values if it is already
  // returned.
  auto new_result_types = llvm::to_vector<4>(replicate_op->getResultTypes());
  for (TF::AssignVariableOp assign : tail_assign_variable_ops) {
    return_op->insertOperands(return_op->getNumOperands(), assign.value());
    new_result_types.insert(new_result_types.end(), num_replicas,
                            assign.value().getType());
  }

  OpBuilder builder(replicate_op);
  // Clone this old replicate op but with new result types.
  auto new_replicate_op = builder.create<tf_device::ReplicateOp>(
      replicate_op->getLoc(), new_result_types, replicate_op->getOperands(),
      replicate_op->getAttrs());
  // Move region to the new op.
  new_replicate_op.getRegion().takeBody(replicate_op.getRegion());

  // Replace all old uses with new op results. The old results map one-to-one
  // onto the first `old_num_results` results of the new op.
  const unsigned old_num_results = replicate_op->getNumResults();
  replicate_op->replaceAllUsesWith(
      new_replicate_op->getResults().take_front(old_num_results));

  // Move assign ops after replicate and use the output of first replica. The
  // results for the `index`-th hoisted value start at
  // `old_num_results + num_replicas * index`.
  for (auto indexed_assign : llvm::enumerate(tail_assign_variable_ops)) {
    auto assign_op = indexed_assign.value();
    auto index = indexed_assign.index();
    assign_op->moveAfter(new_replicate_op);
    // Operand 1 is the operand being rewired to the replicate result.
    assign_op->setOperand(
        1, new_replicate_op->getResult(old_num_results + num_replicas * index));
  }
  replicate_op->erase();
}
// Looks for AssignVariable ops from the end of the tf_device.replicate op. It
// returns all the last writes to replicate invariant resource variables
// (resource handles defined outside the tf_device.replicate op).
//
// The ops of the replicate body are scanned in reverse. An AssignVariableOp is
// collected only if its resource has not been seen as a resource operand of an
// op later in the block and the resource is defined in a region that is a
// proper ancestor of the replicate region. The returned ops are in reverse
// program order (last op first).
SmallVector<TF::AssignVariableOp> GetTailWritesToReplicateInvariantResourceVars(
    tf_device::ReplicateOp replicate_op) {
  // Declared with exactly the return type so it can be returned directly
  // (eligible for NRVO) rather than converted from a SmallVector with a
  // different inline capacity via `std::move`.
  SmallVector<TF::AssignVariableOp> tail_assign_variable_ops;
  // Resource operands of the ops already visited, i.e. of ops that come later
  // in the block than the one currently being examined.
  llvm::SmallDenseSet<Value, 16> visited_resources;
  for (Operation& op :
       llvm::reverse(replicate_op.getRegion().front().getOperations())) {
    SmallVector<Value> op_accessed_resources = GetAccessedResources(op);
    // Ops without resource operands neither qualify nor block later writes.
    if (op_accessed_resources.empty()) continue;
    if (auto assign = llvm::dyn_cast<TF::AssignVariableOp>(op)) {
      Value resource_var = assign.resource();
      // Skip writes that are followed by another access to the same resource,
      // and writes to resources that are not defined outside the replicate
      // op. Note that a skipped assign does not add its operands to
      // `visited_resources`.
      if (visited_resources.contains(resource_var) ||
          !resource_var.getParentRegion()->isProperAncestor(
              &replicate_op.getRegion()))
        continue;
      tail_assign_variable_ops.push_back(assign);
    }
    for (Value resource : op_accessed_resources)
      visited_resources.insert(resource);
  }
  return tail_assign_variable_ops;
}
void HoistReplicateInvariantResourceWritesPass::runOnOperation() {
SmallVector<tf_device::ReplicateOp, 2> replicate_ops;
getOperation().walk([&](tf_device::ReplicateOp replicate_op) {
replicate_ops.push_back(replicate_op);
});
for (auto replicate_op : replicate_ops) {
SmallVector<TF::AssignVariableOp> tail_writes =
GetTailWritesToReplicateInvariantResourceVars(replicate_op);
if (tail_writes.empty()) continue;
MoveTailWritesAfterReplicate(replicate_op, tail_writes);
}
}
} // namespace
// Factory for the pass defined in this file: returns a newly allocated
// HoistReplicateInvariantResourceWritesPass as an operation pass on
// func::FuncOp.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateHoistReplicateInvariantResourceWritesPass() {
return std::make_unique<HoistReplicateInvariantResourceWritesPass>();
}
} // namespace TF
} // namespace mlir