/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
namespace {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
template <typename Handle>
struct HashHandle {
std::size_t operator()(const Handle& h) const { return h.Handle(); }
};
template <typename Handle>
struct CompareHandle {
bool operator()(const Handle& h1, const Handle& h2) const {
return h1.SameHandle(h2);
}
};
template <typename Handle>
struct HandleToObject {};
template <>
struct HandleToObject<ShapeHandle> {
typedef ShapeHandle Object;
static ShapeHandle Unknown() { return ShapeHandle(); }
};
template <>
struct HandleToObject<DimensionHandle> {
typedef int64 Object;
static int64 Unknown() { return -1; }
};
template <typename Handle>
struct Processor {};
template <>
struct Processor<ShapeHandle> {
// Extract the shape or dim denoted by the handle.
void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
// Merge the shapes or dims.
Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
if (InferenceContext::RankKnown(*result)) {
// The result was initialized in a previous merge to a shape of known
// rank, make sure we preserve that information.
return Status::OK();
}
if (InferenceContext::RankKnown(h1)) {
*result = h1;
} else {
*result = h2;
}
return Status::OK();
}
};
template <>
struct Processor<DimensionHandle> {
// Assign a negative id to unknown dimensions, starting at -2 (-1 is already
// reserved by TensorFlow to denote an unknown dimension).
void ExtractValue(DimensionHandle d, int64* result) {
if (!InferenceContext::ValueKnown(d)) {
*result = -counter;
counter++;
} else {
int64 val = InferenceContext::Value(d);
if (val >= 0) {
*result = val;
} else {
// A shape inference function generated an invalid dimension handle.
// Use a symbolic dimension to encode this.
*result = -counter;
counter++;
}
}
}
// Merge the dimensions d1 and d2. Return the known value if there is one,
// otherwise look for a symbolic value. If there is neither a known nor a
// symbolic value, the dimension is fully unknown, so return -1.
Status Merge(DimensionHandle d1, DimensionHandle d2, int64* result) {
const int64 dim1 = InferenceContext::Value(d1);
const int64 dim2 = InferenceContext::Value(d2);
if (dim1 >= 0 && dim2 >= 0) {
CHECK_EQ(dim1, dim2);
return RefineDim(dim1, result);
} else if (dim1 >= 0 && dim2 < 0) {
return RefineDim(dim1, result);
} else if (dim1 < 0 && dim2 >= 0) {
return RefineDim(dim2, result);
} else if (dim1 < -1) {
return RefineDim(dim1, result);
} else if (dim2 < -1) {
return RefineDim(dim2, result);
} else {
CHECK_EQ(dim1, dim2);
CHECK_EQ(-1, dim1);
return RefineDim(-1, result);
}
return Status::OK();
}
private:
Status RefineDim(int64 dim, int64* result) {
if (*result >= 0) {
if (!(*result == dim || dim < 0)) {
return errors::InvalidArgument("Inconsistent dimensions detected");
}
} else if (dim >= 0) {
*result = dim;
} else if (dim < *result) {
*result = dim;
}
return Status::OK();
}
int64 counter = 2;
};
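// Illustrative example (not part of the implementation): with the encoding
// above, the first unknown dimension ExtractValue() sees is assigned -2, the
// next -3, and so on, while known dimensions keep their value. Merge() then
// prefers known values: merging a dimension valued 8 with a symbolic -3
// refines the result to 8, merging -3 with the fully unknown -1 keeps the
// symbolic -3, and a conflict with a previously recorded known value makes
// RefineDim() return an InvalidArgument error.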
// Traditional Disjoint-Set data structure with path compression.
// (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
template <typename Handle>
class DisjointSet {
public:
DisjointSet() {}
~DisjointSet() {
for (auto rep : nodes_) {
delete rep.second;
}
}
Status Merge(Handle x, Handle y);
const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
private:
// All the handles that belong to the same set are part of the same tree, and
// ultimately represented by the root of that tree.
struct Rep {
// Parent in the tree used to encode the set.
Rep* parent;
// Rank of the tree, used to decide which root becomes the new parent when
// two trees are merged (union by rank).
int rank;
// The handle.
typename HandleToObject<Handle>::Object value;
};
// Find the representative node for the value, creating a new singleton set
// for it if none exists yet.
Rep* Find(Handle value);
private:
Processor<Handle> processor_;
std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
nodes_;
};
template <typename Handle>
const typename HandleToObject<Handle>::Object
DisjointSet<Handle>::GetMergedValue(Handle value) {
Rep* rep = Find(value);
if (!rep) {
// We don't know anything about this handle.
return HandleToObject<Handle>::Unknown();
}
return rep->value;
}
template <typename Handle>
Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
Rep* x_root = Find(x);
Rep* y_root = Find(y);
// x and y are already in the same set
if (x_root == y_root) {
return Status::OK();
}
// x and y are not in the same set, so we merge them.
// Use the occasion to strengthen what we know about the handles by merging
// the information about the two subsets.
if (x_root->rank < y_root->rank) {
TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
x_root->parent = y_root;
} else if (x_root->rank > y_root->rank) {
TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
y_root->parent = x_root;
} else {
TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
// Arbitrarily make one root the new parent
y_root->parent = x_root;
x_root->rank = x_root->rank + 1;
}
return Status::OK();
}
template <typename Handle>
typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
auto it = nodes_.find(value);
if (it == nodes_.end()) {
// This is the first time we process this handle, create an entry for it.
Rep* node = new Rep;
node->parent = node;
node->rank = 0;
processor_.ExtractValue(value, &node->value);
nodes_[value] = node;
return node;
}
// Return the representative for the set, which is the root of the tree. Apply
// path compression to speed up future queries.
Rep* node = it->second;
Rep* root = node->parent;
while (root != root->parent) {
root = root->parent;
}
while (node->parent != root) {
Rep* next = node->parent;
node->parent = root;
node = next;
}
return root;
}
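// Usage sketch (illustrative only, not part of the implementation):
//   DisjointSet<DimensionHandle> dims;
//   TF_RETURN_IF_ERROR(dims.Merge(d1, d2));  // d1 has value 32, d2 is unknown
//   dims.GetMergedValue(d1);  // 32
//   dims.GetMergedValue(d2);  // also 32: the two sets were unified and
//                             // Processor<DimensionHandle> kept the known value
// This is how SymbolicShapeManager (below) consolidates equivalent shapes and
// dimensions across the whole graph.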
// TODO(dyoon): Move many helper functions in this file (including those within
// SymbolicShapeRefiner class) to shared utils.
bool IsEnqueue(const NodeDef& n) {
return (n.op().find("Enqueue") != string::npos &&
n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
return (n.op().find("Dequeue") != string::npos &&
n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
if (proto.unknown_rank()) {
return true;
}
for (const auto& dim : proto.dim()) {
if (dim.size() < 0) {
return true;
}
}
return false;
}
// This really should be done in an external debugging tool
void VerboseLogUnknownDimensionSources(
const GraphDef& graph,
const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
input_properties_map,
const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
output_properties_map) {
if (!VLOG_IS_ON(2)) {
return;
}
VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
// Find all nodes in the graph for which we
// do not have any unknown dimensions in their inputs, but
// we have some unknown dimensions in their outputs.
std::map<string, int> op_to_count;
for (const NodeDef& node : graph.node()) {
const auto& input_properties = input_properties_map.at(node.name());
const auto& output_properties = output_properties_map.at(node.name());
bool has_unknown_inputs = false;
for (const auto& input_prop : input_properties) {
if (HasAnyUnknownDimensions(input_prop.shape())) {
has_unknown_inputs = true;
break;
}
}
if (has_unknown_inputs) {
continue;
}
for (const auto& output_prop : output_properties) {
if (HasAnyUnknownDimensions(output_prop.shape())) {
string inputs = "input_shapes=[";
for (const auto& input_prop : input_properties) {
inputs += PartialTensorShape::DebugString(input_prop.shape());
}
inputs += "]";
string outputs = "output_shapes=[";
for (const auto& output_prop : output_properties) {
outputs += PartialTensorShape::DebugString(output_prop.shape());
}
outputs += "]";
VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
<< inputs << ", " << outputs;
op_to_count[node.op()]++;
// don't log again for this node
break;
}
}
}
VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
<< "(format: <op_type> (<count>)):";
for (const auto& p : op_to_count) {
VLOG(2) << p.first << " (" << p.second << ")";
}
}
bool IsShapeFullyDefinedIntegerVectorOrScalar(
InferenceContext* ic, const ShapeHandle& shape,
const ShapeHandle& tensor_as_shape, const DataType& dtype) {
if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
!ic->FullyDefined(tensor_as_shape) ||
(dtype != DT_INT32 && dtype != DT_INT64)) {
return false;
}
return true;
}
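// For example (illustrative): shape = [3] (fully defined), tensor_as_shape =
// [2, 5, 7] and dtype = DT_INT32 returns true, whereas a rank-2 `shape`, a
// partially known `tensor_as_shape`, or a non-integer dtype returns false.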
// The returned tensor's shape matches `shape`, and its values and dtype come
// from `tensor_as_shape` and `dtype`.
TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
const ShapeHandle& shape,
const ShapeHandle& tensor_as_shape,
const DataType& dtype) {
TensorProto tensor_proto;
tensor_proto.set_dtype(dtype);
auto* shape_proto = tensor_proto.mutable_tensor_shape();
if (ic->Rank(shape) == 1) {
shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
}
// For a scalar tensor, tensor_shape field will be left empty; no dim.
for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
if (dtype == DT_INT32) {
tensor_proto.add_int_val(value);
} else {
tensor_proto.add_int64_val(value);
}
}
return tensor_proto;
}
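// For example (illustrative): with a rank-1 `shape`, tensor_as_shape =
// [2, 5, 7] and dtype = DT_INT32, the resulting proto describes a DT_INT32
// tensor of shape [3] with int_val = {2, 5, 7}.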
// Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
const TensorProto& tensor_proto,
const DataType& dtype) {
NodeDef const_node;
const_node.set_name("const_from_shape");
const_node.set_op("Const");
auto* attr = const_node.mutable_attr();
(*attr)["dtype"].set_type(dtype);
auto* tensor = (*attr)["value"].mutable_tensor();
*tensor = tensor_proto;
return const_node;
}
// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
// and dtype = `dtype`.
NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
const ShapeHandle& shape,
const ShapeHandle& tensor_as_shape,
const DataType& dtype) {
return MakeConstNodeDefFromTensorProto(
ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
}
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
// dequeued in (roughly) topological order. Propagating shapes following a
// topological ordering isn't required for correctness, but it helps speed
// things up by avoiding repeated processing of the same node as the
// information about its inputs is refined.
class TopoQueue {
public:
explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
: topo_order_(TopoOrder(topo_order)) {}
void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
const NodeDef* n = it->first;
queue_.erase(it);
return n;
}
bool empty() const { return queue_.empty(); }
std::size_t size() const { return queue_.size(); }
private:
using NodeAndId = std::pair<const NodeDef*, int>;
// Graph nodes are created in (roughly) topological order. Therefore we can
// use their id to ensure they're sorted topologically.
struct OrderByIdAscending {
bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
return lhs.second < rhs.second;
}
};
const std::unordered_map<const NodeDef*, int> TopoOrder(
const std::vector<const NodeDef*>& topo_order) const {
std::unordered_map<const NodeDef*, int> map;
map.reserve(topo_order.size());
for (int i = 0; i < topo_order.size(); ++i) {
map.emplace(topo_order[i], i);
}
return map;
}
const std::unordered_map<const NodeDef*, int> topo_order_;
std::set<NodeAndId, OrderByIdAscending> queue_;
};
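// Usage sketch (illustrative only): given topo_order = {a, b, c},
//   TopoQueue q(topo_order);
//   q.push(c);
//   q.push(a);
//   q.pop();  // returns a, the earliest node in the topological order
// so nodes are revisited roughly in dependency order even when fanouts are
// pushed out of order.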
bool IsNumericType(const DataType dtype) {
static const gtl::FlatSet<DataType>* const kRealNumberTypes =
CHECK_NOTNULL((new gtl::FlatSet<DataType>{
// Floating point.
DT_BFLOAT16,
DT_HALF,
DT_FLOAT,
DT_DOUBLE,
// Int / UInt.
DT_INT8,
DT_INT16,
DT_INT32,
DT_INT64,
DT_UINT8,
DT_UINT16,
DT_UINT32,
DT_UINT64,
// Quantized Int.
DT_QINT8,
DT_QUINT8,
DT_QINT16,
DT_QUINT16,
DT_QINT32,
// Bool.
DT_BOOL,
}));
return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
}
bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
static const gtl::FlatSet<string>* const kOpTypeWhitelist =
CHECK_NOTNULL((new gtl::FlatSet<string>{
// Unary arithmetic ops
"Floor",
"Round",
"Sqrt",
"Square",
"Sign",
// Binary arithmetic ops
"Add",
"Div",
"FloorDiv",
"FloorMod",
"Greater",
"GreaterEqual",
"Less",
"LessEqual",
"LogicalAnd",
"LogicalNot",
"LogicalOr",
"Maximum",
"Minimum",
"Mod",
"Mul",
"NotEqual",
"QuantizedAdd",
"QuantizedMul",
"SquareDifference",
"Sub",
"TruncateDiv",
"TruncateMod",
"RealDiv",
// N-ary arithmetic ops
"AddN",
// Others
"StridedSlice",
"OnesLike",
"ZerosLike",
"Concat",
"ConcatV2",
"Split",
"Range",
"Fill",
"Cast",
}));
return kOpTypeWhitelist->find(op_type) != kOpTypeWhitelist->end();
}
// Processes symbolic shapes.
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
// shape refiner which creates new handles every time it processes an unknown
// shape/dimension, the symbolic shape refiner assigns a specific handle to each
// unknown shape/dimension of a given node.
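// For example, if the shape of node N's output 0 is unknown, every query for
// it resolves to the same handle (see GetUnknownOutputShape below), so all
// consumers of that output observe the same symbolic shape, and the
// SymbolicShapeManager can later unify it with any other shapes it gets
// merged with.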
class SymbolicShapeRefiner {
public:
explicit SymbolicShapeRefiner(
const GraphView& graph,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
const bool aggressive_shape_inference)
: graph_(graph),
function_library_(OpRegistry::Global(), graph.graph()->library()),
fed_ports_(fed_ports),
aggressive_shape_inference_(aggressive_shape_inference) {
graph_def_version_ = graph.graph()->versions().producer();
node_to_context_.reserve(graph.graph()->node_size());
}
const GraphView& graph() const { return graph_; }
struct NodeContext {
const OpRegistrationData* op_data;
DataTypeVector input_types;
DataTypeVector output_types;
std::unique_ptr<InferenceContext> inference_context;
// Additional info for propagating tensor values and tensor shapes.
std::vector<const TensorProto*> input_tensor_protos;
std::vector<const TensorProto*> output_tensor_protos;
std::vector<ShapeHandle> output_tensors_as_shapes;
// Output shapes incompatible between annotation and shape inference.
bool shape_incompatible = false;
};
NodeContext* GetNodeContext(const NodeDef* node) {
auto it = node_to_context_.find(node);
if (it == node_to_context_.end()) {
return nullptr;
}
return &it->second;
}
InferenceContext* GetContext(const NodeDef* node) {
auto it = node_to_context_.find(node);
if (it == node_to_context_.end()) {
return nullptr;
}
return it->second.inference_context.get();
}
// Forward the shapes from the function input nodes to
// the argument nodes (which are Placeholder nodes), then
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
// to function node `function_node`.
//
// In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
Status UpdateFunction(const NodeDef* function_node) {
auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
function_node->op(),
" was not previously added to SymbolicShapeRefiner.");
}
const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
it->second;
if (!maybe_grappler_function_item.has_value()) {
VLOG(3) << "Skipping function call that failed to instantiate: function_name="
<< function_node->op();
auto* ctx = GetNodeContext(function_node);
auto* ic = ctx->inference_context.get();
for (int i = 0; i < ic->num_outputs(); ++i) {
TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
}
return Status::OK();
}
// Copy (not reference) so that changes we make here (e.g., replacing
// _Arg with Const and _Retval with Identity) don't affect the one in
// fun_to_grappler_function_item_.
GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
MutableGraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
auto& fun_input = grappler_function_item.input(i);
NodeDef* fun_node = gv.GetNode(fun_input.node_name);
const TensorId input_tensor = ParseTensorName(function_node->input(i));
if (IsControlInput(input_tensor)) {
return errors::FailedPrecondition(
"Function inputs should not contain control nodes.");
}
const NodeDef* input_node = graph_.GetNode(input_tensor.node());
if (input_node == nullptr) {
return errors::FailedPrecondition(input_tensor.node(),
" was not found in the graph.");
}
InferenceContext* input_ic = GetContext(input_node);
if (input_ic == nullptr) {
return errors::FailedPrecondition(
"Inference context has not been created for ", input_tensor.node());
}
int output_port_num = input_tensor.index();
AttrValue attr_output_shape;
TensorShapeProto proto;
const auto& handle = input_ic->output(output_port_num);
input_ic->ShapeHandleToProto(handle, &proto);
// There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
for (int i = 0; i < proto.dim_size(); i++) {
if (proto.dim(i).size() < -1) {
proto.mutable_dim(i)->set_size(-1);
}
}
// Turn _Arg node into a Placeholder. _Arg node is a system op without a
// valid shape function.
*attr_output_shape.mutable_shape() = proto;
fun_node->set_op("Placeholder");
(*fun_node->mutable_attr())["dtype"] = (*fun_node->mutable_attr())["T"];
(*fun_node->mutable_attr()).erase("index");
(*fun_node->mutable_attr()).erase("T");
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
// Replace input nodes with Consts, if values are known. Note that
// we don't check for errors here since that's done in the loop above.
auto* ctx = GetNodeContext(function_node);
auto* ic = ctx->inference_context.get();
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i);
const string& node_name = NodeName(input);
const NodeDef* input_node = graph_.GetNode(node_name);
if (IsConstant(*input_node)) {
TF_CHECK_OK(
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
} else if (ctx->input_tensor_protos.size() > i &&
ctx->input_tensor_protos[i] != nullptr) {
NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
&grappler_function_item));
} else if (ic->input_tensors_as_shapes().size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i])) {
// We have fully defined input_tensors_as_shapes for this input; use it
// as a const input to the function node.
NodeDef const_input_node = MakeConstNodeDefFromShape(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i]);
TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
&grappler_function_item));
}
}
// Replace output _Retval nodes with Identity nodes. _Retval is a system op
// without outputs and without a registered shape function.
for (const auto& output_arg : grappler_function_item.outputs()) {
NodeDef* output_node = gv.GetNode(output_arg.node_name);
DCHECK_EQ(output_node->op(), "_Retval");
output_node->set_op("Identity");
output_node->mutable_attr()->erase("index");
}
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/aggressive_shape_inference_,
/*include_tensor_values=*/true));
// Set the function node's output shapes and values from the function body's
// return nodes.
int output = 0;
ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
nullptr);
for (auto const& out_arg : grappler_function_item.outputs()) {
// It is guaranteed that output_tensors does not contain any control
// inputs, so port_id >= 0.
TensorId out_tensor = ParseTensorName(out_arg.node_name);
const NodeDef* retnode = gv.GetNode(out_tensor.node());
if (retnode == nullptr) {
return errors::FailedPrecondition(
"Unable to find return function_node ", out_tensor.node(), " for ",
function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
if (out_tensor.index() >= output_properties.size()) {
return errors::InvalidArgument(
out_tensor.ToString(), " has invalid position ", out_tensor.index(),
" (output_properties.size() = ", output_properties.size(), ").");
}
auto const& outprop = output_properties[out_tensor.index()];
const TensorShapeProto& shape = outprop.shape();
ShapeHandle out;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
ic->set_output(output, out);
if (outprop.has_value()) {
// Forward tensor value to output_tensors_as_shape.
MaybeTensorProtoToShape(ic, outprop.value(),
&ctx->output_tensors_as_shapes[output]);
const_tensors_to_propagate_.push_back(outprop.value());
ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
}
output++;
}
return Status::OK();
}
// Prepares input shapes/values/handles, then runs shape inference, and
// finally sets output shapes/values/handles.
Status UpdateNode(const NodeDef* node, bool* refined) {
NodeContext* ctx = GetNodeContext(node);
if (ctx == nullptr) {
TF_RETURN_IF_ERROR(AddNode(node));
ctx = CHECK_NOTNULL(GetNodeContext(node));
*refined = true;
}
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have, update the node input shapes.
InferenceContext* ic = ctx->inference_context.get();
std::vector<ShapeHandle> input_tensors_as_shapes(ic->num_inputs());
ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
const GraphView::InputPort port(node, dst_input);
const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
int src_output = fanin.port_id;
const NodeDef* src = fanin.node;
NodeContext* src_ctx = GetNodeContext(src);
if (src_ctx == nullptr) {
return errors::FailedPrecondition(
"Input ", dst_input, " for '", node->name(),
"' was not previously added to SymbolicShapeRefiner.");
}
InferenceContext* src_ic = src_ctx->inference_context.get();
if (src_output >= src_ic->num_outputs()) {
return errors::OutOfRange("src_output = ", src_output,
", but num_outputs is only ",
src_ic->num_outputs());
}
// Propagate input node's NodeContext info to the current node's
// NodeContext:
// output_tensor_protos to input_tensor_protos and input_tensors, and
// output_tensors_as_shapes to input_tensors_as_shapes.
if (src_ctx->output_tensors_as_shapes.size() > src_output) {
input_tensors_as_shapes[dst_input] =
src_ctx->output_tensors_as_shapes[src_output];
}
if (src_ctx->output_tensor_protos.size() > src_output) {
const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
if (tensor_proto != nullptr) {
ctx->input_tensor_protos[dst_input] = tensor_proto;
if (!ic->FullyDefined(input_tensors_as_shapes[dst_input])) {
// Tensorflow uses '-1' to encode unknown shape or dimension:
//
// -1 : unknown shape
// [-1] : vector of unknown size
// [-1, -1] : matrix of unknown size
//
// For example `tf.reshape(x, [-1])` will reshape an arbitrary
// tensor x to a vector.
//
// It's possible that the same Const with -1 is used in many
// places, but that doesn't mean the resultant shapes are
// identical. e.g., x1 = Reshape(x, c) and y1 = Reshape(y, c),
// where c is [-1]. In this case, shape inference yields both x1 and
// y1 as rank 1, size unknown, but still the shapes of x1 and y1
// can be different. (Even if we use a different Const([-1]) for x1
// and y1, the graph optimizer may merge them into a single Const
// through duplicate removal.)
// If we reuse output_tensors_as_shapes to input_tensors_as_shapes
// by copying ShapeHandle, they share the same Shape object, and
// SymbolicShapeManager, later in InferStatically(), assigns the
// same symbolic dim value (unique value < -1); in the above
// Reshape example, the shapes of x1 and y1 become, for example,
// [-278], and the graph optimizer may yield incorrect output because it
// assumes x1 and y1 have the same shape.
// To prevent this, we re-create a ShapeHandle from the Const
// tensor, instead of reusing output_tensors_as_shapes (so that
// ShapeHandles of the const fanouts have the same values,
// but different Shape objects -- SymbolicShapeManager assigns
// different symbol id to each fanout shape).
// TODO(dyoon): clean up the way values are propagated.
MaybeTensorProtoToShape(ic, *tensor_proto,
&input_tensors_as_shapes[dst_input]);
}
}
}
// NOTE: we only check whether the shape is refined; we do not (yet) check
// whether the tensor value is refined.
if (!*refined &&
!ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
*refined = true;
}
ic->SetInput(dst_input, src_ic->output(src_output));
if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
// The input value may have changed. Since we have no way to know if
// that's indeed the case, err on the safe side.
*refined = true;
}
// Also propagate handle shape and dtype of edges which are carrying
// resource handles.
if (ctx->input_types[dst_input] == DT_RESOURCE) {
auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
if (!outputs) continue;
auto* inputs = ic->input_handle_shapes_and_types(dst_input);
if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
*refined = true;
ic->set_input_handle_shapes_and_types(dst_input, *outputs);
}
}
// Make sure we schedule the fanout of resources (which have no input)
// whenever the resources are updated.
*refined |= ic->num_inputs() == 0;
if (!*refined) {
// No input shape has changed, we're done.
return Status::OK();
}
// Notice: UpdateFunction only uses input_tensors_as_shapes, so for function
// nodes, we don't perform the conversion from TensorProtos to Tensors for
// constant inputs here.
ic->set_input_tensors_as_shapes(input_tensors_as_shapes);
// Properly handle function nodes.
if (ctx->op_data && ctx->op_data->is_function_op) {
// TODO(jmdecker): Detect if the input shapes have changed for this
// function. Note that when we hit a function call node, refined will be
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
auto s = UpdateFunction(node);
if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
<< ". Defaulting to ShapeUnknown.\n"
<< s.ToString();
}
}
// Construct Tensors for constant inputs used by shape functions.
std::vector<Tensor> const_values(ic->num_inputs());
std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
if (tensor_proto != nullptr &&
const_values[dst_input].FromProto(*tensor_proto)) {
input_tensors[dst_input] = &const_values[dst_input];
}
}
ic->set_input_tensors(input_tensors);
// Update the shapes of the outputs.
return InferShapes(*node, ctx);
}
Status SetUnknownShape(const NodeDef* node, int output_port) {
shape_inference::ShapeHandle shape =
GetUnknownOutputShape(node, output_port);
InferenceContext* ctx = GetContext(node);
if (ctx == nullptr) {
return errors::InvalidArgument("Missing context");
}
ctx->set_output(output_port, shape);
return Status::OK();
}
struct ShapeId {
const NodeDef* node;
int port_id;
bool operator==(const ShapeId& other) const {
return node == other.node && port_id == other.port_id;
}
};
struct HashShapeId {
std::size_t operator()(const ShapeId& shp) const {
return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
}
};
struct DimId {
const NodeDef* node;
int port_id;
int dim_index;
bool operator==(const DimId& other) const {
return node == other.node && port_id == other.port_id &&
dim_index == other.dim_index;
}
};
struct HashDimId {
std::size_t operator()(const DimId& dim) const {
return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
dim.dim_index;
}
};
// Compute the output shape for 'port_index' as the union of shape1 and shape2.
ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
ShapeHandle shape1, ShapeHandle shape2) {
if (shape1.SameHandle(shape2)) {
return shape1;
}
InferenceContext* ctx = GetContext(node);
ShapeHandle relaxed = shape1;
const int rank = ctx->Rank(shape1);
if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
relaxed = GetUnknownOutputShape(node, port_index);
} else {
for (int d = 0; d < rank; ++d) {
if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
int64 val1 = ctx->Value(ctx->Dim(shape1, d));
int64 val2 = ctx->Value(ctx->Dim(shape2, d));
if (val1 != val2 || (val1 < 0 && val2 < 0)) {
DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
}
}
}
}
return relaxed;
}
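// Example of the relaxation above (illustrative): the union of [2, 4] and
// [2, 8] is [2, ?d], where ?d is the per-(node, port, dim) unknown dimension
// returned by GetUnknownOutputDim(); if the two shapes have different or
// unknown ranks, the result is the fully unknown shape for that port.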
bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
if (s1.SameHandle(s2)) {
return true;
}
if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
return false;
}
if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
return true;
}
const int rank = InferenceContext::Rank(s1);
for (int i = 0; i < rank; ++i) {
if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
InferenceContext::DimKnownRank(s2, i))) {
int64 val1 =
InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
int64 val2 =
InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
if (val1 >= 0 && val2 >= 0 && val1 == val2) {
continue;
}
return false;
}
}
return true;
}
// Return true if the annotated shape is compatible with shape inference
// result. Examples:
// Inferred shape: ?, annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
// Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
// Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
bool CompatibleShapes(ShapeHandle inferred_shape,
ShapeHandle annotated_shape) const {
if (inferred_shape.SameHandle(annotated_shape)) {
return true;
}
if (!InferenceContext::RankKnown(inferred_shape)) {
return true;
}
if (InferenceContext::Rank(inferred_shape) !=
InferenceContext::Rank(annotated_shape)) {
return false;
}
const int rank = InferenceContext::Rank(inferred_shape);
for (int i = 0; i < rank; ++i) {
if (!InferenceContext::DimKnownRank(inferred_shape, i)
.SameHandle(
InferenceContext::DimKnownRank(annotated_shape, i))) {
int64 val1 = InferenceContext::Value(
InferenceContext::DimKnownRank(inferred_shape, i));
int64 val2 = InferenceContext::Value(
InferenceContext::DimKnownRank(annotated_shape, i));
if (val1 >= 0 && val1 != val2) {
return false;
}
}
}
return true;
}
bool SameShapes(ShapeHandle inferred_shape,
ShapeHandle annotated_shape) const {
if (inferred_shape.SameHandle(annotated_shape)) {
return true;
}
if (InferenceContext::Rank(inferred_shape) !=
InferenceContext::Rank(annotated_shape)) {
return false;
}
const int rank = InferenceContext::Rank(inferred_shape);
for (int i = 0; i < rank; ++i) {
int64 val1 = InferenceContext::Value(
InferenceContext::DimKnownRank(inferred_shape, i));
int64 val2 = InferenceContext::Value(
InferenceContext::DimKnownRank(annotated_shape, i));
if (val1 != val2) {
return false;
}
}
return true;
}
bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
const std::vector<ShapeAndType>& st2) const {
if (st1.size() != st2.size()) {
return false;
}
for (int i = 0; i < st1.size(); ++i) {
const ShapeAndType& s1 = st1[i];
const ShapeAndType& s2 = st2[i];
if (s1.dtype != s2.dtype) {
return false;
}
if (!EquivalentShapes(s1.shape, s2.shape)) {
return false;
}
}
return true;
}
Status AddFunction(const NodeDef* function_node) {
auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it != fun_to_grappler_function_item_.end()) {
return Status::OK();
}
const FunctionDef* function_def =
CHECK_NOTNULL(function_library_.Find(function_node->op()));
GrapplerFunctionItem grappler_function_item;
Status function_instantiated =
MakeGrapplerFunctionItem(*function_def, function_library_,
graph_def_version_, &grappler_function_item);
// If function instantiation failed we will skip it during shape inference.
if (!function_instantiated.ok()) {
VLOG(3) << "Failed to instantiate a function. Error: "
<< function_instantiated.error_message();
fun_to_grappler_function_item_[function_def->signature().name()] =
absl::nullopt;
return Status::OK();
}
if (grappler_function_item.inputs().size() > function_node->input_size()) {
return errors::FailedPrecondition(
"Function input size should be smaller than node input size.");
}
for (int i = grappler_function_item.inputs().size();
i < function_node->input_size(); ++i) {
const string& input = function_node->input(i);
if (!IsControlInput(input)) {
return errors::FailedPrecondition(
"Found regular input (", input,
") instead of control nodes for node ", function_node->name());
}
}
fun_to_grappler_function_item_[function_def->signature().name()] =
grappler_function_item;
return Status::OK();
}
Status AddNode(const NodeDef* node) {
NodeContext& node_ctx = node_to_context_[node];
TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
if (node_ctx.op_data->is_function_op) {
TF_RETURN_IF_ERROR(AddFunction(node));
}
TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
&node_ctx.input_types,
&node_ctx.output_types));
// Create the inference context for this node.
const int num_inputs = node_ctx.input_types.size();
std::vector<ShapeHandle> input_shapes(num_inputs);
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types(num_inputs);
std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
std::vector<ShapeHandle> input_tensors_as_shapes;
node_ctx.inference_context.reset(new InferenceContext(
graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
input_tensors, input_tensors_as_shapes,
std::move(input_handle_shapes_and_types)));
const Status s = node_ctx.inference_context->construction_status();
if (!s.ok()) {
node_ctx.inference_context.reset(nullptr);
}
return s;
}
private:
// Return the one ShapeHandle used to denote a fully unknown shape for a node
// output.
ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
ShapeId id{node, index};
auto it = unknown_shapes_.find(id);
if (it != unknown_shapes_.end()) {
return it->second;
}
InferenceContext* c = GetContext(node);
ShapeHandle shp = c->UnknownShape();
unknown_shapes_[id] = shp;
return shp;
}
// Return the one ShapeHandle used to denote a fully unknown dimension for a
// node output.
DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
int dim_id) {
DimId id{node, index, dim_id};
auto it = unknown_dims_.find(id);
if (it != unknown_dims_.end()) {
return it->second;
}
InferenceContext* c = GetContext(node);
DimensionHandle dim = c->UnknownDim();
unknown_dims_[id] = dim;
return dim;
}
// Returns true if all the output tensors have known values.
bool AllOutputValuesKnown(NodeContext* c) {
InferenceContext* ic = c->inference_context.get();
if (c->output_tensors_as_shapes.size() < ic->num_outputs() &&
c->output_tensor_protos.size() < ic->num_outputs()) {
return false;
} else {
// Checks if we can get output value via either output_tensor_proto or
// output_tensors_as_shapes.
for (int i = 0; i < ic->num_outputs(); i++) {
if (c->output_tensor_protos.size() > i &&
c->output_tensor_protos[i] != nullptr) {
continue;
}
if (c->output_tensors_as_shapes.size() > i &&
ic->FullyDefined(c->output_tensors_as_shapes[i])) {
continue;
}
// Unknown for output[i].
return false;
}
}
return true;
}
// Returns true if we can infer output tensors' values -- we know values of
// all the input tensors.
bool AllInputValuesKnown(NodeContext* c) {
InferenceContext* ic = c->inference_context.get();
// Check inputs are fully defined and values are known.
for (int i = 0; i < ic->num_inputs(); i++) {
const Tensor* tensor = ic->input_tensor(i);
// Note that we don't check c->input_tensor_protos[i], as UpdateNode()
// already converted it to ic->input_tensor(i);
const ShapeHandle& input_tensors_as_shape =
ic->input_tensors_as_shapes()[i];
// Either input_tensor is valid, or input_tensors_as_shape, which holds the
// input tensor's value in shape format, should be fully defined.
if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
return false;
}
}
return true;
}
// Returns true if we want to update output shapes and values by running
// EvaluateNode() for this op, based on op type, data type, and size.
bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64 max_size) {
InferenceContext* ic = c->inference_context.get();
// Due to the cost of running EvaluateNode(), we limit evaluation to
// whitelisted op types only.
if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
return false;
}
// Check input dtypes are number types.
for (const auto& input_type : c->input_types) {
if (!IsNumericType(input_type)) {
return false;
}
}
// Check output dtypes are number types.
for (const auto& output_type : c->output_types) {
if (!IsNumericType(output_type)) {
return false;
}
}
// Check that the number of elements of each input tensor is no larger than
// the given max size.
for (int i = 0; i < ic->num_inputs(); i++) {
const Tensor* tensor = ic->input_tensor(i);
const ShapeHandle& input_shape_handle = ic->input(i);
if (tensor != nullptr) {
if (tensor->NumElements() > max_size) {
return false;
}
} else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
return false;
}
}
// Check that we know the shape of each output tensor, and that its number
// of elements is no larger than the given max size.
for (int i = 0; i < ic->num_outputs(); i++) {
const ShapeHandle& shape_handle = ic->output(i);
if (!ic->FullyDefined(shape_handle) ||
ic->Value(ic->NumElements(shape_handle)) > max_size) {
return false;
}
}
return true;
}
// Create input tensors from the NodeContext.
void CreateInputTensors(NodeContext* c,
std::vector<Tensor>* input_tensor_vector,
TensorVector* inputs) {
InferenceContext* ic = c->inference_context.get();
for (int i = 0; i < ic->num_inputs(); i++) {
if (ic->input_tensor(i)) {
input_tensor_vector->at(i) = *ic->input_tensor(i);
inputs->emplace_back(&input_tensor_vector->at(i));
// Note that we don't check c->input_tensor_protos[i], as UpdateNode()
// already converted it to ic->input_tensor(i);
} else {
// Create a Tensor from input_tensors_as_shapes, and then emplace it
// into inputs.
// Note that the shape in input_tensors_as_shapes is a scalar or a vector.
const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
const DataType& data_type = c->input_types[i];
int32 rank = ic->Rank(shape_handle);
if (rank < 1) {
input_tensor_vector->at(i) = Tensor(data_type, {});
} else {
input_tensor_vector->at(i) = Tensor(data_type, {rank});
}
auto* tensor = &input_tensor_vector->at(i);
if (data_type == DT_INT32) {
auto flat = tensor->flat<int32>();
for (int j = 0; j < rank; j++) {
int32 dim = ic->Value(ic->Dim(shape_handle, j));
flat(j) = dim;
}
} else {
auto flat = tensor->flat<int64>();
for (int j = 0; j < rank; j++) {
int64 dim = ic->Value(ic->Dim(shape_handle, j));
flat(j) = dim;
}
}
inputs->emplace_back(tensor);
}
}
}
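// For example (illustrative): if input i has no constant tensor but
// ic->input_tensors_as_shapes()[i] is the fully known shape [2, 3, 5], the
// loop above materializes a rank-1 tensor of 3 elements holding {2, 3, 5}
// (int32 or int64 depending on the input type) and appends it to `inputs`.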
// Run a node to infer output shapes and values, and add them to the
// NodeContext.
Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
InferenceContext* ic = c->inference_context.get();
// Input to EvaluateNode()
TensorVector inputs;
// Container for temporarily created tensor objects.
std::vector<Tensor> input_tensor_vector(ic->num_inputs());
CreateInputTensors(c, &input_tensor_vector, &inputs);
// Output for EvaluateNode() and output tensor clean up object.
TensorVector outputs;
auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
for (const auto& output : outputs) {
if (output.tensor) {
delete output.tensor;
}
}
});
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
&resource_mgr_, &outputs));
c->output_tensors_as_shapes.resize(outputs.size());
c->output_tensor_protos.resize(outputs.size(), nullptr);
for (int k = 0; k < outputs.size(); k++) {
const auto& t = outputs[k];
// Override output shape.
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(
ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
if (ic->FullyDefined(ic->output(k)) &&
!EquivalentShapes(ic->output(k), output_shape)) {
LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
<< ", inferred output shape "
<< "doesn't match for k=" << k << ": "
<< "ic->output(k): " << ic->DebugString(ic->output(k))
<< ", output_shape: " << ic->DebugString(output_shape)
<< " -- " << node.DebugString();
}
ic->set_output(k, output_shape);
// Set output_tensors_as_shapes.
MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
// Set output_tensor_protos.
TensorProto tensor_proto;
t->AsProtoTensorContent(&tensor_proto);
const_tensors_to_propagate_.push_back(tensor_proto);
c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
}
return Status::OK();
}
// Update output shapes with annotated information.
// Currently this only handles nodes with static shapes, i.e. shapes that do
// not change during execution.
// TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
NodeContext* c) const {
const auto& attr = node.attr();
if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
attr.count(kOutputShapes) == 0)
return Status::OK();
InferenceContext* ic = c->inference_context.get();
int output_size = attr.at(kOutputShapes).list().shape_size();
for (int i = 0; i < ic->num_outputs(); i++) {
// An annotated Switch node has only one output shape annotation. Propagate
// it to all the outputs.
int shape_index = IsSwitch(node) ? 0 : i;
if (shape_index >= output_size) {
LOG(WARNING)
<< "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
<< node.name() << ", inferred output shape size "
<< ic->num_outputs() << ", annotated output shape size "
<< output_size;
break;
}
const TensorShapeProto& shape =
attr.at(kOutputShapes).list().shape(shape_index);
if (shape.dim().empty()) continue;
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
// Check if annotated shapes are incompatible with inferred shapes.
if ((ic->FullyDefined(ic->output(i)) &&
!SameShapes(ic->output(i), output_shape)) ||
(!ic->FullyDefined(ic->output(i)) &&
!CompatibleShapes(ic->output(i), output_shape))) {
LOG(WARNING)
<< "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
<< node.name() << ", inferred output shape "
<< "doesn't match for i=" << i << ": "
<< "ic->output(k): " << ic->DebugString(ic->output(i))
<< ", annotated output shape: " << ic->DebugString(output_shape)
<< " -- " << node.DebugString();
c->shape_incompatible = true;
}
// Only use the annotated shape if the inferred shape is not fully defined
// and is compatible with the annotated shape.
if (!ic->FullyDefined(ic->output(i)) &&
CompatibleShapes(ic->output(i), output_shape)) {
VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
<< node.name() << ", inferred output shape " << i << ": "
<< "ic->output(i): " << ic->DebugString(ic->output(i))
<< ", annotated output shape: " << ic->DebugString(output_shape)
<< " -- " << node.ShortDebugString();
ic->set_output(i, output_shape);
}
}
return Status::OK();
}
Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
NodeContext* c) {
// Propagate tensors and shape tensors unless the node is fed.
// TODO(bsteiner) We should still propagate the shapes to the ports that
// aren't fed in the case of a ShapeN node.
InferenceContext* ic = c->inference_context.get();
if (!is_fed) {
if (IsConstant(node)) {
c->output_tensor_protos.resize(1);
const TensorProto& tensor_proto = node.attr().at("value").tensor();
c->output_tensor_protos[0] = &tensor_proto;
c->output_tensors_as_shapes.resize(1);
MaybeTensorProtoToShape(ic, tensor_proto,
&c->output_tensors_as_shapes[0]);
} else if (IsRank(node)) {
if (ic->RankKnown(ic->input(0))) {
// Propagate rank value.
int32 rank = ic->Rank(ic->input(0));
const_tensors_to_propagate_.push_back(
MakeIntegerScalarTensorProto(DT_INT32, rank));
c->output_tensor_protos.resize(1);
c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
}
} else if (IsSize(node)) {
DimensionHandle size = ic->NumElements(ic->input(0));
if (ic->ValueKnown(size)) {
// Propagate size value.
int64 sz = ic->Value(size);
bool valid = false;
if (node.attr().at("out_type").type() == DT_INT32) {
if (sz < std::numeric_limits<int32>::max()) {
const_tensors_to_propagate_.push_back(
MakeIntegerScalarTensorProto(DT_INT32, sz));
valid = true;
}
} else {
const_tensors_to_propagate_.push_back(
MakeIntegerScalarTensorProto(DT_INT64, sz));
valid = true;
}
if (valid) {
c->output_tensor_protos.resize(1);
c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
}
}
} else if (IsShape(node)) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = c->inference_context->input(0);
} else if (IsShapeN(node)) {
c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
c->output_tensors_as_shapes[i] = c->inference_context->input(i);
}
} else if (node.op() == "ConcatV2") {
bool valid = true;
ShapeHandle result;
for (int i = 0; i < ic->num_inputs() - 1; ++i) {
ShapeHandle input = ic->input_tensors_as_shapes()[i];
if (!ic->RankKnown(input)) {
valid = false;
break;
} else if (i == 0) {
result = input;
} else {
TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = result;
}
} else if (IsPack(node)) {
// A Pack node concatenating scalars is often used to generate a shape.
std::vector<DimensionHandle> dims;
bool valid = true;
for (int i = 0; i < ic->num_inputs(); ++i) {
const Tensor* t = ic->input_tensor(i);
if (t) {
if (t->dims() != 0 ||
(t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
valid = false;
break;
}
int64 size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
: t->scalar<int64>()();
dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
} else {
// We don't have the tensor value; use input_tensors_as_shapes if
// possible.
const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
ic->ValueKnown(ic->Dim(shape_handle, 0))) {
dims.push_back(ic->Dim(shape_handle, 0));
} else {
dims.push_back(ic->UnknownDim());
}
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
}
} else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
if (c->input_tensor_protos[0] != nullptr) {
c->output_tensor_protos.resize(1);
c->output_tensor_protos[0] = c->input_tensor_protos[0];
}
} else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
const Tensor* slice_offset = ic->input_tensor(1);
valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
const Tensor* slice_size = ic->input_tensor(2);
valid &= slice_size != nullptr && slice_size->NumElements() == 1;
if (valid) {
int64 start = slice_offset->dtype() == DT_INT32
? slice_offset->flat<int32>()(0)
: slice_offset->flat<int64>()(0);
int64 size =
(slice_size->dtype() == DT_INT32 ? slice_size->flat<int32>()(0)
: slice_size->flat<int64>()(0));
ShapeHandle result;
if (size == -1) {
TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
} else {
int64 end = start + size;
TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
}
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = result;
}
} else if (IsStridedSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
const Tensor* slice_begin = ic->input_tensor(1);
valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
const Tensor* slice_end = ic->input_tensor(2);
valid &= slice_end != nullptr && slice_end->NumElements() == 1;
const Tensor* slice_stride = ic->input_tensor(3);
valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
if (node.attr().count("ellipsis_mask") > 0 &&
node.attr().at("ellipsis_mask").i() != 0) {
valid = false;
}
if (node.attr().count("new_axis_mask") > 0 &&
node.attr().at("new_axis_mask").i() != 0) {
valid = false;
}
if (node.attr().count("shrink_axis_mask") > 0 &&
node.attr().at("shrink_axis_mask").i() != 0) {
valid = false;
}
int begin_mask = 0;
if (node.attr().count("begin_mask") > 0) {
begin_mask = node.attr().at("begin_mask").i();
}
int end_mask = 0;
if (node.attr().count("end_mask") > 0) {
end_mask = node.attr().at("end_mask").i();
}
if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
valid = false;
}
if (valid) {
int64 begin = 0;
if (begin_mask == 0) {
begin = slice_begin->dtype() == DT_INT32
? slice_begin->flat<int32>()(0)
: slice_begin->flat<int64>()(0);
}
int64 end = std::numeric_limits<int64>::max();
if (end_mask == 0) {
end =
(slice_end->dtype() == DT_INT32 ? slice_end->flat<int32>()(0)
: slice_end->flat<int64>()(0));
}
int64 stride = slice_stride->dtype() == DT_INT32
? slice_stride->flat<int32>()(0)
: slice_stride->flat<int64>()(0);
ShapeHandle result;
TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = result;
}
}
}
if (aggressive_shape_inference_) {
// Update output shapes with annotated information. This is optional.
UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
// Update output tensor values using EvaluateNode() if we can.
// Due to the cost of EvaluateNode(), we run it only for certain op types
// (white listed) and small integer tensors.
const int max_element_size = 17; // Max up to 4x4 matrix or similar.
if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
!ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
return Status::OK();
}
UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
}
return Status::OK();
}
Status InferShapes(const NodeDef& node, NodeContext* c) {
// Infer the shapes of output tensors.
if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
!c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
// Annotate outputs with unknown shapes. Update output shapes with
// annotated information later on if available.
// Note that the shape inference function may return an error, but we ignore
// it and use UnknownShape in that case.
TF_RETURN_IF_ERROR(
c->inference_context->Run(shape_inference::UnknownShape));
}
Status status = Status::OK();
auto it = fed_ports_.find(node.name());
const bool is_fed = it != fed_ports_.end();
if (is_fed) {
// It is possible to feed node output ports with tensors of any shape: as
// a result, the shape of a fed port is completely unknown.
for (const int output_port : it->second) {
status.Update(SetUnknownShape(&node, output_port));
}
}
// Update NodeContext output fields after shape inference function runs.
status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
return status;
}
private:
bool IsIntegerVector(const Tensor& tensor) {
if (tensor.dims() == 1 &&
(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
return true;
}
return false;
}
bool IsIntegerScalar(const Tensor& tensor) {
if (tensor.dims() == 0 &&
(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
tensor.NumElements() == 1) {
return true;
}
return false;
}
TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
const int64 val) {
TensorProto tensor_proto;
tensor_proto.set_dtype(dtype);
// Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
tensor_proto.mutable_tensor_shape();
if (dtype == DT_INT32) {
tensor_proto.add_int_val(val);
} else if (dtype == DT_INT64) {
tensor_proto.add_int64_val(val);
}
return tensor_proto;
}
bool MaybeTensorProtoToShape(InferenceContext* ic,
const TensorProto& tensor_proto,
ShapeHandle* tensors_as_shapes) {
// Skip if dtype is not integer.
if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
return false;
}
// Skip if shape is neither scalar nor vector.
if (tensor_proto.tensor_shape().unknown_rank() ||
tensor_proto.tensor_shape().dim_size() > 1) {
return false;
}
Tensor tensor;
if (!tensor.FromProto(tensor_proto)) {
return false;
}
return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
}
bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
ShapeHandle* tensors_as_shapes) {
// Integer tensors of rank one can also be interpreted as a shape
// provided all their values are >= -1.
if (IsIntegerVector(tensor)) {
bool has_values_smaller_than_minus_1 = false;
std::vector<DimensionHandle> dims;
for (int i = 0; i < tensor.NumElements(); i++) {
int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
: tensor.flat<int64>()(i);
has_values_smaller_than_minus_1 |= (value < -1);
dims.push_back(value < 0 ? ic->UnknownDim() : ic->MakeDim(value));
}
if (!has_values_smaller_than_minus_1) {
*tensors_as_shapes = ic->MakeShape(dims);
}
} else if (IsIntegerScalar(tensor)) {
// Scalar constant.
int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
: tensor.flat<int64>()(0);
if (value == -1) {
// Scalar value -1 represents an unknown shape. If we were to call
// MakeShape(MakeDim) with it, we would get a vector of unknown size.
*tensors_as_shapes = ic->UnknownShape();
return true;
} else if (value >= 0) {
// Ideally, values could be < -1, but MakeDim() fails for any value < -1.
// This is a limitation of using ShapeHandle as a means to pass values.
*tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
return true;
}
}
return false;
}
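// Examples for the conversion above (illustrative): the vector [2, -1, 5]
// becomes the shape [2, ?, 5]; the scalar 7 becomes the shape [7]; the scalar
// -1 becomes a fully unknown shape; and [2, -7, 5] is not converted because
// it contains a value smaller than -1.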
const GraphView& graph_;
int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
// Store function instantiations only for valid functions. If a function
// instantiation failed, the entry holds `absl::nullopt`.
std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
// Store TensorProtos for tensor value propagation. Note that we use a list,
// not a vector, as we keep pointers to the TensorProtos in this container.
// A vector may resize and copy the objects into a new buffer, leaving the
// existing pointers dangling.
std::list<TensorProto> const_tensors_to_propagate_;
// For more aggressive shape and value inference.
bool aggressive_shape_inference_;
ResourceMgr resource_mgr_;
};
// Keep track of shapes and dimensions in a graph.
// In particular, use disjoint sets to track equivalence between shapes and
// dims, and consolidate the information globally.
class SymbolicShapeManager {
public:
SymbolicShapeManager() {}
Status Merge(ShapeHandle s1, ShapeHandle s2) {
if (!s1.IsSet() || !s2.IsSet()) {
return Status::OK();
}
TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
InferenceContext::DimKnownRank(s2, i)));
}
}
return Status::OK();
}
Status Merge(DimensionHandle d1, DimensionHandle d2) {
if (!d1.IsSet() || !d2.IsSet()) {
return Status::OK();
}
return dims_.Merge(d1, d2);
}
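// Resolves `shape` and each of its dimensions through the disjoint sets and
// writes the result into `properties`; dimensions that are still unknown are
// exported with negative sentinel sizes.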
void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
OpInfo::TensorProperties* properties) {
properties->set_dtype(type);
ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
if (!InferenceContext::RankKnown(actual_shape)) {
properties->mutable_shape()->set_unknown_rank(true);
} else {
for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
shape_inference::DimensionHandle dim =
InferenceContext::DimKnownRank(actual_shape, j);
int64 d = dims_.GetMergedValue(dim);
properties->mutable_shape()->add_dim()->set_size(d);
}
}
}
private:
DisjointSet<shape_inference::ShapeHandle> shapes_;
DisjointSet<shape_inference::DimensionHandle> dims_;
};
Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<ShapeAndType>& shapes_and_types,
std::vector<ShapeAndType>* queue_shapes_and_types) {
if (shapes_and_types.size() != queue_shapes_and_types->size()) {
return errors::InvalidArgument(
"Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
" vs ", queue_shapes_and_types->size());
}
for (size_t i = 0; i < shapes_and_types.size(); ++i) {
const ShapeAndType& a = shapes_and_types[i];
ShapeAndType& b = (*queue_shapes_and_types)[i];
if (a.dtype != b.dtype) {
return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
i, ": ", DataTypeString(a.dtype), " vs ",
DataTypeString(b.dtype));
}
b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
}
return Status::OK();
}
// Compute the output shape of the merge node as the union of the available
// input shapes.
Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
const NodeDef* node,
bool* new_shapes) const {
InferenceContext* ic = shape_refiner->GetContext(node);
if (!ic) {
// The node hasn't been added to the refiner yet: add it now so that we can
// run shape inference on it.
TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
*new_shapes = true;
// Infer the shape of the second output once and for all since it never
// changes.
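// (Merge's second output, value_index, is the index of the forwarded input
// and is therefore always a scalar.)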
ShapeHandle out1 = ic->Scalar();
ic->set_output(1, out1);
}
ShapeHandle out;
const std::vector<ShapeAndType>* out_handle = nullptr;
bool out_initialized = false;
for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
*node, /*include_controlling_edges=*/false)) {
InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
if (!src_ic) {
// Handling a loop for the first time, the back edge won't have any shape
// info.
continue;
}
ShapeHandle input = src_ic->output(fanin.src.port_id);
ic->SetInput(fanin.dst.port_id, input);
auto* input_handle =
src_ic->output_handle_shapes_and_types(fanin.src.port_id);
if (input_handle)
ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
if (!out_initialized) {
out_initialized = true;
out = input;
out_handle = input_handle;
} else {
// Note here only out, not out_handle, is modified.
out = shape_refiner->OutputAsUnion(node, 0, input, out);
}
}
if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
ic->set_output(0, out);
if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
*new_shapes = true;
}
return Status::OK();
}
// Manually propagate the input shape for Enter nodes.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
const NodeDef* node, bool* new_shapes) {
InferenceContext* ic = shape_refiner->GetContext(node);
if (!ic) {
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
ic = shape_refiner->GetContext(node);
}
GraphView::InputPort port(node, 0);
GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
ShapeHandle input = src_ic->output(fanin.port_id);
if (!ic->output(0).SameHandle(input)) {
ic->SetInput(0, input);
ic->set_output(0, input);
*new_shapes = true;
}
auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
if (outputs) {
ic->set_input_handle_shapes_and_types(0, *outputs);
ic->set_output_handle_shapes_and_types(0, *outputs);
*new_shapes = true;
}
return Status::OK();
}
Status GraphProperties::UpdateShapes(
SymbolicShapeRefiner* shape_refiner,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
const NodeDef* n, bool* new_shapes) const {
if (IsEnter(*n)) {
// The Enter shape function always forwards an UnknownShape, so manually
// propagate the input shape here instead.
TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
} else if (IsMerge(*n)) {
// Properly handle merge nodes.
TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
} else if (IsEnqueue(*n)) {
// Make sure the shapes of enqueued tensors are propagated to the queue
// itself.
TF_RETURN_IF_ERROR(
UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
} else if (IsQueue(*n)) {
// Set shapes and types of Queue ops, if needed.
TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
// Rely on regular TF shape refinement for all the other nodes.
// UpdateNode calls UpdateFunction if a function node is detected.
TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}
// Propagates the shapes in the transitive fan-out of <new_shapes>.
Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most
// num_nested_loops^2 * max_rank * num_nodes iterations. We approximate
// max_rank with the constant 4 and bound the loop length by the total number
// of nodes in the graph. The same reasoning applies to resources.
VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
<< num_loops << " loops and " << resource_handles.size()
<< " resources" << std::endl;
const int64 max_loop_length = item_.graph.node_size();
const int64 max_rank = 4;
const int64 max_loop_iterations =
max_rank * max_loop_length * std::max<int64>(1, num_loops * num_loops);
const int64 num_queues = resource_handles.size();
const int64 max_resource_iterations = num_queues * num_queues * max_rank;
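// For example, with 1000 nodes, 2 nested loops, and 3 queues this yields
// max_loop_iterations = 4 * 1000 * 4 = 16000 and
// max_resource_iterations = 3 * 3 * 4 = 36.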
int64 num_resource_iterations = 0;
do {
int64 num_loop_iterations = 0;
while (!new_shapes->empty() &&
num_loop_iterations++ < max_loop_iterations) {
const NodeDef* n = new_shapes->pop();
bool updated = false;
TF_RETURN_IF_ERROR(
UpdateShapes(shape_refiner, resource_handles, n, &updated));
if (updated) {
for (const auto& fanout : shape_refiner->graph().GetFanouts(
*n, /*include_controlled_nodes=*/false)) {
new_shapes->push(fanout.node);
}
// Make sure the corresponding queue nodes are (re)processed.
if (IsEnqueue(*n)) {
auto it = resource_handles.find(n);
if (it != resource_handles.end()) {
new_shapes->push(it->second);
}
}
}
}
} while (!new_shapes->empty() &&
num_resource_iterations++ < max_resource_iterations);
if (!new_shapes->empty()) {
return errors::Internal("Shape inference failed to converge");
}
return Status::OK();
}
Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
SymbolicShapeRefiner* shape_refiner,
bool* new_shapes) {
auto* ctx = shape_refiner->GetNodeContext(queue_node);
if (!ctx) {
TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
}
auto* ic = ctx->inference_context.get();
auto* outputs = ic->output_handle_shapes_and_types(0);
if (outputs) {
// Shapes and types are already set, presumably by Enqueue ops.
return shape_refiner->UpdateNode(queue_node, new_shapes);
}
if (queue_node->attr().count("shapes") <= 0 ||
queue_node->attr().count("component_types") <= 0 ||
queue_node->attr().at("shapes").list().shape_size() !=
queue_node->attr().at("component_types").list().type_size()) {
// Errors in shapes and component_types attr.
return shape_refiner->UpdateNode(queue_node, new_shapes);
}
// Extract types and shapes from Queue attr.
const auto& shapes = queue_node->attr().at("shapes").list().shape();
const auto& types = queue_node->attr().at("component_types").list().type();
std::vector<ShapeAndType> shapes_and_types;
for (int i = 0; i < types.size(); i++) {
const auto& shape = shapes[i];
ShapeHandle shape_handle;
TF_RETURN_IF_ERROR(
ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
DataType data_type =
queue_node->attr().at("component_types").list().type(i);
ShapeAndType shape_and_type(shape_handle, data_type);
shapes_and_types.push_back(shape_and_type);
}
ic->set_output_handle_shapes_and_types(0, shapes_and_types);
// The queue node is updated with output_handle_shapes_and_types, so set
// new_shapes here and ignore the result of UpdateNode().
*new_shapes = true;
bool dummy_new_shapes = false;
return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
}
Status GraphProperties::UpdateEnqueue(
const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
auto ctx = shape_refiner->GetNodeContext(enqueue_node);
if (!ctx) {
TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
}
auto it = resource_handles.find(enqueue_node);
if (it == resource_handles.end()) {
// The corresponding queue was not found; there isn't much we can do.
return Status::OK();
}
const NodeDef* qnode = it->second;
auto qctx = shape_refiner->GetContext(qnode);
if (!qctx) {
return Status::OK();
}
auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
// TODO(bsteiner): handle EnqueueMany as well.
std::vector<ShapeAndType> shapes_and_types;
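// Start at input 1: input 0 is the queue's resource handle, not a component
// tensor.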
for (int i = 1; i < ctx->input_types.size(); ++i) {
GraphView::InputPort inp(enqueue_node, i);
GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
InferenceContext* in = shape_refiner->GetContext(fanin.node);
ShapeHandle input = in->output(fanin.port_id);
ctx->inference_context->SetInput(i, input);
shapes_and_types.push_back({input, ctx->input_types[i]});
}
if (queue_handle_data == nullptr) {
qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
*new_shapes = true;
} else {
TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
*new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
shapes_and_types);
qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
}
return Status::OK();
}
Status GraphProperties::InferStatically(bool assume_valid_feeds,
bool aggressive_shape_inference,
bool include_input_tensor_values,
bool include_output_tensor_values) {
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library());
std::unordered_map<string, std::unordered_set<int>> fed_ports;
if (!assume_valid_feeds) {
for (const auto& feed : item_.feed) {
SafeTensorId tensor_id = ParseTensorName(feed.first);
fed_ports[tensor_id.node()].insert(tensor_id.index());
}
}
GraphView graph_view(&item_.graph);
// List the resources and the nodes using them. Also collect the Merge nodes,
// fed nodes, and primary inputs.
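// For each queue resource, the pair holds the enqueue nodes (first) and the
// dequeue nodes (second) that operate on it.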
std::unordered_map<const NodeDef*,
std::pair<std::unordered_set<const NodeDef*>,
std::unordered_set<const NodeDef*>>>
resources;
std::unordered_set<const NodeDef*> merge_nodes;
std::unordered_set<const NodeDef*> fed_nodes;
std::unordered_set<const NodeDef*> primary_inputs;
int num_loops = 0;
for (const NodeDef& node : item_.graph.node()) {
if (IsQueue(node)) {
for (const GraphView::InputPort& fanout :
graph_view.GetFanouts(node, false)) {
if (IsEnter(*fanout.node)) {
const NodeDef& enter = *fanout.node;
for (const GraphView::InputPort& fanout :
graph_view.GetFanouts(enter, false)) {
if (IsEnqueue(*fanout.node)) {
resources[&node].first.insert(fanout.node);
} else if (IsDequeue(*fanout.node)) {
resources[&node].second.insert(fanout.node);
}
}
} else {
if (IsEnqueue(*fanout.node)) {
resources[&node].first.insert(fanout.node);
} else if (IsDequeue(*fanout.node)) {
resources[&node].second.insert(fanout.node);
}
}
}
}
if (NumNonControlInputs(node) == 0) {
primary_inputs.insert(&node);
} else if (IsMerge(node)) {
merge_nodes.insert(&node);
} else if (IsNextIteration(node)) {
++num_loops;
}
if (fed_ports.find(node.name()) != fed_ports.end()) {
fed_nodes.insert(&node);
}
}
std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
std::vector<TopologicalDependency> extra_deps;
for (const auto& resource : resources) {
for (const NodeDef* src : resource.second.first) {
resource_handles[src] = resource.first;
for (const NodeDef* dst : resource.second.second) {
// Add control edges from enqueue to dequeue nodes to ensure they are
// processed in their logical order.
extra_deps.emplace_back(src, dst);
}
}
}
std::vector<const NodeDef*> topo_order;
Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
if (!s.ok()) {
if (extra_deps.empty()) {
return s;
} else {
// There is a loop between queues: we'll just use the graph's topological
// order. This makes the shape inference less precise, but since this
// situation isn't common it's not worth figuring out where to break the loop
// and doing a proper relaxation.
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
}
}
// Heap-allocate the SymbolicShapeRefiner so as not to consume a large amount
// of stack space.
auto refiner = absl::make_unique<SymbolicShapeRefiner>(
graph_view, fed_ports, aggressive_shape_inference);
TopoQueue new_shapes(topo_order);
// Also seed the propagation of shapes in the fanout of primary inputs.
for (const NodeDef* node : primary_inputs) {
new_shapes.push(node);
}
// Also seed the propagation of shapes in the fanout of fed nodes.
for (const NodeDef* node : fed_nodes) {
new_shapes.push(node);
}
// Propagate shapes normally.
TF_RETURN_IF_ERROR(
PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
// Track shapes globally across the graph.
std::unique_ptr<SymbolicShapeManager> shape_manager =
absl::make_unique<SymbolicShapeManager>();
bool found_error = false;
for (const NodeDef& node : item_.graph.node()) {
auto node_ctx = refiner->GetContext(&node);
if (!node_ctx) {
continue;
}
// Skip any information that comes from fed nodes.
if (fed_ports.find(node.name()) != fed_ports.end()) {
VLOG(2) << "Skipping feed node shape: " << node.name();
continue;
}
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
.ok()) {
found_error = true;
break;
}
}
for (const auto& merged_dims : node_ctx->MergedDims()) {
if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
found_error = true;
break;
}
}
if (found_error) {
// The shapes aren't consistent, we can't infer safely: discard all the
// information discovered so far.
shape_manager = absl::make_unique<SymbolicShapeManager>();
break;
}
}
for (const NodeDef& node : item_.graph.node()) {
VLOG(3) << "Filling in graph properties for node: " << node.name();
auto ctx = refiner->GetNodeContext(&node);
if (!ctx) {
continue;
}
auto* ic = ctx->inference_context.get();
// Fill input properties.
{
auto& input_properties = input_properties_[node.name()];
// Should always be empty, since node names in the graph are supposed to be
// unique.
CHECK_EQ(input_properties.size(), 0);
input_properties.resize(ic->num_inputs());
GraphView::InputPort input(&node, -1);
for (int i = 0; i < ic->num_inputs(); ++i) {
shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
if (include_input_tensor_values) {
// Export tensor value to input_properties.value.
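// Prefer the value of a Const fanin; otherwise use a propagated TensorProto;
// otherwise materialize a value from a fully defined tensors-as-shapes entry.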
if (IsConstant(*fanin.node)) {
const TensorProto& raw_val =
fanin.node->attr().at("value").tensor();
*input_properties[i].mutable_value() = raw_val;
} else if (ctx->input_tensor_protos.size() > i &&
ctx->input_tensor_protos[i] != nullptr) {
*input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
} else if (ic->input_tensors_as_shapes().size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i])) {
*input_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i]);
}
}
}
}
// Fill output properties.
{
auto& output_properties = output_properties_[node.name()];
// Should always be empty, since node names in the graph are supposed to be
// unique.
CHECK_EQ(output_properties.size(), 0);
output_properties.resize(ic->num_outputs());
for (int i = 0; i < ic->num_outputs(); ++i) {
shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
if (include_output_tensor_values) {
// Export tensor value to output_properties.value.
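// Same priority as for inputs: the node's own Const value, then a propagated
// TensorProto, then a value materialized from output_tensors_as_shapes.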
if (IsConstant(node)) {
// TODO(rmlarsen): Eliminate this copy.
const TensorProto& raw_val = node.attr().at("value").tensor();
*output_properties[i].mutable_value() = raw_val;
} else if (ctx->output_tensor_protos.size() > i &&
ctx->output_tensor_protos[i] != nullptr) {
*output_properties[i].mutable_value() =
*ctx->output_tensor_protos[i];
} else if (ctx->output_tensors_as_shapes.size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i])) {
*output_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i]);
}
}
}
}
if (aggressive_shape_inference && ctx->shape_incompatible)
incompatible_shape_nodes_.insert(node.name());
}
if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
LOG(WARNING) << incompatible_shape_nodes_.size()
<< " nodes have incompatible output shapes.";
// Help trace the unknown dimensions to their origins.
VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
output_properties_);
return Status::OK();
}
Status GraphProperties::InferDynamically(Cluster* cluster) {
TF_RETURN_IF_ERROR(cluster->Initialize(item_));
// Runs the model once to collect the shapes in the cost model.
RunMetadata metadata;
TF_RETURN_IF_ERROR(
cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
return InferFromCostGraph(metadata.cost_graph());
}
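// Copies the input graph into *output_graph_def and records each node's
// inferred output shapes in an "_output_shapes" attr.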
Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
*output_graph_def = item_.graph;
for (int i = 0; i < output_graph_def->node_size(); i++) {
auto node = output_graph_def->mutable_node(i);
AttrValue attr_output_shape;
auto tensor_properties = GetOutputProperties(node->name());
for (const auto& tensor_property : tensor_properties) {
*attr_output_shape.mutable_list()->add_shape() = tensor_property.shape();
}
(*node->mutable_attr())["_output_shapes"] = attr_output_shape;
}
return Status::OK();
}
Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
if (cost_graph.node_size() == 0) {
LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
}
std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
std::unordered_map<string, const NodeDef*> name_to_node; // Empty
for (auto& node : cost_graph.node()) {
name_to_cost[node.name()] = &node;
std::vector<OpInfo::TensorProperties> output_properties;
for (const auto& out : node.output_info()) {
OpInfo::TensorProperties properties;
properties.set_dtype(out.dtype());
*properties.mutable_shape() = out.shape();
output_properties.push_back(properties);
}
output_properties_[node.name()] = output_properties;
}
for (const auto& node : item_.graph.node()) {
// Skip the nodes that are not in the cost graph: these are nodes that
// aren't run, because they aren't in the intersection of transitive fan-in
// of a fetch node and the transitive fan-out of an input, or nodes that
// were optimized away by the optimizer.
auto it = name_to_cost.find(node.name());
if (it == name_to_cost.end()) {
continue;
}
std::vector<OpInfo::TensorProperties> inputs =
FindInputFeatures(node, name_to_cost, name_to_node);
input_properties_[node.name()] = inputs;
}
return Status::OK();
}
bool GraphProperties::HasInputProperties(const string& node_name) const {
return input_properties_.find(node_name) != input_properties_.end();
}
bool GraphProperties::HasOutputProperties(const string& node_name) const {
return output_properties_.find(node_name) != output_properties_.end();
}
const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetInputProperties(const string& node_name) const {
auto it = input_properties_.find(node_name);
if (it != input_properties_.end()) {
return it->second;
}
return missing_properties_;
}
const std::vector<OpInfo::TensorProperties>&
GraphProperties::GetOutputProperties(const string& node_name) const {
auto it = output_properties_.find(node_name);
if (it != output_properties_.end()) {
return it->second;
}
return missing_properties_;
}
void GraphProperties::ClearInputProperties(const string& node_name) {
input_properties_.erase(node_name);
}
void GraphProperties::ClearOutputProperties(const string& node_name) {
output_properties_.erase(node_name);
}
} // end namespace grappler
} // end namespace tensorflow