blob: 82ce409d9017951b011cb601f7a145fd15122398 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include <functional>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace {
// Checks that `id` names a node and selects a non-negative output slot.
// Note: output_index == 0 is accepted; only negative indices are rejected,
// even though the error text says "positive".
Status ValidateTensorId(const tf2xla::TensorId& id) {
  const bool has_node_name = !id.node_name().empty();
  if (!has_node_name) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  const bool has_valid_index = id.output_index() >= 0;
  if (!has_valid_index) {
    return errors::InvalidArgument("TensorId output_index must be positive");
  }
  return Status::OK();
}
// Records `name` in `names` and fails with InvalidArgument if it was already
// recorded. Empty names are skipped entirely (never recorded, never reported).
// `kind` is only used to label the error message (e.g. "feed"/"fetch").
Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (name.empty()) {
    return Status::OK();
  }
  const bool newly_inserted = names->insert(name).second;
  if (!newly_inserted) {
    return errors::InvalidArgument("duplicate ", kind, " name: ", name);
  }
  return Status::OK();
}
// Fails with InvalidArgument if `names` contains both "foo" and "foo_data" for
// some "foo": such a pair would produce colliding codegen symbols. The first
// conflicting name in sorted (std::set) order is the one reported.
Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  for (auto it = names.begin(); it != names.end(); ++it) {
    const string suffixed = *it + "_data";
    if (names.count(suffixed) > 0) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", *it,
                                     " and ", suffixed);
    }
  }
  return Status::OK();
}
// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
// Copies into `fld`, from `lookup_fld`, the FunctionDef called by every
// function-call node in `g`, so that `fld` can instantiate the function calls
// made by `g`. SymbolicGradient and function-attr associations are skipped.
// Returns Internal if a called function is missing from `lookup_fld`; errors
// from FunctionLibraryDefinition::AddFunctionDef are propagated.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* node : g->op_nodes()) {
    const std::vector<AssociatedFunctionInfo> infos =
        GetAssociatedFunctions(*node, lookup_fld);
    for (const AssociatedFunctionInfo& info : infos) {
      if (info.type() != AssociatedFunctionInfo::kFunctionCallNode) {
        continue;
      }
      const FunctionDef* fdef = lookup_fld->Find(info.func_name());
      if (fdef == nullptr) {
        return errors::Internal("Cannot find function ", info.func_name(),
                                " for function call node ",
                                node->DebugString());
      }
      TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
    }
  }
  return Status::OK();
}
// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
// For graph `g`, replaces every data usage of the _Arg nodes whose "index"
// attribute is a key of `const_input_index_to_node` with a fresh copy of the
// mapped Const node (the copy gets a new unique name in `g`). Each consumer is
// rebuilt with ReplaceNode so its NodeDef input also names the Const. Control
// edges out of the _Arg are left alone, and the _Arg nodes are not removed.
//
// Returns Internal if a requested index has no matching _Arg node in `g`;
// errors from attr lookup, AddNode, ReplaceNode and input_edge are propagated.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const std::unordered_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes, keyed by their "index" attribute.
  std::unordered_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    const int arg_index = iter.first;
    // Checked lookup: operator[] would insert a null Node* for a missing
    // index, which was then dereferenced below.
    const auto arg_it = arg_nodes.find(arg_index);
    if (arg_it == arg_nodes.end()) {
      return errors::Internal("Cannot find _Arg node with index ", arg_index,
                              " to replace with Const node ",
                              iter.second->name());
    }
    Node* arg_node = arg_it->second;

    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);

    // Collect all data usages of the _Arg node. Consumers are tracked by node
    // id rather than Node* because ReplaceNode below removes and recreates
    // them.
    struct OutEdgeInfo {
      int dst_node_id, dst_input;
    };
    std::vector<OutEdgeInfo> usages;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      usages.push_back({e->dst()->id(), e->dst_input()});
    }

    for (int i = 0, end = usages.size(); i < end; i++) {
      // Make a copy of `usage_node`, and change its input to const node.
      Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
      NodeDef replace_def = usage_node->def();
      *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
      TF_ASSIGN_OR_RETURN(Node * replace_node,
                          ReplaceNode(g, usage_node, replace_def));
      // ReplaceNode re-created the _Arg -> consumer edge on the new node;
      // swap it for a Const -> consumer edge.
      const Edge* usage_edge;
      TF_RETURN_IF_ERROR(
          replace_node->input_edge(usages[i].dst_input, &usage_edge));
      g->RemoveEdge(usage_edge);
      g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
      // Later entries in `usages` might have `usage_node` as dst node, but
      // `usage_node` is removed. Replace such entries with `replace_node`.
      for (int j = i + 1; j < end; j++) {
        if (usages[j].dst_node_id == usages[i].dst_node_id) {
          usages[j].dst_node_id = replace_node->id();
        }
      }
    }
  }
  return Status::OK();
}
// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
//
// The function named by `n`'s `attr_name` attr is looked up in `lookup_fld`,
// instantiated with the attr's own attrs, rewritten, and added to `fld` exactly
// once under a fresh name (prefix "<original>_const_"), carrying the original
// function's stack traces. `n`'s attr is then repointed at the new name, and
// the function-call targets of the rewritten body are copied into `fld`.
// Returns Internal if the function is missing from `lookup_fld`.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const std::unordered_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));

  // Save rewritten function. This must happen while `func_attr` still holds
  // the ORIGINAL name, so the stack traces are looked up for that function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(
      replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));

  // Change the node to use rewritten function. (The function was already
  // added above; a second AddFunctionDef here was redundant and looked up
  // stack traces under the new name, which `lookup_fld` does not know.)
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
  return Status::OK();
}
// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
// If `if_node` has Const data inputs (other than its predicate), rewrites its
// "then_branch" and "else_branch" functions so the matching _Arg nodes are
// replaced by those Const nodes. Does nothing when there are no Const inputs.
// `g` is not used.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Input 0 is the predicate; input k (k >= 1) feeds branch _Arg k - 1.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  const int input_count = if_node->num_inputs();
  for (int input_idx = 1; input_idx < input_count; ++input_idx) {
    const Node* src = nullptr;
    TF_RETURN_IF_ERROR(if_node->input_node(input_idx, &src));
    if (src->type_string() != "Const") continue;
    const_input_index_to_node[input_idx - 1] = src;
  }
  if (const_input_index_to_node.empty()) return Status::OK();

  static const char* const kBranchAttrs[] = {"then_branch", "else_branch"};
  for (const char* branch_attr : kBranchAttrs) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, branch_attr, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
