// Source-viewer residue (not part of the original file): blob fad8e22399c867fec7e6f3aa4c480d1fec1ebeb7 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h"
#include <deque>
#include <map>
#include <unordered_map>
#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_set.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h"
namespace tensorflow {
namespace tpu {
namespace {
// String stored by AddReshardOp() in both elements of the default sharding
// Const tensor, which is fed to the unshard ops in place of a compilation key.
constexpr char kDefaultShardingValue[] = "";
// Looks through the outgoing edges of `src` and returns the first one whose
// destination node carries the same name as `dst`. Yields nullptr when `src`
// has no such edge. Control edges are not filtered out.
const Edge* FindEdgeConnecting(const Node* src, const Node* dst) {
  const auto& wanted_name = dst->name();
  for (const Edge* candidate : src->out_edges()) {
    if (candidate->dst()->name() != wanted_name) continue;
    return candidate;
  }
  return nullptr;
}
// Pairs one TPUExecute node with the input edges that carry its DT_RESOURCE
// arguments corresponding to model weights.
//
// NOTE: instances are created with positional aggregate initialization
// (`{node}`), so the member order below must not change.
struct ExecuteNodeInfo {
// The TPUExecute node this record describes.
Node* execute_node;
// Input edges of `execute_node` that feed the resource variables selected
// for resharding; filled in by ExtractExecuteNodeInfo().
std::vector<const Edge*> var_inputs;
};
// Decides whether `node` only leads into TPUExecute nodes. The answer is:
//   * true  if `node` is a member of `execute_nodes`;
//   * false if `node` is not part of the loop (`loop_nodes`);
//   * true  if `node` is a NextIteration node;
//   * false if `node` is anything other than an Identity;
//   * for an Identity, true only when every consumer reached through a data
//     (non-control) edge satisfies this same predicate.
bool IsExecuteNodeOrIdentityToExecuteNode(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const absl::flat_hash_set<Node*>& execute_nodes, Node* node) {
  if (execute_nodes.contains(node)) return true;
  if (loop_nodes.count(node) == 0) return false;
  if (node->IsNextIteration()) return true;
  if (!node->IsIdentity()) return false;

  // `node` is an Identity inside the loop: check each data consumer.
  for (const Edge* out_edge : node->out_edges()) {
    if (out_edge->IsControlEdge()) continue;
    const bool consumer_ok = IsExecuteNodeOrIdentityToExecuteNode(
        graph, loop_nodes, execute_nodes, out_edge->dst());
    if (!consumer_ok) return false;
  }
  return true;
}
// Walks backwards from `input_node` (an input of a TPUExecute op) through any
// chain of Identity nodes, always following input 0, looking for the pattern
//   Enter ----> (identity) ---> While body input
// Returns the Enter node at the head of the chain, or nullptr when the first
// non-Identity node reached is not an Enter node. Propagates the error if an
// Identity node's input cannot be retrieved.
xla::StatusOr<Node*> FindEnterNodeFromTPUExecuteNodeInput(Node* input_node) {
  Node* current = input_node;
  while (current->IsIdentity()) {
    TF_RETURN_IF_ERROR(current->input_node(0, &current));
  }
  Node* enter = current->IsEnter() ? current : nullptr;
  return enter;
}
// Returns true when every data consumer of `enter_node` (ignoring control
// edges and Exit nodes) is either one of `execute_nodes` or leads only to them
// as judged by IsExecuteNodeOrIdentityToExecuteNode(); returns false as soon
// as one consumer fails that test.
//
// `execute_nodes` is taken by const reference: the set is only read here, so
// there is no reason to copy it on every call.
xla::StatusOr<bool> ResourceOnlyUsedForTPUExecuteInLoop(
    const Graph& graph, const std::unordered_set<Node*>& loop_nodes,  // NOLINT
    const Node* enter_node, const absl::flat_hash_set<Node*>& execute_nodes) {
  for (const Edge* output_edge : enter_node->out_edges()) {
    Node* output_node = output_edge->dst();
    if (output_edge->IsControlEdge() || output_node->IsExit()) continue;
    // If output node is not execute node, it must be output node
    // to the while loop body.
    if (!IsExecuteNodeOrIdentityToExecuteNode(graph, loop_nodes, execute_nodes,
                                              output_node)) {
      return false;
    }
  }
  return true;
}
// Given a TPUCompile node, finds all TPUExecute nodes that execute the
// compiled program, together with their model weight variable inputs.
//
// The "metadata" attribute of `compile_node` is parsed into `new_metadata`.
// For every DT_RESOURCE argument whose sharding setting is TENTATIVE or
// ALLOWED, the matching input edge of each execute node is appended to
// `ExecuteNodeInfo::var_inputs` and the argument is marked ALLOWED in
// `new_metadata`, provided that, for all execute nodes, the input traces back
// to an Enter node and that Enter node is consumed only by execute nodes
// inside the loop. The TPUCompileMetadataProto of the TPUCompile node must be
// reset to `new_metadata` by the caller if new reshard ops are added.
//
// `execute_node_info` is always cleared first and is left empty when there is
// nothing to rewrite: more than one core per replica, no TPUExecute consumers,
// or no eligible variable input.
//
// Returns an error when the metadata attribute is missing or cannot be parsed,
// or when the graph does not have the expected structure.
Status ExtractExecuteNodeInfo(const Node* compile_node, const Graph& graph,
                              const std::unordered_set<Node*>& loop_nodes,  // NOLINT
                              std::vector<ExecuteNodeInfo>* execute_node_info,
                              TPUCompileMetadataProto* new_metadata) {
  // Start from a clean slate so that every early return below reports
  // "nothing to rewrite" through an empty `execute_node_info`, even if the
  // caller handed in a vector that already had entries.
  execute_node_info->clear();

  string metadata_string;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(compile_node->attrs(), "metadata", &metadata_string));
  // A failed parse leaves `new_metadata` in an unspecified state; bail out
  // instead of making decisions based on garbage.
  if (!new_metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument(
        "Unable to parse the metadata attribute of TPUCompile node: ",
        compile_node->name());
  }
  if (new_metadata->num_cores_per_replica() != 1) {
    // We do not support model parallelism yet.
    return Status::OK();
  }

