blob: 8f58607701b540542252b7e2d81f8663e37c3c64 [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
class Node;
struct NodeDebugInfo;
// We forward declare protos so that kernels don't need to depend on them
class NodeDef;
class OpDef;
class AttrSlice;
// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. Desire for explicit co-location
// is described by list(string) attribute containing the name of colocation
// groups.
extern const char* const kColocationAttrName;
// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;
// Produce a human-readable version of a Node or NodeDef that is more concise
// than a text-format proto.
string SummarizeNode(const Node& node);
string SummarizeNodeDef(const NodeDef& node_def);
string SummarizeAttrs(const NodeDef& node_def);
string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
// Produces a formatted string pattern from the node which can uniquely identify
// this node upstream to produce an informative error message. The pattern
// followed is: {{node <node_name>}}
string FormatNodeForError(const Node& node);
string FormatNodeDefForError(const NodeDef& node_def);
string FormatNodeDefForError(
StringPiece node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
// Merges the original node names from the debug information of 'from' to the
// debug information of 'to'.
void MergeDebugInfo(const NodeDebugInfo& from, Node* to);
void MergeDebugInfo(const NodeDebugInfo& from, NodeDef* to);
void MergeDebugInfo(const NodeDef& from, NodeDef* to);
typedef protobuf::Map<string, AttrValue> AttrValueMap;
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the type of value.
//
// Overloads taking an already-built AttrValue, by const reference or by
// rvalue reference.
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
// Overloads taking a single value.
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int32 value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int64 value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const NameAttrList& value,
NodeDef* node_def);
// Overloads taking a list of values.
void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
NodeDef* node_def);
// NOTE(review): std::vector<bool> gets a dedicated overload, presumably
// because its packed representation cannot be viewed as a
// gtl::ArraySlice<bool> -- confirm against the implementation.
void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
NodeDef* node_def);
// Overload for braced lists: C++'s "perfect" forwarding cannot forward a
// {...} initializer, so the list is received as a std::initializer_list<T>,
// wrapped in a gtl::ArraySlice<T>, and passed on to the AddNodeAttr overload
// that accepts that slice type.
template <class T>
void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
                 NodeDef* node_def) {
  const gtl::ArraySlice<T> elements(value);
  AddNodeAttr(name, elements, node_def);
}
// Adds an attr to an attr value map.
void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
void AddAttr(StringPiece name, bool value, AttrValueMap* map);
// Read-only accessor for a set of attrs, constructible from a NodeDef or
// directly from an AttrValueMap. It stores only pointers (ndef_, attrs_) and
// declares no destructor, so it does not manage the lifetime of what it
// points at.
class AttrSlice {
 public:
  // Deliberately implicit (hence the NOLINT) so that a NodeDef can be passed
  // wherever an AttrSlice is expected.
  AttrSlice(const NodeDef& node_def);  // NOLINT(runtime/explicit)
  AttrSlice();                         // Empty
  explicit AttrSlice(const AttrValueMap* a);

  // Number of entries in the underlying attr map, narrowed to int.
  int size() const { return static_cast<int>(attrs_->size()); }

  // Looks up attr_name. Returns a pointer to its value when present and
  // nullptr otherwise.
  const AttrValue* Find(StringPiece attr_name) const;

  // Looks up attr_name and stores a pointer to its value in *attr_value when
  // present. Otherwise returns a NotFound status.
  Status Find(StringPiece attr_name, const AttrValue** attr_value) const;

  // Reusable string buffers that let EqualAttrs avoid allocations.
  // TODO(irving): Will go away once NodeInfo is used.
  struct Scratch {
    string a;
    string b;
  };

  // Returns true if both slices contain the same attrs with the same values.
  // Defaults are not taken into account.
  //
  // TODO(irving): There is a bug in this routine inherited from its
  // OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be
  // represented in more than one way as an AttrValue, since TensorProto is
  // not 1-1. This bug will go away once I replace everything with NodeInfo,
  // which stores a Tensor object directly. The Scratch object will also go
  // away.
  bool EqualAttrs(AttrSlice other, Scratch* scratch) const;

  // Summarizes the attached NodeDef, if this AttrSlice has one. Intended for
  // error messages only: direct access to the NodeDef is intentionally not
  // provided, since it is not always there.
  string SummarizeNode() const;

  // Iteration over all attrs.
  AttrValueMap::const_iterator begin() const {
    return attrs_->begin();
  }
  AttrValueMap::const_iterator end() const {
    return attrs_->end();
  }

  string DebugString() const;