// For a "While" node, finds Const inputs that are loop invariants and rewrites
// its "cond" and "body" functions so the matching _Arg nodes are replaced by
// those Const nodes. Input i counts as loop-invariant when the body's i-th
// retval is wired directly from its i-th argument. Inputs at positions past the
// body's output count are ignored: per TF2XLA convention, resource inputs are
// placed after all data inputs and are not returned, so a While node may have
// more inputs than outputs.
// Returns Internal if the body function, or a retval's input, is missing.
// `g` is not used.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (body_func == nullptr) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  const OpDef& body_sig = body_func->signature();

  std::unordered_map<int, const Node*> const_input_index_to_node;
  for (int idx = 0; idx < while_node->num_inputs(); ++idx) {
    const Node* src;
    TF_RETURN_IF_ERROR(while_node->input_node(idx, &src));
    // Only Const inputs that the body also returns are candidates.
    if (src->type_string() != "Const" || idx >= body_sig.output_arg_size()) {
      continue;
    }
    const string& out_name = body_sig.output_arg(idx).name();
    const auto ret_it = body_func->ret().find(out_name);
    if (ret_it == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ", out_name,
                              " in function ", body_attr.name());
    }
    const bool is_loop_invariant =
        ret_it->second == body_sig.input_arg(idx).name();
    if (is_loop_invariant) {
      const_input_index_to_node[idx] = src;
    }
  }
  if (const_input_index_to_node.empty()) return Status::OK();

  static const char* const kLoopAttrs[] = {"cond", "body"};
  for (const char* loop_attr : kLoopAttrs) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, loop_attr, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
} // namespace
// Node-attribute key strings exported by this library (presumably declared
// `extern` in tf2xla_util.h, which gives them external linkage — confirm
// before changing their declaration form).
const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
// Validates a tf2xla::Config. The checks run in this order:
//   1. each feed has a valid TensorId, a valid shape, and a name unique among
//      feeds;
//   2. feed names contain no "foo"/"foo_data" pair;
//   3. each fetch has a valid TensorId and a name unique among fetches;
//   4. fetch names contain no "foo"/"foo_data" pair;
//   5. at least one fetch is present.
// The first failing check's error is returned.
Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> feed_names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &feed_names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", feed_names));

  std::set<string> fetch_names;
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(
        CheckNameDuplicates("fetch", fetch.name(), &fetch_names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", fetch_names));

  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return Status::OK();
}
// Adds one "Placeholder" node to `graph_def` per distinct fed tensor in
// `config`, named "aot_feed_<output_index>/<node_name>". Every node input in
// `graph_def` whose ParseTensorName(...).ToString() form equals a feed's
// TensorIdToString() key is rewritten to read from that placeholder. The
// original fed nodes are left in the graph.
//
// `feed_remapping` receives TensorIdToString(feed id) -> placeholder name.
// If the same tensor is fed more than once, the last feed in `config` wins.
//
// Placeholder dtype: the feed's `type` when it is not DT_INVALID; otherwise
// the fed node is built in a scratch Graph (after adding default attrs using
// `op_registry`) and the base type of the selected output is used.
// Placeholder shape: copied from the feed's `shape`.
//
// Returns NotFound if a feed's node is not in `graph_def`, or
// InvalidArgument if a feed's output_index is >= the node's output count.
Status AddPlaceholdersForFeeds(
const tf2xla::Config& config, const OpRegistryInterface* op_registry,
std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
struct PlaceholderInfo {
const tf2xla::Feed* feed = nullptr; // point to Feed in <config>.
string placeholder_name;
DataType data_type = DT_INVALID;
};
// Put each fed tensor into a map by name:port. A map is used for determinism
// when creating placeholders (genrules want deterministic output).
std::map<string, PlaceholderInfo> placeholder_info;
for (int i = 0; i < config.feed_size(); ++i) {
const tf2xla::Feed* feed = &config.feed(i);
const string name_port = TensorIdToString(feed->id());
// Re-feeding the same name:port overwrites the earlier entry.
PlaceholderInfo& info = placeholder_info[name_port];
info.feed = feed;
info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
"/", feed->id().node_name());
(*feed_remapping)[name_port] = info.placeholder_name;
}
// Verify node exists and determine data type.
std::unordered_map<string, const NodeDef*> name_to_node;
for (int i = 0; i < graph_def->node_size(); ++i) {
name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
}
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
PlaceholderInfo& info = it->second;
const tf2xla::TensorId& feed_id = info.feed->id();
// Find the existing node and determine data type.
auto node_it = name_to_node.find(feed_id.node_name());
if (node_it == name_to_node.end()) {
return errors::NotFound("Can't find feed node: ",
TensorIdToString(feed_id));
}
const NodeDef* existing = node_it->second;
if (info.feed->type() != DT_INVALID) {
// An explicitly configured feed type takes precedence over inference.
info.data_type = info.feed->type();
} else {
// Build the node in order to infer its type.
// Must first add default attrs as well, so do this in a copied GraphDef.
GraphDef gd;
*gd.mutable_versions() = graph_def->versions();
*gd.add_node() = *existing;
MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
TF_RETURN_IF_ERROR(
AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
// Now build the node from the copied node def.
Graph g(op_registry);
g.set_versions(graph_def->versions());
Status status;
Node* feed_node = g.AddNode(gd.node(0), &status);
TF_RETURN_IF_ERROR(status);
// Negative indices were presumably rejected earlier (ValidateConfig) —
// this check only guards the upper bound.
if (info.feed->id().output_index() < feed_node->num_outputs()) {
info.data_type =
BaseType(feed_node->output_type(info.feed->id().output_index()));
} else {
return errors::InvalidArgument(
"Invalid output_index ", info.feed->id().output_index(),
" for feed node ", info.feed->id().node_name());
}
}
}
// Create placeholders. Note that we could avoid creating a placeholder for
// feeds which are already placeholders, but we omit that to avoid more cases
// in this code.
for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
const PlaceholderInfo& info = it->second;
// TODO(shikharagarwal): Add original node information.
NodeDef* d = graph_def->add_node();
d->set_name(info.placeholder_name);
d->set_op("Placeholder");
auto& attr_map = *d->mutable_attr();
attr_map["dtype"].set_type(info.data_type);
*attr_map["shape"].mutable_shape() = info.feed->shape();
}
// Rewrite references to the fed tensors to refer to the placeholder.
// This also visits the placeholders just appended; they have no inputs.
for (int i = 0; i < graph_def->node_size(); ++i) {
NodeDef* node_def = graph_def->mutable_node(i);
for (int j = 0; j < node_def->input_size(); ++j) {
auto id = ParseTensorName(node_def->input(j));
auto it = placeholder_info.find(id.ToString());
if (it != placeholder_info.end()) {
node_def->set_input(j, it->second.placeholder_name);
}
}
}
return Status::OK();
}
// Copies `in` into `out`, keeping only the nodes reachable from the fetch
// nodes in `config`. Reachability walks each node's inputs (including control
// inputs) backwards. It does not cross an input that is a fed tensor, so a
// producer needed only for fed outputs is dropped. It can still be kept if
// another, unfed output of that producer is needed. Surviving nodes keep their
// order from `in`; all non-node GraphDef fields are copied unchanged.
// Returns InvalidArgument if a needed node is missing from `in`.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // (node name, output index) of every fed tensor.
  std::set<std::pair<string, int>> fed;
  for (const tf2xla::Feed& feed : config.feed()) {
    fed.emplace(feed.id().node_name(), feed.id().output_index());
  }

  // Node name -> (visited flag, node).
  std::unordered_map<string, std::pair<bool, const NodeDef*>> nodes;
  for (const NodeDef& node : in.node()) {
    nodes[node.name()] = std::make_pair(false, &node);
  }

  // Breadth-first walk from the fetches.
  std::queue<string> pending;
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    pending.push(fetch.id().node_name());
  }
  while (!pending.empty()) {
    const string current = pending.front();
    pending.pop();
    auto it = nodes.find(current);
    if (it == nodes.end()) {
      return errors::InvalidArgument("While pruning graph, node ", current,
                                     " needed but not found in the graph.");
    }
    bool& visited = it->second.first;
    if (visited) continue;
    visited = true;
    for (const string& input : it->second.second->input()) {
      const TensorId tensor = ParseTensorName(input);
      const string producer(tensor.first);
      const bool is_fed =
          fed.count(std::make_pair(producer, tensor.second)) > 0;
      if (!is_fed) {
        pending.push(producer);
      }
    }
  }

  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (nodes[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}
// Formats `id` as "<node_name>:<output_index>". The index is always written,
// even when it is 0.
string TensorIdToString(const tf2xla::TensorId& id) {
  string result = id.node_name();
  absl::StrAppend(&result, ":", id.output_index());
  return result;
}
// Copies the assigned and requested device of one neighbor onto `n`. The
// neighbors are the data-edge consumers of `n` when `out_edges` is true, and
// its data-edge producers otherwise. Only neighbors whose device parses to a
// MAXIMAL sharding are considered. The neighbor with the lowest core index
// wins, and ties go to the first one seen. If no neighbor qualifies, `n` is
// left unchanged. Sharding parse errors are propagated.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int best_core = -1;
  const Node* best_node = nullptr;
  const auto& edges = out_edges ? n->out_edges() : n->in_edges();
  for (const Edge* edge : edges) {
    if (edge->IsControlEdge()) continue;
    const Node* neighbor = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *neighbor,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
            /*add_metadata=*/false));
    if (!sharding.has_value() ||
        sharding->type() != xla::OpSharding::MAXIMAL) {
      continue;
    }
    const int core = sharding->tile_assignment_devices(0);
    if (best_core == -1 || core < best_core) {
      best_core = core;
      best_node = neighbor;
    }
  }
  if (best_node == nullptr) return Status::OK();
  n->set_assigned_device_name(best_node->assigned_device_name());
  n->set_requested_device(best_node->requested_device());
  return Status::OK();
}
// Appends `dtype` to the allowed-values type list of every constraint in
// `kdef` whose name equals `name`. Other constraints are left untouched.
void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
  auto* constraints = kdef->mutable_constraint();
  for (int i = 0; i < constraints->size(); ++i) {
    KernelDef::AttrConstraint* constraint = constraints->Mutable(i);
    if (constraint->name() != name) continue;
    constraint->mutable_allowed_values()->mutable_list()->add_type(dtype);
  }
}
namespace {
// Returns a nondeterministic, odd starting value drawn from std::random_device.
//
// Plumbing the TF seed through to XLA is still being worked on. Users who want
// deterministic behavior are best served by starting from a known checkpoint,
// which also handles random calls being invoked in any order by the TF
// executor. Stateless random ops are another option with much cleaner
// semantics. A deterministic seed for XLA-based devices would be set here.
uint32 InitialRandomSeed() {
  std::random_device entropy_source;
  const uint32 raw = entropy_source();
  // Setting the low bit makes the starting value odd.
  return raw | 1u;
}
} // namespace
// Returns a pseudo-random, odd (hence never zero) seed for XLA random ops.
// Each call takes the next value of a process-wide atomic counter and
// scrambles it with a Mersenne Twister engine that is local to the call. The
// counter starts at an odd InitialRandomSeed() value and steps by two.
uint32 GetXLARandomSeed() {
  // We initialize counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  const uint32 seed = counter.fetch_add(2);
  // Scramble with a call-local engine instead of std::srand/std::rand.
  // Those share hidden global state and are not required to be thread-safe,
  // so concurrent compilations could race. Reseeding them also clobbers the
  // sequence for every other std::rand user in the process. Finally, RAND_MAX
  // may be as small as 32767, which leaves only 15 random bits.
  std::mt19937 engine(seed);
  return static_cast<uint32>(engine()) | 1u;
}
// TODO(b/77601805): add tests for associated function related stuff.
bool HasAssociatedFunction(const NodeDef& node_def,
const FunctionLibraryDefinition* fld) {
if (fld->Contains(node_def.op())) {
return true;
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
// Gradient op has "f" attr, which is set to the function we are getting
// gradient for. We need to functionalize the gradient function.
return true;
}
if (node_def.op() == "XlaHostCompute") {
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
return false;
}
for (const auto& iter : node_def.attr()) {
if (iter.second.has_func()) {
return true;
}
}
return false;
}
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, const FunctionLibraryDefinition* fld) {
std::vector<AssociatedFunctionInfo> results;
const string& op = node.type_string();
if (fld->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
// This is a SymbolicGradient op.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else if (node.type_string() == "XlaHostCompute") {
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
}
return results;
}
// Redirects `node`'s association with a function to `rewritten_function_name`.
// The action depends on `associated_function.type()`:
//  - kFunctionCallNode: builds a new node with the same name, attrs and
//    device (assigned device if non-empty, else requested device) whose op is
//    `rewritten_function_name`. It re-creates all of `node`'s in- and out-edges
//    on the new node, then removes `node` from `graph`. `node` must not be
//    used by the caller afterwards.
//  - kSymbolicGradient: registers `rewritten_function_name` in `fld` as the
//    gradient of the function named by the node's kFuncAttr attr. It adds a
//    GradientDef when none exists, or replaces an existing one that names a
//    different function. The node itself is not modified.
//  - kFunctionAttr: replaces the function name in the node's
//    `associated_function.attr_name()` attr, keeping that attr's own attrs.
// Errors from attr lookup, NodeDefBuilder, AddNode and the gradient registry
// are propagated.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
const string& rewritten_function_name) {
switch (associated_function.type()) {
case AssociatedFunctionInfo::kFunctionCallNode: {
// Change this node to call the new function.
NodeDebugInfo debug_info(*node);
NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
&debug_info);
for (const auto& attr : node->attrs()) {
builder.Attr(attr.first, attr.second);
}
for (int i = 0; i < node->num_inputs(); i++) {
Node* input_node;
TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
// NOTE(review): passes the input position `i` as the source output slot.
// The edges re-created below use each edge's real src_output — confirm
// this NodeDef input mismatch is harmless.
builder.Input(input_node->name(), i, node->input_type(i));
}
builder.Device(node->assigned_device_name().empty()
? node->requested_device()
: node->assigned_device_name());
NodeDef node_def;
TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
Status s;
Node* new_node = graph->AddNode(node_def, &s);
TF_RETURN_IF_ERROR(s);
// Copy data and control edges (control edges keep their control slot).
for (auto edge : node->in_edges()) {
graph->AddEdge(edge->src(), edge->src_output(), new_node,
edge->dst_input());
}
for (auto edge : node->out_edges()) {
graph->AddEdge(new_node, edge->src_output(), edge->dst(),
edge->dst_input());
}
graph->RemoveNode(node);
break;
}
case AssociatedFunctionInfo::kSymbolicGradient: {
NameAttrList func;
TF_RETURN_IF_ERROR(GetNodeAttr(
node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
GradientDef gradient_def;
gradient_def.set_function_name(func.name());
gradient_def.set_gradient_func(rewritten_function_name);
string original_grad_func = fld->FindGradient(func.name());
// Add if absent; replace only if it currently points elsewhere.
if (original_grad_func.empty()) {
TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
} else if (original_grad_func != rewritten_function_name) {
TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
}
break;
}
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
TF_RETURN_IF_ERROR(
GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
node->ClearAttr(associated_function.attr_name());
func.set_name(rewritten_function_name);
node->AddAttr(associated_function.attr_name(), func);
break;
}
}
return Status::OK();
}
// Stores in `*handle` a handle for `func_name` instantiated with `attrs`.
// A handle cached under the same Canonicalize(func_name, attrs) key is reused.
// Otherwise the function is instantiated through `flr_` and the handle is
// cached. If instantiation fails, its error is returned and nothing is cached.
Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
  const string key = Canonicalize(func_name, attrs);
  const auto cached = handles_.find(key);
  if (cached != handles_.end()) {
    *handle = cached->second;
    return Status::OK();
  }
  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[key] = *handle;
  return Status::OK();
}
// Releases every cached handle through `flr_` and empties the cache. A release
// is attempted for every handle, even after a failure. The per-handle
// statuses are combined with Status::Update, and the combined status is
// returned.
Status CachedFunctionHandles::ReleaseAllHandles() {
  Status combined;
  for (auto it = handles_.begin(); it != handles_.end(); ++it) {
    combined.Update(flr_->ReleaseHandle(it->second));
  }
  handles_.clear();
  return combined;
}
// Adds a node built from `node_def` to `g` and moves all of `n`'s input and
// output edges (data and control) onto it. It then removes `n` and returns the
// new node. If AddNode rejects `node_def`, its error is returned and `n` is
// left untouched.
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  Status add_status;
  Node* replacement = g->AddNode(node_def, &add_status);
  if (!add_status.ok()) {
    return add_status;
  }

  // Detach `n`'s outputs before wiring the replacement so no consumer input
  // ever has two producers. Snapshot first: RemoveEdge mutates the edge set.
  std::vector<OutEdgeInfo> saved_outputs;
  std::vector<const Edge*> old_out_edges;
  for (const Edge* e : n->out_edges()) {
    old_out_edges.push_back(e);
    saved_outputs.push_back({e->dst(), e->src_output(), e->dst_input()});
  }
  for (const Edge* e : old_out_edges) {
    g->RemoveEdge(e);
  }

  // Re-wire inputs, then outputs, onto the replacement.
  for (const Edge* e : n->in_edges()) {
    g->AddEdge(e->src(), e->src_output(), replacement, e->dst_input());
  }
  for (const OutEdgeInfo& info : saved_outputs) {
    g->AddEdge(replacement, info.src_output, info.dst, info.dst_input);
  }

  g->RemoveNode(n);
  return replacement;
}
// Adds to `graph` an "Identity" node named `node_name` with attr T = `dtype`.
// If `input` is non-null, its name is listed as the NodeDef's input. If
// `requested_device` is set, it becomes the NodeDef's device. Returns the new
// node, or the AddNode error.
xla::StatusOr<Node*> BuildIdentityNode(
    Graph* graph, const string& node_name, DataType dtype, const Node* input,
    absl::optional<string> requested_device) {
  NodeDef identity_def;
  identity_def.set_name(node_name);
  identity_def.set_op("Identity");
  if (input != nullptr) identity_def.add_input(input->name());
  if (requested_device.has_value()) identity_def.set_device(*requested_device);
  AddNodeAttr("T", dtype, &identity_def);

  Status add_status;
  Node* identity = graph->AddNode(identity_def, &add_status);
  TF_RETURN_IF_ERROR(add_status);
  return identity;
}
// Runs const propagation on every If and While op node in `g`. Const inputs of
// those nodes are pushed into their branch or cond/body functions. Original
// functions are read from `lookup_fld`, and rewritten ones are added to `fld`.
// The first error is returned.
Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  for (Node* node : g->op_nodes()) {
    if (node->IsIfNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, node, lookup_fld, fld));
      continue;
    }
    if (node->IsWhileNode()) {
      TF_RETURN_IF_ERROR(
          PropagateConstIntoWhileNode(g, node, lookup_fld, fld));
    }
  }
  return Status::OK();
}
// Removes from `fld` every function that ReachableDefinitions() does not
// report as reachable from `g`'s GraphDef. The first removal error is
// returned.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef snapshot;
  g.ToGraphDef(&snapshot);
  const FunctionLibraryDefinition reachable =
      fld->ReachableDefinitions(snapshot);
  const auto all_names = fld->ListFunctionNames();
  for (const string& name : all_names) {
    if (reachable.Find(name) == nullptr) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(name));
    }
  }
  return Status::OK();
}
// For each EmptyTensorList node in `g`, looks for the following pattern:
//   - exactly one data edge from the list into a forward While node;
//   - exactly one While node (the backward While) consuming the forward
//     While's output at the same slot;
//   - the forward body's matching _Arg feeds exactly one TensorListPushBack
//     whose element input (input 1) is a Const;
//   - the backward body's matching _Arg feeds exactly one TensorListPopBack.
// When the pattern matches, every consumer of the popped element
// (TensorListPopBack output 1) in the backward body is rewired to read a copy
// of that Const. The rewritten body is added to `fld` under a new unique name
// ("<body>_tl_rewrite_" prefix), and the backward While's "body" attr is
// updated. Anything that does not match exactly is skipped.
//
// Returns InvalidArgument if a While body function is missing from `fld`, or
// has no _Arg at the list's input position. Errors from attr lookup, function
// instantiation, node lookup and graph conversion are propagated. Previously
// several of these paths aborted the process via TF_CHECK_OK, or indexed
// arg_nodes out of bounds.
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }
    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into forward While body function and check if TensorListPushBack op
    // has a Const input.
    NameAttrList fwd_body_attr;
    TF_RETURN_IF_ERROR(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
    if (fwd_while_dst_input < 0 ||
        fwd_while_dst_input >= static_cast<int>(fwd_fbody->arg_nodes.size())) {
      return errors::InvalidArgument(
          "Function ", fwd_body_attr.name(), " has no _Arg for input ",
          fwd_while_dst_input, " of While node ", fwd_while->DebugString());
    }

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack.
      continue;
    }

    // Get input for the TensorListPushBack node.
    Node* input_node;
    TF_RETURN_IF_ERROR(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // Input for the TensorList is not Const node.
      continue;
    }
    NodeDef const_input_nodedef = input_node->def();

    // Rewrite backward While body function, replace usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_RETURN_IF_ERROR(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
    if (bwd_while_dst_input < 0 ||
        bwd_while_dst_input >= static_cast<int>(bwd_fbody->arg_nodes.size())) {
      return errors::InvalidArgument(
          "Function ", bwd_body_attr.name(), " has no _Arg for input ",
          bwd_while_dst_input, " of While node ", bwd_while->DebugString());
    }

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack.
      continue;
    }

    // Replace TensorListPopBack usages (output 1, the popped element) with
    // Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    Status s;
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
    TF_RETURN_IF_ERROR(s);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return Status::OK();
}
} // namespace tensorflow