blob: 0dccee582ee3a022629ec1ec473c45c4d9379ab1 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/graph_view.h"
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view_internal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
namespace grappler {
namespace utils {
FaninView::FaninView(NodeView* node_view, int index)
: NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
index) {}
FanoutView::FanoutView(NodeView* node_view, int index)
: NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
index) {}
const NodeDef* NodeView::node() const {
return &graph_view_->graph()->node(node_index_);
}
bool NodeView::HasFanin(const FanoutView& fanin) const {
if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
return false;
}
return fanins_set_.contains(
{&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
}
bool NodeView::HasFanout(const FaninView& fanout) const {
if (fanout.index() < Graph::kControlSlot ||
graph_view_ != fanout.graph_view_) {
return false;
}
NodeView* view = fanout.node_view();
if (view == nullptr) {
return false;
} else if (fanout.index() == Graph::kControlSlot) {
return view->fanins_set_.contains({this->node(), Graph::kControlSlot});
} else if (fanout.index() >= view->regular_fanins_.size()) {
return false;
}
return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
}
inline const FanoutView& NodeView::GetMissingFanin() const {
return graph_view_->missing_fanin_;
}
inline const std::vector<FaninView>& NodeView::GetMissingFanout() const {
return graph_view_->missing_fanout_;
}
namespace {
const char kGraphViewError[] = "GraphView::GraphView error: ";
} // namespace
GraphView::GraphView(const GraphDef* graph, Status* status)
: GraphViewInternal(graph) {
const int num_nodes = graph->node_size();
node_index_by_name_.reserve(num_nodes);
nodes_.reserve(num_nodes);
for (const NodeDef& node : graph->node()) {
if (!AddUniqueNodeInternal(&node)) {
*status = errors::InvalidArgument(
kGraphViewError, "graph has multiple nodes with the name '",
node.name(), "'.");
Reset();
return;
}
}
Status s;
for (NodeView& node_view : nodes_) {
s = CheckAndAddFaninsInternal(&node_view);
if (!s.ok()) {
*status = s;
Reset();
return;
}
}
*status = Status::OK();
}
bool GraphView::AddUniqueNodeInternal(const NodeDef* node) {
const int node_index = node_index_by_name_.size();
auto it = node_index_by_name_.emplace(node->name(), node_index);
if (it.second) {
nodes_.emplace_back(this, node_index);
return true;
}
return false;
}
Status GraphView::CheckAndAddFaninsInternal(NodeView* node_view) {
bool has_observed_control = false;
const NodeDef* node = node_view->node();
const string& node_name = node->name();
const int node_index = node_view->node_index_;
node_view->fanins_set_.reserve(node->input_size());
for (const string& input : node->input()) {
TensorId fanin_id = ParseTensorName(input);
if (fanin_id.node() == node_name) {
return errors::InvalidArgument(kGraphViewError, "node '", node_name,
"' has self cycle fanin '", input, "'.");
}
bool is_control = IsTensorIdControl(fanin_id);
if (!is_control && has_observed_control) {
return errors::InvalidArgument(kGraphViewError, "node '", node_name,
"' has regular fanin '", input,
"' after controlling fanins.");
}
auto it = node_index_by_name_.find(fanin_id.node());
if (it == node_index_by_name_.end()) {
return errors::InvalidArgument(kGraphViewError, "node '", node_name,
"' has missing fanin '", input, "'.");
}
const int fanin_node_index = it->second;
NodeView& fanin_node_view = nodes_[fanin_node_index];
if (is_control) {
fanin_node_view.controlled_fanouts_.emplace_back(this, node_index,
Graph::kControlSlot);
node_view->controlling_fanins_.emplace_back(this, fanin_node_index,
Graph::kControlSlot);
node_view->fanins_set_.emplace(fanin_node_view.node(),
Graph::kControlSlot);
has_observed_control = true;
} else {
if (fanin_node_view.regular_fanouts_by_port_.size() <
fanin_id.index() + 1) {
fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
}
fanin_node_view.regular_fanouts_by_port_[fanin_id.index()].emplace_back(
this, node_index, node_view->regular_fanins_.size());
++fanin_node_view.num_regular_fanouts_;
node_view->regular_fanins_.emplace_back(this, fanin_node_index,
fanin_id.index());
node_view->fanins_set_.emplace(fanin_node_view.node(), fanin_id.index());
}
}
return Status::OK();
}
MutableFaninView::MutableFaninView(MutableNodeView* node_view, int index)
: NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
index) {}
MutableFanoutView::MutableFanoutView(MutableNodeView* node_view, int index)
: NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_,
index) {}
NodeDef* MutableNodeView::node() const {
return graph_view_->graph()->mutable_node(node_index_);
}
bool MutableNodeView::HasFanin(const MutableFanoutView& fanin) const {
if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) {
return false;
}
return fanins_count_.contains(
{&graph_view_->graph_->node(fanin.node_index_), fanin.index()});
}
bool MutableNodeView::HasFanout(const MutableFaninView& fanout) const {
if (fanout.index() < Graph::kControlSlot ||
graph_view_ != fanout.graph_view_) {
return false;
}
MutableNodeView* view = fanout.node_view();
if (view == nullptr) {
return false;
} else if (fanout.index() == Graph::kControlSlot) {
return view->fanins_count_.contains({this->node(), Graph::kControlSlot});
} else if (fanout.index() >= view->regular_fanins_.size()) {
return false;
}
return view->regular_fanins_[fanout.index()].node_index_ == node_index_;
}
const MutableFanoutView& MutableNodeView::GetMissingFanin() const {
return graph_view_->missing_fanin_;
}
const std::vector<MutableFaninView>& MutableNodeView::GetMissingFanout() const {
return graph_view_->missing_fanout_;
}
namespace {
const char kMutationAddNodeError[] = "Mutation::AddNode error: ";
bool IsTensorIdRegular(const TensorId& tensor_id) {
return tensor_id.index() >= 0;
}
} // namespace
Mutation::Mutation(MutableGraphView* graph_view) : graph_view_(graph_view) {}
MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) {
bool has_observed_control = false;
const string& node_name = node.name();
std::vector<SafeTensorId> regular_fanins;
absl::flat_hash_set<string> controlling_fanins;
const int num_fanins = node.input_size();
for (int i = 0; i < num_fanins; ++i) {
const string& input = node.input(i);
TensorId fanin_id = ParseTensorName(input);
if (fanin_id.node() == node_name) {
*status =
errors::InvalidArgument(kMutationAddNodeError, "node '", node_name,
"' has self cycle fanin '", input, "'.");
return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
}
bool is_control = IsTensorIdControl(fanin_id);
if (is_control) {
has_observed_control = true;
controlling_fanins.emplace(fanin_id.node());
} else if (has_observed_control) {
*status = errors::InvalidArgument(kMutationAddNodeError, "node '",
node_name, "' has regular fanin '",
input, "' after controlling fanins.");
return MutationNewNode(this, mutation_counter_, internal::kMissingIndex);
} else {
regular_fanins.emplace_back(fanin_id);
}
}
node.mutable_input()->Clear();
new_nodes_.emplace_back(graph_view_, std::move(node));
MutationNewNodeHolder& mutation_node = new_nodes_.back();
mutation_node.regular_fanins = std::move(regular_fanins);
mutation_node.num_regular_fanins = mutation_node.regular_fanins.size();
mutation_node.controlling_fanins = std::move(controlling_fanins);
*status = Status::OK();
return MutationNewNode(this, mutation_counter_, new_nodes_.size() - 1);
}
void Mutation::AddMutation(
MutableNodeView* node,
std::function<void(MutableNodeViewDiff*)> mutate_fn) {
DCHECK(node->graph_view_ == graph_view_);
if (node->update_index_ == internal::kMissingIndex) {
node->update_index_ = updated_nodes_.size();
updated_nodes_.emplace_back(graph_view_, node->node_index_);
mutate_fn(&updated_nodes_.back());
} else if (!removed_nodes_[node->node_index_]) {
auto& diff = updated_nodes_[node->update_index_];
mutate_fn(&diff);
}
}
void Mutation::RemoveNode(MutableNodeView* node) {
auto& update_index = node->update_index_;
if (update_index != internal::kMissingIndex) {
if (update_index < updated_nodes_.size() - 1) {
graph_view_->nodes_[updated_nodes_.back().node_index].update_index_ =
update_index;
std::swap(updated_nodes_[update_index], updated_nodes_.back());
}
updated_nodes_.pop_back();
update_index = internal::kMissingIndex;
}
removed_nodes_[node->node_index_] = true;
}
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
AddMutation(node, [name](MutableNodeViewDiff* diff) {
internal::UpdateName(diff, name);
});
}
void Mutation::UpdateNodeName(const MutationNewNode& node,
absl::string_view name) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::UpdateName(&new_nodes_[node.index_], name);
}
void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) {
AddMutation(
node, [op](MutableNodeViewDiff* diff) { internal::UpdateOp(diff, op); });
}
void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::UpdateOp(&new_nodes_[node.index_], op);
}
void Mutation::UpdateNodeDevice(MutableNodeView* node,
absl::string_view device) {
AddMutation(node, [device](MutableNodeViewDiff* diff) {
internal::UpdateDevice(diff, device);
});
}
void Mutation::UpdateNodeDevice(const MutationNewNode& node,
absl::string_view device) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::UpdateDevice(&new_nodes_[node.index_], device);
}
void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index,
const TensorId& fanin) {
AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) {
internal::AddOrUpdateRegularFanin(diff, index, fanin);
});
}
void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
const TensorId& fanin) {
DCHECK(node.mutation_ == this &&
node.mutation_counter_ == mutation_counter_ && index >= 0 &&
IsTensorIdRegular(fanin));
internal::AddOrUpdateRegularFanin(&new_nodes_[node.index_], index, fanin);
}
void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) {
AddMutation(node, [index](MutableNodeViewDiff* diff) {
internal::RemoveRegularFanin(diff, index);
});
}
void Mutation::RemoveRegularFanin(const MutationNewNode& node, int index) {
DCHECK(node.mutation_ == this &&
node.mutation_counter_ == mutation_counter_ && index >= 0);
internal::RemoveRegularFanin(&new_nodes_[node.index_], index);
}
void Mutation::AddControllingFanin(MutableNodeView* node,
absl::string_view fanin_node_name) {
AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
auto it = node->controlling_fanins_index_.find(fanin_node_name);
const int control_index = it != node->controlling_fanins_index_.end()
? it->second
: internal::kMissingIndex;
internal::AddControllingFanin(diff, control_index, fanin_node_name);
});
}
void Mutation::AddControllingFanin(const MutationNewNode& node,
absl::string_view fanin_node_name) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::AddControllingFanin(&new_nodes_[node.index_], fanin_node_name);
}
void Mutation::RemoveControllingFanin(MutableNodeView* node,
absl::string_view fanin_node_name) {
AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) {
auto it = node->controlling_fanins_index_.find(fanin_node_name);
const int control_index = it != node->controlling_fanins_index_.end()
? it->second
: internal::kMissingIndex;
internal::RemoveControllingFanin(diff, control_index, fanin_node_name);
});
}
void Mutation::RemoveControllingFanin(const MutationNewNode& node,
absl::string_view fanin_node_name) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::RemoveControllingFanin(&new_nodes_[node.index_], fanin_node_name);
}
void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node,
absl::string_view attr_name,
const AttrValue& attr_value) {
AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) {
internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
});
}
void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node,
absl::string_view attr_name,
const AttrValue& attr_value) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::AddOrUpdateAttribute(&new_nodes_[node.index_], attr_name,
attr_value);
}
void Mutation::RemoveNodeAttr(MutableNodeView* node,
absl::string_view attr_name) {
AddMutation(node, [attr_name](MutableNodeViewDiff* diff) {
internal::RemoveAttribute(diff, attr_name);
});
}
void Mutation::RemoveNodeAttr(const MutationNewNode& node,
absl::string_view attr_name) {
DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_);
internal::RemoveAttribute(&new_nodes_[node.index_], attr_name);
}
void Mutation::ResetInternal() {
std::vector<MutableNodeViewDiff>().swap(updated_nodes_);
std::vector<bool>(graph_view_->NumNodes()).swap(removed_nodes_);
std::vector<MutationNewNodeHolder>().swap(new_nodes_);
}
void Mutation::Reset() {
for (const auto& update : updated_nodes_) {
graph_view_->nodes_[update.node_index].update_index_ =
internal::kMissingIndex;
}
ResetInternal();
}
Status Mutation::Apply() { return graph_view_->ApplyMutationInternal(); }
namespace {
const char kMutableGraphViewError[] =
"MutableGraphView::MutableGraphView error: ";
const char kMutableGraphViewApplyError[] = "Mutation::Apply error: ";
inline void IncrementFaninCount(
absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
const internal::NodeDefAndPortIndex& fanin) {
++(*fanins_count)[fanin];
}
inline void DecrementFaninCount(
absl::flat_hash_map<internal::NodeDefAndPortIndex, int>* fanins_count,
const internal::NodeDefAndPortIndex& fanin) {
auto it = fanins_count->find(fanin);
if (it != fanins_count->end()) {
if (it->second <= 1) {
fanins_count->erase(it);
} else {
--it->second;
}
}
}
} // namespace
MutableGraphView::MutableGraphView(GraphDef* graph, Status* status)
: GraphViewInternal(graph), mutation_(Mutation(this)) {
const int num_nodes = graph->node_size();
node_index_by_name_.reserve(num_nodes);
nodes_.reserve(num_nodes);
for (NodeDef& node : *graph->mutable_node()) {
if (!AddUniqueNodeInternal(&node)) {
*status = errors::InvalidArgument(
kMutableGraphViewError, "graph has multiple nodes with the name '",
node.name(), "'.");
Reset();
return;
}
}
std::vector<std::vector<TensorId>> fanins;
Status s = CheckFaninsInternal(&fanins);
if (!s.ok()) {
*status = s;
Reset();
return;
}
AddFaninsInternal(&fanins);
mutation_.ResetInternal();
*status = Status::OK();
}
Mutation* MutableGraphView::GetMutationBuilder() { return &mutation_; }
bool MutableGraphView::AddUniqueNodeInternal(NodeDef* node) {
const int node_index = node_index_by_name_.size();
auto it = node_index_by_name_.emplace(node->name(), node_index);
if (it.second) {
nodes_.emplace_back(this, node_index);
return true;
}
return false;
}
Status MutableGraphView::CheckFaninsInternal(
std::vector<std::vector<TensorId>>* fanins) {
const int num_nodes = nodes_.size();
fanins->reserve(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
bool has_observed_control = false;
const NodeDef* node = nodes_[i].node();
const string& node_name = node->name();
std::vector<TensorId> node_fanins;
node_fanins.reserve(node->input_size());
for (const string& input : node->input()) {
TensorId fanin_id = ParseTensorName(input);
if (fanin_id.node() == node_name) {
return errors::InvalidArgument(kMutableGraphViewError, "node '",
node_name, "' has self cycle fanin '",
input, "'.");
}
bool is_control = IsTensorIdControl(fanin_id);
if (!is_control && has_observed_control) {
return errors::InvalidArgument(kMutableGraphViewError, "node '",
node_name, "' has regular fanin '",
input, "' after controlling fanins.");
}
if (!node_index_by_name_.contains(fanin_id.node())) {
return errors::InvalidArgument(kMutableGraphViewError, "node '",
node_name, "' has missing fanin '",
input, "'.");
}
if (is_control) {
has_observed_control = true;
}
node_fanins.push_back(std::move(fanin_id));
}
fanins->push_back(std::move(node_fanins));
}
return Status::OK();
}
void MutableGraphView::AddFaninsInternal(
std::vector<std::vector<TensorId>>* fanins) {
const int num_nodes = nodes_.size();
for (int i = 0; i < num_nodes; ++i) {
MutableNodeView& node_view = nodes_[i];
NodeDef* node = node_view.node();
std::vector<TensorId>& node_fanins = fanins->at(i);
absl::flat_hash_set<absl::string_view> observed_controls;
int pos = 0;
const int last_idx = node_fanins.size() - 1;
int last_pos = last_idx;
node_view.fanins_count_.reserve(node->input_size());
node_view.controlling_fanins_index_.reserve(node->input_size());
while (pos <= last_pos) {
const TensorId& fanin_id = node_fanins[pos];
bool is_control = IsTensorIdControl(fanin_id);
const int fanin_node_index = node_index_by_name_[fanin_id.node()];
MutableNodeView& fanin_node_view = nodes_[fanin_node_index];
if (is_control) {
if (gtl::InsertIfNotPresent(&observed_controls, fanin_id.node())) {
fanin_node_view.controlled_fanouts_.emplace_back(
this, i, Graph::kControlSlot,
node_view.controlling_fanins_.size());
node_view.controlling_fanins_.emplace_back(
this, fanin_node_index, Graph::kControlSlot,
fanin_node_view.controlled_fanouts_.size() - 1);
IncrementFaninCount(
&node_view.fanins_count_,
{&graph_->node(fanin_node_index), Graph::kControlSlot});
node_view.controlling_fanins_index_.emplace(
fanin_id.node(), pos - node_view.NumRegularFanins());
++pos;
} else {
node->mutable_input()->SwapElements(pos, last_pos);
std::swap(node_fanins[pos], node_fanins[last_pos]);
--last_pos;
}
} else {
if (fanin_node_view.regular_fanouts_by_port_.size() <
fanin_id.index() + 1) {
fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1);
}
auto& fanin_regular_fanouts =
fanin_node_view.regular_fanouts_by_port_[fanin_id.index()];
fanin_regular_fanouts.emplace_back(this, i,
node_view.regular_fanins_.size(),
node_view.regular_fanins_.size());
++fanin_node_view.num_regular_fanouts_;
node_view.regular_fanins_.emplace_back(
this, fanin_node_index, fanin_id.index(),
fanin_regular_fanouts.size() - 1);
IncrementFaninCount(
&node_view.fanins_count_,
{&graph_->node(fanin_node_index), fanin_id.index()});
++pos;
}
}
if (last_pos < last_idx) {
node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
}
}
}
Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
absl::flat_hash_map<absl::string_view, int>* node_names,
std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
std::vector<int>* inplace_nodes,
std::vector<int>* empty_diff_node_indices) {
// For all nodes to be removed and renamed, mark their original names as
// missing and put associated node index in graph.
for (const auto& diff : mutation_.updated_nodes_) {
if (diff.update_name) {
const int index = diff.node_index;
const string& node_name = nodes_[index].GetName();
node_names->emplace(node_name, index);
}
}
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
if (mutation_.removed_nodes_[i]) {
const string& node_name = nodes_[i].GetName();
node_names->emplace(node_name, i);
}
}
auto name_conflict = [](const absl::string_view node_name) {
return errors::InvalidArgument(kMutableGraphViewApplyError,
"multiple nodes with the name: '", node_name,
"' exists in Mutation.");
};
// Partition updated nodes by if they will be renamed or not.
const int num_updated_nodes = mutation_.updated_nodes_.size();
renamed_nodes->reserve(num_updated_nodes);
inplace_nodes->reserve(num_updated_nodes);
empty_diff_node_indices->reserve(num_updated_nodes);
for (int i = 0; i < num_updated_nodes; ++i) {
auto& diff = mutation_.updated_nodes_[i];
if (internal::IsEmpty(&diff)) {
empty_diff_node_indices->emplace_back(diff.node_index);
continue;
}
// Get name of updated node after potential mutation.
const string& node_name =
diff.update_name ? diff.name : nodes_[diff.node_index].GetName();
auto it = node_names->insert({node_name, internal::kNodeNamePresent});
if (!it.second) {
if (it.first->second == internal::kNodeNamePresent) {
// Another node in the mutation is already using this name, which will
// result in a conflict.
return name_conflict(node_name);
} else {
// Mark name as present (node was marked missing from either being
// removed or renamed).
it.first->second = internal::kNodeNamePresent;
}
}
if (diff.update_name) {
// Lookup new name of node in current graph. If a node has such name,
// store its index for later lookups as this node will be overwritten.
auto node_name_it = node_index_by_name_.find(node_name);
const int overwritten_node_index =
node_name_it != node_index_by_name_.end() ? node_name_it->second
: internal::kMissingIndex;
renamed_nodes->emplace_back(i, overwritten_node_index);
} else {
inplace_nodes->push_back(i);
}
}
// Get names of new nodes after potential mutation.
for (const auto& new_node : mutation_.new_nodes_) {
const string& node_name = new_node.node.name();
auto it = node_names->insert({node_name, internal::kNodeNamePresent});
if (it.second) {
continue;
}
if (it.first->second == internal::kNodeNamePresent) {
// Another node in the mutation is already using this name, which will
// result in a conflict.
return name_conflict(node_name);
} else {
// Mark name as present (node was marked missing from either being removed
// or renamed).
it.first->second = internal::kNodeNamePresent;
}
}
return Status::OK();
}
Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
const absl::flat_hash_map<absl::string_view, int>& node_names,
const std::vector<RenamedOrOverwrittenNode>& renamed_nodes) {
auto bad_fanout = [](absl::string_view fanout_node_name,
absl::string_view node_name) {
return errors::InvalidArgument(
kMutableGraphViewApplyError, "fanout '", fanout_node_name,
"' exist for missing node '", node_name, "'.");
};
// Lookup nodes to be overwritten.
std::vector<bool> overwritten_nodes(NumNodes());
for (auto& renamed_node : renamed_nodes) {
if (renamed_node.overwritten_node_index_ == internal::kMissingIndex) {
continue;
}
overwritten_nodes[renamed_node.overwritten_node_index_] = true;
}
// Check if removed nodes and previous state of renamed nodes have no fanouts.
for (const auto& node_name_state : node_names) {
if (node_name_state.second == internal::kNodeNamePresent) {
continue;
}
const MutableNodeView& node_view = nodes_[node_name_state.second];
for (const auto& regular_fanouts : node_view.GetRegularFanouts()) {
for (const auto& regular_fanout : regular_fanouts) {
// Check all fanouts of a single port.
MutableNodeView* fanout_view = regular_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
// Fanout node will be removed, this can be ignored.
continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) {
// Fanout is not updated or removed/overwritten.
return bad_fanout(fanout_view->GetName(), node_name_state.first);
}
} else {
auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
const int last_index = fanout_view->NumRegularFanins() -
diff.num_regular_inputs_to_remove - 1;
if (regular_fanout.index() > last_index) {
// Fanin of fanout is removed, this can be ignored.
continue;
}
// Check if fanin is updated.
if (diff.regular_inputs_to_update.find(regular_fanout.index()) ==
diff.regular_inputs_to_update.end()) {
return bad_fanout(fanout_view->GetName(), node_name_state.first);
}
}
}
}
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
MutableNodeView* fanout_view = controlled_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
// Fanout node will be removed, this can be ignored.
continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) {
// Fanout is not updated or removed/overwritten.
return bad_fanout(fanout_view->GetName(), node_name_state.first);
}
} else {
auto& diff = mutation_.updated_nodes_[fanout_view->update_index_];
// Check if controlling fanin is removed.
if (diff.controlling_inputs_to_remove.find(
controlled_fanout.fanin_index_) ==
diff.controlling_inputs_to_remove.end()) {
return bad_fanout(fanout_view->GetName(), node_name_state.first);
}
}
}
}
return Status::OK();
}
Status MutableGraphView::CheckNodeNamesAndFanins(
const absl::flat_hash_map<absl::string_view, int>& node_names,
const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
const std::vector<int>& inplace_nodes) {
// Check if removed/missing node fanouts are valid.
TF_RETURN_IF_ERROR(
RemovedOrMissingNodeFanoutsWellFormed(node_names, renamed_nodes));
// Check if updated nodes and their fanins are valid.
for (auto& inplace_node : inplace_nodes) {
auto& diff = mutation_.updated_nodes_[inplace_node];
if (!internal::IsWellFormed(&diff, node_names)) {
return errors::InvalidArgument(
kMutableGraphViewApplyError, "inplace updated node '",
nodes_[diff.node_index].GetName(), "' is ill-formed.");
}
}
for (auto& renamed_node : renamed_nodes) {
auto& diff = mutation_.updated_nodes_[renamed_node.renamed_update_index_];
if (!internal::IsWellFormed(&diff, node_names)) {
return errors::InvalidArgument(
kMutableGraphViewApplyError, "renamed updated node '", diff.name,
"' ('", nodes_[diff.node_index].GetName(), "') is ill-formed.");
}
}
// Check if new nodes and their fanins are valid.
for (auto& new_node : mutation_.new_nodes_) {
if (!internal::IsWellFormed(&new_node, node_names)) {
return errors::InvalidArgument(kMutableGraphViewApplyError, "new node '",
new_node.node.name(), "' is ill-formed.");
}
}
return Status::OK();
}
Status MutableGraphView::CheckKernelRegisteredForNodes() {
Status s;
for (auto& diff : mutation_.updated_nodes_) {
if (internal::IsEmpty(&diff)) {
continue;
}
NodeDef* node = nodes_[diff.node_index].node();
diff.processed_attrs =
AttrValueMap(node->attr().begin(), node->attr().end());
for (const auto& attr_to_remove : diff.attrs_to_remove) {
diff.processed_attrs.erase(attr_to_remove);
}
for (const auto& attr_to_add : diff.attrs_to_add) {
gtl::InsertOrUpdate(&diff.processed_attrs, attr_to_add.first,
attr_to_add.second);
}
const string& device = diff.update_device ? diff.device : node->device();
if (device.empty()) {
continue;
}
s = IsKernelRegisteredForNode(diff.update_name ? diff.name : node->name(),
node->has_experimental_debug_info(),
node->experimental_debug_info(),
diff.update_op ? diff.op : node->op(), device,
AttrSlice(&diff.processed_attrs));
if (!s.ok()) {
return errors::InvalidArgument(kMutableGraphViewApplyError,
s.error_message());
}
}
for (const auto& new_node_holder : mutation_.new_nodes_) {
const auto& new_node_def = new_node_holder.node;
if (new_node_def.device().empty()) {
continue;
}
s = IsKernelRegisteredForNode(new_node_def);
if (!s.ok()) {
return errors::InvalidArgument(kMutableGraphViewApplyError,
s.error_message());
}
}
return Status::OK();
}
template <typename T>
void MutableGraphView::ReplaceNodeFanouts(MutableNodeView* node, T* fanouts) {
node->num_regular_fanouts_ = fanouts->num_regular_fanouts_;
node->regular_fanouts_by_port_ = std::move(fanouts->regular_fanouts_by_port_);
for (int i = 0; i < node->regular_fanouts_by_port_.size(); ++i) {
for (int j = 0; j < node->regular_fanouts_by_port_[i].size(); ++j) {
auto& fanout = node->regular_fanouts_by_port_[i][j];
auto* fanout_node_view = fanout.node_view();
auto& fanout_fanin = fanout_node_view->regular_fanins_[fanout.index()];
auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
DecrementFaninCount(
fanout_fanins_count,
{&graph_->node(fanout_fanin.node_index_), fanout_fanin.index()});
fanout_fanin.node_index_ = node->node_index_;
IncrementFaninCount(
fanout_fanins_count,
{&graph_->node(node->node_index_), fanout_fanin.index()});
}
}
node->controlled_fanouts_ = std::move(fanouts->controlled_fanouts_);
for (int i = 0; i < node->controlled_fanouts_.size(); ++i) {
auto& fanout = node->controlled_fanouts_[i];
auto* fanout_node_view = fanout.node_view();
auto& fanout_fanin =
fanout_node_view->controlling_fanins_[fanout.fanin_index_];
auto* fanout_fanins_count = &fanout_node_view->fanins_count_;
DecrementFaninCount(
fanout_fanins_count,
{&graph_->node(fanout_fanin.node_index_), Graph::kControlSlot});
fanout_fanin.node_index_ = node->node_index_;
fanout_fanin.fanout_index_ = i;
IncrementFaninCount(fanout_fanins_count, {&graph_->node(node->node_index_),
Graph::kControlSlot});
}
}
void MutableGraphView::FixRenamedNodes(
std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
std::vector<bool>* overwritten_name_removed_nodes) {
// Extract all renamed node fanouts.
renamed_fanouts->reserve(renamed_nodes->size());
for (auto& renamed : *renamed_nodes) {
auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
// Remove node index by name from graph.
node_index_by_name_.erase(nodes_[diff.node_index].GetName());
MutableNodeView& renamed_node = nodes_[diff.node_index];
renamed_fanouts->try_emplace(
renamed_node.GetName(),
std::move(renamed_node.regular_fanouts_by_port_),
renamed_node.num_regular_fanouts_,
std::move(renamed_node.controlled_fanouts_));
}
// Replace renamed node fanouts with fanouts associated with updated name.
for (auto& renamed : *renamed_nodes) {
auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_];
MutableNodeView& renamed_node = nodes_[diff.node_index];
auto fanouts_it = renamed_fanouts->find(diff.name);
if (fanouts_it != renamed_fanouts->end()) {
// Another renamed node's fanout.
auto& fanouts = fanouts_it->second;
ReplaceNodeFanouts(&renamed_node, &fanouts);
renamed_fanouts->erase(fanouts_it);
// Node to be overwritten is being renamed, so it won't be overwritten.
renamed.overwritten_node_index_ = internal::kMissingIndex;
} else if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
// Existing node in graph.
MutableNodeView& node_to_overwrite =
nodes_[renamed.overwritten_node_index_];
ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
node_index_by_name_.erase(node_to_overwrite.GetName());
if (mutation_.removed_nodes_[node_to_overwrite.node_index_]) {
(*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
}
} else {
// No existing fanouts.
renamed_node.num_regular_fanouts_ = 0;
}
// Update node name.
renamed_node.node()->set_name(diff.name);
diff.update_name = false;
diff.name.clear();
// Rehash renamed nodes with updated name.
node_index_by_name_.emplace(renamed_node.GetName(), diff.node_index);
}
}
void MutableGraphView::AddNewNodes(
absl::flat_hash_map<string, NodeViewFanouts>* renamed_fanouts,
std::vector<int>* new_node_indices) {
new_node_indices->reserve(mutation_.new_nodes_.size());
for (auto& new_node : mutation_.new_nodes_) {
int node_index;
auto graph_it = node_index_by_name_.find(new_node.node.name());
if (graph_it != node_index_by_name_.end()) {
// Overwrite existing node.
node_index = graph_it->second;
MutableNodeView& node_view = nodes_[node_index];
RemoveAllFaninFanoutInternal(&node_view);
auto* node_def = graph_->mutable_node(node_index);
node_def->mutable_op()->swap(*new_node.node.mutable_op());
node_def->mutable_device()->swap(*new_node.node.mutable_device());
node_def->mutable_input()->Clear();
node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
mutation_.removed_nodes_[node_index] = false;
} else {
// New node.
auto* new_node_def = graph_->add_node();
*new_node_def = std::move(new_node.node);
node_index = nodes_.size();
nodes_.emplace_back(this, node_index);
MutableNodeView& new_node_view = nodes_.back();
auto it = renamed_fanouts->find(new_node_view.GetName());
if (it != renamed_fanouts->end()) {
// Reuse fanouts of renamed node.
NodeViewFanouts& fanouts = it->second;
ReplaceNodeFanouts(&new_node_view, &fanouts);
renamed_fanouts->erase(it);
}
node_index_by_name_.emplace(new_node_view.GetName(), node_index);
}
new_node_indices->emplace_back(node_index);
}
}
void MutableGraphView::FixRenamedFanouts(
const absl::flat_hash_map<string, NodeViewFanouts>& renamed_fanouts) {
// Leftover fanouts in renamed_fanouts are due to nodes not existing anymore
// or a node being renamed without another node taking its place. For these
// leftover fanouts, mark their respective fanin fanout_index_ to
// internal::kMissingIndex as an indicator so when it comes to updating or
// removing fanins inplace, nodes with the same index don't get affected and
// other fanouts are accidently removed.
for (auto& renamed_fanout : renamed_fanouts) {
for (auto& regular_fanouts :
renamed_fanout.second.regular_fanouts_by_port_) {
for (auto& fanout : regular_fanouts) {
auto* fanout_node_view = fanout.node_view();
auto& fanin = fanout_node_view->regular_fanins_[fanout.index()];
fanout_node_view->fanins_count_.erase(
{fanin.node_view()->node(), fanin.index()});
fanin.fanout_index_ = internal::kMissingIndex;
}
}
for (auto& fanout : renamed_fanout.second.controlled_fanouts_) {
auto* fanout_node_view = fanout.node_view();
auto& fanin = fanout_node_view->controlling_fanins_[fanout.fanin_index_];
fanout_node_view->fanins_count_.erase(
{fanin.node_view()->node(), Graph::kControlSlot});
fanout_node_view->controlling_fanins_index_.erase(renamed_fanout.first);
fanin.fanout_index_ = internal::kMissingIndex;
}
}
}
inline void MutableGraphView::RemoveRegularFaninFanoutInternal(
MutableNodeView* node_view, int i) {
MutableFanoutView& fanin = node_view->regular_fanins_[i];
// Fanin was marked as removed via FixRenamedFanouts.
if (fanin.fanout_index_ == internal::kMissingIndex) {
return;
}
DecrementFaninCount(&node_view->fanins_count_,
{&graph_->node(fanin.node_index_), fanin.index()});
auto* fanin_node_view = fanin.node_view();
auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin.index()];
if (fanin.fanout_index_ < fanouts.size() - 1) {
// Swap fanout with last fanout in vector, and update it's associated fanin
// index.
MutableFaninView& last_fanout = fanouts.back();
last_fanout.node_view()
->regular_fanins_[last_fanout.index()]
.fanout_index_ = fanin.fanout_index_;
std::swap(last_fanout, fanouts[fanin.fanout_index_]);
}
// Remove fanout.
fanouts.pop_back();
--fanin.node_view()->num_regular_fanouts_;
// Resize fanouts. Fanouts may not be removed sequentially in relation to
// output port, so trailing empty output ports may be left behind. It is
// necessary to loop through all of the output ports to determine the maximum
// output port before resizing.
int last_fanout_index = fanin_node_view->regular_fanouts_by_port_.size();
for (int i = fanin_node_view->regular_fanouts_by_port_.size() - 1; i >= 0;
--i) {
if (fanin_node_view->regular_fanouts_by_port_[i].empty()) {
last_fanout_index = i;
} else {
break;
}
}
if (last_fanout_index < fanin_node_view->regular_fanouts_by_port_.size()) {
fanin_node_view->regular_fanouts_by_port_.resize(last_fanout_index);
}
}
inline void MutableGraphView::AddRegularFaninInternal(
MutableNodeView* node_view, const SafeTensorId& fanin_id) {
MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
// Resize fanouts to include new output port index.
if (fanin_node_view->regular_fanouts_by_port_.size() < fanin_id.index() + 1) {
fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
}
// Add node as fanout to fanin.
auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
fanouts.emplace_back(this, node_view->node_index(),
node_view->regular_fanins_.size(),
node_view->regular_fanins_.size());
++fanin_node_view->num_regular_fanouts_;
// Add fanin to node.
node_view->regular_fanins_.emplace_back(this, fanin_node_view->node_index(),
fanin_id.index(), fanouts.size() - 1);
IncrementFaninCount(
&node_view->fanins_count_,
{&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
}
inline void MutableGraphView::UpdateRegularFaninInternal(
MutableNodeView* node_view, const int i, const SafeTensorId& fanin_id) {
// Remove fanin.
RemoveRegularFaninFanoutInternal(node_view, i);
MutableNodeView* fanin_node_view = GetNode(fanin_id.node());
// Resize fanouts to include new output port index.
if (fanin_node_view->regular_fanouts_by_port_.size() < fanin_id.index() + 1) {
fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1);
}
// Add node as fanout to fanin.
auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()];
fanouts.emplace_back(this, node_view->node_index(), i, i);
++fanin_node_view->num_regular_fanouts_;
// Replace fanin in node.
node_view->regular_fanins_[i] =
MutableFanoutView(this, fanin_node_view->node_index(), fanin_id.index(),
fanouts.size() - 1);
IncrementFaninCount(
&node_view->fanins_count_,
{&graph_->node(fanin_node_view->node_index()), fanin_id.index()});
}
inline void MutableGraphView::RemoveControllingFaninFanoutInternal(
MutableNodeView* node_view, int i) {
auto& control_to_remove = node_view->controlling_fanins_[i];
if (control_to_remove.fanout_index_ != internal::kMissingIndex) {
// Update internal state associated with node.
node_view->fanins_count_.erase(
{control_to_remove.node_view()->node(), Graph::kControlSlot});
node_view->controlling_fanins_index_.erase(
control_to_remove.node_view()->GetName());
// Remove controlled fanout from controlling fanin, via swapping last
// controlled fanout in controlling fanin with controlled fanout to be
// removed.
auto* control_to_remove_view = control_to_remove.node_view();
if (control_to_remove.fanout_index_ <
control_to_remove_view->controlled_fanouts_.size() - 1) {
auto& control_to_remove_view_last_control =
control_to_remove_view->controlled_fanouts_.back();
control_to_remove_view_last_control.node_view()
->controlling_fanins_[control_to_remove_view_last_control
.fanin_index_]
.fanout_index_ = control_to_remove.fanout_index_;
std::swap(control_to_remove_view_last_control,
control_to_remove_view
->controlled_fanouts_[control_to_remove.fanout_index_]);
}
control_to_remove_view->controlled_fanouts_.pop_back();
}
}
inline void MutableGraphView::RemoveControllingFaninInternal(
MutableNodeView* node_view, const std::set<int>& indices_to_remove) {
const int num_regular_fanins = node_view->NumRegularFanins();
auto* mutable_input = node_view->node()->mutable_input();
// Iterate in descending order so indices stay consistent.
for (auto rit = indices_to_remove.rbegin(); rit != indices_to_remove.rend();
++rit) {
const int control_index = *rit;
RemoveControllingFaninFanoutInternal(node_view, control_index);
// Swap last controlling fanin in node with controlling fanin to be removed.
if (control_index < node_view->controlling_fanins_.size() - 1) {
auto& last_control = node_view->controlling_fanins_.back();
auto* last_control_view = last_control.node_view();
last_control_view->controlled_fanouts_[last_control.fanout_index_]
.fanin_index_ = control_index;
node_view->controlling_fanins_index_.find(last_control_view->GetName())
->second = control_index;
mutable_input->SwapElements(
num_regular_fanins + control_index,
num_regular_fanins + node_view->NumControllingFanins() - 1);
std::swap(last_control, node_view->controlling_fanins_[control_index]);
}
mutable_input->RemoveLast();
node_view->controlling_fanins_.pop_back();
}
}
inline void MutableGraphView::AddControllingFaninInternal(
MutableNodeView* node_view, absl::string_view fanin_node_name) {
NodeDef* node = node_view->node();
// Add controlling fanin to NodeDef.
node->add_input(AsControlDependency(string(fanin_node_name)));
MutableNodeView* fanin_node_view = GetNode(fanin_node_name);
const int index = node_view->controlling_fanins_.size();
fanin_node_view->controlled_fanouts_.emplace_back(
this, node_view->node_index(), Graph::kControlSlot, index);
node_view->controlling_fanins_.emplace_back(
this, fanin_node_view->node_index(), Graph::kControlSlot,
fanin_node_view->controlled_fanouts_.size() - 1);
IncrementFaninCount(
&node_view->fanins_count_,
{&graph_->node(fanin_node_view->node_index()), Graph::kControlSlot});
// Parse new fanin string for node name.
TensorId tensor_id = ParseTensorName(node->input(node->input_size() - 1));
node_view->controlling_fanins_index_.emplace(tensor_id.node(), index);
}
void MutableGraphView::ApplyNodeUpdates() {
for (auto& diff : mutation_.updated_nodes_) {
if (internal::IsEmpty(&diff)) {
continue;
}
MutableNodeView& node_view = nodes_[diff.node_index];
diff.node_index = internal::kMissingIndex;
// Clean up node view.
node_view.update_index_ = internal::kMissingIndex;
NodeDef* node_def = node_view.node();
// Set updated fields and attributes of node.
if (diff.update_op) {
node_def->set_op(diff.op);
}
if (diff.update_device) {
node_def->set_device(diff.device);
}
node_def->mutable_attr()->swap(diff.processed_attrs);
// Updated fanins. Only one of `regular_inputs_to_remove_` or
// `regular_inputs_to_add_` can be set.
if (diff.num_regular_inputs_to_remove > 0) {
// Truncate trailing regular fanins.
const int first_index =
node_view.NumRegularFanins() - diff.num_regular_inputs_to_remove;
for (int i = first_index; i < node_view.NumRegularFanins(); ++i) {
RemoveRegularFaninFanoutInternal(&node_view, i);
}
node_view.regular_fanins_.resize(first_index);
node_def->mutable_input()->DeleteSubrange(
node_view.NumRegularFanins(), diff.num_regular_inputs_to_remove);
} else if (diff.num_regular_inputs_to_add > 0) {
// Append regular fanins.
node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
diff.num_regular_inputs_to_add);
int curr_index = node_view.NumRegularFanins();
int curr_control_start = curr_index;
for (const SafeTensorId& fanin : diff.regular_inputs_to_add) {
AddRegularFaninInternal(&node_view, fanin);
node_def->add_input(SafeTensorIdToString(fanin));
node_def->mutable_input()->SwapElements(curr_index,
node_def->input_size() - 1);
if (curr_control_start == curr_index) {
curr_control_start = node_def->input_size() - 1;
}
++curr_index;
}
// Rotate shifted controlling fanins to match up with
// `node_view.controlling_fanins_` as `num_regular_inputs_to_add_` may not
// be a multiple of `num_regular_inputs_to_add_`. This is to prevent
// rehashing controlling fanins in `node_view.controlling_fanins_index_`.
if (node_view.NumControllingFanins() > 1 &&
curr_control_start != node_view.NumRegularFanins()) {
std::rotate(
node_def->mutable_input()->begin() + node_view.NumRegularFanins(),
node_def->mutable_input()->begin() + curr_control_start,
node_def->mutable_input()->end());
}
}
for (const auto& update_fanin : diff.regular_inputs_to_update) {
UpdateRegularFaninInternal(&node_view, update_fanin.first,
update_fanin.second);
node_def->set_input(update_fanin.first,
SafeTensorIdToString(update_fanin.second));
}
RemoveControllingFaninInternal(&node_view,
diff.controlling_inputs_to_remove);
node_def->mutable_input()->Reserve(node_def->mutable_input()->size() +
diff.controlling_inputs_to_add.size());
for (const auto& control_to_add : diff.controlling_inputs_to_add) {
AddControllingFaninInternal(&node_view, control_to_add);
}
}
}
void MutableGraphView::SetNewNodesFanins(
const std::vector<int>& new_node_indices) {
auto new_node = mutation_.new_nodes_.begin();
for (const int new_node_index : new_node_indices) {
MutableNodeView& new_node_view = nodes_[new_node_index];
NodeDef* new_node_def = new_node_view.node();
new_node_def->mutable_input()->Reserve(new_node->num_regular_fanins +
new_node->controlling_fanins.size());
for (const SafeTensorId& fanin : new_node->regular_fanins) {
AddRegularFaninInternal(&new_node_view, fanin);
new_node_def->add_input(SafeTensorIdToString(fanin));
}
for (const string& control_to_add : new_node->controlling_fanins) {
AddControllingFaninInternal(&new_node_view, control_to_add);
}
++new_node;
}
}
inline void MutableGraphView::RemoveAllFaninFanoutInternal(
MutableNodeView* node_view) {
const int num_regular_fanins = node_view->NumRegularFanins();
for (int i = 0; i < num_regular_fanins; ++i) {
RemoveRegularFaninFanoutInternal(node_view, i);
}
std::vector<MutableFanoutView>().swap(node_view->regular_fanins_);
const int num_controlling_fanins = node_view->NumControllingFanins();
for (int i = 0; i < num_controlling_fanins; ++i) {
RemoveControllingFaninFanoutInternal(node_view, i);
}
std::vector<MutableFanoutView>().swap(node_view->controlling_fanins_);
}
void MutableGraphView::RemoveNodesInternal(
const std::vector<RenamedOrOverwrittenNode>& renamed_nodes,
const std::vector<bool>& overwritten_name_removed_nodes) {
// Get all nodes overwritten by renamed nodes and remove their fanins.
std::vector<int> overwritten_nodes;
overwritten_nodes.reserve(renamed_nodes.size());
for (const auto& renamed : renamed_nodes) {
if (renamed.overwritten_node_index_ != internal::kMissingIndex) {
auto& node = nodes_[renamed.overwritten_node_index_];
RemoveAllFaninFanoutInternal(&node);
overwritten_nodes.emplace_back(renamed.overwritten_node_index_);
}
}
// Get all nodes explicitly marked for removal and remove their fanins.
std::vector<int> node_indices_to_remove;
node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
overwritten_nodes.size());
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
if (mutation_.removed_nodes_[i]) {
auto& node = nodes_[i];
RemoveAllFaninFanoutInternal(&node);
node_indices_to_remove.push_back(i);
if (!overwritten_name_removed_nodes[i]) {
node_index_by_name_.erase(node.GetName());
}
}
}
node_indices_to_remove.insert(node_indices_to_remove.end(),
overwritten_nodes.begin(),
overwritten_nodes.end());
std::set<int> sorted_node_indices_to_remove(node_indices_to_remove.begin(),
node_indices_to_remove.end());
// Iterate in descending order so indices stay consistent.
for (auto rit = sorted_node_indices_to_remove.rbegin();
rit != sorted_node_indices_to_remove.rend(); ++rit) {
const int removed_node_index = *rit;
MutableNodeView& last_node = nodes_.back();
if (last_node.node_index_ > removed_node_index) {
last_node.node_index_ = removed_node_index;
for (auto& regular_fanin : last_node.regular_fanins_) {
// Update fanouts of regular fanins with new index.
regular_fanin.node_view()
->regular_fanouts_by_port_[regular_fanin.index()]
[regular_fanin.fanout_index_]
.node_index_ = removed_node_index;
}
for (auto& controlling_fanin : last_node.controlling_fanins_) {
// Update fanouts of controlling fanins with new index.
controlling_fanin.node_view()
->controlled_fanouts_[controlling_fanin.fanout_index_]
.node_index_ = removed_node_index;
}
for (auto& regular_fanouts : last_node.regular_fanouts_by_port_) {
for (auto& regular_fanout : regular_fanouts) {
// Update fanins of regular fanouts.
MutableNodeView* fanout_node_view = regular_fanout.node_view();
fanout_node_view->regular_fanins_[regular_fanout.fanin_index_]
.node_index_ = removed_node_index;
}
}
for (auto& controlled_fanout : last_node.controlled_fanouts_) {
// Update fanins of controlled fanouts.
MutableNodeView* fanout_node_view = controlled_fanout.node_view();
fanout_node_view->controlling_fanins_[controlled_fanout.fanin_index_]
.node_index_ = removed_node_index;
}
const int last_node_index = nodes_.size() - 1;
std::swap(nodes_[last_node_index], nodes_[removed_node_index]);
graph()->mutable_node()->SwapElements(last_node_index,
removed_node_index);
node_index_by_name_.find(nodes_[removed_node_index].GetName())->second =
removed_node_index;
}
nodes_.pop_back();
graph()->mutable_node()->RemoveLast();
}
}
namespace {
constexpr int kTopologicalSortDone = -1;
const char kMutableGraphViewSortTopologicallyError[] =
"MutableGraphView::SortTopologically error: ";
// TraversalState is an enum representing the state of a node when it is being
// traversed via DFS.
enum TraversalState : uint8_t { PENDING, PROCESSING, PROCESSED };
// RecursionStackState is an enum representing the recursion stack state
// when using DFS iteratively. `ENTER` is the state representing entering into
// a recursive call, while `EXIT` is the state representing exiting a
// recursive call.
enum RecursionStackState : bool { ENTER, EXIT };
// RecursionStackEntry is a helper struct representing an instance of a
// recursive call in the iterative DFS simulating a recursive ordering.
struct RecursionStackEntry {
RecursionStackEntry(int node_index, RecursionStackState recursion_state)
: node_index(node_index), recursion_state(recursion_state) {}
const int node_index;
const RecursionStackState recursion_state;
};
// Edge is a helper struct representing an edge in the graph.
struct Edge {
Edge(int from, int to) : from(from), to(to) {}
const int from;
const int to;
};
} // namespace
Status MutableGraphView::SortTopologically(
bool ignore_cycles,
absl::Span<const TopologicalDependency> extra_dependencies) {
if (!mutation_.updated_nodes_.empty() || !mutation_.new_nodes_.empty()) {
// Cannot sort when there is an active mutation due to indices possibly
// being changed or invalidated.
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"active mutation exists.");
}
const int num_nodes = nodes_.size();
// Group extra dependencies by `from` node.
absl::flat_hash_map<int, std::vector<int>> extra_dependencies_by_parent;
for (const auto& extra_dependency : extra_dependencies) {
if (extra_dependency.graph_view_ != this ||
extra_dependency.from_ == extra_dependency.to_ ||
extra_dependency.from_ < 0 || extra_dependency.from_ >= num_nodes ||
extra_dependency.to_ < 0 || extra_dependency.to_ >= num_nodes) {
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"invalid extra dependencies.");
}
extra_dependencies_by_parent[extra_dependency.from_].push_back(
extra_dependency.to_);
}
// Reversed colored post-order DFS traversal. This does not fail on cycles,
// but there are no guarantees on ordering within a cycle.
std::vector<TraversalState> traversal_state(num_nodes, PENDING);
int curr_pos = num_nodes - 1;
std::vector<int> order(num_nodes);
std::vector<Edge> edges_in_cycle;
auto push_onto_stack = [this](
const int curr_index, const int fanout_index,
std::vector<RecursionStackEntry>* recursion_stack,
std::vector<TraversalState>* traversal_state,
std::vector<Edge>* edges_in_cycle) {
// Ignore NextIteration -> Merge connections to break control flow cycles.
if (IsNextIteration(graph_->node(curr_index)) &&
IsMerge(graph_->node(fanout_index))) {
return;
}
auto& fanout_traversal_state = (*traversal_state)[fanout_index];
if (fanout_traversal_state == PROCESSING) {
// Cycle detected.
edges_in_cycle->push_back({curr_index, fanout_index});
} else if (fanout_traversal_state == PENDING) {
// Unvisited node, simply add to stack for future traversal.
recursion_stack->push_back({fanout_index, ENTER});
}
};
auto process_fanouts = [this, &extra_dependencies_by_parent,
&push_onto_stack](
const int curr_index,
std::vector<RecursionStackEntry>* recursion_stack,
std::vector<TraversalState>* traversal_state,
std::vector<Edge>* edges_in_cycle) {
const auto& node_view = nodes_[curr_index];
// Regular fanouts.
for (const auto& regular_fanouts_port_i : node_view.GetRegularFanouts()) {
for (const auto& regular_fanout : regular_fanouts_port_i) {
push_onto_stack(curr_index, regular_fanout.node_index_, recursion_stack,
traversal_state, edges_in_cycle);
}
}
// Controlled fanouts.
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
push_onto_stack(curr_index, controlled_fanout.node_index_,
recursion_stack, traversal_state, edges_in_cycle);
}
// Extra dependencies.
auto it = extra_dependencies_by_parent.find(curr_index);
if (it != extra_dependencies_by_parent.end()) {
for (const auto& extra_fanout : it->second) {
push_onto_stack(curr_index, extra_fanout, recursion_stack,
traversal_state, edges_in_cycle);
}
}
};
auto reversed_postorder_dfs =
[&process_fanouts](const MutableNodeView& root_node_view,
std::vector<int>* order,
std::vector<TraversalState>* traversal_state,
int* curr_pos, std::vector<Edge>* edges_in_cycle) {
std::vector<RecursionStackEntry> recursion_stack;
// Add the root to stack to start the traversal.
const int root_index = root_node_view.node_index_;
auto& root_traversal_state = (*traversal_state)[root_index];
if (root_traversal_state == PENDING) {
recursion_stack.push_back({root_index, ENTER});
}
while (!recursion_stack.empty()) {
auto curr_entry = recursion_stack.back();
recursion_stack.pop_back();
const int curr_index = curr_entry.node_index;
auto& curr_traversal_state = (*traversal_state)[curr_index];
if (curr_traversal_state == PROCESSED) {
// Node already processed which can be ignored.
continue;
} else if (curr_entry.recursion_state == EXIT) {
// Node from recursion stack where all fanouts were visited.
// Instead of adding node index to a vector, simply set what its
// index would be, so there will not be a need for inversion later
// on. The value set is in decending order so the reversed
// post-order is returned.
(*order)[curr_index] = *curr_pos;
curr_traversal_state = PROCESSED;
--(*curr_pos);
} else {
// Process current node and fanouts.
curr_traversal_state = PROCESSING;
recursion_stack.push_back({curr_index, EXIT});
process_fanouts(curr_index, &recursion_stack, traversal_state,
edges_in_cycle);
}
}
};
// Determine sources to start DFS (nodes with no inputs) and unique fanout
// nodes.
for (int i = num_nodes - 1; i >= 0; --i) {
auto& node = nodes_[i];
if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
&edges_in_cycle);
}
}
if (!ignore_cycles && !edges_in_cycle.empty()) {
std::vector<string> edges_formatted;
edges_formatted.reserve(edges_in_cycle.size());
for (const auto& edge : edges_in_cycle) {
edges_formatted.push_back(
absl::StrCat("'", graph_->node(edge.from).name(), "' -> '",
graph_->node(edge.to).name(), "'"));
}
const string edges_str =
absl::StrCat("{", absl::StrJoin(edges_formatted, ", "), "}");
return errors::InvalidArgument(kMutableGraphViewSortTopologicallyError,
"detected edge(s) creating cycle(s) ",
edges_str, ".");
}
if (curr_pos != kTopologicalSortDone) {
// Not all nodes were processed.
if (!ignore_cycles) {
return errors::InvalidArgument(
kMutableGraphViewSortTopologicallyError,
"was not able to sort all nodes topologically.");
}
// Otherwise process all nodes regardless of cycles.
for (const auto& node : nodes_) {
reversed_postorder_dfs(node, &order, &traversal_state, &curr_pos,
&edges_in_cycle);
}
}
// Permute nodes by reversed post-order DFS.
std::vector<MutableNodeView> permuted_nodes(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
permuted_nodes[order[i]] = std::move(nodes_[i]);
}
nodes_.swap(permuted_nodes);
// Fix up indices of MutableNodeViews.
for (MutableNodeView& node_view : nodes_) {
const int prev_node_index = node_view.node_index_;
if (prev_node_index != order[prev_node_index]) {
const string& node_name = graph_->node(prev_node_index).name();
node_view.node_index_ = order[prev_node_index];
node_index_by_name_.find(node_name)->second = node_view.node_index_;
}
for (MutableFanoutView& regular_fanin : node_view.regular_fanins_) {
regular_fanin.node_index_ = order[regular_fanin.node_index_];
}
for (MutableFanoutView& controlling_fanin : node_view.controlling_fanins_) {
controlling_fanin.node_index_ = order[controlling_fanin.node_index_];
}
for (std::vector<MutableFaninView>& regular_fanouts_port_i :
node_view.regular_fanouts_by_port_) {
for (MutableFaninView& regular_fanout : regular_fanouts_port_i) {
regular_fanout.node_index_ = order[regular_fanout.node_index_];
}
}
for (MutableFaninView& controlled_fanout : node_view.controlled_fanouts_) {
controlled_fanout.node_index_ = order[controlled_fanout.node_index_];
}
}
// Permute graph NodeDefs.
PermuteNodesInPlace(graph_, &order, /*invert_permutation=*/false);
return Status::OK();
}
inline Status MutableGraphView::ValidateInternal(
absl::flat_hash_map<absl::string_view, int>* node_names,
std::vector<RenamedOrOverwrittenNode>* renamed_nodes,
std::vector<int>* inplace_nodes,
std::vector<int>* empty_diff_node_indices) {
// Get node names and partition updated_nodes_ by if they are renamed or not,
// skipping empty MutableNodeViewDiff.
TF_RETURN_IF_ERROR(GetNodeNamesAndPartitionUpdatedNodes(
node_names, renamed_nodes, inplace_nodes, empty_diff_node_indices));
// Check existence of fanins and validity (i.e. no self loops).
TF_RETURN_IF_ERROR(
CheckNodeNamesAndFanins(*node_names, *renamed_nodes, *inplace_nodes));
// Check if nodes after mutation have kernels registered.
TF_RETURN_IF_ERROR(CheckKernelRegisteredForNodes());
return Status::OK();
}
Status MutableGraphView::ApplyMutationInternal() {
// Node name -> node index mapping. If a node index is -1, the associated node
// with key node name exists. Otherwise the node index is the node's index in
// the graph.
absl::flat_hash_map<absl::string_view, int> node_names;
// Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
// renamed (and possibly have other fields mutated).
std::vector<RenamedOrOverwrittenNode> renamed_nodes;
// Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are
// not renamed but have fields mutated.
std::vector<int> inplace_nodes;
// Indices of nodes in graph where MutableNodeViewDiff are empty.
// `update_index_` of nodes associated to empty MutableNodeViewDiff should be
// cleared after validation success.
std::vector<int> empty_diff_node_indices;
// Check if this mutation is valid before applying, and partition
// updated_nodes_ into inplace mutated nodes and renamed nodes.
TF_RETURN_IF_ERROR(ValidateInternal(
&node_names, &renamed_nodes, &inplace_nodes, &empty_diff_node_indices));
// Clear `update_index_` of MutableNodeView with empty associated
// MutableNodeViewDiff.
for (const int empty_diff_node_index : empty_diff_node_indices) {
nodes_[empty_diff_node_index].update_index_ = internal::kMissingIndex;
}
// Node name and associated fanouts.
absl::flat_hash_map<string, NodeViewFanouts> renamed_fanouts;
// Removed nodes where name was overwritten by a renamed node.
std::vector<bool> overwritten_name_removed_nodes(nodes_.size());
// Fix renaming of existing nodes by swapping fanouts and rehashing names.
// This will also overwrite removed or unmodified nodes.
FixRenamedNodes(&renamed_nodes, &renamed_fanouts,
&overwritten_name_removed_nodes);
// Indices of nodes in graph where new nodes were inserted/appended. These
// will be corresponding to `new_nodes_` in order.
std::vector<int> new_node_indices;
// Add new nodes, overwriting removed or unmodified nodes.
AddNewNodes(&renamed_fanouts, &new_node_indices);
// For abandoned fanouts, mark their respective fanins so the original node
// associated will not have their fanouts removed and be left in an
// inconsistent state.
FixRenamedFanouts(renamed_fanouts);
// Apply mutations to updated nodes (renamed nodes are treated as inplace
// nodes as they have already been renamed). Removed nodes are ignored.
ApplyNodeUpdates();
// Set fanins of new nodes.
SetNewNodesFanins(new_node_indices);
// Remove overwritten nodes and updated nodes set to be removed.
RemoveNodesInternal(renamed_nodes, overwritten_name_removed_nodes);
mutation_.ResetInternal();
mutation_.mutation_counter_++;
return Status::OK();
}
} // namespace utils
} // namespace grappler
} // namespace tensorflow