  for (Node* node : compile_node->out_nodes()) {
    if (node->type_string() == "TPUExecute") {
      execute_node_info->push_back({node});
    }
  }
  if (execute_node_info->empty()) {
    return Status::OK();
  }
  TF_RET_CHECK(execute_node_info->size() == new_metadata->num_replicas())
      << "Number of replicas does not equal number of execute nodes: "
      << new_metadata->num_replicas() << " vs " << execute_node_info->size();

  // All execute nodes run the same program, so the argument types of the
  // first one are used for every replica.
  DataTypeVector arg_types;
  TF_RETURN_IF_ERROR(GetNodeAttr((*execute_node_info)[0].execute_node->attrs(),
                                 "Targs", &arg_types));
  for (int64 i = 0; i < arg_types.size(); ++i) {
    if (arg_types[i] != DT_RESOURCE) {
      continue;
    }
    // Guard the metadata lookup: an execute argument without a metadata entry
    // would otherwise be an out-of-range access.
    TF_RET_CHECK(i < new_metadata->args_size())
        << "Execute op argument " << i
        << " has no matching entry in the compile metadata, which has "
        << new_metadata->args_size() << " args.";
    const auto sharding_config = new_metadata->args(i).enable_xla_sharding();
    if (sharding_config != TPUCompileMetadataProto::Arg::TENTATIVE &&
        sharding_config != TPUCompileMetadataProto::Arg::ALLOWED) {
      continue;
    }

    // edges[j] is the i-th input edge of the j-th execute node.
    std::vector<const Edge*> edges(execute_node_info->size());
    bool is_supported = true;
    std::unordered_map<Node*, absl::flat_hash_set<Node*>>
        enter_to_execute_nodes;
    for (int64 j = 0; j < edges.size(); ++j) {
      auto execute = (*execute_node_info)[j].execute_node;
      TF_RETURN_IF_ERROR(execute->input_edge(i, &edges[j]));
      TF_RET_CHECK(edges[j]->src()->output_type(edges[j]->src_output()) ==
                   arg_types[i])
          << "Execute op has an unexpected input type.";
      // Traverse backwards to find the Enter node from which the input is
      // passed.
      // This makes sure that we are checking the usages of all potential
      // aliases of the input node as well.
      TF_ASSIGN_OR_RETURN(auto enter_node, FindEnterNodeFromTPUExecuteNodeInput(
                                               edges[j]->src()));
      if (enter_node == nullptr) {
        is_supported = false;
        enter_to_execute_nodes.clear();
        break;
      }
      enter_to_execute_nodes[enter_node].insert(edges[j]->dst());
    }

    for (const auto& it : enter_to_execute_nodes) {
      // Size of execute nodes should be either 1 (per-replica variables) or
      // num_replicas (distributed variables).
      if ((it.second.size() != 1) &&
          (it.second.size() != new_metadata->num_replicas())) {
        is_supported = false;
        break;
      }
      TF_ASSIGN_OR_RETURN(bool no_other_use,
                          ResourceOnlyUsedForTPUExecuteInLoop(
                              graph, loop_nodes, it.first, it.second));
      if (!no_other_use) {
        is_supported = false;
        break;
      }
    }

    // Add the variable input edges only when they are supported for all
    // executes.
    if (is_supported) {
      for (int64 j = 0; j < edges.size(); ++j) {
        (*execute_node_info)[j].var_inputs.push_back(edges[j]);
      }
      new_metadata->mutable_args(i)->set_enable_xla_sharding(
          TPUCompileMetadataProto::Arg::ALLOWED);
    }
  }

