blob: fb9c6d3576cdd8c114a4878d1ba2a32c8e10df83 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/function.h"
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
// Out-of-line definitions of the static constexpr string members of
// FunctionLibraryDefinition. They carry no initializer here (a constexpr
// static member must be initialized in the class declaration); these lines
// only give the members storage so that ODR-uses link under pre-C++17 rules.
/* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
/* static */ constexpr const char* const
FunctionLibraryDefinition::kDeviceArgOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
/* static */ constexpr const char* const
FunctionLibraryDefinition::kDeviceRetOp;
/* static */ constexpr const char* const
FunctionLibraryDefinition::kIntsOnDeviceAttr;
/* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
// Extracts the actual type from "attr_values" based on its definition
// "arg_def".
//
// If "arg_def" is a N*T type, *is_type_list is set to false, and
// *dtypes is set to be a vector of size N and each element is T.
//
// If "arg_def" is a list(type), *is_type_list is set to true, and
// *dtypes is set to be a vector of types specified in attrs for
// arg_def.
//
// Otherwise (arg_def is a simple type T), *is_type_list is set to
// false, and *dtypes is set to a single element vector, whose only
// element is T.
//
// Returns NotFound when an attr named by "arg_def" (type_list_attr,
// number_attr or type_attr) is missing from "attrs", and InvalidArgument when
// the number attr holds a negative count.
Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
                  bool* is_type_list, DataTypeVector* dtypes) {
  dtypes->clear();
  if (!arg_def.type_list_attr().empty()) {
    const AttrValue* v = attrs.Find(arg_def.type_list_attr());
    if (v == nullptr) {
      return errors::NotFound("type attr not found: ",
                              arg_def.type_list_attr());
    }
    *is_type_list = true;
    for (int i = 0; i < v->list().type_size(); ++i) {
      dtypes->push_back(v->list().type(i));
    }
    return Status::OK();
  }

  *is_type_list = false;
  // Kept as int64 so a large attr value is not silently truncated.
  int64 num = 1;
  if (!arg_def.number_attr().empty()) {
    const AttrValue* v = attrs.Find(arg_def.number_attr());
    if (v == nullptr) {
      // Name the attr that was actually looked up.
      return errors::NotFound("number attr not found: ",
                              arg_def.number_attr());
    }
    num = v->i();
    // A negative count would wrap to an enormous size in resize() below.
    if (num < 0) {
      return errors::InvalidArgument("Number attr ", arg_def.number_attr(),
                                     " must be non-negative, got ", num);
    }
  }

  DataType dtype;
  if (arg_def.type() != DT_INVALID) {
    dtype = arg_def.type();
  } else if (arg_def.type_attr().empty()) {
    dtype = DT_INVALID;
  } else {
    const AttrValue* v = attrs.Find(arg_def.type_attr());
    if (v == nullptr) {
      return errors::NotFound("type attr not found: ", arg_def.type_attr());
    }
    dtype = v->type();
  }
  dtypes->resize(num, dtype);
  return Status::OK();
}
namespace {
template <typename T>
void AddAttr(const string& name, const T& val, NodeDef* ndef) {
SetAttrValue(val, &((*ndef->mutable_attr())[name]));
}
// Verifies that every attr declared by "sig" has a value in "attr_values" and
// that the value matches the declared attr type. Attrs present in
// "attr_values" but not declared by "sig" are currently accepted.
Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) {
  // attr_values should specify all attrs defined in fdef.
  for (const auto& attr_def : sig.attr()) {
    const AttrValue* value = attr_values.Find(attr_def.name());
    if (value == nullptr) {
      return errors::NotFound("Attr ", attr_def.name(), " is not found from ",
                              SummarizeOpDef(sig));
    }
    Status type_status = AttrValueHasType(*value, attr_def.type());
    if (type_status.ok()) continue;
    errors::AppendToMessage(&type_status, "for attr '", attr_def.name(), "'");
    return type_status;
  }

  // TODO(josh11b): Enable this code once it works with function gradients.
  // Right now the C++ function gradient code assumes it can pass
  // all the attrs of the function to the gradient, and any attrs that
  // the gradient doesn't care about will be ignored.
#if 0
  if (attr_values.size() != sig.attr_size()) {
    for (const auto& a : attr_values) {
      // TODO(josh11b): Possibly should ignore attrs that start with "_" here?
      bool found = false;
      for (const auto& s : sig.attr()) {
        if (a.first == s.name()) {
          found = true;
          break;
        }
      }
      if (!found) {
        return errors::NotFound("Attr ", a.first, " is not found in ",
                                SummarizeOpDef(sig));
      }
    }
  }
#endif

  return Status::OK();
}
// A helper class for instantiating functions. This contains shared information
// like the resulting graph and node name index.
//
// InstantiateFunction drives it in this order: BuildInputArgIndex for every
// input arg, BuildNodeOutputIndex for every body node, InstantiateNode for
// every body node, AddReturnNode for every output arg, and finally
// AddNodeInputs, which turns the recorded (node index, output index) inputs
// into input names on result_.nodes.
class FunctionInstantiationHelper {
 public:
  // "result" must outlive the helper; its node list is cleared here.
  FunctionInstantiationHelper(GetFunctionSignature get_function,
                              InstantiationResult* result)
      : get_function_(std::move(get_function)), result_(*result) {
    result_.nodes.clear();
  }

  // Builds index for nodes that can be used as node's input arguments.
  //
  // Registers "<arg>" (whole arg) and "<arg>:<i>" (each element) in the name
  // index and emits one argument node per element of "arg_def". Int32
  // elements use kDeviceArgOp when "ints_on_device" is set, otherwise kArgOp.
  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
                            bool ints_on_device) {
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(
        ArgNumType(attr_values, arg_def, &is_type_list, &dtypes));
    // An arg resolving to zero elements (e.g. an empty list(type)) used to
    // trip a CHECK and abort the process; report it as an error instead.
    if (dtypes.empty()) {
      return errors::Internal("Expected a list of at least one dtype");
    }
    int arg_index = result_.nodes.size();
    TF_RETURN_IF_ERROR(
        AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes}));
    // Creates dtypes.size() nodes in the graph.
    for (size_t i = 0; i < dtypes.size(); ++i) {
      TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i),
                                 {true, arg_index, 0, false, {dtypes[i]}}));
      DCHECK_EQ(arg_index, result_.nodes.size());
      string name = arg_def.name();
      if (dtypes.size() > 1) {
        strings::StrAppend(&name, "_", i);
      }
      NodeDef* gnode = AddNode(name);
      if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
        gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
      } else {
        gnode->set_op(FunctionLibraryDefinition::kArgOp);
      }
      DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
      AddAttr("T", dtype, gnode);
      AddAttr("index", arg_index, gnode);
      result_.arg_types.push_back(dtypes[i]);
      ++arg_index;
    }
    return Status::OK();
  }

  // Registers the names under which body node "node"'s outputs can be
  // referenced. If the op has no outputs, only "<node>" is registered.
  // Otherwise "<node>:<out_arg>" and "<node>:<out_arg>:<j>" are registered for
  // each output arg. "arg_index" is the index the node will occupy in
  // result_.nodes.
  Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs,
                              const int arg_index) {
    const OpDef* node_sig = nullptr;
    TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig));
    if (node_sig->output_arg_size() == 0) {
      return AddItem(node.name(), {false, arg_index, 0, false, {}});
    }
    const int num_retval = node_sig->output_arg_size();
    int start = 0;
    bool is_type_list;
    DataTypeVector dtypes;
    for (int i = 0; i < num_retval; ++i) {
      TF_RETURN_IF_ERROR(
          ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes));
      // Note that we rely on the backwards-compatibility test enforcing
      // that output_arg(*).name() doesn't change here.
      const string base_name =
          strings::StrCat(node.name(), ":", node_sig->output_arg(i).name());
      TF_RETURN_IF_ERROR(
          AddItem(base_name, {false, arg_index, start, is_type_list, dtypes}));
      for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) {
        TF_RETURN_IF_ERROR(
            AddItem(strings::StrCat(base_name, ":", j),
                    {false, arg_index, start + j, false, {dtypes[j]}}));
      }
      start += dtypes.size();
    }
    return Status::OK();
  }

  // Emits the graph node for body node "fnode" with the bound attrs "attrs".
  // Data inputs are resolved through the name index and type-checked
  // against the op signature. Any inputs left over afterwards must be
  // control inputs ("^name").
  Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) {
    const OpDef* fnode_sig = nullptr;
    // Propagate a failed signature lookup instead of aborting the process.
    TF_RETURN_IF_ERROR(get_function_(fnode.op(), &fnode_sig));
    NodeDef* gnode = AddNode(fnode.name());
    gnode->set_op(fnode.op());
    gnode->set_device(fnode.device());
    int gnode_idx = nodes_.size() - 1;

    // Input
    const int num_args = fnode_sig->input_arg_size();
    bool is_type_list;  // ignored
    DataTypeVector dtypes;
    int fnode_arg_index = 0;
    for (int i = 0; i < num_args; ++i) {
      TF_RETURN_IF_ERROR(
          ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes));
      // Consume inputs (indexed by fnode_arg_index) until we have
      // matched each element of dtypes (indexed by j).
      for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) {
        if (fnode_arg_index >= fnode.input_size()) {
          // Should never happen if we computed dtypes correctly.
          return errors::InvalidArgument(
              "Attempt to access beyond input size: ", fnode_arg_index,
              " >= ", fnode.input_size());
        }
        // Look up the next input.
        const string& input_name = fnode.input(fnode_arg_index);
        const auto* item = GetItemOrNull(input_name);
        if (item == nullptr) {
          return errors::InvalidArgument(
              "input ", input_name,
              " is not found: ", FormatNodeDefForError(fnode));
        }
        if (item->dtypes.size() > dtypes.size() - j) {
          return errors::InvalidArgument("Input ", input_name, " too long for ",
                                         fnode_sig->input_arg(i).name());
        }
        // Match up all the elements of this input (indexed by k) with
        // elements of dtypes (advancing j).
        for (size_t k = 0; k < item->dtypes.size(); ++k, ++j) {
          if (item->dtypes[k] != dtypes[j]) {
            return errors::InvalidArgument(
                "input ", fnode_sig->input_arg(i).name(), "[", j,
                "] expected type ", DataTypeString(dtypes[j]),
                " != ", DataTypeString(item->dtypes[k]), ", the type of ",
                input_name, "[", k, "]");
          }
          // Function args use one node per element (output 0 of each), while
          // body-node outputs are consecutive slots of a single node.
          if (item->is_func_arg) {
            AddInput(gnode_idx, item->nid + static_cast<int>(k), 0);
          } else {
            AddInput(gnode_idx, item->nid, item->idx + static_cast<int>(k));
          }
        }
      }
    }

    // Control deps.
    for (int i = fnode_arg_index; i < fnode.input_size(); ++i) {
      const string& input = fnode.input(i);
      if (input.empty() || input[0] != '^') {
        return errors::InvalidArgument("Expected input[", i, "] == '", input,
                                       "' to be a control input.");
      }
      int nid = -1;
      const string node_name = input.substr(1);
      const string node_colon = node_name + ":";
      const string node_colon_bound = node_name + ";";
      // index_ is a map sorted lexicographically, so the key we are looking for
      // must lie in the range [node_name, node_colon_bound).
      auto it = index_.lower_bound(node_name);
      while (it != index_.end() && it->first < node_colon_bound) {
        if (it->first == node_name || absl::StartsWith(it->first, node_colon)) {
          nid = it->second.nid;
          break;
        }
        ++it;
      }
      if (nid == -1) {
        return errors::InvalidArgument("input[", i, "] == '", input,
                                       "', is not found.");
      }
      AddDep(gnode_idx, nid);
    }

    // Attrs.
    for (const auto& p : attrs) {
      (*gnode->mutable_attr())[p.first] = p.second;
    }

    return Status::OK();
  }

  // Emits one retval node per element of output arg "ret_def". The node's
  // input is the name that "ret_map" maps ret_def.name() to. Int32 elements
  // use kDeviceRetOp when "ints_on_device" is set, otherwise kRetOp.
  // "*ret_index" is the running retval index and is advanced per element.
  Status AddReturnNode(
      const OpDef::ArgDef& ret_def, AttrSlice attrs,
      const ::tensorflow::protobuf::Map<string, string>& ret_map,
      bool ints_on_device, int* ret_index) {
    auto ret_iter = ret_map.find(ret_def.name());
    if (ret_iter == ret_map.end()) {
      return errors::InvalidArgument("Return ", ret_def.name(), " missing.");
    }
    bool is_type_list;
    DataTypeVector dtypes;
    TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes));
    // Report zero-element returns instead of aborting via CHECK.
    if (dtypes.empty()) {
      return errors::Internal("Expected a list of at least one dtype");
    }
    const auto* item = GetItemOrNull(ret_iter->second);
    if (item == nullptr) {
      return errors::InvalidArgument("Return ", ret_def.name(), " -> ",
                                     ret_iter->second, " is not found.");
    }
    if (dtypes != item->dtypes) {
      return errors::InvalidArgument("Invalid ret types ", ret_def.name(),
                                     " : ", DataTypeVectorString(dtypes),
                                     " vs. ",
                                     DataTypeVectorString(item->dtypes));
    }
    for (size_t i = 0; i < dtypes.size(); ++i) {
      string name = strings::StrCat(ret_def.name(), "_RetVal");
      if (dtypes.size() > 1) {
        strings::StrAppend(&name, "_", i);
      }
      NodeDef* gnode = AddNode(name);
      if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
        gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp);
      } else {
        gnode->set_op(FunctionLibraryDefinition::kRetOp);
      }
      const int ret_node_idx = nodes_.size() - 1;
      // Same addressing rule as InstantiateNode: a function arg's elements are
      // separate single-output nodes, so element i is node nid + i, output 0.
      // (Using output idx + i here referenced nonexistent outputs when a
      // multi-element function arg was returned directly.)
      if (item->is_func_arg) {
        AddInput(ret_node_idx, item->nid + static_cast<int>(i), 0);
      } else {
        AddInput(ret_node_idx, item->nid, item->idx + static_cast<int>(i));
      }
      DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i];
      AddAttr("T", dtype, gnode);
      AddAttr("index", (*ret_index)++, gnode);
      result_.ret_types.push_back(dtypes[i]);
    }
    return Status::OK();
  }

  // Adds the actual node inputs to the result graph by converting indexes to
  // the node names.
  void AddNodeInputs() {
    for (size_t i = 0; i < result_.nodes.size(); i++) {
      NodeInfo& node_info = nodes_[i];
      for (const auto& p : node_info.data_inputs) {
        result_.nodes[i].add_input(Name(p.first, p.second));
      }
      for (int index : node_info.control_inputs) {
        result_.nodes[i].add_input(Dep(index));
      }
    }
  }

 private:
  // This is used to build a small index for all names that can be used as a
  // node's input arguments.
  //
  // If is_func_arg is true, the name is a function's argument. In
  // this case, the produced graph def has node[nid:nid + dtype.size()].
  //
  // Otherwise, the name is a function body's node return value. In
  // this case, the produced graph def has one node node[nid] and
  // the node's output index [idx ... idx + num) corresponds to the
  // named outputs.
  //
  // In all cases, "dtype" specifies the data type.
  struct NameInfoItem {
    bool is_func_arg;
    int nid;
    int idx;
    bool is_type_list;
    DataTypeVector dtypes;
  };

  // Adds an item into the input name index.
  Status AddItem(const string& name, const NameInfoItem& item) {
    if (!index_.insert({name, item}).second) {
      return errors::InvalidArgument(
          strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret",
                          " name: "),
          name);
    }
    return Status::OK();
  }

  const NameInfoItem* GetItemOrNull(const string& name) const {
    return gtl::FindOrNull(index_, name);
  }

  string Dep(int node_index) const {
    return strings::StrCat("^", Name(node_index));
  }

  string Name(int node_index) const {
    CHECK_LT(node_index, nodes_.size());
    return nodes_[node_index].name;
  }

  string Name(int node_index, int output_index) const {
    if (output_index == 0) {
      return Name(node_index);
    } else {
      return strings::StrCat(Name(node_index), ":", output_index);
    }
  }

  NodeDef* AddNode(const string& name) {
    result_.nodes.emplace_back();
    NodeDef* gnode = &result_.nodes.back();
    gnode->set_name(name);
    nodes_.push_back({name, {}, {}});
    CHECK_EQ(result_.nodes.size(), nodes_.size());
    return gnode;
  }

  void AddInput(int node_index, int output_node, int output_index) {
    CHECK_LT(node_index, nodes_.size());
    nodes_[node_index].data_inputs.push_back(
        std::make_pair(output_node, output_index));
  }

  void AddDep(int node_index, int dep_index) {
    CHECK_LT(node_index, nodes_.size());
    nodes_[node_index].control_inputs.push_back(dep_index);
  }

  GetFunctionSignature get_function_;
  InstantiationResult& result_;
  // A small index for all names that can be used as a node's input arguments.
  std::map<string, NameInfoItem> index_;
  // This contains information about a node in the new graph including the node
  // names and input nodes' indexes.
  struct NodeInfo {
    string name;
    // Data inputs where <n, k> means arg k of node n.
    std::vector<std::pair<int, int>> data_inputs;
    // Control inputs (dependencies).
    std::vector<int> control_inputs;
  };
  // nodes_[i] is the information about result_.nodes[i].
  std::vector<NodeInfo> nodes_;
};
// Various helpers Print(proto) to print relevant protos to ascii.
//
// Prints an ArgDef as "name:[Ref(][N*]<type>[)]". <type> is the concrete
// dtype when set. Otherwise it is the name of the attr supplying the type:
// type_attr, or type_list_attr for list(type) arguments.
string Print(const OpDef::ArgDef& arg) {
  string out;
  strings::StrAppend(&out, arg.name(), ":");
  if (arg.is_ref()) strings::StrAppend(&out, "Ref(");
  if (!arg.number_attr().empty()) {
    strings::StrAppend(&out, arg.number_attr(), "*");
  }
  if (arg.type() != DT_INVALID) {
    strings::StrAppend(&out, DataTypeString(arg.type()));
  } else if (arg.type_attr().empty() && !arg.type_list_attr().empty()) {
    // list(type) args keep their types in type_list_attr; without this branch
    // nothing was printed after the colon.
    strings::StrAppend(&out, arg.type_list_attr());
  } else {
    strings::StrAppend(&out, arg.type_attr());
  }
  if (arg.is_ref()) strings::StrAppend(&out, ")");
  return out;
}
// TODO(josh11b): Merge this with SummarizeAttrValue().
//
// Prints an AttrValue using these formats:
//   - a single type: its name;
//   - a non-empty list of types: "{t1, t2, ...}";
//   - a function: "name", or "name[k1=v1, k2=v2]" with entries sorted;
//   - anything else: SummarizeAttrValue().
// Canonicalize() embeds this text in the names it builds, so the format must
// stay stable.
string Print(const AttrValue& attr_value) {
  if (attr_value.value_case() == AttrValue::kType) {
    return DataTypeString(attr_value.type());
  } else if ((attr_value.value_case() == AttrValue::kList) &&
             (attr_value.list().type_size() > 0)) {
    string ret = "{";
    for (int i = 0; i < attr_value.list().type_size(); ++i) {
      if (i > 0) strings::StrAppend(&ret, ", ");
      strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i)));
    }
    strings::StrAppend(&ret, "}");
    return ret;
  } else if (attr_value.value_case() == AttrValue::kFunc) {
    if (attr_value.func().attr_size() == 0) {
      return attr_value.func().name();
    }
    std::vector<string> entries;
    entries.reserve(attr_value.func().attr_size());
    // Bind by reference: a by-value loop deep-copied every nested AttrValue.
    for (const auto& p : attr_value.func().attr()) {
      entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    }
    std::sort(entries.begin(), entries.end());
    return strings::StrCat(attr_value.func().name(), "[",
                           absl::StrJoin(entries, ", "), "]");
  }
  return SummarizeAttrValue(attr_value);
}
// TODO(josh11b): Merge this with SummarizeNodeDef().
//
// Renders a NodeDef as "name = op[attrs, device=type:id](data inputs) @ deps".
// Attrs are sorted. The device entry is appended last and appears only when
// attrs exist. The " @ ..." suffix lists control inputs without their '^'.
string Print(const NodeDef& n) {
  string out = strings::StrCat(n.name(), " = ", n.op());
  if (n.attr_size() > 0) {
    std::vector<string> attr_entries;
    attr_entries.reserve(n.attr_size() + 1);
    for (const auto& kv : n.attr()) {
      attr_entries.push_back(strings::StrCat(kv.first, "=", Print(kv.second)));
    }
    std::sort(attr_entries.begin(), attr_entries.end());
    // The short device string goes after all (sorted) attributes.
    if (!n.device().empty()) {
      DeviceNameUtils::ParsedName parsed;
      attr_entries.push_back(
          DeviceNameUtils::ParseFullName(n.device(), &parsed)
              ? strings::StrCat("device=", parsed.type, ":", parsed.id)
              : string("device=<FAILED_TO_PARSE>"));
    }
    strings::StrAppend(&out, "[", absl::StrJoin(attr_entries, ", "), "]");
  }
  std::vector<StringPiece> data_inputs;
  std::vector<string> control_inputs;
  for (StringPiece input : n.input()) {
    if (absl::ConsumePrefix(&input, "^")) {
      control_inputs.emplace_back(input);
    } else {
      data_inputs.push_back(input);
    }
  }
  strings::StrAppend(&out, "(", absl::StrJoin(data_inputs, ", "), ")");
  if (!control_inputs.empty()) {
    strings::StrAppend(&out, " @ ", absl::StrJoin(control_inputs, ", "));
  }
  return out;
}
// Renders a FunctionDef as
//   "\n<name>[attrs](inputs) -> (outputs) {\n<body lines>}\n"
// Each body node takes one line. Control returns are printed as "@return"
// lines before the data "return" lines.
string Print(const FunctionDef& fdef) {
  const OpDef& sig = fdef.signature();
  string out = strings::StrCat("\n", sig.name());
  if (sig.attr_size() > 0) {
    std::vector<string> attr_strs;
    attr_strs.reserve(sig.attr_size());
    for (const auto& attr : sig.attr()) {
      // "type" attrs show their allowed values; others show the attr type.
      const string detail =
          attr.type() == "type" ? Print(attr.allowed_values()) : attr.type();
      attr_strs.push_back(strings::StrCat(attr.name(), ":", detail));
    }
    strings::StrAppend(&out, "[", absl::StrJoin(attr_strs, ", "), "]");
  }
  std::vector<string> input_strs;
  for (const auto& arg : sig.input_arg()) input_strs.push_back(Print(arg));
  std::vector<string> output_strs;
  for (const auto& arg : sig.output_arg()) output_strs.push_back(Print(arg));
  strings::StrAppend(&out, "(", absl::StrJoin(input_strs, ", "), ") -> (",
                     absl::StrJoin(output_strs, ", "), ") {\n");
  for (const auto& node : fdef.node_def()) {
    strings::StrAppend(&out, " ", Print(node), "\n");
  }
  for (const auto& ctrl : fdef.control_ret()) {
    strings::StrAppend(&out, " @return ", ctrl.first, " = ", ctrl.second,
                       "\n");
  }
  for (const auto& data : fdef.ret()) {
    strings::StrAppend(&out, " return ", data.first, " = ", data.second, "\n");
  }
  out += "}\n";
  return out;
}
// Prints an instantiated function, given as a flat node list, in the form
//   "\n(args) -> (rets) {\n<body lines>}\n"
// Nodes whose op is kArgOp/kDeviceArgOp form the inputs and nodes whose op is
// kRetOp/kDeviceRetOp form the outputs. Both groups are ordered by their
// "index" attr. All remaining nodes are the body, kept in input order.
string Print(gtl::ArraySlice<const NodeDef*> nodes) {
  std::vector<const NodeDef*> arg;
  std::vector<const NodeDef*> ret;
  std::vector<const NodeDef*> body;
  for (const NodeDef* n : nodes) {
    if (n->op() == FunctionLibraryDefinition::kArgOp ||
        n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
      arg.push_back(n);
    } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
               n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
      ret.push_back(n);
    } else {
      body.push_back(n);
    }
  }
  // Orders arg / retval nodes by their (required) "index" attr.
  auto comp = [](const NodeDef* x, const NodeDef* y) {
    int xi;
    TF_CHECK_OK(GetNodeAttr(*x, "index", &xi));
    int yi;
    TF_CHECK_OK(GetNodeAttr(*y, "index", &yi));
    return xi < yi;
  };
  std::sort(arg.begin(), arg.end(), comp);
  std::sort(ret.begin(), ret.end(), comp);
  string out;
  strings::StrAppend(&out, "\n(");
  // Returns "<dtype>" or "<dtype>@<device type>:<device id>" for node "n".
  auto get_type_and_device = [](const NodeDef& n) {
    DataType dt;
    if (!TryGetNodeAttr(n, "T", &dt)) {
      dt = DT_INVALID;
    }
    if (!n.device().empty()) {
      DeviceNameUtils::ParsedName parsed;
      if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) {
        return strings::StrCat(DataTypeString(dt), "@", parsed.type, ":",
                               parsed.id);
      } else {
        LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in "
                     << n.op() << ":" << n.name();
        return strings::StrCat(DataTypeString(dt), "@",
                               "<FAILED_TO_PARSE_DEVICE>");
      }
    }
    return DataTypeString(dt);
  };
  for (size_t i = 0; i < arg.size(); ++i) {
    const NodeDef* n = arg[i];
    if (i > 0) strings::StrAppend(&out, ", ");
    CHECK_GE(n->attr_size(), 2);
    strings::StrAppend(&out, n->name(), ":", get_type_and_device(*n));
  }
  strings::StrAppend(&out, ") -> (");
  for (size_t i = 0; i < ret.size(); ++i) {
    const NodeDef* n = ret[i];
    if (i > 0) strings::StrAppend(&out, ", ");
    CHECK_LE(2, n->attr_size());

    // The _RetVal op should have a unique non-control input. We assert that
    // here and add it to the output.
    bool found_non_control_input = false;
    for (const string& input : n->input()) {
      if (!input.empty() && input[0] != '^') {
        DCHECK_EQ(found_non_control_input, false)
            << "RetVal node has more than one non-control input: "
            << absl::StrJoin(n->input(), ", ");
        // Print the data input that was actually found; input(0) is wrong
        // whenever a control (or empty) input precedes it.
        strings::StrAppend(&out, input, ":", get_type_and_device(*n));
        found_non_control_input = true;
      }
    }
    DCHECK_EQ(found_non_control_input, true)
        << "RetVal did not have any non-control inputs: "
        << absl::StrJoin(n->input(), ", ");
  }
  strings::StrAppend(&out, ") {\n");
  for (size_t i = 0; i < body.size(); ++i) {
    strings::StrAppend(&out, " ", Print(*body[i]), "\n");
  }
  strings::StrAppend(&out, "}\n");
  return out;
}
// Looks up the signature of "op" via "get_function". For every attr that has
// a default value and is not yet in "attrs", inserts that default. Fails if
// the lookup fails.
Status AddDefaultAttrs(const string& op,
                       const GetFunctionSignature& get_function,
                       AttrValueMap* attrs) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(get_function(op, &op_def));
  // The slice views the live map, so defaults inserted below are visible.
  AttrSlice existing(attrs);
  for (const auto& attr_def : op_def->attr()) {
    if (!attr_def.has_default_value()) continue;
    if (existing.Find(attr_def.name()) != nullptr) continue;
    const bool inserted =
        attrs->insert({attr_def.name(), attr_def.default_value()}).second;
    if (!inserted) {
      return errors::Internal("Somehow duplicated: ", attr_def.name());
    }
  }
  return Status::OK();
}
} // end namespace
// TODO(shikharagarwal): Transmit original node names correctly in file.
//
// Expands "fdef" under the attr bindings "attr_values" into a flat node list
// in "result". It emits, in order:
//   - argument nodes, one per input element;
//   - one node per body node, with placeholders substituted and op defaults
//     added;
//   - retval nodes, one per output element.
// Failures are annotated with the offending arg, node, or output.
Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                           GetFunctionSignature get_function,
                           InstantiationResult* result) {
  VLOG(4) << "Instantiation Function: " << Print(fdef);

  const OpDef& sig = fdef.signature();
  TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));

  const auto ints_attr =
      fdef.attr().find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
  const bool ints_on_device =
      ints_attr != fdef.attr().end() && ints_attr->second.b();

  FunctionInstantiationHelper helper(get_function, result);

  // Phase 1: argument nodes and their names.
  for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
    Status status =
        helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
    if (!status.ok()) {
      errors::AppendToMessage(&status, "In ", Print(arg_def));
      return status;
    }
  }

  // Phase 2: copy every body node's attrs and substitute placeholders, so
  // that every attr is bound to a concrete value.
  auto substitute = [attr_values](StringPiece name, AttrValue* val) {
    const AttrValue* bound = attr_values.Find(name);
    if (bound == nullptr) return false;
    *val = *bound;
    return true;
  };
  const int num_nodes = fdef.node_def_size();
  std::vector<AttrValueMap> node_attrs(num_nodes);
  for (int i = 0; i < num_nodes; ++i) {
    const NodeDef& fnode = fdef.node_def(i);
    for (auto attr : fnode.attr()) {  // By value: substituted in place.
      if (!SubstitutePlaceholders(substitute, &attr.second)) {
        return errors::InvalidArgument("Failed to bind all placeholders in ",
                                       SummarizeAttrValue(attr.second));
      }
      if (!node_attrs[i].insert(attr).second) {
        return errors::Internal("Somehow duplicated: ", attr.first);
      }
    }
    TF_RETURN_IF_ERROR(
        AddDefaultAttrs(fnode.op(), get_function, &node_attrs[i]));
  }

  // Phase 3: register body output names. Only the name index changes here, so
  // body node i will sit at result->nodes[first_body_node + i].
  const int first_body_node = result->nodes.size();
  for (int i = 0; i < num_nodes; ++i) {
    Status status = helper.BuildNodeOutputIndex(
        fdef.node_def(i), AttrSlice(&node_attrs[i]), first_body_node + i);
    if (!status.ok()) {
      errors::AppendToMessage(&status, "In ",
                              FormatNodeDefForError(fdef.node_def(i)));
      return status;
    }
  }

  // Phase 4: emits one node for each fdef.node_def.
  for (int i = 0; i < num_nodes; ++i) {
    Status status =
        helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i]));
    if (!status.ok()) {
      errors::AppendToMessage(&status, "In ",
                              FormatNodeDefForError(fdef.node_def(i)));
      return status;
    }
  }

  // Phase 5: emits nodes for the function's return values.
  int ret_index = 0;
  for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
    Status status = helper.AddReturnNode(ret_def, attr_values, fdef.ret(),
                                         ints_on_device, &ret_index);
    if (!status.ok()) {
      errors::AppendToMessage(&status, "In function output ", Print(ret_def));
      return status;
    }
  }

  // Converts the recorded input indexes into input names.
  helper.AddNodeInputs();
  return Status::OK();
}
// Human-readable dump of "func_def" (same text as Print(FunctionDef)).
string DebugString(const FunctionDef& func_def) {
  return Print(func_def);
}
// Human-readable dump of an instantiated function stored as a GraphDef.
string DebugString(const GraphDef& instantiated_func_def) {
  std::vector<const NodeDef*> node_ptrs;
  node_ptrs.reserve(instantiated_func_def.node_size());
  for (const NodeDef& node : instantiated_func_def.node()) {
    node_ptrs.push_back(&node);
  }
  return Print(node_ptrs);
}
// Human-readable dump of an instantiated function given as a node slice.
string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) {
  std::vector<const NodeDef*> node_ptrs;
  node_ptrs.reserve(instantiated_func_nodes.size());
  for (const NodeDef& node : instantiated_func_nodes) {
    node_ptrs.push_back(&node);
  }
  return Print(node_ptrs);
}
// Dumps every library function of "gdef", then a blank line, then one line
// per graph node.
string DebugStringWhole(const GraphDef& gdef) {
  string text;
  for (const auto& fdef : gdef.library().function()) {
    text += Print(fdef);
  }
  text += "\n";
  for (const auto& node : gdef.node()) {
    strings::StrAppend(&text, Print(node), "\n");
  }
  return text;
}
namespace {

// Returns the name -> attr mapping of fdef's attrs that have a value set. In
// Python, it's possible to access unset attrs, which returns a default value
// and adds an unset attr to the map.
std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) {
  std::map<string, AttrValue> set_attrs;
  // Bind by reference; only the attrs that are kept get copied.
  for (const auto& pair : fdef.attr()) {
    if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) {
      set_attrs[pair.first] = pair.second;
    }
  }
  return set_attrs;
}

}  // end namespace
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
if (!OpDefEqual(f1.signature(), f2.signature())) return false;
std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1);
std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2);
if (f1_attrs.size() != f2_attrs.size()) return false;
for (auto iter1 : f1_attrs) {
auto iter2 = f2_attrs.find(iter1.first);
if (iter2 == f2_attrs.end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
}
if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
return false;
}
std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
if (ret1 != ret2) return false;
std::map<string, string> control_ret1(f1.control_ret().begin(),
f1.control_ret().end());
std::map<string, string> control_ret2(f2.control_ret().begin(),
f2.control_ret().end());
if (control_ret1 != control_ret2) return false;
return true;
}
// Hashes "fdef" by folding these parts, in order, into a running 64-bit hash:
//   - the signature;
//   - the set attrs, sorted by name;
//   - the node defs;
//   - the ret mapping, sorted;
//   - the control_ret mapping, sorted.
uint64 FunctionDefHash(const FunctionDef& fdef) {
  // Folds each (key, value) of a sorted string map into "seed", key first.
  auto fold_string_map = [](const std::map<string, string>& entries,
                            uint64 seed) {
    for (const auto& kv : entries) {
      seed = Hash64(kv.first.data(), kv.first.size(), seed);
      seed = Hash64(kv.second.data(), kv.second.size(), seed);
    }
    return seed;
  };

  uint64 h = OpDefHash(fdef.signature());
  for (const auto& kv : GetSetAttrs(fdef)) {
    h = Hash64(kv.first.data(), kv.first.size(), h);
    h = Hash64Combine(AttrValueHash(kv.second), h);
  }
  h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h);
  h = fold_string_map(
      std::map<string, string>(fdef.ret().begin(), fdef.ret().end()), h);
  h = fold_string_map(std::map<string, string>(fdef.control_ret().begin(),
                                               fdef.control_ret().end()),
                      h);
  return h;
}
// Attr name consulted by FunctionLibraryRuntime::ExecutorType() when
// options.executor_type is empty. Canonicalize() skips this attr while copying
// attrs and instead appends the resolved executor type under the same key.
static constexpr const char* const kExecutorAttr = "_executor";
/* static */
// Returns options.executor_type if set, else the "_executor" attr's string
// value if present, else an empty string (meaning the default executor).
string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options,
                                            AttrSlice attrs) {
  if (options.executor_type.empty()) {
    const AttrValue* executor_attr = attrs.Find(kExecutorAttr);
    return executor_attr != nullptr ? executor_attr->s() : string();
  }
  return options.executor_type;
}
// Builds the canonical name "funcname[k1=v1,k2=v2,...]" for instantiating
// "funcname" with "attrs" and "options". Entries are sorted and each one comes
// from an attr or from an instantiation option. The "_executor" attr is not
// copied directly; the value resolved by ExecutorType() (which prefers
// options.executor_type) is emitted instead.
string Canonicalize(const string& funcname, AttrSlice attrs,
                    const FunctionLibraryRuntime::InstantiateOptions& options) {
  std::vector<string> entries;
  // Capacity hint only. It counts:
  //   - one entry per attr, input/output device and input tensor shape;
  //   - two entries per resource input;
  //   - the optional singletons (_target, _lib_def, _state_handle, _executor,
  //     _config_proto).
  // The previous hint counted _target when it was *absent*.
  entries.reserve(attrs.size() + static_cast<size_t>(!options.target.empty()) +
                  options.input_devices.size() +
                  options.output_devices.size() +
                  options.input_tensor_shapes.size() +
                  2 * options.input_resource_dtypes_and_shapes.size() +
                  4 /* _lib_def, _state_handle, _executor, _config_proto */);
  for (const auto& p : attrs) {
    if (p.first != kExecutorAttr) {
      entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
    }
  }
  if (!options.target.empty()) {
    entries.push_back(
        strings::StrCat("_target", "=", absl::CEscape(options.target)));
  }
  for (size_t i = 0; i < options.input_devices.size(); ++i) {
    entries.push_back(strings::StrCat("_input_dev", i, "=",
                                      absl::CEscape(options.input_devices[i])));
  }
  for (size_t i = 0; i < options.output_devices.size(); ++i) {
    entries.push_back(strings::StrCat(
        "_output_dev", i, "=", absl::CEscape(options.output_devices[i])));
  }
  for (const auto& iter : options.input_tensor_shapes) {
    entries.push_back(
        strings::StrCat("_input_tensor_shape", iter.first, "=",
                        absl::CEscape(iter.second.DebugString())));
  }
  for (const auto& iter : options.input_resource_dtypes_and_shapes) {
    entries.push_back(strings::StrCat("_input_resource_dtype", iter.first, "=",
                                      DataTypeString(iter.second.dtype)));
    entries.push_back(
        strings::StrCat("_input_resource_shape", iter.first, "=",
                        absl::CEscape(iter.second.shape.DebugString())));
  }
  if (options.lib_def) {
    entries.push_back(strings::StrCat(
        "_lib_def", "=", reinterpret_cast<uintptr_t>(options.lib_def)));
  }
  if (!options.state_handle.empty()) {
    entries.push_back(
        strings::StrCat("_state_handle", "=", options.state_handle));
  }
  string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs);
  if (!executor_type.empty()) {
    entries.push_back(strings::StrCat(kExecutorAttr, "=", executor_type));
  }
  string config_proto_serialized;
  options.config_proto.SerializeToString(&config_proto_serialized);
  if (!config_proto_serialized.empty()) {
    entries.push_back(strings::StrCat("_config_proto", "=",
                                      absl::CEscape(config_proto_serialized)));
  }
  std::sort(entries.begin(), entries.end());
  return strings::StrCat(funcname, "[", absl::StrJoin(entries, ","), "]");
}
// Records the expected argument / return dtypes and allocates one empty slot
// for each argument and each return value.
FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types,
                                     DataTypeSlice ret_types)
    : arg_types_(arg_types.begin(), arg_types.end()),
      ret_types_(ret_types.begin(), ret_types.end()) {
  const size_t num_args = arg_types.size();
  const size_t num_rets = ret_types.size();
  args_.resize(num_args);
  rets_.resize(num_rets);
}
// Nothing to release beyond the members' own destructors.
FunctionCallFrame::~FunctionCallFrame() = default;
// Stores "args" as the call's arguments. The count must match the declared
// arg types and each dtype must match its slot. Slots are filled in order, so
// a dtype mismatch at index i leaves slots [0, i) already set.
Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) {
  // Input type checks.
  if (args.size() != arg_types_.size()) {
    return errors::InvalidArgument("Expects ", arg_types_.size(),
                                   " arguments, but ", args.size(),
                                   " is provided");
  }
  for (size_t idx = 0; idx < args.size(); ++idx) {
    const Tensor& arg = args[idx];
    const DataType expected = arg_types_[idx];
    if (arg.dtype() != expected) {
      return errors::InvalidArgument("Expects arg[", idx, "] to be ",
                                     DataTypeString(expected), " but ",
                                     DataTypeString(arg.dtype()),
                                     " is provided");
    }
    args_[idx] = arg;
  }
  return Status::OK();
}
// Copies all return values into "*rets" (cleared first). Fails on the first
// return slot that was never set.
Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
  rets->clear();
  rets->reserve(rets_.size());
  for (size_t idx = 0; idx < rets_.size(); ++idx) {
    const Retval& slot = rets_[idx];
    if (!slot.has_val) {
      return errors::Internal("Retval[", idx, "] does not have value");
    }
    rets->push_back(slot.val);
  }
  return Status::OK();
}
// Moves all return values into "*rets" (cleared first). An unset slot yields
// a default-constructed Tensor when "allow_dead_tensors" is true; otherwise
// it is an error.
Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets,
                                         bool allow_dead_tensors) {
  rets->clear();
  rets->reserve(rets_.size());
  for (size_t idx = 0; idx < rets_.size(); ++idx) {
    Retval& slot = rets_[idx];
    if (!slot.has_val && !allow_dead_tensors) {
      return errors::Internal("Retval[", idx, "] does not have value");
    }
    if (slot.has_val) {
      // The frame's copy is left in a moved-from state.
      rets->emplace_back(std::move(slot.val));
    } else {
      rets->emplace_back();
    }
  }
  return Status::OK();
}
// Copies argument "index" into "*val"; the index must be in [0, #args).
Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
  const bool out_of_range =
      index < 0 || static_cast<size_t>(index) >= args_.size();
  if (out_of_range) {
    return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
                                   args_.size(), ")");
  }
  *val = args_[index];
  return Status::OK();
}
// Sets return value "index" to "val". Fails if the index is out of range, if
// the dtype differs from the declared return type, or if the slot is already
// set.
Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
  const bool out_of_range =
      index < 0 || static_cast<size_t>(index) >= rets_.size();
  if (out_of_range) {
    return errors::InvalidArgument("SetRetval ", index, " is not within [0, ",
                                   rets_.size(), ")");
  }
  if (val.dtype() != ret_types_[index]) {
    return errors::InvalidArgument(
        "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]),
        ", but ", DataTypeString(val.dtype()), " is provided.");
  }
  Retval& slot = rets_[index];
  if (slot.has_val) {
    return errors::Internal("Retval[", index, "] has already been set.");
  }
  slot.has_val = true;
  slot.val = val;
  return Status::OK();
}
// Stores a copy of "fdef_in" and builds op registration data from its
// signature, marked as a function.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
: fdef(fdef_in),
// Exact shape inference for functions is handled by ShapeRefiner.
// Here we pass a dummy shape inference function for legacy code paths.
// NOTE: this reads the member "fdef" (not "fdef_in"), so it relies on
// "fdef" being declared before "op_registration_data" in the class.
op_registration_data(fdef.signature(), shape_inference::UnknownShape,
true /* is_function */) {}
// Copy constructor. It snapshots `other`'s function and gradient maps while
// holding a shared lock on `other`. The function map stores shared_ptrs, so
// the copy shares the registration entries with `other` instead of
// duplicating the FunctionDefs.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const FunctionLibraryDefinition& other)
    : default_registry_(other.default_registry_) {
  tf_shared_lock other_lock(other.mu_);
  func_grad_ = other.func_grad_;
  function_defs_ = other.function_defs_;
}
// Builds a library from `def_lib`. When several FunctionDefs share a name, the
// one that appears last wins; gradients are overwritten the same way. Unlike
// AddFunctionDef(), this does not check names against `default_registry`.
FunctionLibraryDefinition::FunctionLibraryDefinition(
    const OpRegistryInterface* default_registry,
    const FunctionDefLibrary& def_lib)
    : default_registry_(default_registry),
      function_defs_(def_lib.function_size()) {
  for (const auto& fdef : def_lib.function()) {
    // The latter function definition wins. make_shared allocates the object
    // and its control block together (and avoids a naked `new`).
    function_defs_[fdef.signature().name()] =
        std::make_shared<FunctionDefAndOpRegistration>(fdef);
  }
  for (const auto& grad : def_lib.gradient()) {
    func_grad_[grad.function_name()] = grad.gradient_func();
  }
}
// Members (function and gradient maps) release themselves.
FunctionLibraryDefinition::~FunctionLibraryDefinition() = default;
// Returns true iff the library has an entry named `func`.
bool FunctionLibraryDefinition::Contains(const string& func) const {
  tf_shared_lock lock(mu_);
  auto it = function_defs_.find(func);
  return it != function_defs_.end();
}
// Returns the FunctionDef named `func`, or nullptr if there is none. The
// pointer refers into the entry stored in this library.
const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
  tf_shared_lock lock(mu_);
  const auto registration = FindHelper(func);
  return registration ? &registration->fdef : nullptr;
}
// Returns the registration stored under `func`, or nullptr if absent. Does not
// acquire `mu_`; every caller in this file holds it (shared or exclusive).
std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration>
FunctionLibraryDefinition::FindHelper(const string& func) const {
  auto it = function_defs_.find(func);
  if (it != function_defs_.end()) return it->second;
  return nullptr;
}
// Adds `fdef` under its signature name. An identical existing definition is
// ignored. It fails if a different function with that name exists, or if the
// name is an op in the default registry.
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
  mutex_lock lock(mu_);
  bool added_unused;  // This entry point does not report whether it added.
  return AddFunctionDefHelper(fdef, &added_unused);
}
// Adds `fdef` under its signature name. Callers hold `mu_` exclusively.
// `*added` is set to true only when a new entry is inserted.
//  - An identical FunctionDef already present is ignored (OK, *added false).
//  - A different FunctionDef with the same name -> InvalidArgument.
//  - A name registered as an op in `default_registry_` -> InvalidArgument.
// Every error path leaves `function_defs_` unchanged. Looking the name up
// with operator[] up front would insert an empty shared_ptr that remains when
// the op-name check fails. LookUp(), ToProto() and AddLibrary() dereference
// map values unconditionally, so such an entry would crash them.
Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
                                                       bool* added) {
  *added = false;
  const string& name = fdef.signature().name();
  auto iter = function_defs_.find(name);
  // A pre-existing empty entry is treated as absent, as before.
  if (iter != function_defs_.end() && iter->second != nullptr) {
    if (!FunctionDefsEqual(iter->second->fdef, fdef)) {
      return errors::InvalidArgument(
          "Cannot add function '", name,
          "' because a different function with the same name already "
          "exists.");
    }
    // Ignore duplicate FunctionDefs.
    return Status::OK();
  }
  const OpDef* op_def;
  if (default_registry_->LookUpOpDef(name, &op_def).ok()) {
    return errors::InvalidArgument(
        "Cannot add function '", name,
        "' because an op with the same name already exists.");
  }
  // Insert only after all checks passed.
  function_defs_[name] = std::make_shared<FunctionDefAndOpRegistration>(fdef);
  *added = true;
  return Status::OK();
}
// Adds an existing `registration` (shared, not copied) under its signature
// name. Callers hold `mu_` exclusively. `*added` is set to true only when a
// new entry is inserted.
//  - An identical FunctionDef already present is ignored (OK, *added false).
//  - A different FunctionDef with the same name -> InvalidArgument.
//  - A name registered as an op in `default_registry_` -> InvalidArgument.
// Every error path leaves `function_defs_` unchanged. Using operator[] before
// the checks would leave an empty shared_ptr behind when the op-name check
// fails.
Status FunctionLibraryDefinition::AddHelper(
    std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) {
  *added = false;
  // Refers into *registration. That object stays alive after the move below,
  // because ownership passes to the map.
  const string& name = registration->fdef.signature().name();
  auto iter = function_defs_.find(name);
  // A pre-existing empty entry is treated as absent, as before.
  if (iter != function_defs_.end() && iter->second != nullptr) {
    if (!FunctionDefsEqual(iter->second->fdef, registration->fdef)) {
      return errors::InvalidArgument(
          "Cannot add function '", name,
          "' because a different function with the same name already "
          "exists.");
    }
    // Ignore duplicate FunctionDefs.
    return Status::OK();
  }
  const OpDef* op_def;
  if (default_registry_->LookUpOpDef(name, &op_def).ok()) {
    return errors::InvalidArgument(
        "Cannot add function '", name,
        "' because an op with the same name already exists.");
  }
  // Insert only after all checks passed.
  function_defs_[name] = std::move(registration);
  *added = true;
  return Status::OK();
}
// Copies the function `func` from `other` into this library, sharing its
// registration entry. Both libraries must use the same default registry.
// Copying a function already present with an identical definition is a no-op;
// a different definition under the same name is an error. `other.mu_` and
// `mu_` are taken one after the other, never at the same time.
Status FunctionLibraryDefinition::CopyFunctionDefFrom(
    const string& func, const FunctionLibraryDefinition& other) {
  if (default_registry_ != other.default_registry_) {
    return errors::InvalidArgument(
        "Cannot copy function '", func,
        "' because CopyFunctionDefFrom() requires that both libraries have the "
        "same default registry.");
  }

  std::shared_ptr<FunctionDefAndOpRegistration> source;
  {
    tf_shared_lock other_lock(other.mu_);
    source = other.FindHelper(func);
  }
  if (source == nullptr) {
    return errors::InvalidArgument(
        "Cannot copy function '", func,
        "' because no function with that name exists in the other library.");
  }

  mutex_lock lock(mu_);
  std::shared_ptr<FunctionDefAndOpRegistration>& slot = function_defs_[func];
  if (!slot) {
    slot = std::move(source);
    return Status::OK();
  }
  if (FunctionDefsEqual(slot->fdef, source->fdef)) {
    return Status::OK();
  }
  return errors::InvalidArgument(
      "Cannot copy function '", func,
      "' because a different function with the same name already "
      "exists.");
}
// Registers `grad.gradient_func()` as the gradient of `grad.function_name()`.
// Re-adding the same pair is a no-op. Assigning a different gradient to a
// function that already has one is an InvalidArgument error.
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
  mutex_lock lock(mu_);
  bool added_unused;  // This entry point does not report whether it added.
  return AddGradientDefHelper(grad, &added_unused);
}
// Records `grad` in `func_grad_`. An empty recorded value counts as "no
// gradient". `*added` becomes true only when a gradient is newly assigned.
// Conflicting assignments return InvalidArgument.
Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
                                                       bool* added) {
  *added = false;
  string& current = func_grad_[grad.function_name()];
  if (current.empty()) {
    current = grad.gradient_func();
    *added = true;
    return Status::OK();
  }
  if (current == grad.gradient_func()) {
    // Ignore duplicate GradientDefs
    return Status::OK();
  }
  return errors::InvalidArgument(
      "Cannot assign gradient function '", grad.gradient_func(), "' to '",
      grad.function_name(), "' because it already has gradient function ",
      "'", current, "'");
}
// Adds every function and gradient of `other` to this library. If any
// addition fails, the functions and gradients this call already added are
// removed again and the error is returned.
Status FunctionLibraryDefinition::AddLibrary(
    const FunctionLibraryDefinition& other) {
  // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
  // the duration of the function could lead to deadlock).
  FunctionLibraryDefinition clone(other);
  mutex_lock l(mu_);
  mutex_lock l2(clone.mu_);
  // Remember the funcs and grads that we added successfully so that
  // we can roll them back on error.
  std::vector<string> funcs;
  std::vector<string> funcs_with_grads;
  Status s;
  bool added;
  // Iterate by const reference. Iterating by value would copy each entry's
  // name string and shared_ptr (an atomic refcount update) every iteration.
  for (const auto& iter : clone.function_defs_) {
    // Skip an empty entry (a name mapped to a null registration) rather than
    // letting AddHelper() dereference it.
    if (iter.second == nullptr) continue;
    s = AddHelper(iter.second, &added);
    if (!s.ok()) {
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs.push_back(iter.second->fdef.signature().name());
    }
  }
  for (const auto& iter : clone.func_grad_) {
    GradientDef grad;
    grad.set_function_name(iter.first);
    grad.set_gradient_func(iter.second);
    s = AddGradientDefHelper(grad, &added);
    if (!s.ok()) {
      Remove(funcs, funcs_with_grads);
      return s;
    }
    if (added) {
      funcs_with_grads.push_back(grad.function_name());
    }
  }
  return Status::OK();
}
// Adds all functions, then all gradients, from `lib_def`. On the first
// failure, the entries this call already added are removed again and the
// error is returned.
Status FunctionLibraryDefinition::AddLibrary(
    const FunctionDefLibrary& lib_def) {
  mutex_lock lock(mu_);
  // Names added by this call, kept so they can be rolled back on error.
  std::vector<string> added_funcs;
  std::vector<string> added_grads;
  bool added;

  for (const FunctionDef& fdef : lib_def.function()) {
    const Status status = AddFunctionDefHelper(fdef, &added);
    if (!status.ok()) {
      Remove(added_funcs, added_grads);
      return status;
    }
    if (added) added_funcs.push_back(fdef.signature().name());
  }

  for (const GradientDef& grad : lib_def.gradient()) {
    const Status status = AddGradientDefHelper(grad, &added);
    if (!status.ok()) {
      Remove(added_funcs, added_grads);
      return status;
    }
    if (added) added_grads.push_back(grad.function_name());
  }
  return Status::OK();
}
// Replaces the function `func` with `fdef`, which is added under its own
// signature name. Returns InvalidArgument if `func` is not in the library.
// Returns the add error if `fdef` cannot be added; in that case the previous
// entry for `func` is reinstated, so a failed replacement no longer silently
// drops `func` from the library.
Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
                                                  const FunctionDef& fdef) {
  mutex_lock l(mu_);
  // Keep the current entry alive so it can be restored on failure.
  std::shared_ptr<FunctionDefAndOpRegistration> previous = FindHelper(func);
  TF_RETURN_IF_ERROR(RemoveFunctionHelper(func));
  bool added;
  Status s = AddFunctionDefHelper(fdef, &added);
  if (!s.ok()) {
    function_defs_[func] = std::move(previous);
  }
  return s;
}
// Replaces the gradient registered for `grad.function_name()` with
// `grad.gradient_func()`. Fails if no gradient is registered for it.
Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
  mutex_lock lock(mu_);
  TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
  bool added_unused;
  return AddGradientDefHelper(grad, &added_unused);
}
// Removes the function `func`. Returns InvalidArgument if it does not exist.
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
  mutex_lock lock(mu_);
  return RemoveFunctionHelper(func);
}
// Erases the entry named `func`. Does not acquire `mu_`; the callers in this
// file hold it. Returns InvalidArgument if there is no such entry.
Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) {
  auto it = function_defs_.find(func);
  if (it == function_defs_.end()) {
    return errors::InvalidArgument("Tried to remove non-existent function '",
                                   func, "'.");
  }
  function_defs_.erase(it);
  return Status::OK();
}
// Erases the gradient registered for `func`. Does not acquire `mu_`; the
// callers in this file hold it. Returns InvalidArgument if none is registered.
Status FunctionLibraryDefinition::RemoveGradient(const string& func) {
  auto it = func_grad_.find(func);
  if (it == func_grad_.end()) {
    return errors::InvalidArgument("Tried to remove non-existent gradient '",
                                   func, "'.");
  }
  func_grad_.erase(it);
  return Status::OK();
}
// Rollback helper. It removes each function in `funcs` and the gradient of
// each function in `funcs_with_grads`; every name is expected to be present
// (DCHECKed). Does not acquire `mu_`; its callers hold it.
void FunctionLibraryDefinition::Remove(
    const std::vector<string>& funcs,
    const std::vector<string>& funcs_with_grads) {
  for (const string& name : funcs) {
    const Status removed = RemoveFunctionHelper(name);
    DCHECK(removed.ok());
  }
  for (const string& name : funcs_with_grads) {
    const Status removed = RemoveGradient(name);
    DCHECK(removed.ok());
  }
}
// Returns the name of the gradient function registered for `func`, or "" if
// none is registered.
string FunctionLibraryDefinition::FindGradient(const string& func) const {
  tf_shared_lock lock(mu_);
  return FindGradientHelper(func);
}
// Lock-free lookup behind FindGradient(); callers hold `mu_`. Returns "" when
// `func` has no registered gradient.
string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
  auto it = func_grad_.find(func);
  if (it == func_grad_.end()) return "";
  return it->second;
}
// Resolves `op` to registration data. A function in this library takes
// precedence; otherwise the lookup is delegated to `default_registry_`
// (still under the shared lock).
Status FunctionLibraryDefinition::LookUp(
    const string& op, const OpRegistrationData** op_reg_data) const {
  tf_shared_lock lock(mu_);
  auto it = function_defs_.find(op);
  if (it == function_defs_.end()) {
    return default_registry_->LookUp(op, op_reg_data);
  }
  *op_reg_data = &it->second->op_registration_data;
  return Status::OK();
}
// Returns `prefix` followed by the smallest non-negative integer that does not
// collide with an existing function name (prefix0, prefix1, ...).
string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
  tf_shared_lock lock(mu_);
  for (int suffix = 0;; ++suffix) {
    string candidate = strings::StrCat(prefix, suffix);
    if (function_defs_.find(candidate) == function_defs_.end()) {
      return candidate;
    }
  }
}
// Returns the FunctionDef whose attrs apply to `ndef`, or nullptr:
//  - For an ordinary node: the library function named `ndef.op()`, if any.
//  - For SymbolicGradient[f=Foo]: Foo's registered gradient function if Foo
//    has one, otherwise Foo itself.
// Foo (or its gradient name) need not be a function in this library; for
// example, Foo may be a primitive op. Those cases now return nullptr. The old
// code dereferenced the null FindHelper() result in that case.
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
    const NodeDef& ndef) const {
  if (ndef.op() != kGradientOp) {
    // If 'ndef' calls a function and the function's def has the attr,
    // returns it.
    return Find(ndef.op());
  }

  // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
  // Foo's attributes.
  const NameAttrList* forward_func_attrs;
  if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) {
    return nullptr;
  }
  const string& func_name = forward_func_attrs->name();
  tf_shared_lock l(mu_);
  const string grad_name = FindGradientHelper(func_name);
  // If 'func' has a user-defined gradient function, uses the grad
  // function's attrs to see if noinline is specified. Otherwise,
  // uses func's attrs.
  const auto registration =
      FindHelper(grad_name.empty() ? func_name : grad_name);
  return registration ? &registration->fdef : nullptr;
}
std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
std::vector<string> function_names;
tf_shared_lock l(mu_);
function_names.reserve(function_defs_.size());
for (const auto& it : function_defs_) {
function_names.emplace_back(it.first);
}
return function_names;
}
// Serializes all functions and gradient mappings into a FunctionDefLibrary.
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
  FunctionDefLibrary lib;
  tf_shared_lock lock(mu_);
  for (const auto& entry : function_defs_) {
    lib.add_function()->CopyFrom(entry.second->fdef);
  }
  for (const auto& grad_entry : func_grad_) {
    GradientDef* grad = lib.add_gradient();
    grad->set_function_name(grad_entry.first);
    grad->set_gradient_func(grad_entry.second);
  }
  return lib;
}
// Reads attr `attr` of the FunctionDef selected by GetAttrImpl(ndef) into
// `*value`. Returns InvalidArgument if there is no such FunctionDef or attr.
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
                                          const string& attr, T* value) const {
  const FunctionDef* fdef = GetAttrImpl(ndef);
  const bool found =
      fdef != nullptr && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value);
  if (!found) {
    return errors::InvalidArgument("Attr ", attr, " is not defined.");
  }
  return Status::OK();
}
// Node overload: looks the attr up through the node's NodeDef.
template <typename T>
Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr,
                                          T* value) const {
  const auto& ndef = node.def();
  return GetAttr(ndef, attr, value);
}
// Explicit instantiations of GetAttr for the value types string and bool,
// for both the Node and NodeDef overloads.
template Status FunctionLibraryDefinition::GetAttr(const Node&, const string&,
                                                   string*) const;
