blob: 6fa21fa6204dcc9446081d07e2a59ccace216713 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
// Op type strings that GraphToFunctionDef compares against each node's
// type_string() to recognize function argument and return-value nodes.
// TODO(phawkins) add a canonical copy of these operator names and refactor
// everything to use it.
constexpr char kArgOp[] = "_Arg";
constexpr char kRetValOp[] = "_Retval";
// Keeps track of how original node names were renamed, handing out each new
// name at most once.
//
// Names used as input and output arguments are expected to match the regexp
// "[a-z][a-z0-9_]*", so they are normalized before use. Normalizing can make
// an argument name clash with another node name; when that happens a numeric
// suffix is appended to keep names distinct. E.g. an input named "A" and a
// function-body node named "a" end up as "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalizes an input/output argument name, makes the result unique, and
  // records the mapping from `name` to the returned name.
  string Normalize(const string& name);

  // Makes a node name unique (no normalization) and records the mapping from
  // `name` to the returned name.
  string Uniquify(const string& name);

  // Returns the name that Normalize() or Uniquify() most recently recorded for
  // `name`, or an empty string if `name` was never seen.
  string Renormalize(const string& name) const;

 private:
  string NormalizeHelper(string name) const;
  string UniquifyHelper(string name);

  // Original name -> name handed out for it.
  std::unordered_map<string, string> name_mapping_;
  // Every name handed out so far.
  std::unordered_set<string> used_names_;
};
// Returns a normalized copy of `name` intended to match "[a-z][a-z0-9_]*":
//   * an empty name becomes "unknown";
//   * upper-case letters are lowered, and every character that is not
//     alphanumeric is replaced by '_';
//   * leading characters that are not letters are dropped so the result
//     starts with a letter; if no letter remains the result is "unknown".
// The result is not yet guaranteed to be unique.
string NodeNameMapping::NormalizeHelper(string name) const {
  if (name.empty()) name = "unknown";
  for (char& ch : name) {
    // The <cctype> classification functions have undefined behavior for
    // negative arguments, so go through unsigned char first.
    const unsigned char uc = static_cast<unsigned char>(ch);
    if (isalnum(uc)) {
      if (isupper(uc)) {
        ch = static_cast<char>(tolower(uc));
      }
    } else {
      ch = '_';
    }
  }
  // A valid argument name must begin with a letter: skip any leading digits
  // or underscores produced above.
  size_t first_letter = 0;
  while (first_letter < name.size() &&
         !isalpha(static_cast<unsigned char>(name[first_letter]))) {
    ++first_letter;
  }
  if (first_letter == name.size()) return "unknown";
  return name.substr(first_letter);
}
// Returns `name` if it has not been handed out before; otherwise returns the
// first of `name`_0, `name`_1, ... that is still free. The returned name is
// recorded as used.
string NodeNameMapping::UniquifyHelper(string name) {
  const bool first_use = used_names_.insert(name).second;
  if (first_use) return name;

  // Name already taken: probe increasing numeric suffixes until one sticks.
  int suffix = 0;
  string candidate = strings::StrCat(name, "_", suffix);
  while (!used_names_.insert(candidate).second) {
    ++suffix;
    candidate = strings::StrCat(name, "_", suffix);
  }
  return candidate;
}
// Normalizes `name`, makes the result unique, records `name` -> result, and
// returns the result.
string NodeNameMapping::Normalize(const string& name) {
  string unique_name = UniquifyHelper(NormalizeHelper(name));
  string& slot = name_mapping_[name];
  slot = unique_name;
  return unique_name;
}
// Makes `name` unique without normalizing it, records `name` -> result, and
// returns the result.
string NodeNameMapping::Uniquify(const string& name) {
  string unique_name = UniquifyHelper(name);
  string& slot = name_mapping_[name];
  slot = unique_name;
  return unique_name;
}
// Returns the name recorded for `name` by an earlier Normalize()/Uniquify()
// call, or an empty string when no mapping exists.
string NodeNameMapping::Renormalize(const string& name) const {
  const auto found = name_mapping_.find(name);
  return found != name_mapping_.end() ? found->second : string();
}
} // anonymous namespace
// Graph to FunctionDef conversion. This code is closely modeled on the Python
// code in third_party/tensorflow/python/framework/function.py.
//
// Fills `*fdef` with a function named `name` built from `graph`:
//   * each "_Arg" node becomes the input argument at the position given by
//     its "index" attribute, typed by its "T" attribute;
//   * each "_Retval" node becomes the output argument at its "index", and the
//     tensor feeding its input 0 becomes the corresponding return value;
//   * every other op node is copied into the function body with its inputs
//     rebuilt from the graph edges and rewritten to function-local names.
//
// Returns a non-OK status if a required attribute is missing, an "index"
// attribute is negative, a "_Retval" node has no input edge 0, a node's input
// edges are not consecutive, an argument position is left unfilled, or an
// input / return value cannot be mapped to a known tensor.
Status GraphToFunctionDef(const Graph& graph, const string& name,
                          FunctionDef* fdef) {
  fdef->mutable_signature()->set_name(name);

  // Graph tensor name ("node:slot") -> name of that tensor inside the function.
  std::unordered_map<string, string> tensor_renaming;
  // Normalized output argument name -> graph tensor name that produces it.
  std::unordered_map<string, string> return_values;
  NodeNameMapping node_names;

  for (Node const* node : graph.op_nodes()) {
    if (node->type_string() == kArgOp) {
      int index;
      DataType type;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      // A negative index would address the repeated field out of range.
      if (index < 0) {
        return errors::InvalidArgument("Invalid negative index ", index,
                                       " for ", kArgOp, " node '",
                                       node->name(), "' in function ", name);
      }
      // Grow the argument list so that position `index` exists; gaps are
      // detected after all nodes have been visited.
      while (fdef->signature().input_arg_size() <= index) {
        fdef->mutable_signature()->add_input_arg();
      }
      OpDef::ArgDef* argdef =
          fdef->mutable_signature()->mutable_input_arg(index);
      argdef->set_type(type);
      const string normalized = node_names.Normalize(node->name());
      argdef->set_name(normalized);
      tensor_renaming[strings::StrCat(node->name(), ":0")] = normalized;
      continue;
    }

    if (node->type_string() == kRetValOp) {
      int index;
      DataType type;
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &type));
      TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "index", &index));
      // A negative index would address the repeated field out of range.
      if (index < 0) {
        return errors::InvalidArgument("Invalid negative index ", index,
                                       " for ", kRetValOp, " node '",
                                       node->name(), "' in function ", name);
      }
      while (fdef->signature().output_arg_size() <= index) {
        fdef->mutable_signature()->add_output_arg();
      }
      OpDef::ArgDef* argdef =
          fdef->mutable_signature()->mutable_output_arg(index);
      argdef->set_type(type);
      const string normalized = node_names.Normalize(node->name());
      argdef->set_name(normalized);
      Edge const* edge;
      // Report a missing input edge to the caller instead of aborting.
      TF_RETURN_IF_ERROR(node->input_edge(0, &edge));
      return_values[normalized] =
          strings::StrCat(edge->src()->name(), ":", edge->src_output());
      continue;
    }

    // Ordinary node: copy it into the function body.
    NodeDef* node_def = fdef->add_node_def();
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Uniquify(node->name()));

    // Reset input names based on graph rather than the NodeDef.
    node_def->clear_input();

    // Data edges, indexed by dst_input; control edges are kept separately.
    std::vector<const Edge*> in_edges;
    std::vector<const Edge*> control_edges;
    for (Edge const* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        const size_t slot = static_cast<size_t>(edge->dst_input());
        if (in_edges.size() <= slot) {
          in_edges.resize(slot + 1);
        }
        in_edges[slot] = edge;
      }
    }

    // Add regular inputs. These still use graph names; they are remapped in
    // a second pass below.
    for (std::vector<const Edge*>::size_type i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      if (edge == nullptr) {
        return errors::InvalidArgument(
            "Nonconsecutive input edges; missing "
            "input edge ",
            i, " for node ", node->name());
      }
      node_def->add_input(
          strings::StrCat(edge->src()->name(), ":", edge->src_output()));
    }

    // Add control inputs
    for (const Edge* edge : control_edges) {
      node_def->add_input(strings::StrCat("^", edge->src()->name()));
    }

    // Populate tensor_renaming: graph output "node:i" maps to
    // "new_node_name:output_arg_name:offset_within_that_output".
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      for (int i = output.second.first; i < output.second.second; ++i) {
        const string tensor_name = strings::StrCat(
            node_def->name(), ":", output.first, ":", i - output.second.first);
        tensor_renaming[strings::StrCat(node->name(), ":", i)] = tensor_name;
      }
    }
  }

  // Detect missing function inputs.
  for (int i = 0; i < fdef->signature().input_arg_size(); ++i) {
    const string& input_name = fdef->signature().input_arg(i).name();
    if (input_name.empty()) {
      return errors::InvalidArgument("Missing input ", i, " to function ",
                                     name);
    }
  }

  // Remap input names. We do this as a second pass to allow the nodes to be in
  // any order.
  for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
    NodeDef* node_def = fdef->mutable_node_def(n_index);
    for (int i = 0; i < node_def->input_size(); ++i) {
      if (StringPiece(node_def->input(i)).starts_with("^")) {
        // Control input
        const string normalized =
            node_names.Renormalize(node_def->input(i).substr(1));
        if (normalized.empty()) {
          return errors::InvalidArgument(
              "Could not remap control input ", i, ", '", node_def->input(i),
              "', of node '", node_def->name(), "' in function ", name);
        }
        *node_def->mutable_input(i) = strings::StrCat("^", normalized);
      } else {
        const auto iter = tensor_renaming.find(node_def->input(i));
        if (iter == tensor_renaming.end()) {
          return errors::InvalidArgument(
              "Could not remap input ", i, ", '", node_def->input(i),
              "', of node '", node_def->name(), "' in function ", name);
        }
        *node_def->mutable_input(i) = iter->second;
      }
    }
  }

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();
    if (ret_name.empty()) {
      return errors::InvalidArgument("Missing output ", r, " to function ",
                                     name);
    }
    const string& return_value = return_values[ret_name];
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return errors::InvalidArgument("Could not remap return value ", r, ", '",
                                     ret_name, "', of '", return_value,
                                     "' in function ", name);
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }
  return Status::OK();
}
} // namespace tensorflow