  // Consistency check: the number of arguments marked ALLOWED in the metadata
  // must match the number of variable inputs recorded per execute node.
  int64 total = 0;
  for (const auto& a : new_metadata->args()) {
    if (a.enable_xla_sharding() == TPUCompileMetadataProto::Arg::ALLOWED) {
      total++;
    }
  }
  TF_RET_CHECK(total == (*execute_node_info)[0].var_inputs.size())
      << " total " << total << " var_inputs "
      << (*execute_node_info)[0].var_inputs.size();
  if (total == 0) {
    // We don't need to process anything if no input is added.
    execute_node_info->clear();
  }
  return Status::OK();
}
// Tells whether `n` is a TPUCompile op, judged purely by its op type string.
bool IsTPUCompileOp(const Node& n) {
  const bool is_compile_op = (n.type_string() == "TPUCompile");
  return is_compile_op;
}
void FindTPUCompileNodes(
const std::string* current_function_name,
const AttrValueMap* current_function_attr,
const std::unordered_map<string, WhileLoopFrame>& frames,
std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
// Adds frames with no children (i.e., the innermost frames) to a worklist.
std::deque<const WhileLoopFrame*> worklist;
for (auto& frame : frames) {
if (frame.second.num_children == 0) {
worklist.push_back(&frame.second);
}
}
// Check TPUCompile node from the innermost while loop to the outermost
// while loop.
while (!worklist.empty()) {
const WhileLoopFrame* frame = worklist.front();
worklist.pop_front();
for (const auto& n : frame->nodes) {
if (!IsTPUCompileOp(*n)) continue;
HostTrainingLoopInfo host_training_loop_info;
host_training_loop_info.compile_node_name = n->name();
host_training_loop_info.loop_cond_node_name = frame->loop_cond->name();
host_training_loop_info.while_loop_name = frame->name;
for (const auto arg : frame->args) {
LoopArgInfo arg_info;
arg_info.enter_node_name = arg.enter->name();
if (arg.exit) arg_info.exit_node_name = arg.exit->name();
host_training_loop_info.loop_arguments.push_back(std::move(arg_info));
}
host_training_loop_info.loop_nodes = frame->nodes;
if (current_function_name) {
host_training_loop_info.encapsulating_function_name =
*current_function_name;
}
if (current_function_attr) {
host_training_loop_info.encapsulating_function_attrs =
*current_function_attr;
}
host_training_loops_info->emplace_back(
std::move(host_training_loop_info));
}
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
if (frame->parent->num_children == 0) {
worklist.push_back(frame->parent);
}
}
}
// From while loop cond node, finds all loop exit nodes by searching/traversing
// nodes in below pattern of nodes:
// LoopCond -----> Switch -----> Exit
std::vector<Node*> FindLoopExitNodes(const Node& loop_cond) {
std::vector<Node*> loop_exit_nodes;
for (const auto e_cond : loop_cond.out_edges()) {
if (e_cond->IsControlEdge() || !e_cond->dst()->IsSwitch()) continue;
auto switch_node = e_cond->dst();
for (const auto e_switch : switch_node->out_edges()) {
if (e_switch->IsControlEdge() || !e_switch->dst()->IsExit()) continue;
loop_exit_nodes.push_back(e_switch->dst());
}
}
return loop_exit_nodes;
}
// Finds any one of the Switch nodes of the while loop by inspecting the direct
// consumers of the loop condition node.
//
// Returns the first Switch node found among the out-nodes of
// `loop_cond_node`, or an error status when there is none.
xla::StatusOr<Node*> GetLoopSwitchNode(const Node& loop_cond_node) {
  // Must start out as nullptr: if no Switch consumer exists, the check below
  // would otherwise dereference an uninitialized pointer.
  Node* loop_switch_node = nullptr;
  for (Node* n : loop_cond_node.out_nodes()) {
    if (n->IsSwitch()) {
      loop_switch_node = n;
      break;
    }
  }
  TF_RET_CHECK(loop_switch_node != nullptr)
      << "Unable to find any switch nodes.";
  return loop_switch_node;
}
// Returns, through `node_out`, a node that is executed before each loop
// iteration of the while loop, creating it if necessary.
//
// If `loop_switch_node` already has an outgoing edge from output 1 (the true
// branch of the switch op), the destination of the first such edge is reused.
// Otherwise a new Identity node with the switch's "T" type is added to `graph`
// and connected to output 1 of the switch.
Status GetOrCreateBeforeEachIterationNode(Graph* graph, Node* loop_switch_node,
                                          Node** node_out) {
  // Reuse an existing consumer of the switch's true branch when there is one.
  for (const Edge* switch_edge : loop_switch_node->out_edges()) {
    if (switch_edge->src_output() != 1) continue;
    *node_out = switch_edge->dst();
    return Status::OK();
  }

  // No consumer yet: build an Identity node that represents execution at
  // every loop iteration. Its type mirrors the switch node's "T" attribute.
  DataType switch_dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &switch_dtype));

