blob: 9b69771dab29175e7adcaca6bba9373815e473d9 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_phi_graph.h"

#include <algorithm>
#include <iterator>
#include <queue>
#include <utility>
namespace xla {
// Returns the id of the value that `value` currently resolves to in the
// graph (its own id unless ReplaceNodeWith has redirected it).
//
// Delegates to FindOptimizedValue so an unregistered value CHECK-fails with a
// clear message instead of `operator[]` default-inserting a null Node* and
// then dereferencing it.
HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
  return FindOptimizedValue(value.id());
}
// Returns true if the input to a hlo value is the same as `inputs`.
bool PhiGraph::InputsEqualTo(const HloValue& value,
absl::Span<const HloValue* const> inputs) {
auto iter = value_id_to_node_.find(value.id());
CHECK(iter != value_id_to_node_.end());
absl::flat_hash_set<HloValue::Id> existing_set;
for (Node* operand : iter->second->operands) {
existing_set.insert(operand->value_id);
}
absl::flat_hash_set<HloValue::Id> new_set;
for (const HloValue* input : inputs) {
new_set.insert(input->id());
}
return existing_set == new_set;
}
// Returns the id of the node that `id` currently maps to: `id` itself if the
// value was never redirected, otherwise the id of the node ReplaceNodeWith
// moved it onto. CHECK-fails if `id` was never registered.
HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
  const auto entry = value_id_to_node_.find(id);
  CHECK(entry != value_id_to_node_.end());
  const Node* representative = entry->second;
  return representative->value_id;
}
// Returns the graph node registered for `value`. If there is none yet, a new
// node is allocated (owned by node_storage_), keyed by `value`'s id in both
// directions, and returned.
//
// A reused node must be non-null and must still carry `value`'s own id; this
// CHECK-fails otherwise (which can happen once ReplaceNodeWith has redirected
// `value` onto another node).
PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
  const HloValue::Id id = value.id();

  auto existing = value_id_to_node_.find(id);
  if (existing != value_id_to_node_.end()) {
    // Already registered: verify it is the node created for this very value.
    CHECK_NE(existing->second, nullptr);
    CHECK_EQ(existing->second->value_id, id);
    return existing->second;
  }

  node_storage_.emplace_back(absl::make_unique<Node>());
  Node* fresh = node_storage_.back().get();
  fresh->value_id = id;
  value_id_to_node_[id] = fresh;
  node_to_value_id_[fresh].push_back(id);
  return fresh;
}
// Replaces the phi `node` with `replace` throughout the graph and retires it.
//
//  - Every user of `node` takes `replace` as an operand instead, and is
//    recorded as a user of `replace`. Recording it matters: when `replace` is
//    itself replaced later, that call rewrites exactly the users listed on
//    `replace`. A user missing from that list would keep pointing at a dead
//    node, and a later phi(dead, dead) would be resolved to that dead node.
//  - `replace` takes `node`'s place in the user lists of `node`'s operands.
//    This may over-approximate; an extra user entry is harmless because
//    c_replace on a node that does not reference the operand changes nothing.
//  - Every HloValue::Id that resolved to `node` now resolves to `replace`,
//    and `node` is marked dead.
//
// Calling this on a node that is already dead is a no-op. If `replace` is
// already dead, the replacement is forwarded to the live node `replace` was
// merged into, so no value id is ever mapped onto a dead node.
void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
  CHECK(node->is_phi);
  if (node->mark_as_dead) {
    // Already replaced earlier; its ids have already been moved.
    return;
  }
  if (replace->mark_as_dead) {
    // A dead node's own id was moved along with it, so the map points at the
    // node it was merged into.
    auto iter = value_id_to_node_.find(replace->value_id);
    CHECK(iter != value_id_to_node_.end());
    replace = iter->second;
    CHECK(!replace->mark_as_dead);
  }
  CHECK_NE(node, replace);

  // Update users, and make them users of `replace`.
  for (Node* user : node->users) {
    absl::c_replace(user->operands, node, replace);
    if (user != node) {
      replace->users.push_back(user);
    }
  }

  // Update operands' users.
  for (Node* operand : node->operands) {
    absl::c_replace(operand->users, node, replace);
  }

  // Update mappings to HloValue::Id. The ids are moved out of `node`'s entry
  // before `replace`'s entry is looked up: operator[] may insert, and a
  // rehash of a map without reference stability (e.g. absl::flat_hash_map)
  // would invalidate a reference to `node`'s entry that is still being read.
  auto moved_ids = std::move(node_to_value_id_[node]);
  node_to_value_id_[node].clear();
  for (HloValue::Id value_id : moved_ids) {
    CHECK(value_id_to_node_.contains(value_id));
    value_id_to_node_[value_id] = replace;
  }
  absl::c_copy(moved_ids, std::back_inserter(node_to_value_id_[replace]));
  node->mark_as_dead = true;
}
void PhiGraph::RegisterPhi(const HloValue& value,
absl::Span<const HloValue* const> inputs) {
Node* node = CreateOrReuseNode(value);
CHECK(value.is_phi());
node->is_phi = true;
node->operands.clear();
for (auto input : inputs) {
CHECK(input != nullptr);
Node* input_node = CreateOrReuseNode(*input);
node->operands.push_back(input_node);
}
}
std::string PhiGraph::ToString() {
std::string out = "PhiGraph: \n";
for (auto& node : node_storage_) {
std::string is_phi = node->is_phi ? ", phi" : "";
std::string is_optimized = node->mark_as_dead ? ", dead" : "";
absl::StrAppend(&out, node->value_id);
absl::StrAppend(&out, is_phi);
absl::StrAppend(&out, is_optimized, ":\n");
for (Node* input : node->operands) {
absl::StrAppend(&out, " ", input->value_id);
absl::StrAppend(&out, "\n");
}
}
return out;
}
void PhiGraph::Optimize() {
// Set up users for each node.
for (auto& node : node_storage_) {
for (Node* input : node->operands) {
input->users.push_back(node.get());
}
}
// input_node->users.push_back(node);
bool changed = true;
// Run the optimization to a fixed point.
while (changed) {
changed = false;
absl::flat_hash_set<Node*> checked_for_closure;
for (auto& node : node_storage_) {
// Only optimize phi node.
if (!node->is_phi) {
continue;
}
// Skip dead nodes
if (node->mark_as_dead) {
continue;
}
Node* node_ptr = node.get();
CHECK_GE(node_ptr->operands.size(), 1);
// Remove self-referencing ids from users and operands.
auto it = absl::c_find(node_ptr->operands, node_ptr);
while (it != node_ptr->operands.end()) {
node_ptr->operands.erase(it);
it = absl::c_find(node_ptr->operands, node_ptr);
}
it = absl::c_find(node_ptr->users, node_ptr);
while (it != node_ptr->users.end()) {
node_ptr->users.erase(it);
it = absl::c_find(node_ptr->users, node_ptr);
}
// If all inputs to phi (after self referencing ids are removed) are the
// same value, replace the phi with that value.
//
// phi(A, A, ... A) => A
// phi(A, self) = phi(A) => A
CHECK_GE(node_ptr->operands.size(), 1);
bool all_inputs_are_same = absl::c_all_of(
node_ptr->operands,
[&](Node* elem) { return elem == node_ptr->operands[0]; });
if (all_inputs_are_same) {
ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
changed = true;
continue;
}
// Find a closure of inter-connected phis and one non-phi node. Replace
// all phis with that non-phi node.
//
// def A = phi(B, C)
// def B = phi(C, D)
// def C = phi(A, B)
// def D = non-phi
// Replace A, B, and C with D:
// A = phi(B, C) => D
// B = phi(C, D) => D
// C = phi(A, B) => D
if (checked_for_closure.contains(node_ptr)) {
continue;
}
// Keeps track of nodes in the current closure being tested.
absl::flat_hash_set<Node*> workset;
std::queue<Node*> worklist;
Node* non_phi = nullptr;
worklist.push(node_ptr);
while (!worklist.empty()) {
Node* todo = worklist.front();
worklist.pop();
if (workset.contains(todo)) {
continue;
}
checked_for_closure.insert(todo);
workset.insert(todo);
for (Node* operand : todo->operands) {
worklist.push(operand);
}
if (!todo->is_phi) {
if (non_phi != nullptr && non_phi != todo) {
// We see distinct non-phi nodes in the closure, can't apply the
// optimization.
non_phi = nullptr;
// Break the while loop non_phi setting to nullptr, signaling that
// the optimization can't be applied.
break;
} else {
// This is the non_phi node we are seeing so far.
non_phi = todo;
}
}
}
if (non_phi != nullptr) {
// Replace all phi nodes in the closure/workset with the non_phi node.
for (Node* node : workset) {
if (!node->is_phi) {
CHECK_EQ(node, non_phi);
continue;
}
ReplaceNodeWith(node, non_phi);
changed = true;
}
}
}
}
}
} // namespace xla