blob: 45597bcf4863beb4ce66a0fb84bf063c5b1c4b28 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
namespace {
// Pass that co-locates resource ops that use composite device resources
// (packed tensors) with the underlying physical TPU device.
struct TPUColocateCompositeResourceOps
: public TF::TPUColocateCompositeResourceOpsPassBase<
TPUColocateCompositeResourceOps> {
void runOnOperation() override;
};
// Wraps single op in `tf_device.launch` for explicit device assignment.
void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
llvm::StringRef device) {
builder->setInsertionPoint(op);
auto launch = builder->create<tf_device::LaunchOp>(
loc, builder->getStringAttr(device), op->getResultTypes());
launch.body().push_back(new Block);
op->replaceAllUsesWith(launch);
builder->setInsertionPointToEnd(&launch.GetBody());
builder->create<tf_device::ReturnOp>(loc, op->getResults());
// Move op inside cluster.
op->moveBefore(launch.GetBody().getTerminator());
}
llvm::SmallVector<Operation*, 4> GetResourceOpsUsingCompositeArgsInReplicate(
tf_device::ReplicateOp replicate) {
llvm::SmallVector<Operation*, 4> resource_users;
const auto add_resource_op_to_list = [&resource_users](Operation* op) {
if (!llvm::isa<TF::AssignVariableOp, TF::ReadVariableOp>(op)) return;
resource_users.emplace_back(op);
};
llvm::SmallVector<Operation*, 4> resource_users_to_visit;
for (auto composite_arguments : replicate.GetPackedBlockArguments()) {
for (auto resource_user : composite_arguments.getUsers())
resource_users_to_visit.emplace_back(resource_user);
}
while (!resource_users_to_visit.empty()) {
llvm::SmallVector<Operation*, 4> new_resource_users;
for (auto resource_user : resource_users_to_visit) {
add_resource_op_to_list(resource_user);
// Account for pass-through identity ops.
if (auto pass_through_identity =
llvm::dyn_cast<TF::IdentityOp>(resource_user)) {
for (auto identity_user : pass_through_identity.output().getUsers()) {
new_resource_users.emplace_back(identity_user);
}
}
}
resource_users_to_visit.swap(new_resource_users);
}
return resource_users;
}
void ColocateCompositeResourceOpsInReplicate(
tf_device::ReplicateOp replicate_op, OpBuilder* builder) {
auto devices = replicate_op.devices();
if (!devices) return;
if (!devices.getValue().get(tensorflow::GetDeviceAliasForLogicalCore(0)))
return;
const auto composite_resource_users =
GetResourceOpsUsingCompositeArgsInReplicate(replicate_op);
for (auto resource_user : composite_resource_users) {
WrapOpInLaunch(builder, resource_user->getLoc(), resource_user,
tensorflow::GetDeviceAliasForLogicalCore(0));
}
}
void TPUColocateCompositeResourceOps::runOnOperation() {
// Find all the executes first, since we will mutate the nodes around each
// execute in the same tf_device.replicate op.
llvm::SmallVector<tf_device::LaunchOp, 8> execute_launches;
getOperation().walk([&](tf_device::LaunchOp op) {
if (op.WrapsSingleOp() &&
llvm::isa<TF::TPUExecuteOp, TF::TPUExecuteAndUpdateVariablesOp>(
op.GetBody().front()))
execute_launches.push_back(op);
});
OpBuilder builder(&getContext());
for (auto execute_launch : execute_launches) {
auto replicate = execute_launch->getParentOfType<tf_device::ReplicateOp>();
if (!replicate) continue;
ColocateCompositeResourceOpsInReplicate(replicate, &builder);
}
}
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUColocateCompositeResourceOps() {
return std::make_unique<TPUColocateCompositeResourceOps>();
}
} // namespace TFTPU
} // namespace mlir