template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,
                                                   const string&,
                                                   string*) const;
template Status FunctionLibraryDefinition::GetAttr(const Node&, const string&,
                                                   bool*) const;
template Status FunctionLibraryDefinition::GetAttr(const NodeDef&,
                                                   const string&,
                                                   bool*) const;
namespace {

// Name of the function attr that declares which API a function implements.
constexpr char kApiImplements[] = "api_implements";

// Returns the names of the functions in `flib` that are reachable from
// `nodes`. A function is reachable when it is
//  - the op of a node, or referenced by a node attr (AttrValue.func or
//    AttrValue.list.func), where the node is in `nodes` or in the body of a
//    reachable function;
//  - the registered gradient of a reachable function; or
//  - a function whose "api_implements" attr equals that of a reachable one.
std::set<string> ReachableFunctions(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  // Functions that are reachable from the graph. A name is recorded when its
  // function is first enqueued, not when it is processed. A function that is
  // referenced several times before being popped is therefore processed
  // once, not once per reference.
  std::set<string> reachable_funcs;

  // For any functions, if it has attribute "api_implements" =
  // "some_interface" and it is reachable, then it means any other
  // function with same attribute name and value could also be potentially
  // reachable, eg via implementation_selector swapping the nodedef.
  absl::flat_hash_set<string> reachable_api_interface;

  // Functions might be reachable from the nested function calls, so we keep a
  // queue of functions that we have to check.
  gtl::InlinedVector<const FunctionDef*, 4> func_queue;

  // Enqueue `func_name` if it names a function in `flib` not seen before.
  const auto add_to_func_queue = [&](const string& func_name) {
    const FunctionDef* func = flib.Find(func_name);
    if (func != nullptr && reachable_funcs.insert(func_name).second) {
      func_queue.push_back(func);
    }
  };

  // If any function with certain API name is reachable, all the other
  // functions with same API name should also be checked.
  const auto add_function_with_api_interface = [&](const string& api_name) {
    if (!reachable_api_interface.insert(api_name).second) return;
    for (const auto& func_name : flib.ListFunctionNames()) {
      const FunctionDef* func_def = flib.Find(func_name);
      // Find() can return nullptr for a listed name, e.g. when the library
      // changed after ListFunctionNames(). Skip it instead of dereferencing.
      if (func_def == nullptr) continue;
      const auto attr_it = func_def->attr().find(kApiImplements);
      if (attr_it != func_def->attr().end() &&
          attr_it->second.s() == api_name) {
        add_to_func_queue(func_name);
      }
    }
  };

  // Add all the functions that are reachable from the given node to the queue.
  const auto process_node = [&](const NodeDef& node) {
    // Node itself can be a call to the function.
    add_to_func_queue(node.op());

    // Or node can have an attribute referencing a function.
    for (const auto& attr : node.attr()) {
      const auto& attr_value = attr.second;

      // 1. AttrValue.func
      if (attr_value.has_func()) {
        add_to_func_queue(attr_value.func().name());
      }

      // 2. AttrValue.ListValue.func
      if (attr_value.has_list()) {
        for (const auto& func : attr_value.list().func()) {
          add_to_func_queue(func.name());
        }
      }
    }
  };

  // Add all functions that are directly called from the optimized graph.
  std::for_each(nodes.begin(), nodes.end(), process_node);

  // Process all reachable functions. Names are already in `reachable_funcs`
  // (added by add_to_func_queue), so nothing is re-inserted here.
  while (!func_queue.empty()) {
    const FunctionDef* func = func_queue.back();
    func_queue.pop_back();

    const string& func_name = func->signature().name();

    const auto attr_it = func->attr().find(kApiImplements);
    if (attr_it != func->attr().end()) {
      add_function_with_api_interface(attr_it->second.s());
    }

    // Find all the functions called from the function body.
    const auto& func_body = func->node_def();
    std::for_each(func_body.begin(), func_body.end(), process_node);

    // Check if the function has a registered gradient.
    const string grad_func_name = flib.FindGradient(func_name);
    if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
  }

  return reachable_funcs;
}

// Returns a new library that uses `flib`'s default registry. It holds the
// functions reachable from `nodes` (see ReachableFunctions) plus the gradient
// mappings registered for them.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
    const FunctionLibraryDefinition& flib,
    const protobuf::RepeatedPtrField<NodeDef>& nodes) {
  std::set<string> reachable_funcs = ReachableFunctions(flib, nodes);

  FunctionLibraryDefinition reachable_flib(flib.default_registry(),
                                           FunctionDefLibrary());

  for (const string& func_name : reachable_funcs) {
    // This should never fail, because we copy functions from a valid flib and
    // use the same default registry.
    Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib);
    TF_DCHECK_OK(added);

    const string grad_func_name = flib.FindGradient(func_name);
    if (!grad_func_name.empty()) {
      GradientDef grad;
      grad.set_function_name(func_name);
      grad.set_gradient_func(grad_func_name);
      // It can only fail if function already has a gradient function.
      const Status added_grad = reachable_flib.AddGradientDef(grad);
      TF_DCHECK_OK(added_grad);
    }
  }

  return reachable_flib;
}

