/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Analysis/CallGraph.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

namespace mlir {
namespace TFTPU {
namespace {
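
// Shared implementation for the WalkReachableFromTpuCluster overloads below:
// walks every tf_device::ClusterOp in `module` and every function reachable
// from it through the call graph, invoking `callback` on each visited op
// together with its root TPU cluster. When `pass_host_device` is true, the
// host device of each cluster is computed once and forwarded to `callback`.
// Callees reached through call ops are expected to be defined in `module`.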
mlir::LogicalResult WalkReachableFromTpuCluster(
    bool pass_host_device, ModuleOp module,
    std::function<WalkResult(Operation*, tf_device::ClusterOp,
                             std::optional<std::string>)>
        callback) {
  mlir::TF::RuntimeDevices devices;
  if (failed(tensorflow::GetDevicesFromOp(module, &devices))) return failure();

  const CallGraph call_graph(module);
  // symbol_table caches the symbol lookups used to resolve callees in the
  // call graph.
  SymbolTableCollection symbol_table;
  // List pending nodes to traverse with their root TPU cluster.
  llvm::SmallVector<std::pair<CallGraphNode*, tf_device::ClusterOp>>
      pending_call_nodes;
  // Cache the host device for each TPU cluster.
  std::unordered_map<Operation*, std::optional<std::string>> cluster_to_host;

  auto insert_pending_op = [&](Operation* op,
                               tf_device::ClusterOp tpu_cluster) {
    // Add callee nodes to pending_call_nodes.
    if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
      CallGraphNode* node = call_graph.resolveCallable(call, symbol_table);
      pending_call_nodes.emplace_back(node, tpu_cluster);
    }
  };

  // Traverse ops in each TPU cluster.
  auto result = module.walk([&](tf_device::ClusterOp tpu_cluster) {
    std::optional<std::string> host_device;
    if (pass_host_device) {
      std::string host_device_value;
      if (failed(tensorflow::GetHostDeviceOutsideComputation(
              devices, tpu_cluster, &host_device_value)))
        return WalkResult::interrupt();
      host_device = host_device_value;
    }
    cluster_to_host[tpu_cluster.getOperation()] = host_device;
    return tpu_cluster.walk([&](Operation* op) {
      insert_pending_op(op, tpu_cluster);
      return callback(op, tpu_cluster,
                      cluster_to_host[tpu_cluster.getOperation()]);
    });
  });
  if (result.wasInterrupted()) return failure();

  // Traverse ops that are reachable from some TPU cluster.
  // node_to_cluster is used to avoid traversing the same node twice, and to
  // check that no node is reachable from multiple TPU clusters.
  std::unordered_map<CallGraphNode*, tf_device::ClusterOp> node_to_cluster;
  while (!pending_call_nodes.empty()) {
    auto pair = pending_call_nodes.back();
    pending_call_nodes.pop_back();
    CallGraphNode* node = pair.first;
    tf_device::ClusterOp tpu_cluster = pair.second;
    if (node_to_cluster.count(node)) {
      if (node_to_cluster[node].getOperation() != tpu_cluster) {
        node->getCallableRegion()->getParentOp()->emitOpError(
            "The same function is reachable from multiple TPU Clusters.");
      }
    } else {
      node_to_cluster[node] = tpu_cluster;
      auto result = node->getCallableRegion()->walk([&](Operation* op) {
        insert_pending_op(op, tpu_cluster);
        return callback(op, tpu_cluster,
                        cluster_to_host[tpu_cluster.getOperation()]);
      });
      if (result.wasInterrupted()) return failure();
    }
  }
  return success();
}
}  // namespace
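
// Walks `module` and, for every op inside a tf_device::ClusterOp or reachable
// from one through function calls, invokes `callback` with the op, its TPU
// cluster, and the cluster's host device.
//
// Illustrative use only (hypothetical caller, not part of this file):
//
//   int64_t num_ops = 0;
//   mlir::LogicalResult status = WalkReachableFromTpuCluster(
//       module, [&](Operation* op, tf_device::ClusterOp cluster,
//                   std::optional<std::string> host_device) {
//         ++num_ops;
//         return WalkResult::advance();
//       });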
mlir::LogicalResult WalkReachableFromTpuCluster(
    ModuleOp module, std::function<WalkResult(Operation*, tf_device::ClusterOp,
                                              std::optional<std::string>)>
                         callback) {
  return WalkReachableFromTpuCluster(true, module, callback);
}
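
// As above, but does not compute per-cluster host devices; `callback` receives
// only the op and its TPU cluster.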
mlir::LogicalResult WalkReachableFromTpuCluster(
    ModuleOp module,
    std::function<WalkResult(Operation*, tf_device::ClusterOp)> callback) {
  auto with_host_op = [&](Operation* op, tf_device::ClusterOp tpu_cluster,
                          std::optional<std::string> host_device) {
    return callback(op, tpu_cluster);
  };
  return WalkReachableFromTpuCluster(false, module, with_host_op);
}

}  // namespace TFTPU
}  // namespace mlir