| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
#include <algorithm>
#include <cctype>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
| |
| #include "absl/strings/match.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/core/framework/attr_value_util.h" |
| #include "tensorflow/core/framework/function.pb.h" |
| #include "tensorflow/core/framework/node_def.pb.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/tensor.pb.h" // NOLINT |
| #include "tensorflow/core/framework/types.h" |
| #include "tensorflow/core/graph/graph.h" |
| #include "tensorflow/core/lib/strings/base64.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| |
| using tensorflow::errors::InvalidArgument; |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Class that maintains a one-to-one original node name -> new node name |
| // mapping. We normalize the names used as input and output arguments to match |
| // regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name. |
| // Once we rename them, we risk creating a name collision with the other |
| // node names, so if necessary we add a suffix to make |
| // names unique. If we have an input named "A" and a node in the function |
| // body named "a", they will be renamed to "a" and "a_0". |
// Class that maintains a one-to-one original node name -> new node name
// mapping. We normalize the names used as input and output arguments to match
// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name.
// Once we rename them, we risk creating a name collision with the other
// node names, so if necessary we add a suffix to make
// names unique. If we have an input named "A" and a node in the function
// body named "a", they will be renamed to "a" and "a_0".
class NodeNameMapping {
 public:
  NodeNameMapping() = default;

  // Normalize the input name and make it unique. This is the same as the
  // function for output, except that it also records a name mapping for the
  // original name (so body nodes can look it up later).
  string GetInputName(const string& name);

  // Normalize the output name and make it unique.
  string GetOutputName(const string& name);

  // Make the node name unique.
  string Uniquify(const string& name);

  // Records name as a used name. If this name is already used,
  // returns an error status.
  Status UseOutputName(const string& name);

  // Look up how a node name was previously normalized/uniquified.
  // Returns empty if name was never seen.
  string Lookup(const string& name) const;

 private:
  // Returns `name` unchanged if it is not in used_names_; otherwise returns
  // `name` with a numeric "_<k>" suffix chosen to avoid collisions.
  string UniquifyHelper(const string& name) const;
  // Lowercases letters, maps non-alphanumerics to '_', and strips leading
  // non-letters; returns "unknown" if no letters remain.
  static string Normalize(string name);