  NodeDef identity_def;
  identity_def.set_op("Identity");
  AddNodeAttr("T", switch_dtype, &identity_def);
  identity_def.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/before_iteration", "/_", internal::GetNodeId())));

  Status add_status;
  Node* identity_node = graph->AddNode(identity_def, &add_status);
  TF_RETURN_IF_ERROR(add_status);

  graph->AddEdge(loop_switch_node, 1, identity_node, 0);
  *node_out = identity_node;
  return Status::OK();
}
// Injects a node that is executed after the very last iteration of the while
// loop but before the while loop exit node.
//
// Despite the function name, the injected node is an Identity (typed with the
// switch's "T" attribute) fed by output 0 of `loop_switch_node`; a control
// edge from it to an Exit node consuming the switch orders it before that
// exit. The new node is returned through `node_out`.
//
// Returns an error when `loop_switch_node` has no Exit consumer, when its "T"
// attribute cannot be read, or when the node cannot be added to `graph`.
Status AddNoOpAfterLastIteration(Graph* graph, Node* loop_switch_node,
                                 Node** node_out) {
  // Find the exit node from loop switch node. The pointer must start out as
  // nullptr so that the check below is meaningful when no Exit node exists.
  Node* exit_node = nullptr;
  for (const auto out_node : loop_switch_node->out_nodes()) {
    if (out_node->IsExit()) {
      exit_node = out_node;
      break;
    }
  }
  TF_RET_CHECK(exit_node != nullptr)
      << "Cannot find exit node connected to switch node :"
      << loop_switch_node->name();

  // Create the Identity node that represents execution at the end of the
  // while loop's last iteration.
  NodeDef after_last_loop_iteration;
  after_last_loop_iteration.set_op("Identity");
  DataType dtype;
  TF_RETURN_IF_ERROR(GetNodeAttr(loop_switch_node->def(), "T", &dtype));
  AddNodeAttr("T", dtype, &after_last_loop_iteration);
  after_last_loop_iteration.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  Status status;
  Node* after_last_iteration_node =
      graph->AddNode(after_last_loop_iteration, &status);
  TF_RETURN_IF_ERROR(status);

  // Newly created node must be executed once after last iteration of the while
  // loop and before while loop exits.
  graph->AddEdge(loop_switch_node, 0, after_last_iteration_node, 0);
  graph->AddControlEdge(after_last_iteration_node, exit_node);
  *node_out = after_last_iteration_node;
  return Status::OK();
}
} // namespace
// Finds host training loops (while loops containing a TPUCompile node) in
// `graph` and, recursively, in the graphs of the functions associated with its
// nodes through function attributes. One HostTrainingLoopInfo per TPUCompile
// node is appended to `host_training_loops_info`.
//
// `current_function_name` / `current_function_attr` identify the function
// whose body `graph` is; they may be null and are forwarded as-is to the
// records created for this graph.
//
// Returns the first error hit while instantiating functions, recursing, or
// analysing control flow. A failure to release a function handle is reported
// through the final return value when everything else succeeded.
Status DetectHostTrainingLoop(
    const std::string* current_function_name,
    const AttrValueMap* current_function_attr,
    const FunctionLibraryDefinition* library, Graph* graph,
    FunctionLibraryRuntime* flr,
    std::vector<HostTrainingLoopInfo>* host_training_loops_info) {
  // Gather the functions associated with every node of the graph.
  std::vector<AssociatedFunctionInfo> associated_function_list;
  for (const auto* graph_node : graph->nodes()) {
    const auto node_functions = GetAssociatedFunctions(*graph_node, library);
    if (!node_functions.empty()) {
      associated_function_list.insert(associated_function_list.end(),
                                      node_functions.begin(),
                                      node_functions.end());
    }
  }

  // Declared before the loop: the per-iteration cleanup below writes into it.
  Status ret_status = Status::OK();
  for (const auto& associated_fn : associated_function_list) {
    const bool is_function_attr =
        associated_fn.type() == AssociatedFunctionInfo::kFunctionAttr;
    if (!is_function_attr) continue;

    // Convert the function to Graph.
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate(
        associated_fn.func_name(), AttrSlice(&associated_fn.attrs()), &handle));
    // Release the handle when this iteration ends, whichever way it ends.
    auto cleanup_handle = gtl::MakeCleanup([&]() {
      auto release_status = flr->ReleaseHandle(handle);
      if (release_status.ok()) return;
      ret_status.Update(release_status);
    });

    const FunctionBody* body = flr->GetFunctionBody(handle);
    Graph* function_graph = body->graph;
    TF_RETURN_IF_ERROR(DetectHostTrainingLoop(
        &associated_fn.func_name(), &associated_fn.attrs(), library,
        function_graph, flr, host_training_loops_info));
  }

  // BuildControlFlowInfo() requires that the graph's source node is connected
  // to all source nodes in the graph. Many graphs violate this invariant.
  // As so, add edges to source/sink nodes so that this invariant is kept.
  FixupSourceAndSinkEdges(graph);

  std::vector<ControlFlowInfo> cf_info;
  TF_RETURN_IF_ERROR(
      BuildControlFlowInfo(graph, &cf_info, /*unreachable_nodes=*/nullptr));

  std::unordered_map<string, WhileLoopFrame> frames;
  TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));

  FindTPUCompileNodes(current_function_name, current_function_attr, frames,
                      host_training_loops_info);
  return ret_status;
}
// Rewrites the host training loop described by `host_loop_info` so that model
// weight variables are resharded before each TPUExecute and unsharded once
// after the last loop iteration.
//
// For every TPUExecute node fed by the loop's TPUCompile node this adds:
//   * a TPUReshardVariables node ("reshard") that runs before the execute
//     node at every iteration, fed by the variable inputs, the compilation
//     key and a per-execute VarHandleOp holding the sharding state;
//   * a TPUReshardVariables node ("unshard") fed by the corresponding Enter
//     nodes, a default (empty) sharding key and the same state variable, which
//     is ordered after the last iteration and before the loop exits.
// The TPUCompile node's "metadata" attribute is replaced with the updated
// proto produced by ExtractExecuteNodeInfo().
//
// Returns OK without touching the graph when the compile node has nothing to
// reshard or when extraction fails (the failure is only logged). Returns an
// error when the graph lacks the expected nodes or a node cannot be added.
Status AddReshardOp(Graph* graph, const HostTrainingLoopInfo& host_loop_info) {
  const auto& compile_node_name = host_loop_info.compile_node_name;
  const auto node_name_map = graph->BuildNodeNameIndex();
  const auto node_it = node_name_map.find(compile_node_name);
  TF_RET_CHECK(node_it != node_name_map.end())
      << "Unable to find compile node : " << compile_node_name;

  const auto compile_node = node_it->second;
  std::vector<ExecuteNodeInfo> execute_nodes_info;

  Status status;
  TPUCompileMetadataProto metadata;
  status =
      ExtractExecuteNodeInfo(compile_node, *graph, host_loop_info.loop_nodes,
                             &execute_nodes_info, &metadata);
  if (!status.ok()) {
    // Best effort: a loop that cannot be analysed is simply left untouched.
    LOG(ERROR) << "Encountered error when trying to extract execute nodes, "
                  "skipping host loop optimization. Status: "
               << status.ToString();
    return Status::OK();
  }

  if (execute_nodes_info.empty()) {
    return Status::OK();
  }

  // Update the TPUCompileMetadata such that sharding config of the
  // sharded resource variable inputs is set to ALLOWED instead of
  // TENTATIVE.
  string new_metadata_string;
  metadata.SerializeToString(&new_metadata_string);
  compile_node->ClearAttr("metadata");
  compile_node->AddAttr("metadata", new_metadata_string);

  // Unsharding of the model weight variables must happen only at the very
  // last loop iteration. As so, add while loop condition predicate as an
  // input to the sharding switch node. If loop condition is true, we do not
  // unshard.
  const auto& cond_node_name = host_loop_info.loop_cond_node_name;
  auto loop_cond_node_it = node_name_map.find(cond_node_name);
  TF_RET_CHECK(loop_cond_node_it != node_name_map.end())
      << "Cannot find loop condition node : " << cond_node_name;
  auto* loop_condition_node = loop_cond_node_it->second;

  // In order to make sure that shard/unshard operations are invoked
  // at the start of every loop body and at the end of last iteration
  // of the loop, respectively, traverse the graph and find a switch node
  // of the host training loop.
  TF_ASSIGN_OR_RETURN(Node * switch_node,
                      GetLoopSwitchNode(*loop_condition_node));

  Node* after_last_iteration_node = nullptr;
  TF_RETURN_IF_ERROR(AddNoOpAfterLastIteration(graph, switch_node,
                                               &after_last_iteration_node));

  Node* before_loop_iteration_node = nullptr;
  TF_RETURN_IF_ERROR(GetOrCreateBeforeEachIterationNode(
      graph, switch_node, &before_loop_iteration_node));

  // Create const op that represents default sharding value
  // (i.e. no-op sharding).
  NodeDef default_sharding;
  default_sharding.set_op("Const");
  default_sharding.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/default_shard_state", "/_", internal::GetNodeId())));
  AddNodeAttr("dtype", DT_STRING, &default_sharding);

  Tensor t(DT_STRING, {2});
  t.vec<tstring>()(0) = kDefaultShardingValue;
  t.vec<tstring>()(1) = kDefaultShardingValue;
  t.AsProtoTensorContent(
      (*default_sharding.mutable_attr())["value"].mutable_tensor());

  Node* default_sharding_node = graph->AddNode(default_sharding, &status);
  TF_RETURN_IF_ERROR(status);
  // Add control edge between loop condition to make sure that
  // default_sharding_node node is inside the while loop frame.
  graph->AddControlEdge(loop_condition_node, default_sharding_node);

  // Build a no-op node used to add control edges after unshard nodes.
  NodeDef after_unshard;
  after_unshard.set_op("NoOp");
  after_unshard.set_name(graph->NewName(strings::StrCat(
      "TPUVariableReshard/last_iteration", "/_", internal::GetNodeId())));
  auto after_unshard_node = graph->AddNode(after_unshard, &status);
  TF_RETURN_IF_ERROR(status);

  // Iterate by const reference: each ExecuteNodeInfo owns a vector of edges
  // that there is no need to copy.
  for (const auto& info : execute_nodes_info) {
    auto execute_node = info.execute_node;
    const int num_var_inputs = static_cast<int>(info.var_inputs.size());

    // Create Reshard op that optionally shards model weight variables
    // prior to program execution.
    NodeDef reshard_node_def;
    reshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard", "/_", internal::GetNodeId())));
    reshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", num_var_inputs, &reshard_node_def);
    Node* reshard_op_node = graph->AddNode(reshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    reshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    // Reshard op must execute at every loop iteration prior to
    // TPUExecute node.
    graph->AddControlEdge(before_loop_iteration_node, reshard_op_node);
    graph->AddControlEdge(reshard_op_node, execute_node);

    // Inputs [0, N) of the reshard op are the variables themselves.
    for (int i = 0; i < num_var_inputs; ++i) {
      const auto variable_edge = info.var_inputs[i];
      graph->AddEdge(variable_edge->src(), variable_edge->src_output(),
                     reshard_op_node, i);
    }

    const int new_key_input = num_var_inputs;
    // Add program input edge from the compiler(i.e. compilation key).
    const auto compilation_key_edge =
        FindEdgeConnecting(compile_node, execute_node);
    // FindEdgeConnecting() returns nullptr when no edge exists; fail cleanly
    // instead of dereferencing it.
    TF_RET_CHECK(compilation_key_edge != nullptr)
        << "Cannot find edge from compile node : " << compile_node->name()
        << " to execute node : " << execute_node->name();
    graph->AddEdge(compile_node, compilation_key_edge->src_output(),
                   reshard_op_node, new_key_input);

    // Create VarHandleOp to store sharding state. Sharding state holds string
    // compilation key that identifies whether the graph is re-compiled and the
    // variables need to be sharded again.
    NodeDef var_handle_def;
    var_handle_def.set_op("VarHandleOp");
    var_handle_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/reshard_state", "/_", internal::GetNodeId())));
    AddNodeAttr("dtype", DT_STRING, &var_handle_def);
    AddNodeAttr("shape", TensorShape({}), &var_handle_def);
    Node* var_handle_node = graph->AddNode(var_handle_def, &status);
    TF_RETURN_IF_ERROR(status);

    // Add control edge between `var_handle_def` node and while loop
    // loop condition so that `var_handle_def` is inside the same while loop
    // frame.
    // TODO(hongjunchoi): Consider adding control edge from another node--such
    // as input control node.
    graph->AddControlEdge(loop_condition_node, var_handle_node);

    // Connect data edge between var handle op and reshard op.
    const int format_state_input = new_key_input + 1;
    graph->AddEdge(var_handle_node, 0, reshard_op_node, format_state_input);

    // Create Reshard op that represents unsharding after TPUExecute.
    NodeDef unshard_node_def;
    unshard_node_def.set_name(graph->NewName(strings::StrCat(
        "TPUVariableReshard/unshard", "/_", internal::GetNodeId())));
    unshard_node_def.set_op("TPUReshardVariables");
    AddNodeAttr("N", num_var_inputs, &unshard_node_def);
    Node* unshard_op_node = graph->AddNode(unshard_node_def, &status);
    TF_RETURN_IF_ERROR(status);

    unshard_op_node->set_assigned_device_name(
        execute_node->assigned_device_name());

    for (int i = 0; i < num_var_inputs; ++i) {
      const auto variable_edge = info.var_inputs[i];
      // Connect model weight resource variables to unshard op. Since unshard
      // op must be only invoked after the very last loop iteration, for each
      // while loop input we traverse backwards from the execute input to its
      // Enter node and feed output 0 of that Enter node to the unshard op.
      TF_ASSIGN_OR_RETURN(
          Node * enter_node,
          FindEnterNodeFromTPUExecuteNodeInput(variable_edge->src()));
      // The lookup yields nullptr when the input does not trace back to an
      // Enter node; never hand a null source to AddEdge().
      TF_RET_CHECK(enter_node != nullptr)
          << "Cannot find enter node for variable input " << i
          << " of execute node : " << execute_node->name();
      graph->AddEdge(enter_node, 0, unshard_op_node, i);
    }

    // Add control dependency before/after unshard node and the control nodes.
    graph->AddControlEdge(after_last_iteration_node, unshard_op_node);
    graph->AddControlEdge(unshard_op_node, after_unshard_node);

    graph->AddEdge(default_sharding_node, 0, unshard_op_node, new_key_input);

    // Add data edge from sharding state var handle op to unshard op.
    graph->AddEdge(var_handle_node, 0, unshard_op_node, format_state_input);
  }

  // Add control dependency from after_unshard_node to all exits nodes. This is
  // to make sure that the unshard ops will be executed as long as any of the
  // exits are used.
  for (auto exit : FindLoopExitNodes(*loop_condition_node)) {
    graph->AddControlEdge(after_unshard_node, exit);
  }
  return Status::OK();
}
} // namespace tpu
} // namespace tensorflow