blob: 361d48fd88f4230be2dcb3fe370f8b138d446f8d [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"

#include <cstddef>
#include <deque>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
namespace tflite {
namespace gpu {
// Runs `transformation` over every chain of nodes reachable from the graph
// inputs. Work is seeded with the consumers of each graph input and then
// drained from the to_process_ queue, which ApplyStartingWithNode() keeps
// feeding as it discovers new chain starts.
//
// Returns false as soon as ApplyStartingWithNode() reports a failure, true
// once the queue is exhausted. In both cases the traversal state
// (to_process_ and processed_) is reset before returning, so a later Apply()
// call starts from a clean slate.
bool ModelTransformer::Apply(const std::string& name,
                             SequenceTransformation* transformation) {
  // Seed transformations with starting node. Each node may start a chain of
  // transformations.
  for (auto input : graph_->inputs()) {
    for (auto node : graph_->FindConsumers(input->id)) {
      AddNodeToProcess(node);
    }
  }
  while (!to_process_.empty()) {
    // GetNode() may return null for a queued id (e.g. a node removed by an
    // earlier rewrite); such entries are simply skipped.
    auto node = graph_->GetNode(to_process_.front());
    if (node && !ApplyStartingWithNode(name, transformation, node)) {
      // Drop the traversal state before bailing out. Otherwise the ids still
      // queued and the "processed" marks from this run would carry over into
      // the next Apply() call, which could then skip nodes it should visit.
      to_process_.clear();
      processed_.clear();
      return false;
    }
    to_process_.pop_front();
  }
  processed_.clear();
  return true;
}
bool ModelTransformer::Apply(const std::string& name,
NodeTransformation* transformation) {
// Apply a transformation only to nodes that are present in the graph before
// transformation.
std::vector<NodeId> nodes;
for (auto node : graph_->nodes()) {
nodes.push_back(node->id);
}
for (auto node_id : nodes) {
auto node = graph_->GetNode(node_id);
if (!node) {
continue;
}
auto result = transformation->ApplyToNode(node, graph_);
last_transformation_message_ = result.message;
if (result.status == TransformStatus::INVALID) {
return false;
}
}
return true;
}
const std::string& ModelTransformer::last_transformation_message() const {
return last_transformation_message_;
}
bool ModelTransformer::ApplyStartingWithNode(
const std::string& name, SequenceTransformation* transformation,
Node* begin) {
int expected_sequence_length = transformation->ExpectedSequenceLength();
std::deque<NodeId> sequence;
std::vector<Node*> nodes;
nodes.reserve(transformation->ExpectedSequenceLength());
sequence.push_back(begin->id);
// Go over nodes with sequence sliding window of size
// expected_sequence_length until a node with multiple dependents is found.
while (true) {
// Apply transformation if possible.
if (sequence.size() == expected_sequence_length) {
nodes.clear();
for (NodeId id : sequence) {
// Nodes present in sequence should be present in a graph. If they are
// not, then this transformation changes a graph but didn't say it.
Node* node = graph_->GetNode(id);
if (node == nullptr) {
return false;
}
nodes.push_back(node);
}
NodeId first_in_sequence = sequence.front();
auto preceding_node =
graph_->FindProducer(graph_->FindInputs(first_in_sequence)[0]->id);
auto result = transformation->ApplyToNodesSequence(nodes, graph_);
last_transformation_message_ = result.message;
if (result.status == TransformStatus::INVALID) {
// graph is broken now.
return false;
}
if (result.status == TransformStatus::APPLIED) {
// Also remove first node of a sequence from a set of processed node.
// Out of all nodes in a sequence only first one may have been added
// to "processed" set because other nodes do not have more than one
// dependent. However, if a sequence is changed, then processing needs
// to be restarted again.
processed_.erase(first_in_sequence);
// Transformation was successful. Restart sequence from the node that
// precedes current sequence.
if (preceding_node) {
processed_.erase(preceding_node->id);
AddNodeToProcess(preceding_node);
} else {
// This is the first node in the graph. Re-seed transformation.
for (auto input : graph_->inputs()) {
for (auto node : graph_->FindConsumers(input->id)) {
AddNodeToProcess(node);
}
}
}
return true;
}
}
// Try to extend current sequence.
Node* next_node_in_sequence = nullptr;
bool has_multiple_children = false;
// Check that all outputs from last node are consumed by a single node.
for (auto output_value : graph_->FindOutputs(sequence.back())) {
for (auto dependent : graph_->FindConsumers(output_value->id)) {
if (has_multiple_children) {
AddNodeToProcess(dependent);
} else if (next_node_in_sequence == nullptr) {
next_node_in_sequence = dependent;
} else if (next_node_in_sequence != dependent) {
// There are more than two nodes depend on the output from end node,
// therefore here a sequence stops and new will start. Push all such
// nodes.
has_multiple_children = true;
AddNodeToProcess(dependent);
AddNodeToProcess(next_node_in_sequence);
}
}
}
// Now check that next node has inputs only produced by the last node.
if (!has_multiple_children && next_node_in_sequence) {
for (auto input : graph_->FindInputs(next_node_in_sequence->id)) {
auto producer = graph_->FindProducer(input->id);
if (producer == nullptr || producer->id != sequence.back()) {
has_multiple_children = true;
AddNodeToProcess(next_node_in_sequence);
break;
}
}
}
if (has_multiple_children || next_node_in_sequence == nullptr) {
// reached end of this transformation sequence.
return true;
}
sequence.push_back(next_node_in_sequence->id);
// Decrease sequence until it matches expected length.
if (sequence.size() > expected_sequence_length) {
sequence.pop_front();
}
}
return true;
}
} // namespace gpu
} // namespace tflite