  // The normalized/uniquified names already used as
  // input names (in signature), output names (in signature), and node names
  // (in node_def).
  // This is a superset of values in name_mapping_.
  std::unordered_set<string> used_names_;
  // Mapping from original node name from the graph to the normalized
  // and uniquified version of it.
  std::unordered_map<string, string> name_mapping_;
};
| |
| string NodeNameMapping::Normalize(string name) { |
| // Convert letters to lowercase and non-alphanumeric characters to '_'. |
| if (name.empty()) return "unknown"; |
| const int n = name.size(); |
| for (int i = 0; i < n; ++i) { |
| char c = name[i]; |
| if (isalnum(c)) { |
| if (isupper(c)) { |
| name[i] = tolower(c); |
| } |
| } else { |
| name[i] = '_'; |
| } |
| } |
| |
| // Find the first letter and start with it. |
| int i = 0; |
| for (; i < n; ++i) { |
| if (isalpha(name[i])) break; |
| } |
| |
| // Return "unknown" if none of the name's chars were letters. |
| return i == n ? "unknown" : name.substr(i); |
| } |
| |
| string NodeNameMapping::UniquifyHelper(const string& name) const { |
| // If the name hasn't been used yet, use it as-is. |
| if (used_names_.find(name) == used_names_.end()) return name; |
| // Add a suffix to name to make it unique. |
| for (int i = 0;; ++i) { |
| const string candidate = strings::StrCat(name, "_", i); |
| if (used_names_.find(candidate) == used_names_.end()) return candidate; |
| } |
| } |
| |
| string NodeNameMapping::GetInputName(const string& name) { |
| const string& input_name = GetOutputName(name); |
| name_mapping_[name] = input_name; |
| return input_name; |
| } |
| |
| string NodeNameMapping::GetOutputName(const string& name) { |
| const string& input_name = UniquifyHelper(Normalize(name)); |
| // Record that we used this name, but don't add it to name_mapping_ |
| // since this name is not for a node. |
| used_names_.insert(input_name); |
| return input_name; |
| } |
| |
| string NodeNameMapping::Uniquify(const string& name) { |
| const string uniqued = UniquifyHelper(name); |
| name_mapping_[name] = uniqued; |
| used_names_.insert(uniqued); |
| return uniqued; |
| } |
| |
| Status NodeNameMapping::UseOutputName(const string& name) { |
| const auto& iter = used_names_.find(name); |
| if (iter != used_names_.end()) { |
| return InvalidArgument("Cannot have duplicate output names. Name '", name, |
| "' appears more than once in 'output_names' array."); |
| } |
| used_names_.insert(iter, name); |
| return Status::OK(); |
| } |
| |
| string NodeNameMapping::Lookup(const string& name) const { |
| const auto iter = name_mapping_.find(name); |
| if (iter == name_mapping_.end()) return string(); |
| return iter->second; |
| } |
| |
| Status ValidateNonRefOutput(const Node* node, int idx) { |
| const DataType& dt = node->output_type(idx); |
| return IsRefType(dt) |
| ? InvalidArgument("Output ", idx, " of node '", node->name(), |
| "' has a reference type ", DataTypeString(dt)) |
| : Status::OK(); |
| } |
| |
// Appends a NodeDef to `fdef` for every node in `body_nodes`, rewiring each
// one for use inside a function:
//  - the node name is replaced with its uniquified form from `node_names`;
//  - regular inputs are rewritten to the nested function-tensor names stored
//    in `tensor_renaming`;
//  - control inputs become "^<normalized-src-name>", sorted by source name so
//    the generated NodeDef is deterministic;
//  - the signature is marked stateful if any body node's op is stateful;
//  - placeholder-valued node attrs are lifted into the signature's attr list.
// `fn_name` is used only to build error messages.
Status FillFunctionBody(
    const string& fn_name, const NodeNameMapping& node_names,
    const std::vector<const Node*>& body_nodes,
    const std::unordered_map<string, string>& tensor_renaming,
    FunctionDef* fdef) {
  // Attr names already present in the signature; used to dedupe below.
  std::unordered_set<string> func_attr_names;
  for (const auto& func_attr : fdef->signature().attr()) {
    func_attr_names.insert(func_attr.name());
  }

  // Reused across loop iterations to avoid reallocating per node.
  std::vector<const Edge*> in_edges;
  std::vector<const Edge*> control_edges;
  for (const Node* node : body_nodes) {
    NodeDef* node_def = fdef->add_node_def();
    // First, copy the node_def as is. We will patch it next.
    *node_def = node->def();
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }
    node_def->set_name(node_names.Lookup(node->name()));

    // Input names must be set based on nested names in tensor_renaming.
    // Clear the flat input names we got from the original node_def
    // from the graph.
    node_def->clear_input();

    // Collect regular and control inputs. Regular inputs are indexed
    // by the index at which they come into the `node`. Control inputs
    // don't follow any order, and we sort control inputs to make sure generated
    // NodeDef is deterministic.
    in_edges.clear();
    in_edges.resize(node->num_inputs(), nullptr);
    control_edges.clear();
    for (const Edge* edge : node->in_edges()) {
      if (edge->src()->IsSource()) continue;
      if (edge->IsControlEdge()) {
        control_edges.push_back(edge);
      } else {
        in_edges[edge->dst_input()] = edge;
      }
    }
    std::sort(control_edges.begin(), control_edges.end(),
              [](const Edge* a, const Edge* b) {
                return a->src()->name() < b->src()->name();
              });

    // Add regular inputs.
    for (size_t i = 0; i < in_edges.size(); ++i) {
      const Edge* edge = in_edges[i];
      string original_input_name;
      if (edge == nullptr) {
        // A backedge might not appear as a regular Edge, but be only present
        // in the node_def. Such edges are referred to as requested_inputs().
        if (i >= node->requested_inputs().size()) {
          return InvalidArgument(
              "Graph to be converted to function appears to be malformed. ",
              "Node ", node->name(), " is missing input edge ", i);
        }
        original_input_name =
            ParseTensorName(node->requested_inputs()[i]).ToString();
      } else {
        original_input_name =
            strings::StrCat(edge->src()->name(), ":", edge->src_output());
      }

      const auto iter = tensor_renaming.find(original_input_name);
      if (iter == tensor_renaming.end()) {
        return InvalidArgument(
            "Input ", i, ", '", original_input_name, "', of node '",
            node->name(), "' in function '", fn_name,
            "' is not available. You might need to include it in inputs "
            "or include its source node in the body");
      }
      node_def->add_input(iter->second);
    }

    // Add control inputs.
    for (const Edge* edge : control_edges) {
      // Add this control input only if the src node is in the body or a part of
      // the inputs.
      const string normalized = node_names.Lookup(edge->src()->name());
      // If we did not find a name for the source of control edge, this
      // source must be outside of the body, and not an input. Raise an error.
      if (normalized.empty()) {
        return InvalidArgument(
            "The source of control edge ", edge->DebugString(),
            " is not in the body. Encountered while creating function '",
            fn_name, "'");
      }
      node_def->add_input(strings::StrCat("^", normalized));
    }

    // A function is stateful if any of its nodes are stateful.
    if (node->op_def().is_stateful()) {
      fdef->mutable_signature()->set_is_stateful(true);
    }

    // If this node has any attributes with placeholder value, add the
    // attribute to FunctionDef signature.
    for (const auto& iter : node->attrs()) {
      if (iter.second.placeholder().empty()) {
        continue;
      }

      // If we already added the attribute, skip it.
      string func_attr_name = iter.second.placeholder();
      if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
        continue;
      }

      // This node's attribute is a placeholder value, so it does not have type
      // information. We check node's OpDef for attribute type.
      string node_attr_name = iter.first;
      const OpDef::AttrDef* node_attr_def = nullptr;
      for (const auto& node_attr : node->op_def().attr()) {
        if (node_attr.name() == node_attr_name) {
          node_attr_def = &node_attr;
        }
      }
      if (!node_attr_def) {
#ifdef TENSORFLOW_LITE_PROTOS
        // Lite-protos build: the error message omits the OpDef dump.
        return errors::Unimplemented(
            "Placeholder value is not supported for attributes not in OpDef. "
            "Attribute: ",
            node_attr_name);
#else
        return errors::Unimplemented(
            "Placeholder value is not supported for attributes not in OpDef. "
            "Attribute: ",
            node_attr_name, ", OpDef: ", node->op_def().DebugString());
#endif
      }
      OpDef::AttrDef* attr_def = fdef->mutable_signature()->add_attr();
      attr_def->set_name(func_attr_name);
      attr_def->set_type(node_attr_def->type());

      func_attr_names.insert(func_attr_name);
    }
  }
  return Status::OK();
}
| |
| // Graph to FunctionDef conversion. This code is closely modeled on the Python |
| // function graph_to_function_def(), which is located in |
| // tensorflow/python/framework/graph_to_function_def.py. |
// Converts the subgraph of `fn_body` given by `body_nodes` into `fdef`.
//  - `inputs`/`outputs` become the signature's input/output args; outputs are
//    named by `output_names` when non-empty (sizes must then match).
//  - `control_outputs` become the signature's control outputs, named by
//    `control_output_names` when non-empty (sizes must then match).
//  - When `append_hash_to_fn_name` is true, the final function name is
//    `fn_name` plus a Base64-encoded hash of the FunctionDef built so far.
Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name,
                          bool append_hash_to_fn_name,
                          const std::vector<const Node*>& body_nodes,
                          const std::vector<OutputTensor>& inputs,
                          const std::vector<OutputTensor>& outputs,
                          const std::vector<string>& output_names,
                          const std::vector<const Node*>& control_outputs,
                          const std::vector<string>& control_output_names,
                          const char* description, FunctionDef* fdef) {
  if (!output_names.empty()) {
    DCHECK_EQ(output_names.size(), outputs.size());
  }

  if (description != nullptr) {
    fdef->mutable_signature()->set_description(description);
  }

  // Keep track of names we used and how we normalized them.
  NodeNameMapping node_names;

  // Mapping from original names of tensors (i.e. "<node_name>:<idx>") to the
  // name we used in the function:
  //   - For input tensors:
  //     {flat_tensor_name -> normalized_name_of_src_node}
  //     e.g. {In:3 -> in}
  //   - For tensors produced by nodes in function's body:
  //     {flat_tensor_name -> nested_tensor_name}
  //     e.g. {Add:3 -> add_0:z:1}
  std::unordered_map<string, string> tensor_renaming;

  // Fill outputs in function's signature.
  // We fill the outputs first to prevent output_names from colliding
  // with the input names we pick below. With this order, no names are used in
  // node_names yet, and output_names won't collide with anything (except
  // potentially with themselves).
  for (size_t i = 0; i < outputs.size(); ++i) {
    const Node* node = outputs[i].node;
    int idx = outputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg();
    argdef->set_type(node->output_type(idx));
    if (!output_names.empty()) {
      TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i]));
      argdef->set_name(output_names[i]);
    } else {
      argdef->set_name(node_names.GetOutputName(node->name()));
    }
  }

  // Fill inputs in function's signature.
  for (size_t i = 0; i < inputs.size(); ++i) {
    const Node* node = inputs[i].node;
    int idx = inputs[i].index;
    OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg();
    argdef->set_type(node->output_type(idx));
    const string& input_name = node_names.GetInputName(node->name());
    argdef->set_name(input_name);
    auto& arg_attrs = (*fdef->mutable_arg_attr())[i];
    for (const auto& attr : node->attrs()) {
      // Only copy internal attributes. These attributes will be applied to
      // _Arg/Placeholder nodes when this FunctionDef is converted to graph, and
      // normal attributes for nodes cannot be applied to those _Arg/Placeholder
      // nodes.
      if (absl::StartsWith(attr.first, "_")) {
        arg_attrs.mutable_attr()->insert(attr);
      }
    }
    tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name;
  }

  // Populate tensor_renaming and node_names.
  // Generate the new output names for every node in the function.
  // The NodeDefs in FunctionDefs use a different naming scheme for
  // their inputs than the NodeDefs in a graph (see the comment for
  // FunctionDef.node_def in function.proto). We do the
  // graph tensor name -> function tensor name conversion for every
  // possible input (i.e. every node's outputs) and store the result
  // in tensor_renaming.
  for (const Node* node : body_nodes) {
    // Make sure node_name does not collide with an input or output name.
    const string& node_name = node_names.Uniquify(node->name());
    // For each output_arg in the op_def, the output_ranges
    // map will have [start, end] range of indices that this arg produces
    // among all the output tensors of this op.
    NameRangeMap output_ranges;
    TF_RETURN_IF_ERROR(
        NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges));
    for (const auto& output : output_ranges) {
      const StringPiece& output_name = output.first;
      int index_start = output.second.first;
      int index_end = output.second.second;
      for (int i = index_start; i < index_end; ++i) {
        const string& original_name = strings::StrCat(node->name(), ":", i);
        const string& new_name =
            strings::StrCat(node_name, ":", output_name, ":", i - index_start);
        // Record the mapping if this tensor is not already mapped.
        // Tensor can be already mapped if it is used as an input.
        if (tensor_renaming.find(original_name) == tensor_renaming.end()) {
          tensor_renaming[original_name] = new_name;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(
      FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef));

  // Remap return values.
  for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
    const string& ret_name = fdef->signature().output_arg(r).name();
    // We convert this flat tensor name to the nested value
    // (e.g. `add:z:1`) that we stored in tensor_renaming.
    const string& return_value =
        strings::StrCat(outputs[r].node->name(), ":", outputs[r].index);
    const auto iter = tensor_renaming.find(return_value);
    if (iter == tensor_renaming.end()) {
      return InvalidArgument(
          "TF_Output ", return_value, " is neither in the function body ",
          "nor among function inputs. Encountered while creating function '",
          fn_name, "'");
    }
    (*fdef->mutable_ret())[ret_name] = iter->second;
  }

  if (append_hash_to_fn_name) {
    const uint64 hash = FunctionDefHash(*fdef);
    string encoded;
    TF_RETURN_IF_ERROR(Base64Encode(
        StringPiece(reinterpret_cast<const char*>(&hash), sizeof(hash)),
        &encoded));
    // Besides letters and digits our Base64 encoding uses '_' and '-'.
    // Dash is invalid in operation names and multiple underscores in random
    // places look strange. Since we never need to decode the hash back,
    // replace these chars with with 'a' and 'A'. Replacing with different
    // letters keeps more entropy.
    std::replace(encoded.begin(), encoded.end(), '-', 'a');
    std::replace(encoded.begin(), encoded.end(), '_', 'A');
    fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded));
  } else {
    fdef->mutable_signature()->set_name(fn_name);
  }

  if (!control_output_names.empty() &&
      (control_outputs.size() != control_output_names.size())) {
    return InvalidArgument(
        "Expected number of control outputs (", control_outputs.size(),
        ") and the number of control output names (",
        control_output_names.size(), ") to match but they do not.");
  }
  // Ordered set: control outputs are added to the signature in sorted order.
  std::set<string> control_output_names_set;
  for (int i = 0; i < control_outputs.size(); ++i) {
    string signature_name;
    if (!control_output_names.empty()) {
      signature_name = control_output_names[i];
    } else {
      signature_name = control_outputs[i]->name();
    }
    if (signature_name.empty()) {
      return errors::InvalidArgument("Control output name must be not empty");
    }
    if (!control_output_names_set.insert(signature_name).second) {
      return errors::InvalidArgument("Repeated control output name: ",
                                     signature_name);
    }
    const string control_output_node =
        node_names.Lookup(control_outputs[i]->name());
    if (control_output_node.empty()) {
      return errors::InvalidArgument(
          "Control output node name must be not empty");
    }
    (*fdef->mutable_control_ret())[signature_name] = control_output_node;
  }
  for (const string& control_output : control_output_names_set) {
    fdef->mutable_signature()->add_control_output(control_output);
  }

  return Status::OK();
}
| |
| // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and |
| // does various checks while doing so. `input_nodes` will contain the same |
| // information as input_tensors just in a different structure to make |
| // following processing easier. TODO(iga): Simplify this nested structure. |
| Status ProcessInputs( |
| const TF_Graph* fn_body, const char* fn_name, int ninputs, |
| const TF_Output* inputs, std::vector<OutputTensor>* input_tensors, |
| std::unordered_map<const Node*, std::vector<int>>* input_nodes) |
| EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
| input_tensors->reserve(ninputs); |
| for (int i = 0; i < ninputs; ++i) { |
| Node* node = &inputs[i].oper->node; |
| int idx = inputs[i].index; |
| |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| fn_body->graph.IsValidOutputTensor(node, idx), |
| "Encountered while processing input ", i, " into function '", fn_name, |
| "'"); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), |
| "Encountered while processing input ", i, |
| " into function '", fn_name, "'"); |
| |
| input_tensors->emplace_back(node, idx); |
| |
| const auto& iter = input_nodes->find(node); |
| if (iter == input_nodes->end()) { |
| input_nodes->insert({node, {idx}}); |
| } else { |
| auto& indices = iter->second; |
| if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { |
| return InvalidArgument("TF_Output ", node->name(), ":", idx, |
| " appears more than once in the input list"); |
| } |
| indices.push_back(idx); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| // Converts `noutputs` and `outputs` into `outputs_tensors` and does various |
| // checks while doing so. |
| Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, |
| int noutputs, const TF_Output* outputs, |
| std::vector<OutputTensor>* output_tensors) |
| EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
| output_tensors->reserve(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| Node* node = &outputs[i].oper->node; |
| int idx = outputs[i].index; |
| TF_RETURN_WITH_CONTEXT_IF_ERROR( |
| fn_body->graph.IsValidOutputTensor(node, idx), |
| "Encountered while processing output ", i, " from function '", fn_name, |
| "'"); |
| TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), |
| "Encountered while creating function '", |
| fn_name, "'"); |
| output_tensors->emplace_back(node, idx); |
| } |
| return Status::OK(); |
| } |
| |
| // Populates `body_nodes` with the nodes that will become function's body. |
| // Performs various checks. |
| Status ComputeBodyNodes( |
| const TF_Graph* fn_body, const char* fn_name, int num_opers, |
| const TF_Operation* const* opers, |
| const std::unordered_map<const Node*, std::vector<int>>& input_nodes, |
| std::vector<const Node*>* body_nodes) |
| EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
| if (num_opers == -1) { |
| for (const Node* node : fn_body->graph.op_nodes()) { |
| const auto& iter = input_nodes.find(node); |
| if (iter == input_nodes.end()) { |
| // This node is not referenced in inputs. Add it to the body. |
| body_nodes->push_back(node); |
| } else { |
| // This node is referenced in inputs. Currently, we place an |
| // artificial restriction and require that when num_opers=-1, such |
| // nodes must have a single output. |
| if (node->num_outputs() != 1) { |
| return InvalidArgument( |
| "When `num_opers` is set to -1, nodes referenced in `inputs` " |
| "must have a single output. Node ", |
| node->name(), " has ", node->num_outputs(), |
| " outputs. Encountered while creating function '", fn_name, "'"); |
| } |
| } |
| } |
| } else { |
| body_nodes->reserve(num_opers); |
| for (int i = 0; i < num_opers; ++i) { |
| const Node* node = &opers[i]->node; |
| body_nodes->push_back(node); |
| } |
| } |
| return Status::OK(); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |
| |
| using tensorflow::Node; |
| using tensorflow::string; |
| |
| TF_Function* TF_GraphToFunctionWithControlOutputs( |
| const TF_Graph* fn_body, const char* fn_name, |
| unsigned char append_hash_to_fn_name, int num_opers, |
| const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, |
| int noutputs, const TF_Output* outputs, const char* const* output_names, |
| int ncontrol_outputs, const TF_Operation* const* control_outputs, |
| const char* const* control_output_names, const TF_FunctionOptions* opts, |
| const char* description, TF_Status* status) { |
| tensorflow::mutex_lock l(*const_cast<tensorflow::mutex*>(&fn_body->mu)); |
| |
| // Process inputs. |
| std::vector<tensorflow::OutputTensor> input_tensors; |
| std::unordered_map<const Node*, std::vector<int>> input_nodes; |
| status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, |
| &input_tensors, &input_nodes); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Process outputs. |
| std::vector<tensorflow::OutputTensor> output_tensors; |
| status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, |
| outputs, &output_tensors); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Process output names. |
| std::vector<string> output_names_vec; |
| if (output_names) { |
| output_names_vec.reserve(noutputs); |
| for (int i = 0; i < noutputs; ++i) { |
| output_names_vec.push_back(string(output_names[i])); |
| } |
| } |
| |
| // Process control output names. |
| std::vector<string> control_output_names_vec; |
| if (control_output_names) { |
| control_output_names_vec.reserve(ncontrol_outputs); |
| for (int i = 0; i < ncontrol_outputs; ++i) { |
| control_output_names_vec.push_back(string(output_names[i])); |
| } |
| } |
| |
| // Compute body nodes. |
| std::vector<const Node*> body_nodes; |
| status->status = tensorflow::ComputeBodyNodes( |
| fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); |
| if (TF_GetCode(status) != TF_OK) return nullptr; |
| |
| // Compute body nodes. |
| std::vector<const Node*> control_output_nodes; |
| for (int i = 0; i < ncontrol_outputs; ++i) { |
| control_output_nodes.push_back(&control_outputs[i]->node); |
| } |
| |
| // Do the actual function creation. |
| TF_Function* tf_function = new TF_Function(); |
| DCHECK(append_hash_to_fn_name <= 1); |
| status->status = tensorflow::GraphToFunctionDef( |
| fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, |
| input_tensors, output_tensors, output_names_vec, control_output_nodes, |
| control_output_names_vec, description, &tf_function->fdef); |
| if (TF_GetCode(status) != TF_OK) { |
| TF_DeleteFunction(tf_function); |
| return nullptr; |
| } |
| return tf_function; |
| } |
| |
| TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, |
| unsigned char append_hash_to_fn_name, |
| int num_opers, const TF_Operation* const* opers, |
| int ninputs, const TF_Output* inputs, |
| int noutputs, const TF_Output* outputs, |
| const char* const* output_names, |
| const TF_FunctionOptions* opts, |
| const char* description, TF_Status* status) { |
| return TF_GraphToFunctionWithControlOutputs( |
| fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, |
| inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, |
| description, status); |
| } |
| |
| const char* TF_FunctionName(TF_Function* func) { |
| return func->fdef.signature().name().c_str(); |
| } |
| |
| void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, |
| const TF_Function* grad, TF_Status* status) { |
| if (func == nullptr) { |
| status->status = InvalidArgument( |
| "'func' argument to TF_GraphCopyFunction cannot be null"); |
| return; |
| } |
| |
| // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph |
| // to avoid the extra copy here. |
| tensorflow::FunctionDefLibrary fdef_lib; |
| *fdef_lib.add_function() = func->fdef; |
| if (grad) { |
| *fdef_lib.add_function() = grad->fdef; |
| tensorflow::GradientDef* gdef = fdef_lib.add_gradient(); |
| gdef->set_function_name(func->fdef.signature().name()); |
| gdef->set_gradient_func(grad->fdef.signature().name()); |
| } |
| |
| tensorflow::mutex_lock l(g->mu); |
| status->status = g->graph.AddFunctionLibrary(fdef_lib); |
| } |
| |
| int TF_GraphNumFunctions(TF_Graph* g) { |
| tensorflow::mutex_lock l(g->mu); |
| return g->graph.flib_def().num_functions(); |
| } |
| |
| int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, |
| TF_Status* status) { |
| tensorflow::FunctionDefLibrary lib; |
| { |
| tensorflow::mutex_lock l(g->mu); |
| lib = g->graph.flib_def().ToProto(); |
| } |
| const auto len = std::min(max_func, static_cast<int>(lib.function_size())); |
| for (int i = 0; i < len; ++i) { |
| TF_Function* func = new TF_Function(); |
| func->fdef = lib.function(i); |
| funcs[i] = func; |
| } |
| status->status = tensorflow::Status::OK(); |
| return len; |
| } |
| |
| void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, |
| TF_Status* status) { |
| status->status = MessageToBuffer(func->fdef, output_func_def); |
| } |
| |
| TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len, |
| TF_Status* status) { |
| TF_Function* func = new TF_Function(); |
| if (!func->fdef.ParseFromArray(proto, proto_len)) { |
| status->status = InvalidArgument( |
| "Invalid FunctionDef given to TF_FunctionImportFunctionDef"); |
| TF_DeleteFunction(func); |
| return nullptr; |
| } |
| status->status = tensorflow::Status::OK(); |
| return func; |
| } |
| |
| void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name, |
| const void* proto, size_t proto_len, |
| TF_Status* status) { |
| tensorflow::AttrValue attr_value; |
| if (!attr_value.ParseFromArray(proto, proto_len)) { |
| status->status = InvalidArgument( |
| "Unparseable AttrValue proto passed to " |
| "TF_FunctionSetAttrValueProto"); |
| return; |
| } |
| (*func->fdef.mutable_attr())[string(attr_name)] = attr_value; |
| status->status = tensorflow::Status::OK(); |
| } |
| |
| void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name, |
| TF_Buffer* output_attr_value, |
| TF_Status* status) { |
| const auto& it = func->fdef.attr().find(attr_name); |
| if (it == func->fdef.attr().end()) { |
| status->status = |
| InvalidArgument("Function '", func->fdef.signature().name(), |
| "' has no attr named '", attr_name, "'."); |
| return; |
| } |
| status->status = MessageToBuffer(it->second, output_attr_value); |
| } |
| |
| void TF_DeleteFunction(TF_Function* func) { delete func; } |