blob: 494a8d1613b7bb9ac7fb02eb33de276f81d605c6 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
namespace tensorflow {
// Definitions of attribute-name constants used by this pass. They have no
// `static` and are not in the anonymous namespace, so they are presumably
// declared in encapsulate_subgraphs_pass.h (not visible here) and shared with
// other passes.
const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
// Attribute set on the NoOp sequencer node built by
// Encapsulator::Subgraph::MakeSequencingNode; its value is the subgraph name.
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
const char* const kXlaHasReferenceVarsAttr = "_XlaHasReferenceVars";
// Canonicalizes the input lists of every node in 'gdef': data inputs are kept
// first, in their original relative order, and control inputs ("^name") are
// moved after them and sorted lexicographically.
void SortControlInputs(GraphDef* gdef) {
  const auto is_control = [](const string& input) {
    return absl::StartsWith(input, "^");
  };
  // Strict weak ordering: every data input is equivalent to every other data
  // input and precedes all control inputs; control inputs compare by name.
  // Because data inputs are mutually equivalent, std::stable_sort leaves their
  // relative order untouched.
  const auto input_less = [&is_control](const string& lhs,
                                        const string& rhs) {
    if (!is_control(rhs)) return false;  // Nothing sorts before a data input.
    if (!is_control(lhs)) return true;   // Data input before control input.
    return lhs < rhs;                    // Both control inputs.
  };
  for (NodeDef& node : *gdef->mutable_node()) {
    std::stable_sort(node.mutable_input()->begin(),
                     node.mutable_input()->end(), input_less);
  }
}
namespace {
// Returns true if 'n' is itself a GuaranteeConst op, or if 'n' has at least
// one data (non-control) input and the source of every data input is in
// 'runtime_const_nodes'. Control edges are ignored.
bool AreAllParentsGuaranteedConst(
    const Node& n,
    const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
  // A cast-to-const node is constant regardless of its inputs.
  if (n.type_string() == "GuaranteeConst") return true;
  bool has_data_input = false;
  for (const Edge* in : n.in_edges()) {
    if (in->IsControlEdge()) continue;
    has_data_input = true;
    if (runtime_const_nodes.count(in->src()) == 0) return false;
  }
  // A node with only control inputs (or none) is not treated as constant.
  return has_data_input;
}
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
absl::flat_hash_set<const Node*> guaranteed_const_nodes;
std::vector<const Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
srcs.push_back(src_arg.first);
}
ReverseDFSFrom(
graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsGuaranteedConst(*n, guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
for (auto& src_arg : src_arg_pairs) {
if (guaranteed_const_nodes.count(src_arg.first) != 0) {
VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
src_arg.second->AddAttr("_is_guaranteed_constant", true);
}
}
}
// Hash functor for (OutputTensor, InputTensor) pairs, i.e. one edge endpoint
// pair. Combines the per-endpoint hashes with Hash64Combine.
struct OutputInputTensorPairHasher {
  uint64 operator()(const std::pair<OutputTensor, InputTensor>& edge) const {
    const uint64 src_hash = OutputTensor::Hash()(edge.first);
    const uint64 dst_hash = InputTensor::Hash()(edge.second);
    return Hash64Combine(src_hash, dst_hash);
  }
};
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
// Op type names used by this pass. kArgOp / kRetValOp are the node types
// created by Subgraph::RecordArg / Subgraph::RecordResult to carry a
// subgraph's inputs and outputs inside the generated function.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
// Op type names of the XLA host-compute / host-transfer ops.
static const char* const kHostComputeOp = "XlaHostCompute";
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
// Extracts every cluster of nodes that share a value of 'group_attribute' into
// its own subgraph, converts each subgraph into a FunctionDef, and rewrites
// the input graph so each cluster is replaced by a call to its function.
class Encapsulator {
 public:
  Encapsulator(string group_attribute, Graph const* graph_in)
      : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}

  // Find subgraphs marked with 'group_attribute', and build a new
  // subgraph, one for each value of 'group_attribute'.
  Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);