 private:
  // The attached NodeDef; not always present (see SummarizeNode()).
  const NodeDef* ndef_;
  // The attr map that size(), begin() and end() read from.
  const AttrValueMap* attrs_;
};
// Return true if the attr with the name attr_name is defined in node_def.
bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
// Look up the attr with name attr_name and set *value to its value. If no
// attr with attr_name is found in attrs, or the attr does not have
// a matching type, a non-ok status will be returned.
//
// The trailing comment on each overload names the attr type that the
// overload's C++ value type corresponds to.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int64* value); // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int32* value); // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
float* value); // type: "float"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool* value); // type: "bool"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataType* value); // type: "type"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShapeProto* value); // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShape* value); // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
PartialTensorShape* value); // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
Tensor* value); // type: "tensor"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type: "list(string)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<float>* value); // type: "list(float)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<bool>* value); // type: "list(bool)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<DataType>* value); // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataTypeVector* value); // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShapeProto>* value); // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShape>* value); // type: "list(shape)"
Status GetNodeAttr(
const AttrSlice& attrs, StringPiece attr_name,
std::vector<PartialTensorShape>* value); // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<Tensor>* value); // type: "list(tensor)"
// This version avoids copying the TensorProto.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const TensorProto** value); // type: "tensor"
// This version avoids copying the NameAttrList.
// REQUIRES: Must not use *value beyond the lifetime of node_def.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
const NameAttrList** value); // type: "func"
// These versions copy the NameAttrList(s).
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
NameAttrList* value); // type: "func"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<NameAttrList>* value); // type: "list(func)"
// Look up the attr with name attr_name and set *value to its value. If no
// attr with attr_name is found in attrs, or the attr does not have
// a matching type, false is returned.
//
// The trailing comment on each overload names the attr type that the
// overload's C++ value type corresponds to.
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
string* value); // type: "string"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int64* value); // type: "int"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int64>* value); // type: "list(int)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
int32* value); // type: "int"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
float* value); // type: "float"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
bool* value); // type: "bool"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
DataType* value); // type: "type"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
TensorShape* value); // type: "shape"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<string>* value); // type: "list(string)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<int32>* value); // type: "list(int)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<float>* value); // type: "list(float)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<bool>* value); // type: "list(bool)"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<DataType>* value); // type: "list(type)"
// NOTE(review): unlike every other overload in this group, this one takes
// `value` by value rather than by pointer, so the caller's vector cannot
// receive the result. This looks like a missing `*` -- confirm against the
// definition before changing the signature.
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<TensorShape> value); // type: "list(shape)"
// Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
// values.
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
std::vector<const string*>* value); // type: "list(string)"
bool TryGetNodeAttr(
const AttrSlice& attrs, StringPiece attr_name,
std::vector<const TensorShapeProto*>* value); // type: "list(shape)"
// Look up the attr with name attr_name and return a reference to its value.
// If no attr with attr_name is found in node_def, or the attr does not have
// a matching type, a reference to an empty string is returned.
// REQUIRES: Must not use the returned value beyond the lifetime of node_def.
const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// Computes the input type for a specific node input.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
// Computes the input types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type);
// Computes the output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* outputs);
Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
DataTypeVector* outputs);
// Computes the input and output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);
// Computes the number of outputs for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
int* num_outputs);
// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
// * All attrs satisfy constraints from the OpDef.
// * Has a signature matching SignatureForNode().
// etc.
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
// Computes the mapping from input/output argument name to the
// corresponding input/output index range. For example,
// input "foo" corresponds to input indices
// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
// NOTE(mrry): To reduce allocations when the map is used and save
// space, the returned `NameRangeMap` objects borrow the input/output
// argument names from `op_def`. The `op_def` must outlive the
// returned `NameRangeMap` objects.
typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
NameRangeMap;
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
Status NameRangesForNode(const Node& node, const OpDef& op_def,
NameRangeMap* inputs, NameRangeMap* outputs);
// Adds default values to *node_def for unspecified attrs from op_def.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
// Validates the syntax of a NodeDef provided externally.
//
// The following is an EBNF-style syntax for NodeDef objects. Note that
// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
// which contain many other fields that are not (currently) validated.
//
// Node = NodeName, Inputs
// Inputs = ( DataInput * ), ( ControlInput * )
// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
// ControlInput = "^", NodeName
// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
// Returns "status" with formatted NodeDef attached as additional text
// in the error message. If 'allow_multiple_formatted_node' is false and there
// is already a formatted NodeDef present in 'status', we simply attach the name
// of the NodeDef instead of the formatted string.
Status AttachDef(const Status& status, const NodeDef& node_def,
bool allow_multiple_formatted_node = false);
Status AttachDef(const Status& status, const Node& node,
bool allow_multiple_formatted_node = false);
// Appends the given prefix and suffix to the original node name in order to
// make the name unique. If it's an "Enter" node and uniquify_frame_name is
// true, use the same way to reset attribute "frame_name".
Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
NodeDef* node_def,
bool uniquify_frame_name = true);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_