/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include <cmath>
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
namespace tensorflow {
namespace grappler {
using TensorVector = gtl::InlinedVector<TensorValue, 4>;
// We only fold/materialize constants smaller than 100kB.
const int64_t kMaxConstantSize = 100 * 1024;
namespace {
template <typename T>
bool AllValuesAre(const TensorProto& proto, const T& value) {
Tensor tensor;
if (!tensor.FromProto(proto)) {
return false;
}
auto values = tensor.flat<T>();
for (int i = 0; i < tensor.NumElements(); ++i) {
if (values(i) != value) {
return false;
}
}
return true;
}
// Add ctrl_input as a control input to node if node does not already depend on it.
// TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
// clean up code that should be using them.
bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
bool already_exists = false;
for (const string& input : node->input()) {
if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
already_exists = true;
break;
}
}
if (!already_exists) {
const string ctrl_dep =
ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
node->add_input(ctrl_dep);
node_map->AddOutput(NodeName(ctrl_input), node->name());
}
return !already_exists;
}
// Remove old_input as a control input to node.
bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
bool removed_input = false;
bool update_node_map = true;
const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
if (old_input_ctrl_dep == input) {
if (IsControlInput(input)) {
node->mutable_input()->SwapElements(i, node->input_size() - 1);
node->mutable_input()->RemoveLast();
removed_input = true;
} else {
// There is a non-control input from the same node.
// Don't remove the output from the NodeMap.
update_node_map = false;
}
}
}
if (update_node_map) {
node_map->RemoveOutput(NodeName(old_input), node->name());
}
return removed_input;
}
bool HasTPUAttributes(const NodeDef& node) {
AttrSlice attrs(node);
for (const auto& attr : attrs) {
if (attr.first.find("_tpu_") != attr.first.npos) {
return true;
}
}
return false;
}
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
return a != b;
}
template <>
bool PackedValuesNotEqual(float a, float b) {
return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
bool PackedValuesNotEqual(double a, double b) {
return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
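// Note: the two specializations above compare bit patterns rather than
// values. A sketch of the consequences (illustrative values only):
//   PackedValuesNotEqual(0.0f, -0.0f) -> true  (0x00000000 vs 0x80000000)
//   PackedValuesNotEqual(NAN, NAN)    -> false (identical bit patterns)
// Plain value comparison would answer the opposite in both cases, which
// matters in CreateNodeDef() below, where trailing runs of bit-identical
// values are elided from the packed representation.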
float QuantizedTypeMinAsFloat(DataType data_type) {
switch (data_type) {
case DT_QINT8:
return Eigen::NumTraits<qint8>::lowest();
case DT_QUINT8:
return Eigen::NumTraits<quint8>::lowest();
case DT_QINT16:
return Eigen::NumTraits<qint16>::lowest();
case DT_QUINT16:
return Eigen::NumTraits<quint16>::lowest();
case DT_QINT32:
return Eigen::NumTraits<qint32>::lowest();
default:
return 0.0f;
}
}
float QuantizedTypeMaxAsFloat(DataType data_type) {
switch (data_type) {
case DT_QINT8:
return Eigen::NumTraits<qint8>::highest();
case DT_QUINT8:
return Eigen::NumTraits<quint8>::highest();
case DT_QINT16:
return Eigen::NumTraits<qint16>::highest();
case DT_QUINT16:
return Eigen::NumTraits<quint16>::highest();
case DT_QINT32:
return Eigen::NumTraits<qint32>::highest();
default:
return 0.0f;
}
}
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization,
bool fold_quantization_emulation)
: opt_level_(opt_level),
cpu_device_(cpu_device),
disable_compressed_tensor_optimization_(
disable_compressed_tensor_optimization),
fold_quantization_emulation_(fold_quantization_emulation) {
resource_mgr_.reset(new ResourceMgr());
}
ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
bool disable_compressed_tensor_optimization,
bool fold_quantization_ops)
: ConstantFolding(RewriterConfig::ON, cpu_device,
disable_compressed_tensor_optimization,
fold_quantization_ops) {}
// static
string ConstantFolding::AddControlDependency(const string& input_name,
GraphDef* graph,
NodeMap* node_map) {
if (IsControlInput(input_name)) {
return input_name;
}
const NodeDef* node = node_map->GetNode(input_name);
// Sanity check for missing node.
if (!node) {
return input_name;
}
if (!IsSwitch(*node)) {
return AsControlDependency(*node);
} else {
// We can't anchor control dependencies directly on the switch node: unlike
// other nodes, only one of the outputs of the switch node will be generated
// when the switch node is executed, and we need to make sure the control
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
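// For example (an illustrative graph, not taken from any test):
//   s = Switch(data, pred);  id = Identity(s:1)
// anchoring on "s:1" returns "^id", so the dependency only fires when the
// corresponding branch actually produces a value.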
for (const NodeDef* output : node_map->GetOutputs(node->name())) {
if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
if (IsSameInput(output->input(0), input_name)) {
return AsControlDependency(*output);
}
}
}
// We haven't found an existing node where we can anchor the control
// dependency: add a new identity node.
int port = 0;
string ctrl_dep_name = ParseNodeName(input_name, &port);
strings::StrAppend(&ctrl_dep_name, "_", port);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type();
NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
if (added_node == nullptr) {
added_node = graph->add_node();
added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity");
added_node->set_device(node->device());
(*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name;
node_map->AddNode(added_node->name(), added_node);
node_map->AddOutput(node->name(), added_node->name());
}
return AsControlDependency(*added_node);
}
}
// Forward inputs at the given indices to outputs and add a control dependency
// on node.
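// A minimal sketch of the effect (hypothetical node names): if
// id = Identity(c) feeds `consumer`, then ForwardInputs(id, {0}) rewires
// `consumer` to read `c` directly and appends the control input "^id".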
bool ConstantFolding::ForwardInputs(NodeDef* node,
absl::Span<const int> inputs_to_forward) {
for (int input_idx : inputs_to_forward) {
if (input_idx < 0 || input_idx >= node->input_size()) {
return false;
}
}
const auto& tmp = node_map_->GetOutputs(node->name());
const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
bool updated_graph = false;
for (int input_idx : inputs_to_forward) {
const string& input = node->input(input_idx);
if (IsControlInput(input) && consumers.size() > 1) {
continue;
}
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
if (input_node == nullptr) {
LOG(ERROR) << "Bad input: " << input;
break;
}
// Update each consumer.
for (NodeDef* consumer : consumers) {
bool add_dep = false;
for (int consumer_input_idx = 0;
consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
const string& consumer_input = consumer->input(consumer_input_idx);
if (IsControlInput(consumer_input)) {
break;
}
// It is illegal to add control dependencies to _Retval nodes, so we
// can't bypass value producing `node` and forward inputs to `consumer`.
if (IsRetval(*consumer)) {
break;
}
int output_idx;
const string input_node_name =
ParseNodeName(consumer_input, &output_idx);
if (input_node_name == node->name() && output_idx == input_idx) {
consumer->set_input(consumer_input_idx, input);
// We will keep the input from the node through a control
// dependency, so we only need to add the consumer as an output
// for the input node.
node_map_->AddOutput(NodeName(input), consumer->name());
add_dep = true;
}
}
if (add_dep) {
consumer->add_input(AsControlDependency(node->name()));
updated_graph = true;
}
}
}
if (updated_graph) {
for (NodeDef* consumer : consumers) {
DedupControlInputs(consumer);
}
}
return updated_graph;
}
// Puts the given value into the tensor at the given "flat" index.
static Status PutValueIntoTensor(const int64_t value, const DataType& type,
const int index, Tensor* tensor) {
if (type == DT_INT32) {
if (value >= INT_MAX) {
return Status(error::INVALID_ARGUMENT, "int32 overflow");
}
tensor->flat<int32>()(index) = static_cast<int32>(value);
} else {
tensor->flat<int64_t>()(index) = value;
}
return Status::OK();
}
// Writes the given tensor shape into the given tensor.
// Op is assumed to be Shape, ShapeN, Size or Rank.
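// E.g. (illustrative) for shp = [2, 3] and type = DT_INT32:
//   "Shape"/"ShapeN" -> [2, 3],  "Size" -> 6 (scalar),  "Rank" -> 2 (scalar).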
static Status ConvertShapeToConstant(const string& op, const DataType& type,
const PartialTensorShape& shp,
Tensor* tensor) {
if (op == "Shape" || op == "ShapeN") {
*tensor = Tensor(type, TensorShape({shp.dims()}));
for (int i = 0; i < shp.dims(); ++i) {
TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
}
} else if (op == "Size") {
int64_t size = 1;
for (int i = 0; i < shp.dims(); ++i) {
size *= shp.dim_size(i);
}
*tensor = Tensor(type, TensorShape({}));
TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
} else {
CHECK_EQ(op, "Rank");
*tensor = Tensor(type, TensorShape({}));
TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
}
return Status::OK();
}
// TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
StringPiece suffix) const {
return node_map_->NodeExists(OptimizedNodeName(node, suffix));
}
string ConstantFolding::OptimizedNodeName(const NodeDef& node,
StringPiece suffix) const {
return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
kConstantFoldingConst);
}
bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
if (!IsConstant(node)) {
return false;
}
// If the node is fed, it's not constant anymore.
return feed_nodes_.find(node.name()) == feed_nodes_.end();
}
// TODO(rmlarsen): Refactor to shared util.
bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = node_map_->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
// Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies and hold
// the materialized shapes: there is no need to process these added nodes, so
// only iterate over the nodes of the input graph.
const int node_count = graph_->node_size();
for (int node_idx = 0; node_idx < node_count; ++node_idx) {
NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
op != "TensorArraySizeV3") {
continue;
}
const std::vector<OpInfo::TensorProperties>& output =
properties.GetOutputProperties(node->name());
const std::vector<OpInfo::TensorProperties>& input =
properties.GetInputProperties(node->name());
if (input.empty() || output.empty()) {
continue;
}
if (op == "Shape" || op == "Size" || op == "Rank") {
CHECK_EQ(1, output.size());
CHECK_EQ(1, input.size());
const DataType type = output[0].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
const PartialTensorShape shape(input[0].shape());
if ((op != "Rank" && !shape.IsFullyDefined()) ||
(op == "Rank" && shape.unknown_rank())) {
continue;
}
Tensor constant_value(type);
if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
continue;
}
// TODO(rmlarsen): Remove this workaround for b/150861569
// The bug involves an expression of the form Shape(ExpandDims(x))
// with an incorrectly inferred zero-size first dimension.
if (op == "Shape") {
if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
}
// Repurpose the existing node to be the constant.
// Device placement is preserved.
graph_modified_ = true;
node->set_op("Const");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["dtype"].set_type(type);
constant_value.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
// Turn the data input into a control dependency: this is needed to
// ensure that the constant value will only be run in the
// cases where the shape/rank/size would have been run in
// the original graph.
string ctrl_dep =
AddControlDependency(node->input(0), graph_, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
node->set_input(0, ctrl_dep);
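// As a sketch (hypothetical shapes): Shape(x) with x : float[2, 3] has now
// become Const(value = [2, 3], dtype = DT_INT32) with input(0) = "^x".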
// Done with the Shape/Size/Rank node, move to the next node.
continue;
}
if (op == "TensorArraySizeV3") {
const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
if (array->input_size() == 0 ||
(array->attr().count("dynamic_size") != 0 &&
array->attr().at("dynamic_size").b())) {
continue;
}
const NodeDef* array_size =
CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
if (IsReallyConstant(*array_size)) {
// Don't materialize 0 sizes to avoid triggering incorrect static
// checks. A 0 sized array that can't grow isn't useful anyway.
if (array_size->attr().count("value") == 0) {
continue;
}
const TensorProto& raw_val = array_size->attr().at("value").tensor();
if (raw_val.dtype() != DT_INT32) {
continue;
}
Tensor value(raw_val.dtype(), raw_val.tensor_shape());
if (!value.FromProto(raw_val)) {
continue;
}
if (value.flat<int32>()(0) == 0) {
continue;
}
graph_modified_ = true;
node->set_op("Const");
*node->mutable_attr() = array_size->attr();
node->set_input(0, AsControlDependency(NodeName(node->input(0))));
node->set_input(1, AddControlDependency(NodeName(node->input(1)),
graph_, node_map_.get()));
}
continue;
}
// Handle ShapeN materialization case.
// It's possible that not all input tensors have known shapes.
CHECK_EQ(op, "ShapeN");
CHECK_EQ(input.size(), output.size());
const NodeDef* const shape_n_node = node;
for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
++port_idx) {
const DataType type = output[port_idx].dtype();
CHECK(type == DT_INT32 || type == DT_INT64);
const PartialTensorShape shape(input[port_idx].shape());
if (!shape.IsFullyDefined()) {
continue;
}
Tensor constant_value(type);
auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
if (!status.ok()) {
continue;
}
// We make a copy because we mutate the nodes.
auto fanouts = node_map_->GetOutputs(shape_n_node->name());
// Find all nodes consuming this shape and connect them through the new
// constant node instead.
for (NodeDef* output : fanouts) {
// Track whether there are any direct edges left between shape_n_node
// and this output node after the transformation.
bool direct_edges_exist = false;
for (int k = 0; k < output->input_size(); ++k) {
int port;
const string node_name = ParseNodeName(output->input(k), &port);
if (node_name == shape_n_node->name() && port == port_idx) {
// Create a const node as ShapeN's output if one does not already exist.
const string const_name = OptimizedNodeName(
*shape_n_node, strings::StrCat("-matshapes-", port_idx));
if (node_map_->GetNode(const_name) == nullptr) {
NodeDef* added_node = graph_->add_node();
added_node->set_name(const_name);
added_node->set_op("Const");
added_node->set_device(shape_n_node->device());
node_map_->AddNode(added_node->name(), added_node);
(*added_node->mutable_attr())["dtype"].set_type(type);
constant_value.AsProtoTensorContent(
(*added_node->mutable_attr())["value"].mutable_tensor());
// We add a control dependency to the original ShapeN node,
// so that the node will only be run if all inputs of the
// original ShapeN node are run.
string ctrl_dep = AddControlDependency(shape_n_node->name(),
graph_, node_map_.get());
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
}
*output->mutable_input(k) = const_name;
node_map_->AddOutput(const_name, output->name());
graph_modified_ = true;
}
if (node_name == shape_n_node->name() && port != port_idx) {
direct_edges_exist = true;
}
}
if (!direct_edges_exist) {
node_map_->RemoveOutput(node->name(), output->name());
}
}
}
}
return Status::OK();
}
namespace {
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64_t* min_id) {
if (shape_node.op() == "Shape") {
const std::vector<OpInfo::TensorProperties>& prop1 =
properties.GetInputProperties(shape_node.name());
if (prop1.size() != 1) {
return false;
}
const TensorShapeProto& shp = prop1[0].shape();
if (shp.unknown_rank()) {
return false;
}
for (const auto& dim : shp.dim()) {
shape->push_back(dim.size());
*min_id = std::min<int64_t>(*min_id, dim.size());
}
} else {
if (shape_node.attr().count("value") == 0) {
return false;
}
const TensorProto& raw_val = shape_node.attr().at("value").tensor();
if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
return false;
}
Tensor value(raw_val.dtype(), raw_val.tensor_shape());
if (!value.FromProto(raw_val)) {
return false;
}
for (int j = 0; j < value.NumElements(); ++j) {
if (raw_val.dtype() == DT_INT64) {
shape->push_back(value.vec<int64_t>()(j));
} else {
shape->push_back(value.vec<int>()(j));
}
}
}
return true;
}
} // namespace
Status ConstantFolding::MaterializeBroadcastGradientArgs(
const NodeDef& node, const GraphProperties& properties) {
const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
if (shape_node1 == nullptr ||
(shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
shape_node2 == nullptr ||
(shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
return Status::OK();
}
// Don't optimize this again if it was already optimized and folded.
if (OptimizedNodeExists(node, "-folded-1") ||
OptimizedNodeExists(node, "-folded-2")) {
return Status::OK();
}
int64_t min_id = 0;
BCast::Vec shape1;
if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
return Status::OK();
}
BCast::Vec shape2;
if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
return Status::OK();
}
// A value of -1 means we don't know anything about the dimension. Replace
// the -1 values with unique dimension ids since we don't want two '-1'
// dimensions to be considered equal.
for (auto& id : shape1) {
if (id == -1) {
id = --min_id;
}
}
for (auto& id : shape2) {
if (id == -1) {
id = --min_id;
}
}
// Beware: the reduction dimensions computed by the BCast class are valid iff
// we assume that two distinct symbolic dimensions can't be equal and a
// symbolic dimension can't be equal to 1. This is often but not always true,
// so to make this optimization safe we filter out these cases.
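// E.g. (illustrative): for shape1 = [-1, 2] and shape2 = [5, 2] the symbolic
// first dimension may or may not equal 5, so whether grad_x needs a
// reduction over dimension 0 is unknowable here, and we bail out below.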
const int common_dims = std::min(shape1.size(), shape2.size());
for (int i = 0; i < common_dims; ++i) {
if (shape1[i] >= 0 && shape2[i] >= 0) {
continue;
}
if (shape1[i] != shape2[i]) {
// We're either dealing with 2 different symbolic dimensions or a symbolic
// and a known dimension. We can't be sure whether both are equal or not,
// so we can't be sure whether we'll be broadcasting or not.
return Status::OK();
}
}
// These extra dims could be equal to 1, in which case there is no
// broadcasting. They could also be greater than 1, in which case there would
// be broadcasting. Since we don't know, we'll just punt.
for (int i = common_dims, end = shape1.size(); i < end; ++i) {
if (shape1[i] < 0) {
return Status::OK();
}
}
for (int i = common_dims, end = shape2.size(); i < end; ++i) {
if (shape2[i] < 0) {
return Status::OK();
}
}
BCast bcast(shape1, shape2);
if (!bcast.IsValid()) {
return Status::OK();
}
BCast::Vec reduce_dims[2];
reduce_dims[0] = bcast.grad_x_reduce_idx();
reduce_dims[1] = bcast.grad_y_reduce_idx();
TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
const DataType type = node.attr().at("T").type();
NodeDef* out[2];
for (int j = 0; j < 2; ++j) {
int reduction_indices = reduce_dims[j].size();
Tensor value(type, TensorShape({reduction_indices}));
for (int i = 0; i < reduction_indices; ++i) {
if (type == DT_INT32) {
value.vec<int32>()(i) = reduce_dims[j][i];
} else {
value.vec<int64_t>()(i) = reduce_dims[j][i];
}
}
string const_name =
OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
out[j] = node_map_->GetNode(const_name);
if (out[j] == nullptr) {
out[j] = graph_->add_node();
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), out[j]));
out[j]->set_device(node.device());
node_map_->AddNode(const_name, out[j]);
string ctrl_dep =
AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
}
}
// We make a copy here since we might mutate the set.
const auto outputs = node_map_->GetOutputs(node.name());
for (NodeDef* output : outputs) {
for (int k = 0; k < output->input_size(); ++k) {
int port;
string node_name = ParseNodeName(output->input(k), &port);
if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
*output->mutable_input(k) = out[port]->name();
node_map_->UpdateInput(output->name(), node_name, out[port]->name());
}
}
}
return Status::OK();
}
Status ConstantFolding::MaterializeReductionIndices(
NodeDef* node, const GraphProperties& properties) {
if (node->input_size() < 2) {
return Status::OK();
}
const NodeDef* indices = node_map_->GetNode(node->input(1));
if (!indices || IsReallyConstant(*indices)) {
// The reduction indices are already constant, there's nothing to do.
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& input_props =
properties.GetInputProperties(node->name());
if (input_props.size() != 2) {
return Status::OK();
}
const OpInfo::TensorProperties& input_prop = input_props[0];
if (input_prop.shape().unknown_rank()) {
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
const int input_rank = input_prop.shape().dim_size();
if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
DataType dtype = reduction_indices_prop.dtype();
if (dtype != DT_INT32 && dtype != DT_INT64) {
return Status::OK();
}
PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
const int num_reduction_indices = reduction_indices_shape.num_elements();
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
const OpInfo::TensorProperties& output_prop = output_props[0];
const int output_rank =
output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
if (!full_reduction) {
// A full reduction will generate a tensor of one of the shapes
// [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
// elements in the output of the reduction, we may deduce it from reshape
// nodes following it.
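// E.g. (an illustrative pattern): Sum(x, idx) feeding Reshape(sum, [1]) can
// only be valid if the sum produces exactly one element, which for a known
// input rank implies that every dimension was reduced.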
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
full_reduction = false;
if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
properties.GetOutputProperties(fanout->name());
if (reshape_props.size() != 1) {
return Status::OK();
}
const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
PartialTensorShape shape(reshape_prop.shape());
if (shape.num_elements() != 1) {
return Status::OK();
} else {
full_reduction = true;
}
}
if (!full_reduction) {
return Status::OK();
}
}
// We know it's a full reduction. We can generate the full set of indices to
// reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
Tensor value(dtype, TensorShape({input_rank}));
for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
value.vec<int64_t>()(i) = i;
}
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
*reduction_indices->add_input() = ctrl_dep;
node_map_->AddNode(const_name, reduction_indices);
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
node->set_input(1, reduction_indices->name());
node_map_->UpdateInput(node->name(), indices->name(),
reduction_indices->name());
return Status::OK();
}
Status ConstantFolding::MaterializeConstantValuedNode(
NodeDef* node, const GraphProperties& properties) {
if (disable_compressed_tensor_optimization_) {
return Status::OK();
}
// Nodes that generate constant-valued outputs can be represented compactly in
// compressed format, regardless of their shape.
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) return Status::OK();
const auto& output_shape = output_props[0].shape();
if (!PartialTensorShape(output_shape).IsFullyDefined()) {
return Status::OK();
}
if (IsFill(*node)) {
const auto output_dtype = output_props[0].dtype();
NodeDef* input_node = nullptr;
for (int i = 0; i < 2; ++i) {
input_node = node_map_->GetNode(NodeName(node->input(i)));
if (input_node == nullptr || !IsReallyConstant(*input_node)) {
return Status::OK();
}
}
TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
// Copy the input tensor to the fill node, set the output shape and data
// type, and change the node type to Const.
TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
const TensorProto& input_tensor = input_node->attr().at("value").tensor();
if (!input_tensor.tensor_content().empty()) {
// Convert the value to repeated field format, so we can use the
// decompression mechanism to store only a single value in the constant
// node, even if the shape specified in the original Fill is large.
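// E.g. (illustrative) Fill(dims = [1000, 1000], value = 7.0f) becomes a
// Const whose TensorProto stores the shape plus a single float_val of 7.0,
// rather than ~4MB of tensor_content.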
Tensor t;
if (!t.FromProto(input_tensor)) {
return errors::InvalidArgument(
"Could not construct Tensor form TensorProto in node: ",
input_node->name());
}
tensor->clear_tensor_content();
t.AsProtoField(tensor);
} else {
*tensor = input_tensor;
}
*(tensor->mutable_tensor_shape()) = output_shape;
(*node->mutable_attr())["dtype"].set_type(output_dtype);
node->mutable_attr()->erase("T");
node->mutable_attr()->erase("index_type");
node->set_op("Const");
for (int i = 0; i < 2; i++) {
// Change inputs to control inputs.
const string ctrl_dep = AsControlDependency(node->input(i));
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
graph_modified_ = true;
} else {
double value =
(IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
if (value >= 0) {
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
value, properties, output_shape, node, graph_));
}
}
return Status::OK();
}
// Materialize output values inferred by the shape inference.
Status ConstantFolding::MaterializeOutputValues(
NodeDef* node, const GraphProperties& properties) {
const std::vector<OpInfo::TensorProperties>& output =
properties.GetOutputProperties(node->name());
if (output.size() != 1 || !output[0].has_value() ||
!IsFoldable(*node, &properties)) {
return Status::OK();
}
// If this is a trivial Identity node with a constant input, just route the
// input around it.
if (IsIdentity(*node)) {
NodeDef* input = node_map_->GetNode(node->input(0));
if (IsReallyConstant(*input)) {
std::vector<int> inputs_to_forward(node->input_size());
std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
graph_modified_ = ForwardInputs(node, inputs_to_forward);
return Status::OK();
}
}
// Repurpose the existing node to be the constant.
// Device placement is preserved.
TensorProto value_copy = output[0].value();
return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
node, graph_);
}
Status ConstantFolding::MaterializeConstants(
const GraphProperties& properties) {
const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
NodeDef& node = *graph_->mutable_node(i);
const string& op = node.op();
if (op == "BroadcastGradientArgs") {
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
} else if (IsReduction(node)) {
TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
} else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
} else {
TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
}
}
return Status::OK();
}
bool ConstantFolding::IsFoldable(const NodeDef& node,
const GraphProperties* properties) {
string key = strings::StrCat(node.name(), "/", node.op());
auto it = maybe_foldable_nodes_.find(key);
if (it == maybe_foldable_nodes_.end()) {
it = maybe_foldable_nodes_
.emplace(std::move(key), MaybeFoldable(node, properties))
.first;
}
if (!it->second) {
return false;
} else {
return IsFoldableUncached(node, properties);
}
}
bool ConstantFolding::IsFoldableUncached(
const NodeDef& node, const GraphProperties* properties) const {
// Folding not applicable to ops with no inputs.
if (node.input().empty()) {
return false;
}
// We can only fold nodes if all their inputs are known statically, except in
// the case of a merge node, which propagates the first input that becomes
// available, and therefore only requires a single constant input to be
// foldable.
bool merge_has_constant_input = false;
const bool is_merge = IsMerge(node);
for (const auto& input : node.input()) {
if (IsControlInput(input)) {
continue;
}
const NodeDef* input_node = node_map_->GetNode(input);
if (!input_node) {
return false;
}
bool is_const = IsReallyConstant(*input_node);
if (is_const) {
// Don't fold string constants for now since this causes problems with
// checkpointing.
if (input_node->attr().count("dtype") == 0 ||
input_node->attr().at("dtype").type() == DT_STRING) {
return false;
}
// Special case: If a Merge node has at least one constant input that
// does not depend on a control input, we can fold it.
merge_has_constant_input |= !HasControlInputs(*input_node);
} else if (!is_merge) {
return false;
}
}
if (is_merge && !merge_has_constant_input) return false;
if (disable_compressed_tensor_optimization_ &&
(IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
return false;
// If we know the output shapes, make sure that the outputs are small enough
// to materialize.
if (properties != nullptr && properties->HasOutputProperties(node.name())) {
const std::vector<OpInfo::TensorProperties>& input_props =
properties->GetInputProperties(node.name());
const std::vector<OpInfo::TensorProperties>& output_props =
properties->GetOutputProperties(node.name());
// Compute total size of inputs.
int64_t input_size_bytes = 0;
for (const auto& input_prop : input_props) {
const PartialTensorShape input_shape(input_prop.shape());
if (input_shape.IsFullyDefined()) {
input_size_bytes +=
input_shape.num_elements() * DataTypeSize(input_prop.dtype());
}
}
for (const auto& output_prop : output_props) {
PartialTensorShape output_shape;
if (!PartialTensorShape::BuildPartialTensorShape(output_prop.shape(),
&output_shape)
.ok()) {
return false;
}
if (output_shape.IsFullyDefined()) {
const int64_t num_bytes =
output_shape.num_elements() * DataTypeSize(output_prop.dtype());
if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
// Do not fold nodes if the in-memory size of output is too large.
// Notice that this is not exactly the same check used in
// CreateNodeDef() where the actual encoded size is checked.
return false;
}
}
}
}
return true;
}
bool ConstantFolding::MaybeFoldable(const NodeDef& node,
const GraphProperties* properties) const {
// Skip constants, they're already folded
if (IsConstant(node)) {
return false;
}
// Don't fold stateful ops such as TruncatedNormal.
if (!IsFreeOfSideEffect(node)) {
return false;
}
// Skip nodes that must be preserved, unless they are allowlisted.
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
return false;
}
// Skip control flow nodes, they can't be folded.
if (ModifiesFrameInfo(node)) {
return false;
}
// Skips ops that don't benefit from folding.
if (IsPlaceholder(node)) {
return false;
}
// `FakeParam` op is used as a placeholder in If branch functions. It doesn't
// have a valid output when executed.
if (IsFakeParam(node)) {
return false;
}
if (node.op() == "AccumulateNV2") {
return false;
}
// Removing LoopCond nodes can screw up the partitioner.
if (node.op() == "LoopCond") {
return false;
}
if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
return false;
}
const string& op = node.op();
if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
op.find("Reader") != string::npos) {
return false;
}
if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
return false;
}
// Don't fold nodes that contain TPU attributes.
// TODO(rmlarsen): We should be able to fold many of these nodes as long as we
// properly forward custom attributes, b/119051778.
if (HasTPUAttributes(node)) {
return false;
}
const OpDef* op_def = nullptr;
Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
if (!status.ok()) {
return false;
}
// Don't fold ops without outputs.
if (op_def->output_arg_size() == 0) {
return false;
}
// Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
// TODO(rmlarsen): Only do this for XLA_* devices.
for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
if (output_arg.type() == DT_VARIANT) {
return false;
}
}
// Don't fold nodes that have no outgoing edges except allowlisted nodes.
// Such nodes could be introduced by an earlier constant folding pass and are
// preserved in case users want to fetch their values; re-processing them
// would lead to an error of adding a duplicated node to graph.
const auto& outputs = node_map_->GetOutputs(node.name());
if (outputs.empty() &&
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
return false;
}
return true;
}
namespace {
#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
case DTYPE: \
t->add_##NAME##_val(static_cast<TYPE>(value)); \
break;
Status CreateConstantTensorAttrValue(DataType type, double value,
const TensorShapeProto& shape,
AttrValue* attr_tensor) {
TensorProto* t = attr_tensor->mutable_tensor();
t->set_dtype(type);
*t->mutable_tensor_shape() = shape;
switch (type) {
case DT_HALF:
t->add_half_val(
Eigen::numext::bit_cast<uint16>(static_cast<Eigen::half>(value)));
break;
case DT_BFLOAT16:
t->add_half_val(
Eigen::numext::bit_cast<uint16>(static_cast<bfloat16>(value)));
break;
SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
SET_TENSOR_VAL_CASE(DT_INT64, int64_t, int64);
SET_TENSOR_VAL_CASE(DT_UINT64, int64_t, int64);
SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
default:
return errors::InvalidArgument(
"Unsupported type in CreateConstantTensorAttrValue: ",
DataTypeString(type));
}
return Status::OK();
}
#undef SET_TENSOR_VAL_CASE
DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
const GraphProperties& properties) {
DataType dtype = DT_INVALID;
if (node.attr().count("T") == 1) {
dtype = node.attr().at("T").type();
} else if (node.attr().count("dtype") == 1) {
dtype = node.attr().at("dtype").type();
} else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
dtype = DT_BOOL;
} else {
auto output_props = properties.GetOutputProperties(node.name());
if (!output_props.empty()) {
dtype = output_props[0].dtype();
}
}
return dtype;
}
// Checks whether the shape of the const input of the Mul op is valid to perform
// the MulConvPushDown optimization.
bool IsValidConstShapeForMulConvPushDown(
const string& data_format, const TensorShapeProto& filter_shape,
const TensorShapeProto& mul_const_input_shape) {
// If the const is a scalar, or it has no more dimensions than the filter
// and contains only a single element, the optimization should work.
if (mul_const_input_shape.dim_size() <=
static_cast<int>(data_format.size()) &&
TensorShape(mul_const_input_shape).num_elements() == 1) {
return true;
}
// Otherwise, check the eligibility according to data format.
if (data_format == "NHWC" || data_format == "NDHWC") {
TensorShapeProto new_filter_shape;
if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
&new_filter_shape)) {
return false;
}
if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
return false;
}
// Only the last dimension (the output channel) may be larger than one;
// broadcasting over any other dimension would result in an invalid filter.
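// E.g. (illustrative, NHWC): with filter shape [3, 3, 8, 16], a mul const
// of shape [16] or [1, 1, 1, 16] is accepted, while [8, 16] is rejected by
// the loop below because its first dimension exceeds one.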
for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
if (mul_const_input_shape.dim(i).size() > 1) return false;
}
return true;
} else if (data_format == "NCHW" || data_format == "NCDHW") {
// TODO(laigd): support NCHW and NCDHW (b/111214513).
return false;
}
return false;
}
} // namespace
// static
Status ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor, NodeDef* node,
size_t original_size) {
node->set_name(name);
node->set_op("Const");
AttrValue attr_type;
attr_type.set_type(tensor->dtype());
node->mutable_attr()->insert({"dtype", attr_type});
AttrValue attr_tensor;
TensorProto* t = attr_tensor.mutable_tensor();
bool optimized = false;
size_t encoded_size;
// Use the packed representation whenever possible to avoid generating large
// graphdefs. Moreover, a trailing run of identical values is stored only once.
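// As a sketch (illustrative values): the tensor [3, 1, 1, 7, 7, 7] is
// stored as int_val = {3, 1, 1, 7}; readers re-expand the last value to
// fill the remaining elements, so a trailing run costs a single entry.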
if (tensor->NumElements() > 4) {
#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE) \
{ \
const auto* val_ptr = tensor->flat<TYPE>().data(); \
auto last = *val_ptr; \
int64_t last_index = 0; \
for (int64_t i = 0; i < tensor->NumElements(); ++i) { \
TYPE cur = *val_ptr++; \
if (PackedValuesNotEqual(cur, last)) { \
last = cur; \
last_index = i; \
} \
} \
encoded_size = (last_index + 1) * sizeof(FIELDTYPE); \
if (encoded_size < kint32max) { \
optimized = true; \
t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1); \
const auto* src_ptr = tensor->flat<TYPE>().data(); \
auto* dst_ptr = \
t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr); \
} \
} \
break
switch (tensor->dtype()) {
case DT_FLOAT:
POPULATE_TENSOR_PROTO(tensor, t, float, float);
case DT_DOUBLE:
POPULATE_TENSOR_PROTO(tensor, t, double, double);
case DT_INT64:
POPULATE_TENSOR_PROTO(tensor, t, int64_t, int64);
case DT_UINT64:
POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
case DT_INT32:
POPULATE_TENSOR_PROTO(tensor, t, int32_t, int);
case DT_UINT32:
POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
case DT_INT16:
POPULATE_TENSOR_PROTO(tensor, t, int16_t, int);
case DT_UINT16:
POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
case DT_INT8:
POPULATE_TENSOR_PROTO(tensor, t, int8_t, int);
case DT_UINT8:
POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
case DT_BOOL:
POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
default:
/* Do nothing. */
break;
}
}
if (optimized) {
// Also specify type and shape.
t->set_dtype(tensor->dtype());
tensor->shape().AsProto(t->mutable_tensor_shape());
} else {
// DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
// DT_QUINT8
tensor->AsProtoTensorContent(t);
encoded_size = t->tensor_content().size();
}
node->mutable_attr()->insert({"value", attr_tensor});
if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
return errors::InvalidArgument(
strings::StrCat("Can't fold ", name, ", its size would be too large (",
encoded_size, " >= ", kMaxConstantSize, " bytes)"));
}
return Status::OK();
}
Status ConstantFolding::EvaluateNode(const NodeDef& node,
const TensorVector& inputs,
TensorVector* output) const {
return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
resource_mgr_.get(), output);
}
Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
std::vector<NodeDef>* outputs,
bool* result_too_large) {
TensorVector inputs;
TensorVector output_tensors;
auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
for (const auto& input : inputs) {
delete input.tensor;
}
for (const auto& output : output_tensors) {
if (output.tensor) {
delete output.tensor;
}
}
});
size_t total_inputs_size = 0;
for (const auto& input : node.input()) {
const TensorId input_tensor = ParseTensorName(input);
if (input_tensor.index() < 0) {
// Control dependency
break;
}
const NodeDef* input_node = node_map_->GetNode(input);
if (!IsReallyConstant(*input_node)) {
return Status(error::INVALID_ARGUMENT,
strings::StrCat("Can't fold ", node.name(), ", its ", input,
" isn't constant"));
}
TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
const TensorProto& raw_val = input_node->attr().at("value").tensor();
if (raw_val.dtype() == DT_INVALID) {
return Status(
error::INVALID_ARGUMENT,
strings::StrCat("A tensor in the input node, with TensorId of ",
input_tensor.ToString(),
" has a dtype of DT_INVALID."));
}
if (IsRefType(raw_val.dtype())) {
return errors::InvalidArgument(
"Not allowed to construct a tensor with reference dtype, got ",
DataTypeString(raw_val.dtype()));
}
Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
if (!value->FromProto(raw_val)) {
delete (value);
return errors::InvalidArgument("Unable to make Tensor from proto for ",
node.name(), " with shape ",
raw_val.tensor_shape().DebugString());
}
inputs.emplace_back(value);
total_inputs_size += value->TotalBytes();
}
TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
if (output_tensors.empty()) {
return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
}
outputs->resize(output_tensors.size());
for (size_t i = 0; i < output_tensors.size(); i++) {
string node_name = OptimizedNodeName(node, "-folded");
if (output_tensors.size() > 1) {
node_name = strings::StrCat(node_name, "-", i);
}
if (output_tensors[i].tensor) {
Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
total_inputs_size);
if (!s.ok()) {
*result_too_large = true;
return s;
}
} else {
// Create an empty NodeDef to identify dead outputs (e.g. the output of a
// switch that's not selected by the switch predicate).
outputs->at(i) = NodeDef();
}
}
return Status::OK();
}
Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
// Merge nodes are special, in the sense that they execute as soon as one of
// their input is ready. We can therefore fold a merge node iff it has at
// least one constant input without control dependency.
// We still need to ensure that the nodes in the fanin of the merge node are
// scheduled. We'll therefore add a control dependency from the merge node
// to the folded constant. We end up with:
// * the merge node and its inputs are preserved as is
// * a new constant node C1, driven by the merge node through a control
// dependency, initialized to the value of the folded input
// * a new constant node C2, driven by the merge node through a control
// dependency, initialized to the index of the folded input
// * the fanout of the merge node is rewired to be driven by either C1 or
// C2.
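// A sketch of the rewrite (hypothetical node names):
//   before: m = Merge(c /*Const*/, x);  y <- m:0;  i <- m:1
//   after:  c1 = Const(<value of c>) with control input "^m"
//           c2 = Const(0 /*index of the folded input*/) with control "^m"
//           y <- c1;  i <- c2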
for (int input_index = 0; input_index < node->input_size(); ++input_index) {
const auto& input = node->input(input_index);
if (IsControlInput(input)) {
// Try the next input.
continue;
}
NodeDef* input_node = node_map_->GetNode(input);
if (!IsReallyConstant(*input_node)) {
continue;
}
bool valid_input = true;
for (const string& fanin_of_input : input_node->input()) {
if (IsControlInput(fanin_of_input)) {
valid_input = false;
break;
}
}
if (!valid_input) {
// Try the next input
continue;
}
string const_out_name = OptimizedNodeName(*node, "_const");
string const_index_name = OptimizedNodeName(*node, "_index");
if (node_map_->GetNode(const_out_name) ||
node_map_->GetNode(const_index_name)) {
// Intended name already exists.
return errors::AlreadyExists(
strings::StrCat(const_out_name, " or ", const_index_name,
" already present in the graph"));
}
NodeDef* const_out = output_graph->add_node();
*const_out = *input_node;
const_out->set_name(const_out_name);
const_out->set_device(node->device());
*const_out->add_input() = AsControlDependency(*node);
node_map_->AddNode(const_out->name(), const_out);
node_map_->AddOutput(node->name(), const_out->name());
NodeDef* const_index = output_graph->add_node();
const_index->set_op("Const");
Tensor index(DT_INT32, TensorShape({}));
index.flat<int32>()(0) = input_index;
(*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
index.AsProtoTensorContent(
(*const_index->mutable_attr())["value"].mutable_tensor());
const_index->set_name(const_index_name);
const_index->set_device(node->device());
*const_index->add_input() = AsControlDependency(*node);
node_map_->AddNode(const_index->name(), const_index);
node_map_->AddOutput(node->name(), const_index->name());
// We make a copy because we mutate the nodes.
auto outputs = node_map_->GetOutputs(node->name());
for (NodeDef* output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int port;
string node_name = ParseNodeName(output->input(i), &port);
if (node_name == node->name()) {
if (port == 0) {
*output->mutable_input(i) = const_out->name();
node_map_->AddOutput(const_out->name(), output->name());
} else if (port == 1) {
*output->mutable_input(i) = const_index->name();
node_map_->AddOutput(const_index->name(), output->name());
} else {
// This is a control dependency (or an invalid edge since the
// merge node has only 2 outputs): preserve them.
}
}
}
}
return Status::OK();
}
return Status::OK();
}
Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
bool* result_too_large) {
*result_too_large = false;
if (IsMerge(*node)) {
return FoldMergeNode(node, output_graph);
}
std::vector<NodeDef> const_nodes;
TF_RETURN_IF_ERROR(
EvaluateOneFoldable(*node, &const_nodes, result_too_large));
VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
NodeDef* constant_output = nullptr;
for (int i = 0, end = const_nodes.size(); i < end; i++) {
NodeDef* const_node = &const_nodes[i];
VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
if (const_node->name().empty()) {
// Dead output: we can't create a constant to encode its value, so we'll
// just skip it. We'll preserve the edges that originate from that
// output below to preserve the overall behavior of the graph wrt dead
// edges.
continue;
}
// Returns `true` iff `const_node` already has a control input named `input`.
const auto is_duplicate_control_input = [&](const string& input) -> bool {
auto it = absl::c_find(const_node->input(), input);
return it != const_node->input().end();
};
// Forward control dependencies.
for (const string& input : node->input()) {
// Forward control dependencies from folded node.
if (IsControlInput(input)) {
if (!is_duplicate_control_input(input)) {
*const_node->add_input() = input;
}
}
// Forward control dependencies from constant inputs to folded node.
if (!IsControlInput(input)) {
NodeDef* input_node = node_map_->GetNode(input);
for (const string& fanin_of_input : input_node->input()) {
if (!is_duplicate_control_input(fanin_of_input)) {
*const_node->add_input() = fanin_of_input;
}
}
}
}
// We rewrite the existing node if it only has a single output, and
// create new nodes otherwise.
if (const_nodes.size() == 1) {
node->set_op("Const");
// Note we need to clear the inputs in NodeMap before we clear the inputs
// in the node, otherwise NodeMap would see empty inputs and effectively
// do nothing.
node_map_->RemoveInputs(node->name());
node->clear_input();
*node->mutable_input() = const_node->input();
for (const auto& input : node->input()) {
node_map_->AddOutput(NodeName(input), node->name());
}
*node->mutable_attr() = const_node->attr();
break;
} else {
if (node_map_->GetNode(const_node->name())) {
// Intended name already exists.
return errors::AlreadyExists(strings::StrCat(
const_node->name(), " already present in the graph"));
}
NodeDef* added_node = output_graph->add_node();
*added_node = *const_node;
added_node->set_device(node->device());
node_map_->AddNode(added_node->name(), added_node);
for (const auto& input : added_node->input()) {
node_map_->AddOutput(NodeName(input), added_node->name());
}
// All the constant nodes encoding output values have the same control
// dependencies (since these are the control dependencies of the node
// we're trying to fold). Record one such constant node.
constant_output = added_node;
}
}
if (const_nodes.size() > 1) {
// We make a copy because we mutate the nodes.
auto outputs = node_map_->GetOutputs(node->name());
for (NodeDef* output : outputs) {
for (int i = 0; i < output->input_size(); i++) {
int port;
string node_name = ParseNodeName(output->input(i), &port);
if (node_name == node->name()) {
if (port < 0) {
// Propagate control dependencies if possible. If not, we'll just
// preserve the existing control dependencies.
if (constant_output != nullptr) {
node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
constant_output->name());
*output->mutable_input(i) = AsControlDependency(*constant_output);
}
} else if (port < static_cast<int>(const_nodes.size()) &&
!const_nodes[port].name().empty()) {
// Replace alive outputs with the corresponding constant.
node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
const_nodes[port].name());
*output->mutable_input(i) = const_nodes[port].name();
} else {
// Leave this edge alone.
VLOG(3) << "Preserving edge from " << node->name() << ":" << port
<< "[" << node->op() << "] to " << output->name() << ":"
<< i << "[" << output->op() << "]";
}
}
}
}
outputs = node_map_->GetOutputs(node->name());
if (outputs.empty() && has_fetch_ &&
nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
node_map_->RemoveInputs(node->name());
node->clear_input();
}
}
return Status::OK();
}
Status ConstantFolding::FoldGraph(
const GraphProperties& properties, GraphDef* optimized_graph,
absl::flat_hash_set<string>* nodes_to_not_simplify) {
// We build a new optimized_graph by inserting the folded nodes into it, then
// copy other nodes that might be needed at the end of this function.
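// The worklist below seeds with every foldable node and re-enqueues the
// fanout of each successful fold: e.g. (illustrative) folding a = Add(c1, c2)
// may turn a downstream b = Mul(a, c3) into a foldable node.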
absl::flat_hash_set<string> processed_nodes;
std::deque<NodeDef*> queue;
for (int i = 0; i < graph_->node_size(); i++) {
const NodeDef& node = graph_->node(i);
if (IsFoldable(node, &properties) &&
!nodes_to_not_simplify->count(node.name())) {
queue.push_back(graph_->mutable_node(i));
}
}
while (!queue.empty()) {
NodeDef* node = queue.front();
queue.pop_front();
if (processed_nodes.count(node->name())) {
continue;
}
// We need to record a copy of output nodes before FoldNode() modifies it.
// We also need to ensure that the fanout is sorted deterministically.
std::vector<NodeDef*> fanout =
node_map_->GetOutputsOrderedByNodeName(node->name());
bool result_too_large = false;
Status s = FoldNode(node, optimized_graph, &result_too_large);
processed_nodes.insert(node->name());
if (!s.ok()) {
VLOG(1) << "Failed to fold node " << node->DebugString()
<< "\nError message: " << s;
if (result_too_large) {
nodes_to_not_simplify->emplace(node->name());
}
} else {
for (auto& fanout_node : fanout) {
if (IsFoldable(*fanout_node, &properties) &&
!nodes_to_not_simplify->count(fanout_node->name())) {
queue.push_back(fanout_node);
}
}
}
}
// Delete the newly created nodes that don't feed anything.
std::vector<int> nodes_to_delete;
for (int i = 0; i < optimized_graph->node_size(); i++) {
const auto& fanout = node_map_->GetOutputs(optimized_graph->node(i).name());
if (fanout.empty()) nodes_to_delete.push_back(i);
}
EraseNodesFromGraph(std::move(nodes_to_delete), optimized_graph);
for (int i = 0; i < graph_->node_size(); ++i) {
NodeDef* node = graph_->mutable_node(i);
// If no fetch nodes are provided, we conservatively
// move all nodes in the original graph to the output, in case users need
// to fetch their values.
const auto& fanout = node_map_->GetOutputs(node->name());
if (!fanout.empty() || !has_fetch_ ||
nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
*(optimized_graph->add_node()) = std::move(*node);
}
}
return Status::OK();
}
Status ConstantFolding::IsSimplifiableReshape(
const NodeDef& node, const GraphProperties& properties) const {
if (!IsReshape(node)) {
return errors::Internal("Node ", node.name(), " is not a Reshape node");
}
if (2 > node.input_size()) {
return errors::Internal("Node ", node.name(),
" must have at most 2 inputs but has ",
node.input_size());
}
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
if (!IsReallyConstant(*new_shape)) {
return errors::Internal("Node ", node.name(), " has shape ",
new_shape->DebugString(),
" which is not a constant");
}
TensorVector outputs;
auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
for (const auto& output : outputs) {
delete output.tensor;
}
});
Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
if (!s.ok()) {
return errors::Internal("Could not evaluate node ", node.name());
}
if (outputs.size() != 1) {
return errors::Internal("Node ", node.name(),
" must have exactly 1 output but has ",
outputs.size());
}
const std::vector<OpInfo::TensorProperties>& props =
properties.GetInputProperties(node.name());
if (props.empty()) {
return errors::Internal("Node ", node.name(), " has no properties");
}
const OpInfo::TensorProperties& prop = props[0];
if (prop.dtype() == DT_INVALID) {
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with invalid dtype");
}
const PartialTensorShape shape(prop.shape());
if (!shape.IsFullyDefined()) {
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with shape ",
shape.DebugString(), " which is not fully defined");
}
PartialTensorShape new_dims;
if (outputs[0]->dtype() == DT_INT32) {
std::vector<int32> shp;
for (int i = 0; i < outputs[0]->NumElements(); ++i) {
int32_t dim = outputs[0]->flat<int32>()(i);
shp.push_back(dim);
}
s = TensorShapeUtils::MakeShape(shp, &new_dims);
if (!s.ok()) return s;
} else {
std::vector<int64_t> shp;
for (int i = 0; i < outputs[0]->NumElements(); ++i) {
int64_t dim = outputs[0]->flat<int64_t>()(i);
shp.push_back(dim);
}
s = TensorShapeUtils::MakeShape(shp, &new_dims);
if (!s.ok()) return s;
}
if (!shape.IsCompatibleWith(new_dims)) {
    return errors::Internal("Expected shape ", shape.DebugString(),
                            " to be compatible with ", new_dims.DebugString());
}
return Status::OK();
}
#define IS_VALUE_CASE(DTYPE, VALUE) \
case DTYPE: \
return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
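// IsOnes/IsZeros return true iff `node` is known to produce a tensor whose
// elements are all 1 (resp. all 0): a OnesLike/ZerosLike node, a Fill whose
// value input satisfies the same predicate, or a Const such as [[1, 1], [1, 1]].
// Feed nodes are excluded, since their values may be overridden at runtime.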
bool ConstantFolding::IsOnes(const NodeDef& node) const {
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
if (IsOnesLike(node)) return true;
if (IsZerosLike(node)) return false;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsOnes(*values);
}
if (node.op() != "Const") return false;
if (node.attr().count("dtype") == 0) return false;
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
IS_ONES_CASE(DT_BOOL);
IS_ONES_CASE(DT_HALF);
IS_ONES_CASE(DT_BFLOAT16);
IS_ONES_CASE(DT_FLOAT);
IS_ONES_CASE(DT_DOUBLE);
IS_ONES_CASE(DT_COMPLEX64);
IS_ONES_CASE(DT_COMPLEX128);
IS_ONES_CASE(DT_UINT8);
IS_ONES_CASE(DT_INT8);
IS_ONES_CASE(DT_UINT16);
IS_ONES_CASE(DT_INT16);
IS_ONES_CASE(DT_INT32);
IS_ONES_CASE(DT_INT64);
IS_ONES_CASE(DT_QINT32);
IS_ONES_CASE(DT_QINT16);
IS_ONES_CASE(DT_QUINT16);
IS_ONES_CASE(DT_QINT8);
IS_ONES_CASE(DT_QUINT8);
default:
VLOG(1) << "Unsupported type " << DataTypeString(dtype);
return false;
}
return false;
}
bool ConstantFolding::IsZeros(const NodeDef& node) const {
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
if (IsOnesLike(node)) return false;
if (IsZerosLike(node)) return true;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsZeros(*values);
}
if (!IsConstant(node)) return false;
if (node.attr().count("dtype") == 0) return false;
const auto dtype = node.attr().at("dtype").type();
switch (dtype) {
IS_ZEROS_CASE(DT_BOOL);
IS_ZEROS_CASE(DT_HALF);
IS_ZEROS_CASE(DT_BFLOAT16);
IS_ZEROS_CASE(DT_FLOAT);
IS_ZEROS_CASE(DT_DOUBLE);
IS_ZEROS_CASE(DT_COMPLEX64);
IS_ZEROS_CASE(DT_COMPLEX128);
IS_ZEROS_CASE(DT_UINT8);
IS_ZEROS_CASE(DT_INT8);
IS_ZEROS_CASE(DT_UINT16);
IS_ZEROS_CASE(DT_INT16);
IS_ZEROS_CASE(DT_INT32);
IS_ZEROS_CASE(DT_INT64);
IS_ZEROS_CASE(DT_QINT32);
IS_ZEROS_CASE(DT_QINT16);
IS_ZEROS_CASE(DT_QUINT16);
IS_ZEROS_CASE(DT_QINT8);
IS_ZEROS_CASE(DT_QUINT8);
default:
VLOG(1) << "Unsupported type " << DataTypeString(dtype);
return false;
}
return false;
}
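// Rewrite `node` in place to a BroadcastTo of input `input_to_broadcast`, with
// the (fully defined) output shape materialized as a new constant node. E.g.
// with input_to_broadcast == 0 and output shape [2, 3]:
//   Op(x, y) => BroadcastTo(x, Const([2, 3])) ^y
// where ^ denotes a control dependency on the now-unused input.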
bool ConstantFolding::ReplaceOperationWithBroadcastTo(
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) {
return false;
}
const PartialTensorShape shape(
properties.GetOutputProperties(node->name())[0].shape());
if (!shape.IsFullyDefined()) {
return false;
}
// Create constant node with shape.
const string const_name = OptimizedNodeName(
*node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
if (node_map_->GetNode(const_name) != nullptr) {
return false;
}
Tensor shape_t;
if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
return false;
}
NodeDef tmp;
if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
return false;
}
NodeDef* const_node = graph->add_node();
const_node->Swap(&tmp);
const_node->set_device(node->device());
node_map_->AddNode(const_name, const_node);
for (int i = 0; i < node->input_size(); ++i) {
if (i != input_to_broadcast) {
// Add a control input on the unused input.
string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
node_map_.get());
*const_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
}
}
// Rewrite `node` in-place to BroadcastTo.
node->set_op("BroadcastTo");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
(*node->mutable_attr())["Tidx"].set_type(DT_INT32);
// Set the designated input to BroadcastTo.
node->mutable_input()->SwapElements(0, input_to_broadcast);
// Keep all other inputs as control dependencies.
for (int i = 1; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
// Add the shape argument.
*node->add_input() = const_node->name();
node_map_->AddOutput(const_name, node->name());
node->mutable_input()->SwapElements(1, node->input_size() - 1);
return true;
}
// Replace an operation with Identity.
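// The designated input is forwarded and all other regular inputs are demoted
// to control dependencies, e.g. with input_to_forward == 0:
//   Op(x, y) => Identity(x) ^y.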
void ConstantFolding::ReplaceOperationWithIdentity(
int input_to_forward, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) return;
node->set_op("Identity");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
// Propagate the designated input through the identity.
node->mutable_input()->SwapElements(0, input_to_forward);
// Add all other inputs as control dependencies.
for (int i = 1; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
graph_modified_ = true;
}
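// Replace an operation with Snapshot, falling back to Identity when the graph
// contains no ops that mutate their inputs. E.g. for Add(x, zeros) with
// input_to_forward == 0: Add(x, zeros) => Snapshot(x) ^zeros.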
void ConstantFolding::ReplaceOperationWithSnapshot(
int input_to_forward, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
// If the graph contains no ops that mutate their inputs, we can
// use Identity instead of Snapshot.
if (!graph_contains_assign_or_inplace_op_) {
ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
return;
}
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) return;
node->set_op("Snapshot");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(dtype);
// Propagate the designated input through the Snapshot.
node->mutable_input()->SwapElements(0, input_to_forward);
// Add all other inputs as control dependencies.
for (int i = 1; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
graph_modified_ = true;
}
// Replace a node with NoOp. Change all inputs to control dependencies.
// If the node has non-control outputs, no change will be performed.
void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
GraphProperties* properties,
GraphDef* graph) {
if (HasRegularOutputs(*node, *node_map_)) return;
node->set_op("NoOp");
EraseRegularNodeAttributes(node);
EraseNodeOutputAttributes(node);
// Erase attributes that describe output properties.
properties->ClearOutputProperties(node->name());
// Change all inputs to control dependencies.
for (int i = 0; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
DedupControlInputs(node);
graph_modified_ = true;
}
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
graph)) {
return;
}
graph_modified_ = true;
}
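// Replace a division whose numerator is known to be all ones by Reciprocal,
// e.g. Div(ones, y) => Reciprocal(y) ^ones.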
void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
GraphDef* graph) {
node->set_op("Reciprocal");
node->mutable_input()->SwapElements(0, 1);
const string ctrl_dep =
AddControlDependency(node->input(1), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
node->set_input(1, ctrl_dep);
graph_modified_ = true;
}
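// Replace a subtraction whose first operand is known to be all zeros by
// negation, e.g. Sub(zeros, y) => Neg(y) ^zeros.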
void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
GraphDef* graph) {
node->set_op("Neg");
node->mutable_input()->SwapElements(0, 1);
const string ctrl_dep =
AddControlDependency(node->input(1), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
node->set_input(1, ctrl_dep);
graph_modified_ = true;
}
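// Replace `node` in place by a Const holding `value`, demoting all regular
// inputs to control dependencies, e.g. Op(x, y) => Const(value) ^x ^y.
// DT_VARIANT tensors are left untouched.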
Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
TensorProto* value,
NodeDef* node,
GraphDef* graph) {
if (dtype == DT_VARIANT) return Status::OK();
node->set_op("Const");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["dtype"].set_type(dtype);
(*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
// Convert all inputs to control dependencies.
for (int i = 0; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
DedupControlInputs(node);
graph_modified_ = true;
return Status::OK();
}
Status ConstantFolding::ReplaceOperationWithConstant(
double value, const GraphProperties& properties,
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_VARIANT) return Status::OK();
AttrValue tensor_attr;
Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
if (!s.ok()) {
// Fail gracefully without mutating the graph.
VLOG(1) << "Failed to replace node " << node->name() << " of type "
<< DataTypeString(dtype) << " with constant tensor of value "
<< value;
return Status::OK();
}
return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
node, graph);
}
Status ConstantFolding::SimplifyGraph(
GraphDef* optimized_graph, GraphProperties* properties,
absl::flat_hash_set<string>* nodes_to_not_simplify) {
for (int i = 0; i < optimized_graph->node_size(); ++i) {
NodeDef* node = optimized_graph->mutable_node(i);
// TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
// generalize to only restrict certain simplifications.
if (nodes_to_not_simplify->find(node->name()) ==
nodes_to_not_simplify->end()) {
if (HasTPUAttributes(*node)) {
nodes_to_not_simplify->insert(node->name());
continue;
}
TF_RETURN_IF_ERROR(SimplifyNode(node, optimized_graph, properties));
}
}
return Status::OK();
}
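// The macros below let SimplifyNode short-circuit: each rewrite is tried in
// turn, and as soon as one modifies the graph we return, so at most one
// simplification is applied to a node per call.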
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
TF_RETURN_IF_ERROR(EXPR); \
if (graph_modified_) return Status::OK()
#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
graph_modified_ = EXPR; \
if (graph_modified_) return Status::OK()
#define RETURN_IF_MODIFIED(EXPR) \
EXPR; \
if (graph_modified_) return Status::OK()
Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
GraphProperties* properties) {
bool graph_modified_cached = graph_modified_;
graph_modified_ = false;
bool use_shape_info = properties->has_properties();
RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
*properties, use_shape_info, optimized_graph, node));
RETURN_IF_MODIFIED(
RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
RemoveReverse(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifySlice(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyTile(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyPad(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_MODIFIED(
SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
SimplifyReduction(optimized_graph, *properties, node));
SET_AND_RETURN_IF_MODIFIED(
SimplifyReshape(*properties, use_shape_info, node));
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
ConstantPushDown(properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
MulConvPushDown(optimized_graph, node, *properties));
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
MergeConcat(use_shape_info, properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
ConstantPushDownBiasAdd(properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
SimplifySelect(*properties, optimized_graph, node));
RETURN_IF_MODIFIED(
RemoveRedundantVariableUpdates(properties, optimized_graph, node));
graph_modified_ = graph_modified_cached;
return Status::OK();
}
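// A Split or SplitV with num_split == 1 is a no-op and is replaced by an
// Identity that forwards the data input, e.g.
//   Split(axis, x, num_split=1) => Identity(x) ^axis.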
void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph,
NodeDef* node) {
if (node->attr().count("num_split") == 0) return;
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
}
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
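// A Shuffle or Transpose is a no-op when every dimension of size > 1 keeps its
// position, e.g. Transpose(x, perm=[1, 0, 2]) with x of shape [1, 1, 5]
// => Identity(x) ^perm.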
Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
return Status::OK();
Tensor permutation_tensor;
if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
properties.HasInputProperties(node->name())) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
std::vector<int> permutation;
for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
if (permutation_tensor.dtype() == DT_INT64) {
permutation.push_back(permutation_tensor.vec<int64_t>()(j));
} else {
permutation.push_back(permutation_tensor.vec<int>()(j));
}
}
int permutation_size = permutation.size();
if (permutation_size != shape.dim_size()) {
      // The number of elements in perm should be the same as dim_size. Skip if
      // not.
return Status::OK();
}
// The node is replaceable iff
// dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not permuted.
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
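// RandomShuffle permutes along the first dimension only, so it is a no-op on
// scalars and on inputs whose first dimension has size 1, e.g.
// RandomShuffle(x) with x of shape [1, 8] => Identity(x).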
void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && IsRandomShuffle(*node) &&
!properties.GetInputProperties(node->name()).empty()) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || first dim is of size 1)
if (!shape.unknown_rank() &&
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
}
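// ReverseV2 is a no-op when every reversed axis has size 1, e.g.
// ReverseV2(x, axis=[0]) with x of shape [1, 8] => Identity(x) ^axis.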
Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
Tensor axis;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &axis)) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) return Status::OK();
std::set<int> target_axes;
for (int j = 0; j < axis.NumElements(); ++j) {
      // Values in axis can be negative.
if (axis.dtype() == DT_INT64) {
target_axes.insert((axis.vec<int64_t>()(j) + shape.dim_size()) %
shape.dim_size());
} else {
target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
shape.dim_size());
}
}
// The node is replaceable iff
// unknown_rank == false &&
// (dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not in target_axes)
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &=
shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
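// A Slice that starts at 0 and takes the full extent of every dimension is a
// no-op, e.g. Slice(x, begin=[0, 0], size=[-1, 3]) with x of shape [2, 3]
// => Identity(x) ^begin ^size.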
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (!use_shape_info || !IsSlice(*node)) return Status::OK();
Tensor begin;
Tensor size;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &begin) &&
GetTensorFromConstNode(node->input(2), &size)) {
const auto& input = properties.GetInputProperties(node->name())[0];
// The node is replaceable iff unknown_rank == false &&
// begin == 0 && (size == -1 || size == input_shape) for all dimensions
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (begin.dtype() == DT_INT32) {
replaceable &= begin.vec<int>()(j) == 0;
} else {
replaceable &= begin.vec<int64_t>()(j) == 0;
}
if (size.dtype() == DT_INT32) {
replaceable &= (size.vec<int>()(j) == -1 ||
size.vec<int>()(j) == input.shape().dim(j).size());
} else {
replaceable &= (size.vec<int64_t>()(j) == -1 ||
size.vec<int64_t>()(j) == input.shape().dim(j).size());
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
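// A StridedSlice with unit strides that covers the full range of every
// dimension (either explicitly or via begin_mask/end_mask bits) is a no-op,
// e.g. StridedSlice(x, begin=[0, 0], end=[2, 3], strides=[1, 1]) with x of
// shape [2, 3] => Identity(x).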
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && IsStridedSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 4) {
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
if (node->attr().at("new_axis_mask").i() != 0 ||
node->attr().at("shrink_axis_mask").i() != 0) {
// Skip nodes with new/shrink axis mask, since they involve dimension
// changes.
return Status::OK();
}
const auto& input = properties.GetInputProperties(node->name())[0];
for (int j = 0; j < input.shape().dim_size(); ++j) {
// Skip if input shape is not fully determined.
if (input.shape().dim(j).size() < 0) {
return Status::OK();
}
}
std::vector<Tensor> input_tensors(3);
for (int i = 1; i < 4; ++i) {
if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
return Status::OK();
}
}
const Tensor& begin = input_tensors[0];
const Tensor& end = input_tensors[1];
const Tensor& strides = input_tensors[2];
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
int ellipsis_index = -1;
for (int j = 0; j < input.shape().dim_size(); ++j) {
      // Find the ellipsis index. If not found, insert one at the end if
      // necessary.
if (node->attr().at("ellipsis_mask").i() & 1 << j ||
(ellipsis_index == -1 && j >= strides.NumElements())) {
ellipsis_index = j;
}
      // Insert the indices that immediately follow ellipsis_index if
      // necessary.
if (ellipsis_index != -1 &&
input.shape().dim_size() >
strides.NumElements() + j - ellipsis_index) {
expanded_ellipsis_indices.insert(j);
}
}
// The node is replaceable iff unknown_rank == false &&
// ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
// && strides == 1) for all dimensions.
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (expanded_ellipsis_indices.find(j) !=
expanded_ellipsis_indices.end()) {
// ellipsis_mask is effective on current dimension.
continue;
}
      // When there is an ellipsis_mask in between, input.shape().dim_size()
      // will be greater than strides.NumElements(), since as many as
      // expanded_ellipsis_indices.size() axes are inserted during computation.
      // We need to subtract this number from j.
int i = j;
int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
if (ellipsis_index != -1 &&
j >= ellipsis_index + expanded_ellipsis_indices_size) {
i = j - expanded_ellipsis_indices_size;
}
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64_t>()(i);
int e =
end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64_t>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64_t>()(i);
replaceable &= (begin_mask & 1 << i || b == 0) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) &&
s == 1;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
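// A Tile whose multiples are all 1 is a no-op, e.g.
// Tile(x, multiplies=[1, 1]) => Identity(x) ^multiplies.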
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
Tensor multiplies;
if (use_shape_info && IsTile(*node) &&
GetTensorFromConstNode(node->input(1), &multiplies)) {
// The node is replaceable iff all values in multiplies are 1.
bool replaceable = true;
if (multiplies.dtype() == DT_INT32) {
for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
replaceable &= multiplies.vec<int>()(j) == 1;
}
} else {
for (int j = 0; replaceable && j < multiplies.vec<int64_t>().size();
++j) {
replaceable &= multiplies.vec<int64_t>()(j) == 1;
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
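// A Pad with all-zero paddings is a no-op, e.g.
// Pad(x, paddings=[[0, 0], [0, 0]]) => Identity(x) ^paddings.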
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
if (!use_shape_info || !IsPad(*node)) return Status::OK();
Tensor paddings;
if (GetTensorFromConstNode(node->input(1), &paddings)) {
// The node is replaceable iff all values in paddings are 0.
bool replaceable = true;
if (paddings.dtype() == DT_INT32) {
const auto flatten = paddings.flat<int32>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
} else {
const auto flatten = paddings.flat<int64_t>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
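// A Squeeze on an input whose dimensions all have size > 1 is a no-op, e.g.
// Squeeze(x) with x of shape [2, 3] => Identity(x).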
void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && IsSqueeze(*node) &&
!properties.GetInputProperties(node->name()).empty()) {
    // https://www.tensorflow.org/api_docs/python/tf/squeeze notes that it is
    // an error to squeeze a dimension that is not 1, so we only need to check
    // whether every input dimension has size > 1.
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
// The node is replaceable iff
// unknown_rank == false && (dim_size == 0 || all dims have size > 1)
bool replaceable = !shape.unknown_rank();
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() > 1;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
}
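// A Pack with a single non-control input is equivalent to an ExpandDims with a
// constant axis, e.g. Pack(x, axis=1) => ExpandDims(x, Const(1)).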
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
node_map_->NodeExists(axis_node_name)) {
return false;
}
// It's unsafe to add a control dependency on the feed node, because it might
  // never have been executed otherwise.
if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
return false;
}
// Create constant axis node.
Tensor axis_t(DT_INT32, TensorShape({}));
const int axis =
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
NodeDef new_node;
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
return false;
}
NodeDef* axis_node = optimized_graph->add_node();
*axis_node = std::move(new_node);
axis_node->set_name(axis_node_name);
node_map_->AddNode(axis_node->name(), axis_node);
// Add a control dependency to make sure axis_node is in the right frame.
const string ctrl_dep = ConstantFolding::AddControlDependency(
node->input(0), optimized_graph, node_map_.get());
axis_node->add_input(ctrl_dep);
axis_node->set_device(node->device());
node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
node->set_op("ExpandDims");
if (node->attr().count("axis") != 0) {
node->mutable_attr()->erase("axis");
}
if (node->attr().count("N") != 0) {
node->mutable_attr()->erase("N");
}
(*node->mutable_attr())["Tdim"].set_type(DT_INT32);
node->add_input(axis_node->name());
node_map_->AddOutput(axis_node->name(), node->name());
if (node->input_size() > 2) {
node->mutable_input()->SwapElements(1, node->input_size() - 1);
}
return true;
}
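// A Case whose branch index is constant always takes the same branch, so it
// can be replaced by a direct call to that branch, e.g.
//   Case(Const(1), x; branches=[f0, f1, f2]) => PartitionedCall(x; f=f1).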
bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
if (node->op() != "Case") return false;
const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
if (output_idx_node == nullptr ||
!CheckAttrExists(*output_idx_node, "value").ok()) {
return false;
}
Tensor output_idx_t;
if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
return false;
int output_idx = output_idx_t.scalar<int>()();
const auto& func_list = node->attr().at("branches").list();
if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
NodeDef call_node = *node;
call_node.set_op("PartitionedCall");
call_node.clear_input();
for (int i = 1; i < node->input_size(); ++i) {
call_node.add_input(node->input(i));
}
auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
*new_func = func_list.func(output_idx);
// Move the output shape of the branch to _output_shapes if it is known.
const auto& output_shape_list =
(*node->mutable_attr())["output_shapes"].list();
if (output_shape_list.shape_size() > output_idx) {
TensorShapeProto* new_output_shape =
(*call_node.mutable_attr())["_output_shapes"]
.mutable_list()
->add_shape();
    *new_output_shape =
        node->attr().at("output_shapes").list().shape(output_idx);
}
call_node.mutable_attr()->erase("output_shapes");
call_node.mutable_attr()->erase("branches");
*node = std::move(call_node);
return true;
}
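// A Select whose predicate is known to be all-true or all-false forwards the
// corresponding input. E.g. when no broadcasting is involved:
//   Select(pred, x, y) with all-true pred => Identity(x) ^pred ^y,
// and otherwise the live input is broadcast to the output shape.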
bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node) {
if (!IsSelect(*node)) return false;
const std::vector<OpInfo::TensorProperties>& input_props =
properties.GetInputProperties(node->name());
if (input_props.size() < 3) return false;
const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
const bool is_all_true = IsOnes(*predicate_node);
const bool is_all_false = IsZeros(*predicate_node);
if (!is_all_true && !is_all_false) {
return false;
}
const int live_input_idx = is_all_true ? 1 : 2;
const int ignored_input_idx = is_all_true ? 2 : 1;
const TensorShapeProto& predicate_shape = input_props[0].shape();
const bool predicate_is_scalar =
!predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
(ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
predicate_is_scalar)) {
// Replace node with Identity if no broadcasting is involved.
node->set_op("Identity");
*node->mutable_input(0) =
AddControlDependency(node->input(0), optimized_graph, node_map_.get());
*node->mutable_input(ignored_input_idx) = AddControlDependency(
node->input(ignored_input_idx), optimized_graph, node_map_.get());
node->mutable_input()->SwapElements(0, live_input_idx);
} else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
optimized_graph)) {
return false;
}
DedupControlInputs(node);
return true;
}
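// Variable updates that add/subtract zeros or multiply/divide by ones leave
// the variable unchanged, e.g. AssignAddVariableOp(v, zeros) => NoOp and
// ScatterMul(ref, indices, ones) => Identity(ref).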
void ConstantFolding::RemoveRedundantVariableUpdates(
GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
  static const absl::flat_hash_set<string>* kVariableUpdateOps =
      new absl::flat_hash_set<string>{"AssignAddVariableOp",
"AssignSubVariableOp",
"AssignAdd",
"AssignSub",
"ScatterAdd",
"ScatterSub",
"ScatterMul",
"ScatterDiv",
"ScatterNdAdd",
"ScatterNdSub",
"ScatterNdMul",
"ScatterNdDiv",
"ResourceScatterAdd",
"ResourceScatterSub",
"ResourceScatterMul",
"ResourceScatterDiv",
"ResourceScatterNdAdd",
"ResourceScatterNdSub",
"ResourceScatterNdMul",
"ResourceScatterNdDiv"};
  if (kVariableUpdateOps == nullptr ||
      kVariableUpdateOps->find(node->op()) == kVariableUpdateOps->end())
return;
const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
if (delta_node == nullptr) return;
const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
absl::StrContains(node->op(), "Sub");
if ((is_add_or_sub && IsZeros(*delta_node)) ||
(!is_add_or_sub && IsOnes(*delta_node))) {
VLOG(1) << "Removing redundant variable update: " << node->DebugString();
if (absl::StrContains(node->op(), "Variable") ||
absl::StrContains(node->op(), "Resource")) {
ReplaceOperationWithNoOp(node, properties, optimized_graph);
} else {
ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
optimized_graph);
}
}
}
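// An Enter marked is_constant forwards a loop-invariant value into a frame.
// When its input is a constant C, we materialize a copy of C inside the frame
// (anchored by a control dependency on the Enter) and rewire non-constant
// consumers to it, so the copy can participate in folding within the loop.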
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
NodeDef* node) {
if (!IsEnter(*node) || node->input_size() == 0 ||
node->attr().count("is_constant") == 0 ||
!node->attr().at("is_constant").b()) {
return false;
}
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input == nullptr || !IsReallyConstant(*input) ||
OptimizedNodeExists(*input, "_enter")) {
return false;
}
// Find non-constant nodes that consume the output of *node.
std::vector<NodeDef*> consumers;
for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
if (!IsConstant(*fanout)) {
for (int i = 0; i < fanout->input_size(); ++i) {
if (fanout->input(i) == node_name) {
consumers.push_back(const_cast<NodeDef*>(fanout));
break;
}
}
}
}
if (consumers.empty()) {
return false;
}
graph_modified_ = true;
NodeDef* new_node = optimized_graph->add_node();
*new_node = *input;
new_node->set_name(OptimizedNodeName(*input, "_enter"));
new_node->set_device(node->device());
new_node->clear_input();
new_node->add_input(AsControlDependency(node_name));
node_map_->AddNode(new_node->name(), new_node);
node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
consumer->set_input(i, new_node->name());
}
}
}
return true;
}
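// Switch(pred, pred) has statically known outputs: port 0 fires only when pred
// is false and port 1 only when pred is true. We therefore anchor constant
// false/true nodes on the respective ports via control dependencies and
// redirect all consumers to those constants.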
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
if (node->op() == "Switch" && node->input(0) == node->input(1) &&
!OptimizedNodeExists(*node, "_const_false") &&
!OptimizedNodeExists(*node, "_const_true")) {
bool already_optimized = true;
// If the optimization was already applied, the switch would have exactly
// one Identity node consuming each of its outputs, each without any
// non-control outputs.
const auto& fanouts = node_map_->GetOutputs(node->name());
if (fanouts.size() == 2) {
for (const NodeDef* fanout : fanouts) {
if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
HasRegularOutputs(*fanout, *node_map_)) {
already_optimized = false;
break;
}
}
}
Tensor false_t(DT_BOOL, TensorShape({}));
Tensor true_t(DT_BOOL, TensorShape({}));
// Make sure we don't proceed if this switch node was already optimized.
if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
SetTensorValue(DT_BOOL, false, &false_t).ok()) {
// Copy the set of consumers of the switch as they will be manipulated
// below.
std::vector<NodeDef*> consumers =
node_map_->GetOutputsOrderedByNodeName(node->name());
// Create constant false & true nodes.
NodeDef tmp_false_node;
tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
&tmp_false_node)
.ok()) {
return false;
}
tmp_false_node.set_device(node->device());
NodeDef tmp_true_node;
tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
&tmp_true_node)
.ok()) {
return false;
}
tmp_true_node.set_device(node->device());
// Add const nodes to graph.
NodeDef* false_node = optimized_graph->add_node();
false_node->Swap(&tmp_false_node);
NodeDef* true_node = optimized_graph->add_node();
true_node->Swap(&tmp_true_node);
// Add controls from the switch ports to the constants, and connect the
// constants to the original switch outputs.
const string false_port = node->name();
const string true_port = strings::StrCat(node->name(), ":1");
const string false_ctrl_dep =
AddControlDependency(false_port, optimized_graph, node_map_.get());
false_node->add_input(false_ctrl_dep);
const string true_ctrl_dep =
AddControlDependency(true_port, optimized_graph, node_map_.get());
true_node->add_input(true_ctrl_dep);
node_map_->AddNode(false_node->name(), false_node);
node_map_->AddNode(true_node->name(), true_node);
node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
const string& input = consumer->input(i);
if (input == false_port) {
consumer->set_input(i, false_node->name());
node_map_->UpdateInput(consumer->name(), false_port,
false_node->name());
} else if (input == true_port) {
consumer->set_input(i, true_node->name());
node_map_->UpdateInput(consumer->name(), true_port,
true_node->name());
}
}
}
return true;
}
}
return false;
}
bool ConstantFolding::IsReductionWithConstantIndices(
const NodeDef& node, bool* indices_is_empty) const {
  // Ensure it's an appropriate reduction node.
if (!IsReduction(node) || node.input_size() < 2) {
return false;
}
// Ensure that the axes to reduce by are constant.
NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
if (!IsReallyConstant(*reductions_indices) ||
!reductions_indices->attr().count("value")) {
return false;
}
const TensorShapeProto& reduction_indices_shape =
reductions_indices->attr().at("value").tensor().tensor_shape();
*indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
return true;
}
bool ConstantFolding::IsReductionCandidateForSimplification(
const NodeDef& node, const GraphProperties& properties,
TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
bool* is_single_element_op) const {
// Get the properties of the input & output tensors and check if they both
// contain a single element.
if (!properties.HasInputProperties(node.name()) ||
!properties.HasOutputProperties(node.name())) {
return false;
}
const auto& input_props = properties.GetInputProperties(node.name())[0];
const auto& output_props = properties.GetOutputProperties(node.name())[0];
if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
!output_props.has_shape() || output_props.shape().unknown_rank()) {
return false;
}
*input_tensor_shape = input_props.shape();
*output_tensor_shape = output_props.shape();
for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
if (input_tensor_shape->dim(i).size() < 0) {
return false;
}
}
for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
if (output_tensor_shape->dim(i).size() < 0) {
return false;
}
}
const int input_num_elements =
TensorShape(*input_tensor_shape).num_elements();
const int output_num_elements =
TensorShape(*output_tensor_shape).num_elements();
*is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
return true;
}
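// A reduction is simplifiable to Identity iff it reduces over an empty set of
// axes, or keep_dims is true and every reduced axis (possibly negative) has
// size 1.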
bool ConstantFolding::IsReductionSimplifiableToIdentity(
const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
const TensorVector& reduction_indices_vector) const {
int output_size = reduction_indices_vector[0]->NumElements();
if (output_size == 0) {
return true;
}
if (!keep_dims) {
return false;
}
bool simplifiable = true;
for (int i = 0; i < output_size; ++i) {
int64_t dim;
if (reduction_indices_vector[0]->dtype() == DT_INT32) {
dim = reduction_indices_vector[0]->flat<int32>()(i);
} else {
dim = reduction_indices_vector[0]->flat<int64_t>()(i);
}
if (dim < 0) {
dim += input_shape.dim_size();
}
if (dim < 0 || dim >= input_shape.dim_size() ||
input_shape.dim(dim).size() != 1) {
simplifiable = false;
break;
}
}
return simplifiable;
}
bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
// Replace the reduction node with an identity node, that can be further
// optimized by other passes.
DataType output_type;
if (node->attr().count("T") != 0) {
output_type = node->attr().at("T").type();
} else if (IsAny(*node) || IsAll(*node)) {
output_type = DT_BOOL;
} else {
return false;
}
node->set_op("Identity");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
return true;
}
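// Simplify reductions: reducing over an empty set of axes, or over size-1 axes
// with keep_dims=true, is an Identity; reducing a single-element tensor with
// keep_dims=false amounts to a Reshape, e.g.
//   Sum(x, [0]) with x of shape [1, 1] => Reshape(x, Const([1])).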
bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
const GraphProperties& properties,
NodeDef* node) {
bool indices_is_empty = false;
if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
return false;
}
if (indices_is_empty) {
return ReplaceReductionWithIdentity(node);
}
bool is_single_element_op = false;
TensorShapeProto input_tensor_shape, output_tensor_shape;
if (!IsReductionCandidateForSimplification(
*node, properties, &input_tensor_shape, &output_tensor_shape,
&is_single_element_op)) {
return false;
}
// Get the reduction indices.
string reduction_indices_input = node->input(1);
NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
TensorVector reduction_indices_vector;
auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
for (const auto& out : reduction_indices_vector) {
delete out.tensor;
}
});
if (!EvaluateNode(*reduction_indices, TensorVector(),
&reduction_indices_vector)
.ok() ||
reduction_indices_vector.size() != 1) {
return false;
}
bool keep_dims =
node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
bool simplifiable_to_reshape =
is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
*node, input_tensor_shape, keep_dims, reduction_indices_vector);
if (simplifiable_to_reshape) {
// Const node to output shape.
const int new_num_dimensions = output_tensor_shape.dim_size();
Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
for (int i = 0; i < new_num_dimensions; i++) {
tensor.flat<int>()(i) = 1;
}
TensorValue shape_value(&tensor);
NodeDef* shape_node = optimized_graph->add_node();
if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
shape_node)
.ok()) {
return false;
}
shape_node->set_device(node->device());
node_map_->AddNode(shape_node->name(), shape_node);
// Control dependency to ensure shape_node is in the correct frame.
shape_node->add_input(AsControlDependency(reduction_indices_input));
node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
// Optimize node to Reshape.
node->set_op("Reshape");
node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
node->set_input(1, shape_node->name());
node->mutable_attr()->erase("keep_dims");
node->mutable_attr()->erase("Tidx");
AttrValue attr_type_indices;
attr_type_indices.set_type(DT_INT32);
(*node->mutable_attr())["Tshape"] = attr_type_indices;
return true;
} else if (simplifiable_to_identity) {
return ReplaceReductionWithIdentity(node);
}
return false;
}
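// A Reshape proven to be a no-op by IsSimplifiableReshape is replaced by
// Identity, with the shape input demoted to a control dependency:
//   Reshape(x, shape) => Identity(x) ^shape.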
bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
bool use_shape_info, NodeDef* node) {
if (!use_shape_info || node->attr().count("T") == 0 ||
!IsSimplifiableReshape(*node, properties).ok()) {
return false;
}
DataType output_type = node->attr().at("T").type();
node->set_op("Identity");
EraseRegularNodeAttributes(node);
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
return true;
}
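// Simplify binary arithmetic whose operands are all-ones or all-zeros
// constants, e.g. Mul(1, y) => y, Add(x, 0) => x, Sub(0, y) => Neg(y),
// Div(1, y) => Reciprocal(y), Mul(x, 0) => Const(0), and
// LogicalOr(x, true) => Const(true), using Snapshot, Identity, BroadcastTo or
// Const depending on how the input shapes relate to the output shape.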
Status ConstantFolding::SimplifyArithmeticOperations(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsAnyMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node);
const bool is_any_div = IsAnyDiv(*node) && !IsFloorDiv(*node);
// Simplify arithmetic operations with ones or zeros.
if (use_shape_info &&
(is_mul || is_matmul || is_add || is_sub || is_any_div) &&
properties.HasInputProperties(node->name()) &&
properties.HasOutputProperties(node->name())) {
const NodeDef* x = node_map_->GetNode(node->input(0));
const NodeDef* y = node_map_->GetNode(node->input(1));
if (x == nullptr || y == nullptr) {
return errors::InvalidArgument("Invalid inputs to node: ",
node->DebugString());
}
const TensorShapeProto& output_shape =
properties.GetOutputProperties(node->name())[0].shape();
// Simplify element-wise multiplication by ones or addition/subtraction
// of zeros.
const TensorShapeProto& y_shape =
properties.GetInputProperties(node->name())[1].shape();
const TensorShapeProto& x_shape =
properties.GetInputProperties(node->name())[0].shape();
const bool y_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, y_shape);
const bool x_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, x_shape);
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
// 1 * y = y or 0 + y = y.
if (y_matches_output_shape) {
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
} else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK();
}
if (y_matches_output_shape && (is_sub && x_is_zero)) {
// Replace 0 - y with Neg(y).
ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
return Status::OK();
}
// Replace 1 / y with Reciprocal op.
if (y_matches_output_shape && is_any_div && x_is_one) {
TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
return Status::OK();
}
}
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
if (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero)) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
if (x_matches_output_shape) {
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
} else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK();
}
// x OR true = true OR y = true.
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
1, properties, output_shape, node, optimized_graph));
return Status::OK();
}
// Simplify multiplication and matmul by zeros.
// Also optimize zeros divided by a tensor, but only if we are in
// aggressive mode, since we might get rid of divisions by zero.
const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
if ((x_is_zero || y_is_zero) &&
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
if (shp.IsFullyDefined()) {
bool is_quantized = IsQuantizedMatMul(*node);
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
0, properties, output_shape, node, optimized_graph));
if (is_quantized && graph_modified_) {
TF_RETURN_IF_ERROR(
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
}
return Status::OK();
}
      // Even if an input shape is only partially known, we may know that it
      // matches the output shape and can thus forward or broadcast the
// corresponding zero input.
if ((is_mul || is_any_div) && x_is_zero) {
if (x_matches_output_shape) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK();
} else if (is_mul && y_is_zero) {
if (y_matches_output_shape) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
} else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK();
}
}
}
return Status::OK();
}
bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
NodeDef* node) {
// Strength reduce floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
// will be constant folded to Mul(x, 1.0/const).
if (node->input_size() >= 2 &&
(IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
const string& const_input = node->input(1);
const NodeDef* denom = node_map_->GetNode(const_input);
CHECK(denom != nullptr);
if (!IsReallyConstant(*denom)) {
return false;
}
if (node->attr().count("T") == 0) {
return false;
}
DataType type = node->attr().at("T").type();
// Skip integer division.
if (IsDiv(*node) &&
!(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
return false;
}
// Insert new reciprocal op and change node from Div to Mul.
NodeDef* reciprocal_node = optimized_graph->add_node();
reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
reciprocal_node->set_op("Reciprocal");
reciprocal_node->set_device(node->device());
reciprocal_node->add_input(const_input);
(*reciprocal_node->mutable_attr())["T"].set_type(type);
// Re-wire inputs and outputs.
if (IsXdivy(*node)) {
node->set_op("MulNoNan");
node->set_input(1, node->input(0));
node->set_input(0, reciprocal_node->name());
} else {
node->set_op("Mul");
node->set_input(1, reciprocal_node->name());
}
node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
return true;
}
return false;
}
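// Common precondition checks for the constant push-down rewrites below: the
// parent must have exactly two regular inputs, at least one of them a constant
// on the same device, and the op child must be safe to mutate (single
// non-control output, not preserved, leaves not both constant).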
bool ConstantFolding::PrepareConstantPushDown(
const NodeDef& parent, const GraphProperties& properties,
bool must_have_properties, ConstantPushDownContext* ctx) const {
if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
return false;
}
NodeDef* left_child = node_map_->GetNode(parent.input(0));
NodeDef* right_child = node_map_->GetNode(parent.input(1));
// Sanity check for missing children.
if (left_child == nullptr || right_child == nullptr) {
return false;
}
ctx->left_child_is_const = IsReallyConstant(*left_child);
ctx->right_child_is_const = IsReallyConstant(*right_child);
ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
// Nothing to do unless the parent has a constant child node.
if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
return false;
}
// Don't move nodes across devices.
if (parent.device() != ctx->op_child->device() ||
parent.device() != ctx->const_child->device()) {
return false;
}
// Make sure that it is safe to change the value of the child node result.
if (ctx->op_child->input_size() < 2 ||
nodes_to_preserve_.find(ctx->op_child->name()) !=
nodes_to_preserve_.end() ||
NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
return false;
}
// Don't apply reassociation to floating point types of low precision.
// The danger of significant numerical changes is too high.
if (!CheckAttrExists(parent, "T").ok()) return false;
DataType dtype = parent.attr().at("T").type();
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return false;
}
// Don't rewrite the tree if it might create cycles.
// TODO(rmlarsen): Add back handling of control dependency from op to C.
const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
if (child_output.find(ctx->const_child) != child_output.end()) {
return false;
}
// Get leaf nodes.
ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
// Child is already foldable, leave it alone.
return false;
}
// Don't move nodes across devices.
if (parent.device() != ctx->left_leaf->device() ||
parent.device() != ctx->right_leaf->device()) {
return false;
}
// Get shape and type information.
ctx->parent_input_props = &properties.GetInputProperties(parent.name());
ctx->op_child_input_props =
&properties.GetInputProperties(ctx->op_child->name());
if (must_have_properties && (ctx->parent_input_props == nullptr ||
ctx->parent_input_props->size() < 2 ||
ctx->op_child_input_props == nullptr ||
ctx->op_child_input_props->size() < 2)) {
return false;
}
VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
<< parent.op() << "(" << left_child->op() << ", " << right_child->op()
<< ")";
return true;
}
bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
GraphDef* optimized_graph,
NodeDef* node) {
// This implements constant push-down for BiasAdd. In the following "CV" is a
// constant vector (tensor of rank 1), "V" is a (possibly) non-constant
// vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
// non-constant matrix, and "BA" is BiasAdd.
// For a valid input graph, the following 4 rewrites are legal:
//
// 1) + +
// / \ / \
// BA CV -- > BA V
// / \ / \
// M V M CV
//
// 2) + +
// / \ / \
// BA CM -- > BA M
// / \ / \
// M V CM V
//
// 3) BA BA
// / \ / \
// + CV -- > + V
// / \ / \
// M V M CV
//
// 4) BA BA = parent
// / \ / \
// BA CV -- > BA V = children
// / \ / \
// M V M CV = leaves
//
// Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
const bool parent_is_bias_add = IsBiasAdd(*node);
if (!parent_is_bias_add && !IsAdd(*node)) return false;
ConstantPushDownContext ctx;
if (!PrepareConstantPushDown(*node, *properties,
/*must_have_properties=*/true, &ctx)) {
return false;
}
// Special case for BiasAdd: Since the left argument to BiasAdd must be rank
// >= 2 and the leaves must be vectors, we cannot swap them.
if (ctx.left_child_is_const && parent_is_bias_add) return false;
const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
// Get properties to validate rank and dtype constraints.
if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
(*ctx.parent_input_props)[0].shape().unknown_rank() ||
(*ctx.parent_input_props)[1].shape().unknown_rank() ||
(*ctx.op_child_input_props)[0].shape().unknown_rank() ||
(*ctx.op_child_input_props)[1].shape().unknown_rank()) {
return false;
}
// Now get the ranks and types of the 3 leaf nodes.
const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
// At least one leaf must be a vector.
if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
const int matrix_idx = 1 - vector_idx;
const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
if (vector_rank != 1) return false; // this should never happen.
const DataType vector_type = vector_prop.dtype();
const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
const int matrix_rank = matrix_prop.shape().dim_size();
const DataType matrix_type = matrix_prop.dtype();
const int const_idx = ctx.left_child_is_const ? 0 : 1;
const auto& const_prop = (*ctx.parent_input_props)[const_idx];
const int const_rank = const_prop.shape().dim_size();
const DataType const_type = const_prop.dtype();
int input_to_swap = -1;
if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
const_type == matrix_type) {
// Case 2:
input_to_swap = matrix_idx;
} else if (const_rank == 1 && const_type == vector_type) {
// Case 1, 3, and, 4:
input_to_swap = vector_idx;
}
if (input_to_swap == -1) return false;
const NodeDef* leaf_to_swap =
node_map_->GetNode(ctx.op_child->input(input_to_swap));
if (IsConstant(*leaf_to_swap)) return false;
node_map_->UpdateInput(node->name(), node->input(const_idx),
ctx.op_child->input(input_to_swap));
node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
if (ctx.op_child->input(input_to_swap) !=
ctx.op_child->input(1 - input_to_swap)) {
node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
ctx.op_child->name());
}
std::swap(*node->mutable_input(const_idx),
*ctx.op_child->mutable_input(input_to_swap));
properties->ClearInputProperties(node->name());
properties->ClearInputProperties(ctx.op_child->name());
return true;
}
bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
GraphDef* optimized_graph,
NodeDef* node) {
// Consider the transformation
//
// + + = parent
// / \ / \
// C + -- > X + = children
// / \ / \
// X Y C Y = leaves
//
// where C is constant, X is non-constant, Y may be constant or non-constant,
// and '+' denotes an associative and commutative operator like addition or
// multiplication. This optimization pushes constants down in the tree to
// canonicalize it. Moreover, in cases where the child node has a second
// constant input Y we will create a leaf node that can be folded, e.g.
//
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
//
// We also handle the non-commutative cases of subtraction and division
// by rotating the tree locally, e.g.
// Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
// Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
// Get parent op type.
const bool is_add = IsAdd(*node);
const bool is_mul = IsMul(*node);
const bool is_sub = IsSub(*node);
const bool is_div = IsDiv(*node);
if (!(is_add || is_sub || is_mul || is_div)) return false;
const bool is_symmetric = is_add || is_mul;
ConstantPushDownContext ctx;
if (!PrepareConstantPushDown(*node, *properties,
/*must_have_properties=*/false, &ctx)) {
return false;
}
// Get child op type.
const bool is_child_add = IsAdd(*ctx.op_child);
const bool is_child_mul = IsMul(*ctx.op_child);
const bool is_child_sub = IsSub(*ctx.op_child);
const bool is_child_div = IsDiv(*ctx.op_child);
const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
if (!is_add_sub && !is_mul_div) {
return false;
}
const bool is_child_symmetric = is_child_add || is_child_mul;
if (!CheckAttrExists(*node, "T").ok()) return false;
DataType dtype = node->attr().at("T").type();
if (!(is_symmetric && is_child_symmetric) &&
!(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
return false;
}
const NodeDef* y_node =
ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
!ctx.op_child_input_props->empty()) {
// If we know the shapes of the nodes being swapped, make sure we don't push
// down a larger node and create more work by broadcasting earlier in the
    // expression tree.
const PartialTensorShape c_shape(
(*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
const PartialTensorShape x_shape(
(*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
c_shape.num_elements() > x_shape.num_elements()) {
return false;
} else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
c_shape.dims() > 0) {
for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
if (x_shape.dim_size(idx) >= 0 &&
c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
return false;
}
}
}
}
// Get the node names corresponding to X, Y, and C.
const string input_x =
ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
const string input_y = input_x == ctx.op_child->input(0)
? ctx.op_child->input(1)
: ctx.op_child->input(0);
const string input_c =
ctx.left_child_is_const ? node->input(0) : node->input(1);
const string input_op =
ctx.left_child_is_const ? node->input(1) : node->input(0);
VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
// Now that we have identified the nodes to swap, update the node map
// accordingly.
node_map_->UpdateInput(node->name(), input_c, input_x);
node_map_->AddOutput(input_c, ctx.op_child->name());
if (input_x != input_y) {
node_map_->RemoveOutput(input_x, ctx.op_child->name());
}
properties->ClearInputProperties(node->name());
properties->ClearInputProperties(ctx.op_child->name());
if (is_symmetric && is_child_symmetric) {
// Easy case (only commutative ops). We always write this as one of
// +
// / \
// X +
// / \
// C Y
node->set_input(0, input_x);
node->set_input(1, input_op);
ctx.op_child->set_input(0, input_c);
ctx.op_child->set_input(1, input_y);
} else {
// More complicated case: When there are non-commutative operations like
// subtractions or divisions involved, we may have to rotate the tree
// and/or change op types. There are 6 non-trivial cases depending on
// the effective generalized "sign" of each of the three terms C, Y, and X.
// Here are the final trees we want to generate for those 6 cases:
//
// (CYX signs): ++- +-- -+- --+ +-+ -++
//
// - - - - + +
// / \ / \ / \ / \ / \ / \
// + X - X - X X + X - X -
// / \ / \ / \ / \ / \ / \
// C Y C Y Y C Y C C Y Y C
//
// First, let's determine the effective sign of each term in the original
// expression.
auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
bool leaf_negated = !is_child_symmetric && is_right_leaf;
bool child_negated = !is_symmetric && (ctx.left_child_is_const);
return leaf_negated != child_negated;
};
const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
bool neg_c = !is_symmetric && !ctx.left_child_is_const;
bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
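// Example: for Sub(C, Sub(X, Y)) the signs of (C, Y, X) are (+, +, -), i.e.
// neg_c = false, neg_y = false, neg_x = true, and the rewrite below yields
// Sub(Add(C, Y), X).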
// Rewrite the parent node.
node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
node->set_input(0, neg_x ? input_op : input_x);
node->set_input(1, neg_x ? input_x : input_op);
// Rewrite the child node.
ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
ctx.op_child->set_input(0, neg_c ? input_y : input_c);
ctx.op_child->set_input(1, neg_c ? input_c : input_y);
}
return true;
}
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
const GraphProperties& properties) {
// Push down multiplication on ConvND.
// * ConvND
// / \ / \
// ConvND C2 -- > X *
// / \ / \
// X C1 C1 C2
//
// where C1 and C2 are constants and X is non-constant.
//
// TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
if (mul_left_child == nullptr || mul_right_child == nullptr) {
return false;
}
// One child must be constant, and the other must be a Conv op.
const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
if (!left_child_is_constant && !right_child_is_constant) {
return false;
}
NodeDef* conv_node =
left_child_is_constant ? mul_right_child : mul_left_child;
if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
return false;
}
if (node->device() != mul_left_child->device() ||
node->device() != mul_right_child->device()) {
return false;
}
// Make sure that it is safe to change the value of the convolution
// output.
if (conv_node->input_size() < 2 ||
NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
return false;
}
// Identify the nodes to swap.
NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
if (!conv_left_is_constant && !conv_right_is_constant) {
// At least one of the convolution inputs should be constant.
return false;
}
if (conv_left_is_constant && conv_right_is_constant) {
// Leverage regular constant folding to handle this.
return false;
}
const auto& mul_props = properties.GetOutputProperties(node->name());
const auto& conv_props = properties.GetOutputProperties(conv_node->name());
if (mul_props.empty() || conv_props.empty()) {
return false;
}
const auto& mul_shape = mul_props[0].shape();
const auto& conv_shape = conv_props[0].shape();
if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
return false;
}
const auto& input_props = properties.GetInputProperties(conv_node->name());
if (input_props.size() < 2) {
return false;
}
const auto& filter_shape = input_props[1].shape();
NodeDef* const_node =
left_child_is_constant ? mul_left_child : mul_right_child;
const auto& const_props = properties.GetOutputProperties(const_node->name());
if (const_props.empty()) {
return false;
}
const auto& const_shape = const_props[0].shape();
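// Pushing the multiplication below the convolution is only valid if scaling
// the filter is equivalent to scaling the convolution output, e.g. when the
// constant is a scalar or broadcasts only along the filter's output-channel
// dimension; IsValidConstShapeForMulConvPushDown checks this.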
if (!IsValidConstShapeForMulConvPushDown(
conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
return false;
}
string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
if (node_map_->NodeExists(mul_new_name)) {
return false;
}
// Make sure we don't introduce loops in the graph by removing any control
// dependency from the conv node to c2.
string conv_const_input =
conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
node_map_.get())) {
// Add a control dep from c1 to c2 to ensure c2 is in the right frame.
MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
node_map_.get());
}
conv_node->set_name(node->name());
node->set_name(mul_new_name);
if (conv_left_is_constant) {
node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
conv_node->set_input(0, mul_new_name);
} else {
node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
conv_node->set_input(1, mul_new_name);
}
NodeDef* conv_const_node =
conv_left_is_constant ? conv_left_child : conv_right_child;
if (left_child_is_constant) {
node->set_input(1, conv_const_node->name());
} else {
node->set_input(0, conv_const_node->name());
}
node_map_->AddNode(mul_new_name, node);
return true;
}
bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
// Partial constant propagation through IdentityN.
if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
!HasRegularInputs(*node))
return false;
std::vector<int> inputs_to_forward;
for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
const string& input = node->input(input_idx);
if (IsControlInput(input)) {
return false;
}
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
if (input_node == nullptr) {
LOG(ERROR) << "Bad input: " << input;
return false;
}
// Forward constant inputs to outputs and add a control dependency on
// the IdentityN node.
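// E.g. a consumer reading output i of this node, where input i is a Const,
// is rewired by ForwardInputs to read the Const directly, with a control
// edge on this node preserving any ordering constraints.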
if (IsReallyConstant(*input_node)) {
inputs_to_forward.push_back(input_idx);
}
}
return ForwardInputs(node, inputs_to_forward);
}
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
GraphProperties* properties,
NodeDef* node) {
// Partial constant folding for associative operators:
// Split AddN/AccumulateNV2 to enable partial
// folding of ops when more than one but not all inputs are constant.
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since
// addition is commutative.
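// E.g. AddN(x, c1, y, c2) becomes AddN(x, s, y) with s = AddN(c1, c2); the
// new all-constant node s can then be folded to a single constant.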
if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
const int num_non_control_inputs = NumNonControlInputs(*node);
if (num_non_control_inputs <= 2) return false;
const int num_control_inputs = node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs;
std::vector<int> nonconst_inputs;
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
if (input_node == nullptr) return false;
if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
const_inputs.push_back(i);
} else {
// Non-const and control inputs.
nonconst_inputs.push_back(i);
}
}
// Promote AccumulateNV2 with all constant inputs to AddN, since it is
// a fake node that cannot be constant folded by itself.
int const_inputs_size = const_inputs.size();
if (const_inputs_size == num_non_control_inputs &&
node->op() == "AccumulateNV2") {
node->set_op("AddN");
node->mutable_attr()->erase("shape");
return true;
}
const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs_size));
if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
// Always use AddN for the constant node, since AccumulateNV2 is a fake
// node that cannot be constant folded because it does not have a kernel.
added_node->set_op("AddN");
added_node->mutable_attr()->erase("shape");
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i : const_inputs) {
added_node->add_input(node->input(i));
node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
added_node->name());
}
// Overwrite the first const input with the added node.
node->set_input(const_inputs[0], added_node->name());
node_map_->AddOutput(added_node->name(), node->name());
nonconst_inputs.push_back(const_inputs[0]);
// Compact the remaining inputs to the original node.
std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
int idx = 0;
for (int i : nonconst_inputs) {
if (idx != i) {
node->set_input(idx, node->input(i));
}
++idx;
}
node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
return true;
}
return false;
}
bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
GraphProperties* properties,
NodeDef* node) {
// Partial constant folding for Concat: concatenation is not commutative, so
// we have to preserve order and can only push consecutive runs of constant
// inputs into sub-nodes.
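// E.g. ConcatV2(x, c1, c2, y, axis) becomes ConcatV2(x, s, y, axis) with
// s = ConcatV2(c1, c2, axis), and s can then be folded to a constant.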
if (!IsConcat(*node) ||
node->name().rfind("_partial_split_") != string::npos) {
return false;
}
const int num_non_control_inputs = NumNonControlInputs(*node);
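// With at most 3 non-control inputs (i.e. at most 2 data inputs plus the
// axis), any constant run of length >= 2 would cover all data inputs, a
// case that is left to full constant folding.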
if (num_non_control_inputs <= 3) return false;
int axis_arg = -1;
int begin = 0;
int end = num_non_control_inputs;
if (node->op() == "Concat") {
begin = 1;
axis_arg = 0;
} else if (node->op() == "ConcatV2") {
end = num_non_control_inputs - 1;
axis_arg = num_non_control_inputs - 1;
} else {
return false;
}
// We search for consecutive runs of constant inputs in the half-open range
// [begin, end) and push them down into child nodes.
std::vector<std::pair<int, int>> constant_input_runs;
int first = begin;
int last = begin;
while (last < end) {
while (first < end && !IsReallyConstant(*node_map_->GetNode(
NodeName(node->input(first))))) {
++first;
}
// Invariant: node[first] is constant || first >= end.
last = first + 1;
while (last < end &&
IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
++last;
}
// Invariant: node[last] is not constant || last >= end.
// Discard intervals shorter than 2 elements.
if (first < end && (last - first) > 1) {
constant_input_runs.emplace_back(first, last);
}
first = last;
}
// Skip if all inputs are constant, and let constant folding take over.
if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
constant_input_runs[0].first == begin &&
constant_input_runs[0].second == end)) {
return false;
}
std::set<int> inputs_to_delete;
for (auto interval : constant_input_runs) {
// Push the constant inputs in the interval to a child node that can be
// constant folded.
string new_node_name = OptimizedNodeName(*node, "_partial_split");
do {
new_node_name += strings::StrCat("_", interval.first);
} while (node_map_->NodeExists(new_node_name));
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
added_node->set_op("ConcatV2");
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i = interval.first; i < interval.second; ++i) {
added_node->add_input(node->input(i));
node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
if (i != interval.first) {
inputs_to_delete.insert(i);
}
}
added_node->add_input(node->input(axis_arg));
(*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
// Overwrite the first constant input with the result of the added
// child node.
node->set_input(interval.first, added_node->name());
}
if (!inputs_to_delete.empty()) {
// Fix up the inputs to the original node.
protobuf::RepeatedPtrField<string> tmp;
tmp.Swap(node->mutable_input());
for (int i = 0; i < tmp.size(); ++i) {
if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
node->add_input(tmp.Get(i));
}
}
(*node->mutable_attr())["N"].set_i(node->input_size() - 1);
properties->ClearInputProperties(node->name());
}
return true;
}
bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
if (node.op() != "ConcatV2") {
return false;
}
int axis_idx = node.input_size() - 1;
while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
--axis_idx;
}
if (axis_idx <= 0) {
return false;
}
Tensor axis_tensor;
if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
return false;
}
*axis = axis_tensor.dtype() == DT_INT64
? static_cast<int>(axis_tensor.scalar<int64_t>()())
: axis_tensor.scalar<int32>()();
return true;
}
bool ConstantFolding::MergeConcat(bool use_shape_info,
GraphProperties* properties,
GraphDef* optimized_graph, NodeDef* node) {
// We only optimize for ConcatV2.
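// Merge a ConcatV2 node into its single ConcatV2 consumer when both
// concatenate along the same axis, e.g.
// ConcatV2(ConcatV2(a, b, axis), c, axis) -> ConcatV2(a, b, c, axis).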
int axis;
if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
node_map_->GetOutputs(node->name()).size() != 1) {
return false;
}
// If all inputs are constant, don't merge and let folding take care of it.
const int num_regular_inputs = NumNonControlInputs(*node);
bool all_inputs_are_const = true;
for (int i = 0; i < num_regular_inputs - 1; ++i) {
const NodeDef* input_node = node_map_->GetNode(node->input(i));
if (!IsReallyConstant(*input_node)) {
all_inputs_are_const = false;
break;
}
}
if (all_inputs_are_const) return false;
NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
int parent_axis;
if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
return false;
}
// Make a pass over the parent inputs to see if any of them have explicit
// device() fields set, and if different inputs are on different tasks. If
// so, this concat of concats may have been carefully constructed to be a
// two-stage concat, and we don't want to undo that here.
string task, device;
absl::flat_hash_set<string> unique_input_tasks;
const int n_parent_inputs = NumNonControlInputs(*parent);
// Iterate over the real inputs to concatenate, [0, n_parent_inputs - 1).
// The input at n_parent_inputs - 1 is the concat axis argument for a
// ConcatV2 node, which we don't want to consider here.
for (int i = 0; i < n_parent_inputs - 1; ++i) {
const NodeDef* input_node = node_map_->GetNode(parent->input(i));
if (!input_node->device().empty() &&
tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
&task, &device)) {
unique_input_tasks.insert(task);
if (unique_input_tasks.size() >= 2) {
// More than one input task represented in the device specifications
// of the parent's input nodes. Don't mess with this.
return false;
}
}
}
protobuf::RepeatedPtrField<string> parent_inputs;
parent_inputs.Swap(parent->mutable_input());
// TODO(rmlarsen): If the child occurs more than once, is it beneficial to
// collapse it into the parent multiple times? Probably not.
for (const auto& input : parent_inputs) {
if (IsSameInput(input, node->name())) {
for (int j = 0; j < num_regular_inputs - 1; ++j) {
// Add the child concat's tensor inputs (except the final axis
// input) to the parent's inputs.
parent->add_input(node->input(j));
node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
}
} else {
parent->add_input(input);
}
}
// Forward the child node's control inputs to the parent.
const int num_inputs = node->input_size();
for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
parent->add_input(node->input(i));
node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
node->mutable_input()->RemoveLast();
}
(*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
DedupControlInputs(parent);
ReplaceOperationWithNoOp(node, properties, optimized_graph);
return true;
}
Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
NodeDef* node, GraphDef* optimized_graph) {
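// The min_out/max_out outputs (ports 1 and 2) of a quantized matmul depend
// only on the "dtype" attribute, so materialize each of them as a Const
// node holding the minimum/maximum representable value of that quantized
// type.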
auto add_quantized_out = [this, node, optimized_graph](
const string& out_const_name, int index) {
NodeDef* out_node = optimized_graph->add_node();
graph_modified_ = true;
Tensor value(DT_FLOAT, TensorShape({}));
const bool is_min = index == 1;
const DataType type_attr = node->attr().at("dtype").type();
value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
: QuantizedTypeMaxAsFloat(type_attr);
TF_RETURN_IF_ERROR(
CreateNodeDef(out_const_name, TensorValue(&value), out_node));
node_map_->AddNode(out_const_name, out_node);
out_node->set_device(node->device());
// Copy all inputs from node.
out_node->mutable_input()->CopyFrom(node->input());
for (const string& input : out_node->input()) {
node_map_->AddOutput(NodeName(input), out_const_name);
}
// Update output nodes consuming node:index to use the new const node.
string old_input = absl::StrCat(node->name(), ":", index);
int old_node_count = 0;
// We make a copy since the set might change.
auto outputs = node_map_->GetOutputs(node->name());
for (const auto& output : outputs) {
for (int i = 0; i < output->input_size(); ++i) {
if (output->input(i) == old_input) {
output->set_input(i, out_const_name);
node_map_->AddOutput(out_const_name, output->name());
} else if (NodeName(output->input(i)) == node->name()) {
++old_node_count;
}
}
if (old_node_count == 0) {
node_map_->RemoveOutput(node->name(), output->name());
}
}
return Status::OK();
};
const string min_out_const_name =
OptimizedNodeName(*node, "-quantized_matmul_min_out");
const string max_out_const_name =
OptimizedNodeName(*node, "-quantized_matmul_max_out");
if (node_map_->GetNode(min_out_const_name) == nullptr &&
node_map_->GetNode(max_out_const_name) == nullptr) {
TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
} else {
return errors::Internal(absl::Substitute(
"Can't create Const for QuantizedMatMul min_out/max_out of "
"node '$0' because of node name conflict",
node->name()));
}
return Status::OK();
}
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
GrapplerItem* item,
GraphProperties* properties,
GraphDef* optimized_graph) {
optimized_graph->Clear();
graph_ = &item->graph;
node_map_.reset(new NodeMap(graph_));
nodes_allowlist_.clear();
// Fold fetch nodes iff they have a single fanout. Note that if a fetch node
// has a single fanout, it will be rewritten as a constant with the same
// node name, so users can still fetch it. This is not the case if the node
// has multiple fanouts: constant folding would replace it with multiple
// constants (one per fanout) under new names, and users would no longer be
// able to fetch the node by its original name.
for (const auto& fetch : item->fetch) {
const NodeDef* fetch_node = node_map_->GetNode(fetch);
if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
nodes_allowlist_.insert(fetch_node->name());
}
}
absl::flat_hash_set<string> nodes_to_not_simplify;
if (properties->has_properties()) {
TF_RETURN_IF_ERROR(MaterializeShapes(*properties));
TF_RETURN_IF_ERROR(MaterializeConstants(*properties));
TF_RETURN_IF_ERROR(
FoldGraph(*properties, optimized_graph, &nodes_to_not_simplify));
} else {
*optimized_graph = *graph_;
}
node_map_.reset(new NodeMap(optimized_graph));
TF_RETURN_IF_ERROR(
SimplifyGraph(optimized_graph, properties, &nodes_to_not_simplify));
return Status::OK();
}
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
// TensorFlow flushes denormals to zero and rounds to nearest, so we do
// the same here.
port::ScopedFlushDenormal flush;
port::ScopedSetRound round(FE_TONEAREST);
nodes_to_preserve_ = item.NodesToPreserve();
for (const auto& feed : item.feed) {
feed_nodes_.insert(NodeName(feed.first));
}
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
cpu_device_ = owned_device_.get();
}
graph_contains_assign_or_inplace_op_ = false;
for (const NodeDef& node : item.graph.node()) {
if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
graph_contains_assign_or_inplace_op_ = true;
break;
}
}
has_fetch_ = !item.fetch.empty();
GrapplerItem item_to_optimize = item;
GraphProperties properties(item_to_optimize);
// It's possible to feed a placeholder with a tensor of any shape: make sure
// that the shape inference deals with this conservatively unless we're in
// aggressive mode.
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
if (!properties
.InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_input_tensor_values=*/false,
/*include_output_tensor_values=*/true)
.ok()) {
properties.Clear();
}
*optimized_graph = GraphDef();
item_to_optimize.graph.Swap(optimized_graph);
int64_t node_count;
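// Run optimization passes until the graph reaches a fixed point: no pass
// reports a modification and the node count stops changing.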
do {
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
graph_modified_ = false;
item_to_optimize.graph.Swap(optimized_graph);
node_count = item_to_optimize.graph.node_size();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, &item_to_optimize,
&properties, optimized_graph));
} while (graph_modified_ || optimized_graph->node_size() != node_count);
*optimized_graph->mutable_library() = item.graph.library();
*optimized_graph->mutable_versions() = item.graph.versions();
return Status::OK();
}
} // namespace grappler
} // namespace tensorflow