  // Build a FunctionDef for each subgraph, and add it to 'library'. The values
  // of the 'group_attribute' annotations become the function names.
  // If 'reuse_existing_functions' is set and 'library' already contains a
  // function with the same name, that function is kept if it is equal to the
  // newly built one and replaced otherwise.
  // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
  // function conversion.
  Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
                           bool reuse_existing_functions,
                           FunctionLibraryDefinition* library);

  // Write a copy of the input graph to 'graph_out', where the subgraphs are
  // replaced with calls to the new functions.
  Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);

 private:
  // A subgraph of the input, all marked with a common 'group_attribute'
  // value.
  //
  // In the following simple example, A, B, ..., E are nodes in the original
  // graph. The group attributes g are each shown as either 0 or empty.
  //
  //  A  -->  B  -->  C  -->  D  -->  E
  //  g:      g:0     g:0     g:0     g:
  //
  // The example is rewritten to two graphs; one on the host and one to be
  // compiled. The host graph is as follows.
  //
  //  A  -->  Call  -->  E
  //
  // The compiled cluster is as follows.
  //
  //  Arg  --> B  --> C  --> D --> Retval
  class Subgraph {
   public:
    // Copies 'node' into this subgraph and returns the copy. The subgraph's
    // graph is created on first use, with the same op registry and versions
    // as 'graph_in'. The first non-empty device seen on a copied node
    // (assigned device if set, else requested device) becomes the subgraph's
    // device.
    Node* MakeNodeImage(const Graph* graph_in, Node* node);

    // Returns the graph the subgraph is being built in.
    Graph* GetGraph() const;

    // Builds a FunctionDef, and adds it to 'library'. The value of the
    // 'group_attribute' annotations becomes the function name. If
    // 'reuse_existing_functions' is set, an existing function with the same
    // name is kept if identical and replaced otherwise. If
    // 'rewrite_subgraph_fn' is set, it is applied to the subgraph before
    // function conversion.
    Status BuildFunctionDef(const string& name_in,
                            const RewriteSubgraphFn& rewrite_subgraph_fn,
                            bool reuse_existing_functions,
                            FunctionLibraryDefinition* library);

    // Adds the function call node to graph_out.
    Status AddFunctionCallNode(
        const std::unordered_map<const Node*, Node*>& node_images,
        Graph* graph_out);

    // Returns the Node that the inputs and outputs of the function should be
    // wired up to. Null until AddFunctionCallNode() has run.
    Node* GetCallNode() const;

    // Returns the index of the arg that the dst of edge should connect to.
    int GetArgIndexForEdge(const Edge* edge) const;

    // Returns the index of the result that the src of edge should connect to.
    int GetResultIndexForEdge(const Edge* edge) const;

    // Creates an _Arg node for the src node of edge, and add its index to
    // args_by_src_, if none exists yet. Also adds its index to args_by_dst_,
    // and adds the edge within the subgraph from the _Arg node to the image of
    // the dst node.
    Status RecordArg(const Edge* edge,
                     const std::unordered_map<const Node*, Node*>& node_images,
                     std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);

    // Records the src of the given edge as a control result of the graph.
    // Used during graph to function conversion to tie control results to
    // the function signature.
    Status RecordControlResult(
        const Edge* edge,
        const std::unordered_map<const Node*, Node*>& node_images);

    // Creates a _Retval node for the src node of edge, and add it to results_,
    // if none exists yet. If a new _Retval node is created, also adds the edge
    // within the subgraph from the src to the _Retval node.
    Status RecordResult(
        const Edge* edge,
        const std::unordered_map<const Node*, Node*>& node_images);

    // Creates the sequencer node if it doesn't exist, adding it to graph_out.
    Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out);

    // If there is a sequencer node, adds a control edge from the sequencer to
    // the call node.
    void ConnectSequencerToCallNode(Graph* graph_out);

    // Regenerates the FunctionDef for this subgraph from its current graph and
    // replaces the function named function_def_name_ in 'library'.
    Status ReplaceFunctionDef(FunctionLibraryDefinition* library);

   private:
    // The subgraph extracted from the input graph, suitable for being turned
    // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
    // returned by _Retval nodes.
    std::unique_ptr<Graph> graph_;

    // Which device are these nodes on? Used to assign a device to the call
    // node.
    string device_;

    // NodeDef for the function call node.
    NodeDef call_node_def_;

    // Name that is used for the call node. This may not be
    // call_node_def_.name() if the client supplies a rewrite lambda.
    string function_def_name_;

    // Placeholder node simulating the host compute key in the output graph.
    // Not owned.
    Node* host_compute_key_placeholder_ = nullptr;

    // Function call node in the output graph. Not owned. Initialized to null
    // (like the other raw Node* members) so it is never read uninitialized
    // before AddFunctionCallNode() sets it.
    Node* call_node_ = nullptr;

    // Maps from source (producer node/slot) and destination
    // (consumer node/slot) tensors in the input graph to _Arg numbers in
    // the subgraph. The source map is one-to-one, whereas the dest map may be
    // many-to-one.
    std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
    std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;

    // The arguments to the subgraph, in order.
    std::vector<Node*> args_;

    // Map from source tensor in the input graph to result #.
    std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;

    // Set of node names that are the source of a control output of the
    // subgraph. We store strings here so that we can tolerate nodes being
    // removed from the graph.
    absl::flat_hash_set<string> control_output_nodes_;

    // NoOp node in the output graph that is sequenced after the call node.
    Node* sequencer_ = nullptr;
  };

  // Sets '*attr' to the string value of the 'group_attribute_' attribute of
  // 'node', or to the empty string if the node has no such attribute. Returns
  // an error if the attribute is present but is not a string.
  Status GetFunctionNameAttr(Node const* node, string* attr) const;

  // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to
  // subgraphs for data edges that cross subgraph boundaries.
  Status CopySubgraphEdges(
      const std::unordered_map<const Node*, Node*>& node_images,
      std::vector<std::pair<const Node*, Node*>>* src_arg_pairs);

  // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes.
  Status CopySubgraphNodes(std::unordered_map<const Node*, Node*>* node_images);

  // Copies all nodes that aren't in a compiled subgraph to the output graph.
  Status CopyNodesToOutputGraph(
      Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);

  // Adds function call nodes for each compiled subgraph.
  Status AddFunctionCallNodes(
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out);

  // Finds the image of an edge source in the output graph. If the edge crosses
  // a subgraph boundary it is the output of a call node, otherwise it is a node
  // in the output graph.
  Status FindOutputImageOfEdgeSrc(
      const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      const Node* original_src_node, Node** src_image);

  // Finds an edge source slot in the output graph. If the edge crosses a
  // subgraph boundary it is a slot on the output of a call node, otherwise it
  // is a slot on a node in the output graph.
  int FindOutputSlotOfEdgeSrc(const string& src_func_id,
                              const string& dst_func_id,
                              const Edge* edge);

  // Finds the image of an edge destination in the output graph. If the edge
  // crosses a subgraph boundary it is the input of a call node, otherwise it is
  // a node in the output graph.
  Status FindOutputImageOfEdgeDst(
      const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      const Node* original_dst_node, Node** dst_image);

  // Finds an edge destination slot in the output graph. If the edge crosses a
  // subgraph boundary it is a slot on the input of a call node, otherwise it is
  // a slot on a node in the output graph.
  int FindOutputSlotOfEdgeDst(const string& src_func_id,
                              const string& dst_func_id,
                              const Edge* edge);

  // Copies a single edge to the output graph. The edge is either entirely
  // within the output graph, or crosses into or out of a compiled subgraph.
  Status CopyEdgeToOutputGraph(
      const Edge* edge, const string& src_func_id, const string& dst_func_id,
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out,
      std::unordered_set<std::pair<OutputTensor, InputTensor>,
                         OutputInputTensorPairHasher>* edges_added);

  // Adds all edges to the output graph.
  Status AddEdgesToOutputGraph(
      const std::unordered_map<const Node*, Node*>& node_images,
      Graph* graph_out);

  // Makes a copy of 'graph' containing only nodes that are ancestors of at
  // least one node in 'sink_nodes' and stores it in 'pruned_graph'. On exit
  // 'node_images' contains a mapping from nodes in 'graph' to nodes in
  // 'pruned_graph'. All functions in the copied graph are inlined.
  Status MakePrunedGraphCopyAndInline(
      const Graph& graph, const std::vector<Node*>& sink_nodes,
      std::unique_ptr<Graph>* pruned_graph,
      std::unordered_map<const Node*, Node*>* node_images,
      FunctionLibraryDefinition* library);

  const string group_attribute_;
  const Graph* graph_in_;

  std::unordered_map<string, Subgraph> subgraphs_;

  TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
};
namespace {
// Return in 'sorted' a topological sort of clusters according to the
// dependencies encoded in ancestors. clusters is the list of all clusters
// including clusters that are not present in the ancestors map. has_successors
// is the set of clusters that are ancestors of some other cluster.
void TopologicalClusterSort(
const std::unordered_set<string>& clusters,
const std::unordered_set<string>& has_successors,
const std::unordered_map<string, std::unordered_set<string>>& ancestors,
std::vector<string>* sorted) {
// The nodes are placed in 'sorted' in topological order.
sorted->clear();
// We don't use the standard DFS because we are not operating on Node*
// objects.
struct Work {
string cluster;
bool leave;
};
std::set<string> visited;
std::vector<Work> stack;
// Seed the processing list with clusters that have no successors.
for (const auto& cluster : clusters) {
if (has_successors.find(cluster) == has_successors.end()) {
stack.push_back({cluster, false});
}
}
while (!stack.empty()) {
const Work item = stack.back();
stack.pop_back();
if (item.leave) {
sorted->push_back(item.cluster);
continue;
}
if (visited.find(item.cluster) != visited.end()) continue;
visited.insert(item.cluster);
stack.push_back({item.cluster, true});
const auto& iter = ancestors.find(item.cluster);
if (iter != ancestors.end()) {
for (const auto& ancestor : iter->second) {
stack.push_back({ancestor, false});
}
}
}
CHECK(sorted->size() == clusters.size());
}
} // namespace
// Returns the function call node placed in the output graph. Only meaningful
// after AddFunctionCallNode() has run.
Node* Encapsulator::Subgraph::GetCallNode() const {
  Node* call = call_node_;
  return call;
}
int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
}
int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
return results_.at(OutputTensor(edge->src(), edge->src_output()));
}
// Copies 'node' into this subgraph's graph and returns the copy. The graph is
// created lazily on the first call, mirroring graph_in's op registry and
// versions. Until a non-empty device has been recorded, the node's assigned
// device (or, if unassigned, its requested device) becomes the subgraph's
// device.
Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
  if (graph_ == nullptr) {
    graph_.reset(new Graph(graph_in->op_registry()));
    graph_->set_versions(graph_in->versions());
  }

  // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
  // determined. In case of hard placement, ensure all the encapsulated nodes
  // have the same requested device, which in turn will be the requested device
  // for the entire encapsulated subgraph. In case of soft placement, use a
  // deterministic approach to fill in the requested device. Handle co-location
  // constraints similarly if they exist.
  if (device_.empty()) {
    const string& assigned = node->assigned_device_name();
    if (assigned.empty()) {
      device_ = node->requested_device();
    } else {
      device_ = assigned;
    }
  }

  return graph_->CopyNode(node);
}
// Returns the (non-owning) graph the subgraph is built in; null until the
// first MakeNodeImage() call creates it.
Graph* Encapsulator::Subgraph::GetGraph() const {
  Graph* graph = graph_.get();
  return graph;
}
// Records the data edge 'edge' as an input of this subgraph. The first time a
// given source tensor (src node, src slot) is seen, a new _Arg node is created
// with the next free index, appended to args_ and reported via
// 'src_arg_pairs'. Every call maps the destination (dst node, dst slot) to that
// arg index and wires the _Arg node to the image of the destination node.
Status Encapsulator::Subgraph::RecordArg(
    const Edge* edge, const std::unordered_map<const Node*, Node*>& node_images,
    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
  Node* src_node = edge->src();
  const int src_slot = edge->src_output();

  // Index the new arg would get; only used if the source tensor is new.
  const int next_index = static_cast<int>(args_by_src_.size());
  const auto emplaced =
      args_by_src_.emplace(OutputTensor(src_node, src_slot), next_index);
  const int arg_index = emplaced.first->second;

  if (emplaced.second) {
    NodeDefBuilder builder(
        absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp,
        NodeDebugInfo(src_node->def()));
    // The arg takes the destination's input type (not the source's output
    // type), matching what the consumer inside the subgraph expects.
    builder.Attr("T", edge->dst()->input_type(edge->dst_input()));
    builder.Attr("index", arg_index);
    NodeDef arg_def;
    Status status = builder.Finalize(&arg_def);
    if (!status.ok()) return status;
    Node* arg = graph_->AddNode(arg_def, &status);
    if (!status.ok()) return status;
    src_arg_pairs->push_back({src_node, arg});
    args_.push_back(arg);
  }

  Node* dst_node = edge->dst();
  Node* dst_image = node_images.at(dst_node);
  const int dst_slot = edge->dst_input();
  args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
  graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
  return Status::OK();
}
// Remembers the image of 'edge's source node as a control output of this
// subgraph. Only the node's name is stored, so the node may later be removed
// from the subgraph without leaving a dangling pointer.
Status Encapsulator::Subgraph::RecordControlResult(
    const Edge* edge,
    const std::unordered_map<const Node*, Node*>& node_images) {
  const Node* src_image = node_images.at(edge->src());
  control_output_nodes_.insert(src_image->name());
  return Status::OK();
}
// Records the source tensor (src node, src slot) of data edge 'edge' as an
// output of this subgraph. The first time a source tensor is seen, a _Retval
// node with the next free result index is created and fed from the image of
// the source node; subsequent edges from the same tensor reuse it.
Status Encapsulator::Subgraph::RecordResult(
    const Edge* edge,
    const std::unordered_map<const Node*, Node*>& node_images) {
  Node* src_node = edge->src();
  Node* src_image = node_images.at(src_node);
  const int src_slot = edge->src_output();

  const int ret_index = static_cast<int>(results_.size());
  const bool is_new =
      results_.emplace(OutputTensor(src_node, src_slot), ret_index).second;
  if (!is_new) return Status::OK();  // This tensor already has a _Retval.

  const DataType dtype = src_node->output_type(src_slot);
  NodeDefBuilder builder(
      absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp,
      NodeDebugInfo(src_node->def()));
  builder.Attr("T", dtype);
  builder.Attr("index", ret_index);
  builder.Input(src_image->name(), src_slot, dtype);

  NodeDef ret_def;
  Status status = builder.Finalize(&ret_def);
  if (!status.ok()) return status;
  Node* ret = graph_->AddNode(ret_def, &status);
  if (!status.ok()) return status;
  graph_->AddEdge(src_image, src_slot, ret, 0);
  return Status::OK();
}
// Lazily adds to 'graph_out' a NoOp named "<subgraph_name>_sequencer", placed
// on the subgraph's device and tagged with kXlaHostTransferSequencerAttr set
// to 'subgraph_name'. Does nothing if the sequencer already exists.
Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
                                                   Graph* graph_out) {
  if (sequencer_ != nullptr) return Status::OK();

  // TODO(shikharagarwal): What source node should we use for errors?
  NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp");
  builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name);
  builder.Device(device_);

  NodeDef seq_def;
  Status status = builder.Finalize(&seq_def);
  if (!status.ok()) return status;
  sequencer_ = graph_out->AddNode(seq_def, &status);
  return status;
}
// Adds a control edge sequencer -> call node in 'graph_out' when a sequencer
// was created by MakeSequencingNode(); otherwise does nothing.
void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
  if (sequencer_ == nullptr) return;
  VLOG(2) << "ConnectSequencerToCallNode";
  graph_out->AddControlEdge(sequencer_, call_node_,
                            /*allow_duplicates=*/true);
}
// Converts this subgraph into a FunctionDef named 'name_in' (or the op name
// chosen by 'rewrite_subgraph_fn', if one is given) and adds it to 'library'.
// Also fills call_node_def_ (op, name, device) for the call node added later.
//
// If 'rewrite_subgraph_fn' is set, it may rewrite the graph and call node and
// permute the inputs/outputs; the returned permutations must have the same
// sizes as the current arg/result lists and be bijections on [0, size).
// Otherwise an InvalidArgument error is returned.
// If 'reuse_existing_functions' is set and 'library' already has a function
// with this name, it is kept if identical and replaced otherwise.
Status Encapsulator::Subgraph::BuildFunctionDef(
    const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
    bool reuse_existing_functions, FunctionLibraryDefinition* library) {
  // name_in is copied here because name may be modified below if
  // rewrite_subgraph_fn is set.
  string name = name_in;
  call_node_def_.set_op(name);
  call_node_def_.set_name(name);
  call_node_def_.set_device(device_);

  if (rewrite_subgraph_fn) {
    std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
    for (const auto& arg : args_by_src_) {
      arg_source_tensors.at(arg.second) = arg.first;
    }
    // Initialize the input and output permutations to the identity.
    std::vector<int> input_permutation(args_by_src_.size());
    std::iota(input_permutation.begin(), input_permutation.end(), 0);
    std::vector<int> output_permutation(results_.size());
    std::iota(output_permutation.begin(), output_permutation.end(), 0);

    TF_RETURN_IF_ERROR(
        rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
                            &output_permutation, &call_node_def_));

    // Apply the input/output permutations to the 'args_by_...' and 'results_'
    // mappings, so when we build edges in BuildOutputGraph() we
    // connect them to the right input/output positions.
    if (input_permutation.size() != args_by_src_.size()) {
      return errors::InvalidArgument("Input permutation has incorrect size.");
    }
    if (output_permutation.size() != results_.size()) {
      return errors::InvalidArgument("Output permutation has incorrect size.");
    }

    // The permutations are produced by client code. A size check alone is not
    // enough: an out-of-range entry would name a call input/output that does
    // not exist, and a repeated entry would wire two different tensors to the
    // same call slot. Reject anything that is not a bijection on [0, size).
    auto check_is_permutation = [](const std::vector<int>& perm,
                                   const char* which) -> Status {
      std::vector<char> seen(perm.size(), 0);
      for (int target : perm) {
        if (target < 0 || static_cast<size_t>(target) >= perm.size() ||
            seen[target]) {
          return errors::InvalidArgument(
              which, " permutation is not a permutation of [0, ", perm.size(),
              "); offending entry: ", target);
        }
        seen[target] = 1;
      }
      return Status::OK();
    };
    TF_RETURN_IF_ERROR(check_is_permutation(input_permutation, "Input"));
    TF_RETURN_IF_ERROR(check_is_permutation(output_permutation, "Output"));

    for (auto& arg : args_by_src_) {
      arg.second = input_permutation[arg.second];
    }
    for (auto& arg : args_by_dst_) {
      arg.second = input_permutation[arg.second];
    }
    for (auto& result : results_) {
      result.second = output_permutation[result.second];
    }

    name = call_node_def_.op();
  }

  function_def_name_ = name;

  FunctionDef fdef;
  // Nodes recorded by RecordControlResult() become control returns of the
  // function.
  auto lookup = [this](const Node* node) -> absl::optional<string> {
    if (control_output_nodes_.contains(node->name())) {
      return absl::make_optional(node->name());
    }
    return absl::nullopt;
  };
  // Verify that the graph has well-formed control flow structure.
  std::vector<ControlFlowInfo> dummy;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &dummy));
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));

  if (VLOG_IS_ON(1)) {
    VLOG(2) << "Build function def " << name;
    DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), *graph_,
                    library);
    DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), fdef);
  }

  const FunctionDef* original_fdef = library->Find(name);
  if (!reuse_existing_functions || original_fdef == nullptr) {
    TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
  } else if (!FunctionDefsEqual(*original_fdef, fdef)) {
    TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
  }
  return Status::OK();
}
// Regenerates the FunctionDef named function_def_name_ from the current
// contents of graph_ and replaces the existing definition in 'library'.
// Control outputs recorded by RecordControlResult() are emitted as control
// returns, exactly as BuildFunctionDef() does. Without that, replacing a
// function would silently drop them from its signature.
Status Encapsulator::Subgraph::ReplaceFunctionDef(
    FunctionLibraryDefinition* library) {
  const string& name = function_def_name_;

  // Names (not pointers) are tracked, so nodes removed from graph_ since the
  // control output was recorded simply don't match.
  auto lookup = [this](const Node* node) -> absl::optional<string> {
    if (control_output_nodes_.contains(node->name())) {
      return absl::make_optional(node->name());
    }
    return absl::nullopt;
  };

  FunctionDef fdef;
  TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, lookup, &fdef));

  if (VLOG_IS_ON(1)) {
    VLOG(2) << "Replace function def " << name;
    DumpGraphToFile(absl::StrCat("replace_encapsulate_fdef_graph_", name),
                    *graph_, library);
    DumpFunctionDefToFile(absl::StrCat("replace_encapsulate_fdef_", name),
                          fdef);
  }

  TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
  return Status::OK();
}
// Adds the call node described by call_node_def_ to 'graph_out', remembers it
// in call_node_, and sets its assigned device to the subgraph's device.
// 'node_images' is not consulted.
Status Encapsulator::Subgraph::AddFunctionCallNode(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  Status status;
  call_node_ = graph_out->AddNode(call_node_def_, &status);
  if (!status.ok()) return status;

  // Only the device is carried over onto the new call node.
  call_node_->set_assigned_device_name(device_);
  return Status::OK();
}
// Sets '*attr' to the string value of the 'group_attribute_' attribute of
// 'node', or to the empty string if 'node' has no such attribute. Returns an
// error (leaving '*attr' empty) if the attribute exists but is not a string.
Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const {
  attr->clear();
  // Look the attribute up directly instead of scanning every attribute; this
  // runs for both endpoints of every edge of the input graph.
  const AttrSlice attrs = node->attrs();
  const AttrValue* value = attrs.Find(group_attribute_);
  if (value == nullptr) return Status::OK();
  TF_RETURN_IF_ERROR(AttrValueHasType(*value, "string"));
  *attr = value->s();
  return Status::OK();
}
// A node belongs to a subgraph exactly when its group-attribute value (as
// returned by Encapsulator::GetFunctionNameAttr) is non-empty.
bool IsInSubgraph(const string& func_id) {
  const bool unmarked = func_id.empty();
  return !unmarked;
}
// Copies every op node of graph_in_ that carries a group attribute into the
// Subgraph for that attribute value (creating the Subgraph on demand), strips
// the group attribute from the copy, and records node -> copy in
// 'node_images'. Unmarked nodes are skipped.
Status Encapsulator::CopySubgraphNodes(
    std::unordered_map<const Node*, Node*>* node_images) {
  for (Node* node : graph_in_->op_nodes()) {
    string func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
    if (IsInSubgraph(func_id)) {
      Node* image = subgraphs_[func_id].MakeNodeImage(graph_in_, node);
      // The copy now lives inside its function; drop the cluster marker.
      image->ClearAttr(group_attribute_);
      (*node_images)[node] = image;
    }
  }
  return Status::OK();
}
// Walks every edge of graph_in_ and:
//  * copies edges whose endpoints are both in the same subgraph into that
//    subgraph's graph;
//  * for edges leaving a subgraph, records the source as a result (data edge)
//    or a control result (control edge) of that subgraph;
//  * for data edges entering a subgraph, records the destination as an arg,
//    appending new (source node, _Arg node) pairs to 'src_arg_pairs'.
//    Control edges entering a subgraph are not recorded here; they are lifted
//    onto the call node when the output graph is built.
// Returns InvalidArgument if a ref-typed tensor would cross a subgraph
// boundary.
Status Encapsulator::CopySubgraphEdges(
    const std::unordered_map<const Node*, Node*>& node_images,
    std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
  for (const Edge* edge : graph_in_->edges()) {
    string src_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id));
    string dst_func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id));
    const bool src_in_subgraph = IsInSubgraph(src_func_id);
    const bool dst_in_subgraph = IsInSubgraph(dst_func_id);
    const bool is_control = edge->IsControlEdge();

    // Both endpoints in the same subgraph: copy the edge between the images.
    if (src_in_subgraph && src_func_id == dst_func_id) {
      Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
      Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
      Graph* subgraph_graph = subgraphs_[src_func_id].GetGraph();
      if (is_control) {
        subgraph_graph->AddControlEdge(src_image, dst_image,
                                       /*allow_duplicates=*/true);
      } else {
        subgraph_graph->AddEdge(src_image, edge->src_output(), dst_image,
                                edge->dst_input());
      }
      continue;
    }

    // Edge leaves the source's subgraph.
    if (src_in_subgraph) {
      if (!is_control &&
          IsRefType(edge->src()->output_type(edge->src_output()))) {
        return errors::InvalidArgument(
            "Ref Tensors (e.g., Variables) are not supported as results: "
            "tensor ",
            edge->src()->name(), ":", edge->src_output());
      }
      Subgraph& src_subgraph = subgraphs_[src_func_id];
      Status recorded;
      if (is_control) {
        recorded = src_subgraph.RecordControlResult(edge, node_images);
      } else {
        recorded = src_subgraph.RecordResult(edge, node_images);
      }
      TF_RETURN_IF_ERROR(recorded);
    }

    // Edge enters the destination's subgraph.
    if (dst_in_subgraph) {
      Subgraph& dst_subgraph = subgraphs_[dst_func_id];
      if (!is_control) {
        // Check the destination's input type rather than the source's output
        // type: a Ref output can be consumed as a non-Ref input.
        if (IsRefType(edge->dst()->input_type(edge->dst_input()))) {
          return errors::InvalidArgument(
              "Ref Tensors (e.g., Variables) are not supported as args: "
              "tensor ",
              edge->src()->name(), ":", edge->src_output());
        }
        TF_RETURN_IF_ERROR(
            dst_subgraph.RecordArg(edge, node_images, src_arg_pairs));
      }
    }
  }
  return Status::OK();
}
// Builds one Subgraph per distinct value of group_attribute_: copies the
// marked nodes and their internal edges, creates _Arg/_Retval nodes for edges
// that cross subgraph boundaries, and tags _Arg nodes whose sources are
// guaranteed constant. 'library' is only used when dumping subgraphs for
// debugging (VLOG level >= 1).
Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
  // Map from input graph nodes to subgraph nodes.
  std::unordered_map<const Node*, Node*> node_images;

  // Each entry of src_arg_pairs is a pair whose first element is a node in the
  // original graph that has an output edge in the subgraph, and whose second
  // element is the arg node in the subgraph that it sends to. The vector is
  // filled in by CopySubgraphEdges() (through Subgraph::RecordArg()).
  std::vector<std::pair<const Node*, Node*>> src_arg_pairs;

  TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images));
  TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs));
  MarkGuaranteedConstants(*graph_in_, src_arg_pairs);

  for (auto& entry : subgraphs_) {
    Subgraph& subgraph = entry.second;
    FixupSourceAndSinkEdges(subgraph.GetGraph());
  }

  if (VLOG_IS_ON(1)) {
    // Dump subgraphs.
    for (auto& entry : subgraphs_) {
      DumpGraphToFile(
          absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
          *entry.second.GetGraph(), library);
    }
  }

  return Status::OK();
}
// Calls Subgraph::BuildFunctionDef() for every subgraph, using the subgraph's
// group-attribute value as the function name. Stops at the first error.
Status Encapsulator::BuildFunctionDefs(
    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
    FunctionLibraryDefinition* library) {
  for (auto& subgraph_entry : subgraphs_) {
    // Bind the key by reference: BuildFunctionDef() takes it by const& and
    // makes its own copy, so copying it here too is wasted work.
    const string& name = subgraph_entry.first;
    Subgraph& subgraph = subgraph_entry.second;
    TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef(
        name, rewrite_subgraph_fn, reuse_existing_functions, library));
  }
  return Status::OK();
}
// Copies every op node of graph_in_ that is NOT being encapsulated into
// 'graph_out' and records node -> copy in 'node_images'. Also maps the input
// graph's SOURCE and SINK nodes onto those of 'graph_out'.
Status Encapsulator::CopyNodesToOutputGraph(
    Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
  std::unordered_map<const Node*, Node*>& images = *node_images;
  for (Node* node : graph_in_->op_nodes()) {
    string func_id;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &func_id));
    // Encapsulated nodes are represented by their subgraph's call node.
    if (!IsInSubgraph(func_id)) {
      Node* copy = graph_out->CopyNode(node);
      images[node] = copy;
    }
  }
  images[graph_in_->source_node()] = graph_out->source_node();
  images[graph_in_->sink_node()] = graph_out->sink_node();
  return Status::OK();
}
// Adds one function call node to 'graph_out' per subgraph; stops at the first
// failure and returns its status.
Status Encapsulator::AddFunctionCallNodes(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  for (auto& entry : subgraphs_) {
    Subgraph& subgraph = entry.second;
    Status status = subgraph.AddFunctionCallNode(node_images, graph_out);
    if (!status.ok()) return status;
  }
  return Status::OK();
}
// Sets '*src_image' to the output-graph node that stands in for an edge's
// source: the call node of the source's subgraph if the source was
// encapsulated, otherwise the source's copy in the output graph.
// 'dst_func_id' is not needed to make this decision.
Status Encapsulator::FindOutputImageOfEdgeSrc(
    const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images,
    const Node* original_src_node, Node** src_image) {
  *src_image = IsInSubgraph(src_func_id)
                   ? subgraphs_.at(src_func_id).GetCallNode()
                   : node_images.at(original_src_node);
  return Status::OK();
}
int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id,
const string& dst_func_id,
const Edge* edge) {
if (IsInSubgraph(src_func_id)) {
const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
// 'src' is in a subgraph and 'dst' is a regular node in the output
// graph. Use the corresponding call output instead.
return src_subgraph.GetResultIndexForEdge(edge);
} else {
// The source of the edge is in the output graph so use the regular edge
// slot.
return edge->src_output();
}
}
// Sets '*dst_image' to the output-graph node that stands in for an edge's
// destination: the call node of the destination's subgraph if the
// destination was encapsulated, otherwise the destination's copy in the
// output graph. 'src_func_id' is not needed to make this decision.
Status Encapsulator::FindOutputImageOfEdgeDst(
    const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images,
    const Node* original_dst_node, Node** dst_image) {
  *dst_image = IsInSubgraph(dst_func_id)
                   ? subgraphs_.at(dst_func_id).GetCallNode()
                   : node_images.at(original_dst_node);
  return Status::OK();
}
int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id,
const string& dst_func_id,
const Edge* edge) {
if (IsInSubgraph(dst_func_id)) {
const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
// 'dst' is in a subgraph and 'src' is a regular node in the output
// graph. Use the corresponding call input instead.
return dst_subgraph.GetArgIndexForEdge(edge);
} else {
// The destination of the edge is in the output graph so use the regular
// edge slot.
return edge->dst_input();
}
}
// Adds to 'graph_out' the image of 'edge', with encapsulated endpoints
// replaced by their subgraph's call node (control edges are thereby lifted
// onto the call node). 'edges_added' deduplicates edges: several input edges
// can map to the same output edge. Control edges are keyed with slot -1 on
// both ends.
Status Encapsulator::CopyEdgeToOutputGraph(
    const Edge* edge, const string& src_func_id, const string& dst_func_id,
    const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
    std::unordered_set<std::pair<OutputTensor, InputTensor>,
                       OutputInputTensorPairHasher>* edges_added) {
  Node* src_image = nullptr;
  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
      src_func_id, dst_func_id, node_images, edge->src(), &src_image));
  Node* dst_image = nullptr;
  TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst(
      src_func_id, dst_func_id, node_images, edge->dst(), &dst_image));

  const bool is_control = edge->IsControlEdge();
  const int src_output =
      is_control ? -1 : FindOutputSlotOfEdgeSrc(src_func_id, dst_func_id, edge);
  const int dst_input =
      is_control ? -1 : FindOutputSlotOfEdgeDst(src_func_id, dst_func_id, edge);

  const bool first_time =
      edges_added
          ->emplace(OutputTensor(src_image, src_output),
                    InputTensor(dst_image, dst_input))
          .second;
  if (!first_time) return Status::OK();

  if (is_control) {
    graph_out->AddControlEdge(src_image, dst_image,
                              /*allow_duplicates=*/true);
  } else {
    graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
  }
  return Status::OK();
}
// Copies every input-graph edge that is not internal to a single subgraph
// into 'graph_out', using 'node_images' to map endpoints. Then asks each
// subgraph to connect its sequencer to its call node.
//
// Edges whose endpoints share the same subgraph are skipped; they already
// live inside that subgraph's function body. A local set deduplicates edges,
// because multiple input edges may map onto the same output edge.
Status Encapsulator::AddEdgesToOutputGraph(
    const std::unordered_map<const Node*, Node*>& node_images,
    Graph* graph_out) {
  std::unordered_set<std::pair<OutputTensor, InputTensor>,
                     OutputInputTensorPairHasher>
      already_copied;

  for (const Edge* e : graph_in_->edges()) {
    string src_cluster;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(e->src(), &src_cluster));
    string dst_cluster;
    TF_RETURN_IF_ERROR(GetFunctionNameAttr(e->dst(), &dst_cluster));

    // Equal ids imply both endpoints agree on subgraph membership, so one
    // IsInSubgraph check suffices to detect a strictly-internal edge.
    const bool internal_to_one_subgraph =
        src_cluster == dst_cluster && IsInSubgraph(src_cluster);
    if (internal_to_one_subgraph) continue;

    // The edge crosses a cluster boundary or is entirely unclustered.
    TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(e, src_cluster, dst_cluster,
                                             node_images, graph_out,
                                             &already_copied));
  }

  for (auto& entry : subgraphs_) {
    entry.second.ConnectSequencerToCallNode(graph_out);
  }
  return Status::OK();
}
namespace {
// Adds a dummy Const node to graph_out. The "constant" has the type of
// data_type and the shape indicated in 'shape'. The dummy node is not a valid
// Const node because it does not have any value defined, but this doesn't
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.) If the node being replaced was within a
// control flow frame, adds appropriate Enter nodes so that the use of the Const
// is well-formed.
//
// Args:
//   src_node: node whose output 'src_port' the dummy stands in for. Its
//     output dtype becomes the Const's dtype. Its control-flow info selects
//     which Enter nodes are added.
//   src_port: output index of 'src_node' whose dtype is used.
//   control_flow_info: per-node control-flow info, indexed by Node::id().
//   shape: shape stored in the (value-less) tensor of the Const's "value" attr.
//   graph_out: graph that the Const and any Enter nodes are added to.
// Returns the node carrying the dummy value: the Const itself when no Enter
// nodes are needed, otherwise the last Enter node created.
//
// NOTE(review): 'options' is built with a null status pointer, so builder
// failures are not reported through a Status here. Confirm how
// GraphDefBuilder::Options handles errors in that case.
Node* AddDummyShapedNode(const Node* src_node, int src_port,
const std::vector<ControlFlowInfo>& control_flow_info,
const TensorShapeProto& shape, Graph* graph_out) {
DataType data_type = src_node->output_type(src_port);
TensorProto dummy_proto;
dummy_proto.set_dtype(data_type);
*dummy_proto.mutable_tensor_shape() = shape;
// Don't set any value field in the proto, since it is only going to be used
// for shape inference.
GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
options.op_registry());
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
Node* node = options.FinalizeBuilder(&node_builder);
// Add any Enter nodes required to bring the constant to the correct control
// flow frame.
// NOTE(review): each iteration moves 'src_node' to its
// control_flow_info[...].parent_frame and reads that node's frame_name. The
// first Enter created (fed directly by the Const) uses the innermost frame
// name. Confirm against ControlFlowInfo's definition that this visits every
// enclosing frame in the correct order for nested loops. TODO verify.
while (!control_flow_info[src_node->id()].frame_name.empty()) {
// Debug info is taken from the current 'src_node', which after the first
// iteration is the parent-frame node rather than the original node.
NodeDebugInfo debug_info(*src_node);
NodeBuilder enter_builder(options.GetNameForOp("Enter"), "Enter",
options.op_registry(), &debug_info);
enter_builder.Attr("frame_name",
control_flow_info[src_node->id()].frame_name);
// The forwarded value is a loop-invariant constant.
enter_builder.Attr("is_constant", true);
enter_builder.Input(node, 0);
Node* enter_node = options.FinalizeBuilder(&enter_builder);
// Adopt the new Enter node as the value in the current frame.
node = enter_node;
// Recurse to the parent frame to see if more Enter nodes need to be added.
src_node = control_flow_info[src_node->id()].parent_frame;
}
return node;
}
} // namespace
// Builds in '*pruned_graph' a copy of the part of 'graph' that 'sink_nodes'
// depend on, then inlines every function-call node in that copy.
//
// Args:
//   graph: graph to copy from; it is not modified.
//   sink_nodes: nodes whose ancestors (including themselves) are copied. The
//     graph's source node is never copied.
//   pruned_graph: receives the newly allocated copy, sharing 'library'.
//   node_images: filled with original-node -> copied-node entries.
//   library: function library used for op lookup and function inlining.
// Returns Internal if a function-call node has no FunctionDef in 'library',
// or propagates errors from op lookup, body instantiation and inlining.
//
// NOTE(review): call nodes are inlined after 'node_images' is populated, so
// entries whose image was a function-call node may refer to nodes that
// inlining removed. Confirm callers do not dereference those entries.
Status Encapsulator::MakePrunedGraphCopyAndInline(
    const Graph& graph, const std::vector<Node*>& sink_nodes,
    std::unique_ptr<Graph>* pruned_graph,
    std::unordered_map<const Node*, Node*>* node_images,
    FunctionLibraryDefinition* library) {
  // Step 1: copy every non-source ancestor of 'sink_nodes'.
  pruned_graph->reset(new Graph(library));
  Graph* copy = pruned_graph->get();
  copy->set_versions(graph.versions());
  auto copy_node = [copy, node_images](Node* original) {
    if (original->IsSource()) return;
    Node* image = copy->CopyNode(original);
    node_images->emplace(original, image);
  };
  ReverseDFSFrom(graph, sink_nodes, /*enter=*/nullptr, /*leave=*/copy_node);

  // Step 2: reproduce each edge whose two endpoints were both copied.
  for (const auto& entry : *node_images) {
    Node* src_copy = entry.second;
    for (const Edge* e : entry.first->out_edges()) {
      const auto dst_it = node_images->find(e->dst());
      if (dst_it == node_images->end()) continue;
      copy->AddEdge(src_copy, e->src_output(), dst_it->second, e->dst_input());
    }
  }

  // Step 3: collect the function-call nodes up front, since inlining mutates
  // the node set being iterated.
  std::vector<Node*> calls_to_inline;
  for (Node* n : copy->nodes()) {
    const OpRegistrationData* reg;
    TF_RETURN_IF_ERROR(library->LookUp(n->type_string(), &reg));
    if (reg->is_function_op) calls_to_inline.push_back(n);
  }

  // Step 4: replace each call node with the body of the function it calls.
  for (Node* call : calls_to_inline) {
    VLOG(2) << "Inlining function " << call->name();
    const FunctionDef* fdef = library->Find(call->type_string());
    if (fdef == nullptr) {
      return errors::Internal("Failed to find function ", call->type_string(),
                              " in function library.");
    }
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(
        FunctionDefToBodyHelper(*fdef, call->attrs(), library, &fbody));
    InlineFunctionBodyOptions inline_opts;
    TF_RETURN_IF_ERROR(
        InlineFunctionBody(*library, copy, call, fbody.get(), inline_opts));
  }
  return Status::OK();
}
// Populates 'graph_out' by running CopyNodesToOutputGraph,
// AddFunctionCallNodes and AddEdgesToOutputGraph in that order. All three
// share a single map from input-graph nodes to their output-graph images.
// Stops at the first error. 'library' is not used by this method.
Status Encapsulator::BuildOutputGraph(Graph* graph_out,
                                      FunctionLibraryDefinition* library) {
  std::unordered_map<const Node*, Node*> images;  // input node -> output node
  TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &images));
  TF_RETURN_IF_ERROR(AddFunctionCallNodes(images, graph_out));
  return AddEdgesToOutputGraph(images, graph_out);
}
} // anonymous namespace
// Encapsulates the subgraphs of 'graph_in' identified by 'group_attribute'
// into functions added to 'library'.
//
// Pipeline: split into subgraphs, build a FunctionDef per subgraph (via
// 'rewrite_subgraph_fn', optionally reusing existing functions), then build
// the rewritten graph. '*graph_out' is assigned only when every stage
// succeeds; on error it is left untouched.
Status EncapsulateSubgraphsInFunctions(
    string group_attribute, const Graph& graph_in,
    const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions,
    std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
  Encapsulator splitter(std::move(group_attribute), &graph_in);
  TF_RETURN_IF_ERROR(splitter.SplitIntoSubgraphs(library));
  TF_RETURN_IF_ERROR(splitter.BuildFunctionDefs(
      rewrite_subgraph_fn, reuse_existing_functions, library));

  std::unique_ptr<Graph> result(new Graph(library));
  result->set_versions(graph_in.versions());
  TF_RETURN_IF_ERROR(splitter.BuildOutputGraph(result.get(), library));
  *graph_out = std::move(result);
  return Status::OK();
}
// Finds the types of the _Arg nodes, indexed by position.
//
// For each _Arg node in 'graph', writes its output dtype into
// '(*types)[index]', where 'index' is the node's "index" attr. '*types' must
// already be sized to the number of arguments. Slots with no matching _Arg
// node are left unchanged. Returns InvalidArgument if an index falls outside
// [0, types->size()), or propagates the error if the attr cannot be read.
static Status GetArgTypes(const Graph& graph, DataTypeVector* types) {
  for (Node* n : graph.op_nodes()) {
    if (n->type_string() != kArgOp) continue;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Reject negatives first, so the unsigned comparison below is exact and
    // free of signed/unsigned mixing.
    if (index < 0 || static_cast<size_t>(index) >= types->size()) {
      return errors::InvalidArgument("Invalid argument number");
    }
    (*types)[index] = n->output_type(0);
  }
  return Status::OK();
}
// Renumber the indices of _Arg nodes in a graph, according to
// 'permutation' that maps old indices to new indices.
//
// Each _Arg node's "index" attr is overwritten with
// 'permutation[old_index]'. Returns InvalidArgument if an old index falls
// outside [0, permutation.size()), or propagates the error if the attr cannot
// be read. Nodes processed before an error keep their new index.
static Status RenumberArguments(Graph* graph,
                                const std::vector<int>& permutation) {
  for (Node* n : graph->op_nodes()) {
    if (n->type_string() != kArgOp) continue;
    int index;
    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
    // Reject negatives first, so the unsigned comparison below is exact and
    // free of signed/unsigned mixing.
    if (index < 0 || static_cast<size_t>(index) >= permutation.size()) {
      return errors::InvalidArgument("Invalid argument number");
    }
    n->AddAttr("index", permutation[index]);
  }
  return Status::OK();
}
// Graph optimization pass entry point.
//
// Encapsulates every XLA cluster (nodes tagged with kXlaClusterAttr) in
// '*options.graph' into a function in 'options.flib_def'. Each encapsulated
// subgraph is optimized (with constant folding run on a private single-CPU
// FunctionLibraryRuntime), and its arguments are reordered to
// [constants, non-constants, resources]. The corresponding call node is
// annotated with kXlaCompiledKernelAttr, kXlaNumConstantArgsAttr and
// kXlaNumResourceArgsAttr. Finally, every node of the rewritten graph gets
// kXlaHasReferenceVarsAttr.
//
// NOTE(review): 'options.session_options' is dereferenced unconditionally;
// callers are assumed to always provide it.
Status EncapsulateSubgraphsPass::Run(
    const GraphOptimizationPassOptions& options) {
  VLOG(1) << "EncapsulateSubgraphsPass::Run";
  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
                    options.flib_def);
  }

  std::unique_ptr<Graph> graph_out;
  FunctionLibraryDefinition* const library = options.flib_def;

  // Constant folding below might need to run part of the function to compute
  // constants. Create a FunctionLibraryRuntime with a single CPU device
  // that can run the part of the function.
  // NOTE: If this turns out to be slow, we can cache the FLRs keyed by
  // `options`.
  SessionOptions session_options;
  auto* device_count = session_options.config.mutable_device_count();
  device_count->insert({"CPU", 1});
  std::vector<std::unique_ptr<Device>> devices;

  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  if (!cpu_factory) {
    return errors::NotFound(
        "CPU Factory not registered. Can't run EncapsulateSubgraphsPass");
  }
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));
  if (devices.empty()) {
    return errors::NotFound(
        "Failed to create a CPU device for EncapsulateSubgraphsPass");
  }

  std::unique_ptr<DeviceMgr> device_mgr =
      absl::make_unique<StaticDeviceMgr>(std::move(devices));
  const auto* config = &options.session_options->config;
  // 'device_mgr' and 'pflr' must outlive every use of 'flr' below, including
  // the rewrite callback and the reference-variable analysis at the end.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr =
      absl::make_unique<ProcessFunctionLibraryRuntime>(
          device_mgr.get(), options.session_options->env,
          /*config=*/config, TF_GRAPH_DEF_VERSION, library,
          config->graph_options().optimizer_options());
  FunctionLibraryRuntime* flr =
      pflr->GetFLR("/job:localhost/replica:0/task:0/device:CPU:0");
  if (flr == nullptr) {
    return errors::Internal(
        "Failed to create and retrieve function library runtime to run "
        "constant folding");
  }

  auto rewrite_subgraph =
      [flr](const std::vector<OutputTensor>& arg_source_tensors,
            std::unique_ptr<Graph>* subgraph,
            std::vector<int>* input_permutation,
            std::vector<int>* output_permutation, NodeDef* node) {
        // Optimize the subgraph.
        // Do not constant fold nodes that output DT_VARIANT type tensors.
        // XLA does not support Const nodes of Variant type since it needs
        // to know the original ops to be able to compile them to the relevant
        // XLA form.
        // TODO(srbs): This filter is a little conservative. E.g. a subgraph of
        // the form:
        //                          Const
        //                            |
        // EmptyTensorList -> TensorListPushBack -> TensorListPopBack -> Op
        //                                                  |
        //                                        (Discard popped list)
        //
        // Would have been reduced to "Const -> Op" without this filter.
        // However since we are only allowed to specify the filter at the "Node"
        // level there is no good way to allow the above behavior. So we
        // disallow any sort of constant folding on Variant nodes for now.
        bool disable_constant_folding =
            GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding;
        auto cf_consider_fn = [disable_constant_folding](const Node* n) {
          if (disable_constant_folding) return false;
          for (const auto& output_arg : n->op_def().output_arg()) {
            if (output_arg.type() == DT_VARIANT) {
              return false;
            }
          }
          return true;
        };
        GraphOptimizer::Options graph_optimizer_options;
        graph_optimizer_options.cf_consider_fn = cf_consider_fn;
        OptimizeGraph(flr, subgraph, graph_optimizer_options);

        const int num_args = input_permutation->size();
        std::vector<bool> const_args(num_args);
        TF_RETURN_IF_ERROR(
            BackwardsConstAnalysis(**subgraph, &const_args,
                                   /*compile_time_const_nodes=*/nullptr, flr));

        DataTypeVector arg_types(num_args);
        TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));

        // Compute a permutation of the arguments such that the constant
        // arguments come first, then the non-constant non-resource
        // arguments, and finally the resource arguments.
        const int num_consts =
            std::count(const_args.begin(), const_args.end(), true);

        const int num_resources =
            std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
        const int num_nonconsts = num_args - num_resources - num_consts;
        if (num_nonconsts < 0) {
          return errors::Internal("num_nonconsts should be >= 0, was ",
                                  num_nonconsts);
        }

        int const_pos = 0;
        int arg_pos = num_consts;
        int resource_pos = num_consts + num_nonconsts;
        for (int i = 0; i < num_args; ++i) {
          if (const_args[i]) {
            if (arg_types[i] == DT_RESOURCE) {
              return errors::Internal(
                  "Resource arguments cannot be constant (argument ", i, ")");
            }
            (*input_permutation)[i] = const_pos;
            ++const_pos;
          } else if (arg_types[i] == DT_RESOURCE) {
            (*input_permutation)[i] = resource_pos;
            ++resource_pos;
          } else {
            (*input_permutation)[i] = arg_pos;
            ++arg_pos;
          }
        }

        // Renumber argument nodes in the graph.
        TF_RETURN_IF_ERROR(
            RenumberArguments(subgraph->get(), *input_permutation));

        // TODO(phawkins): add a forward is-constant analysis, similarly split
        // outputs into host-memory constants and device-memory non-constants.

        AddNodeAttr(kXlaCompiledKernelAttr, true, node);
        AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
        AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
        return Status::OK();
      };

  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      EncapsulateSubgraphsInFunctions(
          kXlaClusterAttr, **options.graph, rewrite_subgraph,
          /*reuse_existing_functions=*/false, &graph_out, library),
      "EncapsulateSubgraphsPass failed");

  if (VLOG_IS_ON(1)) {
    DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
                    options.flib_def);
  }

  *options.graph = std::move(graph_out);

  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related_nodes,
                      GetNodesRelatedToRefVariables(**options.graph, flr));
  for (Node* node : (*options.graph)->nodes()) {
    bool has_ref_vars = ref_related_nodes.contains(node);
    node->AddAttr(kXlaHasReferenceVarsAttr, has_ref_vars);
    // Log a human-readable node description; SerializeAsString() would emit
    // binary protobuf wire bytes into the log.
    VLOG(3) << "Has ref vars = " << has_ref_vars
            << ", node: " << node->DebugString();
  }

  return Status::OK();
}
// Returns true iff 'node' carries kXlaCompiledKernelAttr and that attribute
// reads as true. If TryGetNodeAttr cannot retrieve the attribute (e.g. it is
// absent), the node is treated as not XLA-compiled.
bool IsXlaCompiledKernel(const Node& node) {
  bool compiled_flag = false;
  if (!TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &compiled_flag)) {
    return false;
  }
  return compiled_flag;
}
} // namespace tensorflow