/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace grappler {
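// Rewrites a FusedBatchNorm node (inference only) into primitive ops.
// With inv = rsqrt(variance + epsilon), batch normalization computes
//   y = (x - mean) * inv * scale + offset
// which is regrouped here as
//   y = x * (inv * scale) + (offset - mean * (inv * scale))
// so that (inv * scale) and the entire second term depend only on the
// constant inputs and can later be folded into constants.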
void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
const string& x = fused_node.input(0);
string scale = fused_node.input(1);
string offset = fused_node.input(2);
string mean = fused_node.input(3);
string variance = fused_node.input(4);
if (fused_node.attr().at("data_format").s() == "NCHW") {
    // In NCHW format the four per-channel inputs must be reshaped to
    // [1, C, 1, 1] (written as {1, -1, 1, 1}) so that they broadcast
    // along the channel dimension.
NodeDef* new_shape = optimized_graph->add_node();
new_shape->set_name(AddPrefixToNodeName("NCHWShape", fused_node.name()));
new_shape->set_op("Const");
new_shape->set_device(fused_node.device());
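    // A control dependency on scale keeps the new constant in the same
    // execution frame as the inputs it will reshape.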
*new_shape->add_input() = AsControlDependency(scale);
(*new_shape->mutable_attr())["dtype"].set_type(DT_INT32);
Tensor t(DT_INT32, {4});
t.flat<int32>()(0) = 1;
t.flat<int32>()(1) = -1;
t.flat<int32>()(2) = 1;
t.flat<int32>()(3) = 1;
t.AsProtoTensorContent(
(*new_shape->mutable_attr())["value"].mutable_tensor());
    // Reshape each per-channel parameter through the shared shape constant.
    const auto reshape_param = [&](const string& prefix, string* param) {
      NodeDef* reshaped = optimized_graph->add_node();
      reshaped->set_name(AddPrefixToNodeName(prefix, fused_node.name()));
      reshaped->set_op("Reshape");
      reshaped->set_device(fused_node.device());
      *reshaped->add_input() = *param;
      *reshaped->add_input() = new_shape->name();
      (*reshaped->mutable_attr())["T"] = fused_node.attr().at("T");
      (*reshaped->mutable_attr())["Tshape"].set_type(DT_INT32);
      *param = reshaped->name();
    };
    reshape_param("NCHWShapedScale", &scale);
    reshape_param("NCHWShapedOffset", &offset);
    reshape_param("NCHWShapedMean", &mean);
    reshape_param("NCHWShapedVariance", &variance);
}
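  // Materialize the epsilon attribute as a Const node so the rewrite can
  // express variance + epsilon with regular ops.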
float epsilon = 0.0f;
if (fused_node.attr().count("epsilon")) {
epsilon = fused_node.attr().at("epsilon").f();
}
  DataType dtype = fused_node.attr().at("T").type();
  // Optimize() only applies this rewrite when T is DT_FLOAT, so writing the
  // scalar through the float accessor below is safe.
  Tensor value(dtype, TensorShape());
  value.scalar<float>()() = epsilon;
NodeDef* variance_epsilon = optimized_graph->add_node();
TF_CHECK_OK(ConstantFolding::CreateNodeDef(
AddPrefixToNodeName("Const", fused_node.name()), &value,
variance_epsilon));
variance_epsilon->set_device(fused_node.device());
NodeDef* variance_plus_epsilon = optimized_graph->add_node();
variance_plus_epsilon->set_name(
AddPrefixToNodeName("VarPlusEpsilon", fused_node.name()));
variance_plus_epsilon->set_op("Add");
(*variance_plus_epsilon->mutable_attr())["T"].set_type(dtype);
variance_plus_epsilon->set_device(fused_node.device());
*variance_plus_epsilon->add_input() = variance;
*variance_plus_epsilon->add_input() = variance_epsilon->name();
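  // inv = rsqrt(variance + epsilon).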
NodeDef* inv = optimized_graph->add_node();
inv->set_name(AddPrefixToNodeName("Inv", fused_node.name()));
inv->set_op("Rsqrt");
inv->set_device(fused_node.device());
(*inv->mutable_attr())["T"].set_type(dtype);
*inv->add_input() = variance_plus_epsilon->name();
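  // scaled = inv * scale, shared by both terms of the rewrite.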
NodeDef* scaled = optimized_graph->add_node();
scaled->set_name(AddPrefixToNodeName("Scaled", fused_node.name()));
scaled->set_op("Mul");
scaled->set_device(fused_node.device());
(*scaled->mutable_attr())["T"].set_type(dtype);
*scaled->add_input() = inv->name();
*scaled->add_input() = scale;
NodeDef* a = optimized_graph->add_node();
a->set_name(AddPrefixToNodeName("Mul", fused_node.name()));
a->set_op("Mul");
a->set_device(fused_node.device());
(*a->mutable_attr())["T"].set_type(dtype);
*a->add_input() = x;
*a->add_input() = scaled->name();
NodeDef* b = optimized_graph->add_node();
b->set_name(AddPrefixToNodeName("Mul2", fused_node.name()));
b->set_op("Mul");
b->set_device(fused_node.device());
(*b->mutable_attr())["T"].set_type(dtype);
*b->add_input() = mean;
*b->add_input() = scaled->name();
NodeDef* c = optimized_graph->add_node();
c->set_name(AddPrefixToNodeName("Offset", fused_node.name()));
c->set_op("Sub");
c->set_device(fused_node.device());
(*c->mutable_attr())["T"].set_type(dtype);
*c->add_input() = offset;
*c->add_input() = b->name();
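  // The final Add reuses the fused node's name, so existing consumers of
  // output 0 automatically receive the rewritten value.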
NodeDef* r = optimized_graph->add_node();
r->set_name(fused_node.name());
r->set_op("Add");
r->set_device(fused_node.device());
(*r->mutable_attr())["T"].set_type(dtype);
*r->add_input() = a->name();
*r->add_input() = c->name();
}
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(false));
  // GraphView requires a mutable GraphDef even though it is only read here.
  GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
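      // The rewrite models the inference path only and assumes float data:
      // require is_training == false and T == DT_FLOAT, treating a missing
      // attribute as passing the check.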
bool optimizable = (node.attr().count("T") == 0 ||
node.attr().at("T").type() == DT_FLOAT);
optimizable &= (node.attr().count("is_training") == 0 ||
!node.attr().at("is_training").b());
if (optimizable) {
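        // The rewrite pays off when scale, offset, mean and variance are
        // known constants, i.e. at least 4 of the 5 inputs have a
        // statically inferred value.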
int const_inputs = 0;
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
const_inputs += 1;
}
}
// TODO(bsteiner): use the cost model to compare the cost of fused batch
// norm against that of the optimized form.
optimizable = (const_inputs >= 4);
}
if (optimizable) {
for (GraphView::Edge edge : graph.GetFanoutEdges(node, false)) {
if (edge.src.port_id != 0) {
// The optimized version only generates the first output.
optimizable = false;
break;
}
}
}
if (optimizable) {
VLOG(1) << "Optimizing fused batch norm node " << node.DebugString();
AddBatchNormNodes(optimized_graph, node);
continue;
}
}
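    // Copy nodes that were not rewritten; rewritten FusedBatchNorm nodes
    // were skipped via the continue above.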
*optimized_graph->add_node() = node;
}
*optimized_graph->mutable_library() = item.graph.library();
*optimized_graph->mutable_versions() = item.graph.versions();
return Status::OK();
}
void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
  // Nothing to do for the Remapper optimizer.
}
} // namespace grappler
} // namespace tensorflow