// Formats `attrs` as "[a, b, ...]" using AllocatorAttributes::DebugString().
string AllocatorAttributesToString(
    const std::vector<AllocatorAttributes>& attrs) {
  string result("[");
  // AllocatorAttribute::DebugString produces around 85 bytes now.
  result.reserve(100 * attrs.size());
  for (const AllocatorAttributes& attr : attrs) {
    result.append(attr.DebugString());
    result.append(", ");
  }
  if (!attrs.empty()) {
    // Drop the trailing ", ".
    result.resize(result.size() - 2);
  }
  result.append("]");
  return result;
}

// Returns "set" for a non-null pointer and "unset" for nullptr.
const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set"; }

}  // namespace
// Returns a new library with only the functions (and their gradient mappings)
// reachable from the nodes of `graph`.
FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
    const GraphDef& graph) const {
  const auto& roots = graph.node();
  return ReachableFunctionLibraryDefinition(*this, roots);
}
// Returns a new library with only the functions (and their gradient mappings)
// reachable from the body of `func`.
FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions(
    const FunctionDef& func) const {
  const auto& roots = func.node_def();
  return ReachableFunctionLibraryDefinition(*this, roots);
}
// Human-readable summary of the options. Pointer-valued fields are reported
// only as "set"/"unset".
string FunctionLibraryRuntime::Options::DebugString() const {
  const string pointer_fields = absl::StrCat(
      " rendezvous=", IsSet(rendezvous),
      " cancellation_manager=", IsSet(cancellation_manager),
      " collective_executor=", IsSet(collective_executor),
      " step_container=", IsSet(step_container),
      " stats_collector=", IsSet(stats_collector), " runner=", IsSet(runner));
  return absl::StrCat(
      "FLR::Options(step_id=", step_id, pointer_fields,
      " remote_execution=", remote_execution, " source_device=", source_device,
      " create_rendezvous=", create_rendezvous,
      " allow_dead_tensors=", allow_dead_tensors,
      " args_alloc_attrs=", AllocatorAttributesToString(args_alloc_attrs),
      " rets_alloc_attrs=", AllocatorAttributesToString(rets_alloc_attrs), ")");
}
// "$name" (at least one character after '$') becomes a placeholder for attr
// `name`. Any other string, including a lone "$", becomes a plain string
// attr value.
void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) {
  const bool is_placeholder = val.size() >= 2 && val[0] == '$';
  if (!is_placeholder) {
    SetAttrValue(val, &proto);
    return;
  }
  proto.set_placeholder(val.data() + 1, val.size() - 1);
}
// Builds an attr value that refers to function `name` instantiated with
// `attrs`.
FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef(
    const string& name,
    gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) {
  AttrValueWrapper ret;
  auto* func = ret.proto.mutable_func();
  func->set_name(name);
  auto* func_attrs = func->mutable_attr();
  for (const auto& kv : attrs) {
    func_attrs->insert({kv.first, kv.second.proto});
  }
  return ret;
}
// Converts this helper node to a NodeDef named `ret[0]`. Data inputs are
// copied verbatim and are followed by "^dep" control inputs. The device is
// set only when non-empty.
NodeDef FunctionDefHelper::Node::ToNodeDef() const {
  NodeDef node_def;
  node_def.set_op(op);
  node_def.set_name(ret[0]);
  auto* node_attrs = node_def.mutable_attr();
  for (const auto& kv : attr) {
    node_attrs->insert({kv.first, kv.second.proto});
  }
  for (const string& input : arg) {
    node_def.add_input(input);
  }
  for (const string& control : dep) {
    node_def.add_input(strings::StrCat("^", control));
  }
  if (!device.empty()) {
    node_def.set_device(device);
  }
  return node_def;
}
/* static */
// Builds a FunctionDef named `function_name`.
//  - Signature: `in_def`, `out_def` and `attr_def` are OpDefBuilder specs;
//    each `control_ret_def` entry adds a control output named `.first`.
//  - Body: each `node_def` entry is converted by Node::ToNodeDef() as-is.
//  - `ret_def` and `control_ret_def` populate `ret` and `control_ret`.
// The signature is marked stateful if any body op is stateful or cannot be
// found in OpRegistry::Global(). CHECK-fails if the signature is invalid.
FunctionDef FunctionDefHelper::Create(
    const string& function_name, gtl::ArraySlice<string> in_def,
    gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
    gtl::ArraySlice<Node> node_def,
    gtl::ArraySlice<std::pair<string, string>> ret_def,
    gtl::ArraySlice<std::pair<string, string>> control_ret_def) {
  FunctionDef fdef;

  // Signature.
  OpDefBuilder builder(function_name);
  for (const auto& input : in_def) builder.Input(input);
  for (const auto& output : out_def) builder.Output(output);
  for (const auto& attr : attr_def) builder.Attr(attr);
  for (const auto& control_ret : control_ret_def) {
    builder.ControlOutput(control_ret.first);
  }
  OpRegistrationData reg_data;
  TF_CHECK_OK(builder.Finalize(&reg_data));
  fdef.mutable_signature()->Swap(&reg_data.op_def);

  // Function body, while deciding statefulness. Lookup can fail if e.g. we
  // are calling a function that was not yet defined. If it happens,
  // conservatively assume the op is stateful.
  auto* registry = OpRegistry::Global();
  bool any_stateful = false;
  for (const auto& node : node_def) {
    *fdef.add_node_def() = node.ToNodeDef();
    const OpDef* op_def = nullptr;
    const Status lookup = registry->LookUpOpDef(node.op, &op_def);
    if (!lookup.ok() || op_def->is_stateful()) any_stateful = true;
  }

  // Data returns, then control returns.
  auto* rets = fdef.mutable_ret();
  for (const auto& ret : ret_def) {
    rets->insert({ret.first, ret.second});
  }
  auto* control_rets = fdef.mutable_control_ret();
  for (const auto& control_ret : control_ret_def) {
    control_rets->insert({control_ret.first, control_ret.second});
  }

  if (any_stateful) fdef.mutable_signature()->set_is_stateful(true);
  return fdef;
}
/* static */
// Same as the full Create() overload, with no control returns.
FunctionDef FunctionDefHelper::Create(
    const string& function_name, gtl::ArraySlice<string> in_def,
    gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def,
    gtl::ArraySlice<Node> node_def,
    gtl::ArraySlice<std::pair<string, string>> ret_def) {
  const gtl::ArraySlice<std::pair<string, string>> no_control_rets{};
  return Create(function_name, in_def, out_def, attr_def, node_def, ret_def,
                no_control_rets);
}
// Builds a FunctionDef from the legacy "Define" format. Each name in a body
// node's `ret` list labels one output of that node. Node inputs and signature
// outputs written with those legacy names are rewritten to the
// "node:output_arg:index" form; function input args map to themselves.
// Inputs are resolved against names known so far, so a node may only consume
// function args or outputs of earlier nodes. `dep` entries become "^dep"
// control inputs unchanged. The signature is marked stateful if any body op
// is stateful.
// CHECK-fails on:
//  - an invalid signature;
//  - an unknown input name;
//  - a body op missing from OpRegistry::Global();
//  - a node with fewer `ret` names than outputs;
//  - a signature output with no matching name.
/* static */
FunctionDef FunctionDefHelper::Define(const string& name,
                                      gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  FunctionDef fdef;
  OpDefBuilder b(name);
  for (const auto& a : arg_def) b.Input(a);
  for (const auto& r : ret_def) b.Output(r);
  for (const auto& a : attr_def) b.Attr(a);

  OpRegistrationData op_reg_data;
  TF_CHECK_OK(b.Finalize(&op_reg_data));
  fdef.mutable_signature()->Swap(&op_reg_data.op_def);

  // Mapping from legacy output names to NodeDef outputs.
  std::unordered_map<string, string> ret_index;
  for (const auto& a : fdef.signature().input_arg()) {
    ret_index[a.name()] = a.name();
  }

  // For looking up OpDefs
  auto* op_def_registry = OpRegistry::Global();

  // Function body
  for (const auto& src : node_def) {
    NodeDef* n = fdef.add_node_def();
    n->set_op(src.op);
    n->set_name(src.ret[0]);
    for (const auto& a : src.attr) {
      n->mutable_attr()->insert({a.first, a.second.proto});
    }
    // Rewrite each legacy input name to its resolved NodeDef output.
    for (const string& a : src.arg) {
      const auto iter = ret_index.find(a);
      CHECK(iter != ret_index.end())
          << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name;
      n->add_input(iter->second);
    }
    for (const string& d : src.dep) {
      n->add_input(strings::StrCat("^", d));
    }

    // Add the outputs of this node to ret_index.
    const OpDef* op_def = nullptr;
    TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op();
    CHECK(op_def != nullptr) << n->op();
    NameRangeMap output_names;
    TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names));
    for (const auto& o : output_names) {
      // o.second is the [begin, end) range of output indices for output arg
      // o.first; src.ret must name every one of them.
      CHECK_LE(o.second.second, src.ret.size())
          << "Missing ret for output '" << o.first << "' in '" << src.ret[0]
          << "' of " << name;
      for (int i = o.second.first; i < o.second.second; ++i) {
        ret_index[src.ret[i]] =
            strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
      }
    }
    if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
  }

  // Returns
  for (const auto& r : fdef.signature().output_arg()) {
    const auto iter = ret_index.find(r.name());
    CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name;
    fdef.mutable_ret()->insert({r.name(), iter->second});
  }
  return fdef;
}
// Same as the named Define() overload; the function gets the placeholder
// name "_".
FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def,
                                      gtl::ArraySlice<string> ret_def,
                                      gtl::ArraySlice<string> attr_def,
                                      gtl::ArraySlice<Node> node_def) {
  static constexpr char kPlaceholderName[] = "_";
  return Define(kPlaceholderName, arg_def, ret_def, attr_def, node_def);
}
namespace gradient {
typedef std::unordered_map<string, Creator> OpGradFactory;
OpGradFactory* GetOpGradFactory() {
static OpGradFactory* factory = new OpGradFactory;
return factory;
}
bool RegisterOp(const string& op, Creator func) {
CHECK(GetOpGradFactory()->insert({op, func}).second)
<< "Duplicated gradient for " << op;
return true;
}
Status GetOpGradientCreator(const string& op, Creator* creator) {
auto fac = GetOpGradFactory();
auto iter = fac->find(op);
if (iter == fac->end()) {
return errors::NotFound("No gradient defined for op: ", op);
}
*creator = iter->second;
return Status::OK();
}
} // end namespace gradient
} // namespace tensorflow