blob: 9b1d2b8e2709f965f5dd4355cb8191ef66c7c755 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/lower_if_op.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
namespace {
// Shorthand for the (node, output index) pairs that NodeBuilder consumes.
using NodeOut = NodeBuilder::NodeOut;
// Attribute set to `true` on the generated then/else call nodes. The name is
// taken from LowerFunctionalOpsPass so the two definitions cannot drift apart.
// (A constexpr pointer is implicitly const, so no trailing `const` is needed.)
constexpr const char* kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
// Convenience builder to make it easy to construct a conditional with a single
// function call in the then and else branch. This first converts the if node
// into switches (for inputs) and merges (for outputs) around a function call
// per branch.
//
// Expected call order (as used by RewriteIfNode): CreatePivotNodes(), then
// AddInputs(), then AddOutputs(). AddOutputs() invokes BuildLoweredIfOutput()
// itself.
class CondBuilder {
 public:
  // Index of the Switch output that feeds each branch: output kElseBranch
  // drives the else side and output kThenBranch drives the then side.
  enum Branch { kElseBranch = 0, kThenBranch = 1 };

  // Create a CondBuilder to create the lowered form of `if_op` with then and
  // else functions `then_fn` and `else_fn` respectively in the `graph`. The
  // function names are resolved through `graph->op_registry()`.
  CondBuilder(Node* if_op, const NameAttrList& then_fn,
              const NameAttrList& else_fn, bool keep_node_fetchable,
              Graph* graph);

  // Creates a Switch on the predicate plus the two Identity pivot nodes
  // (pivot_f_ / pivot_t_) that dominate the else / then branches.
  Status CreatePivotNodes();

  // Routes every data input of the If node except the predicate (input 0)
  // through its own Switch node into the then/else call builders, and
  // redirects the If node's incoming control edges to the predicate Switch.
  Status AddInputs();

  // Finalizes the then/else call nodes, adds one Merge per data output and
  // rewires the If node's outgoing edges onto the lowered nodes.
  // Note: no inputs can be added once outputs are added as the then and else
  // nodes are finalized while adding outputs.
  Status AddOutputs();

  // Builds the node that takes over the If op's name: an IdentityN over the
  // merged outputs, or a NoOp (see lowered_if_output_ for when each is used).
  Status BuildLoweredIfOutput();

 private:
  // Returns unique name containing the name of the If op being rewritten
  // (name_), infix and a suffix to ensure it is unique within the graph.
  string NewName(const string& infix);

  // Adds input to both the then and else nodes from src:src_output.
  Status AddInput(Node* src, int src_output);

  // The merged outputs of the then and else nodes.
  std::vector<NodeOut> outputs_;

  // The node that dominates all execution of the then and else body nodes.
  // Set by CreatePivotNodes().
  Node* control_predecessor_ = nullptr;
  // The original If op.
  Node* if_op_ = nullptr;

  // The node with the same name as the original If op:
  //   (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
  //       and if the original If op had non-zero data outputs.
  //   (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
  Node* lowered_if_output_ = nullptr;

  // The predicate of the conditional (input 0 of the If op).
  OutputTensor pred_;

  // Node corresponding to pivot_f branch of predicate switch which is
  // the pivot node that dominates all nodes in the false/else branch.
  Node* pivot_f_ = nullptr;
  // Node corresponding to pivot_t branch of predicate switch which is
  // the pivot node that dominates all nodes in the true/then branch.
  Node* pivot_t_ = nullptr;

  // Call nodes for the then/else functions; created in AddOutputs().
  Node* then_call_node_ = nullptr;
  Node* else_call_node_ = nullptr;

  // Merge node that has inputs from [pivot_t, pivot_f] and control edges from
  // [^then_call_node_, ^else_call_node_]. This node will guarantee that even
  // when then/else branch functions do not have outputs, they still will be
  // executed for the side effects.
  Node* branch_executed_node_ = nullptr;

  // NOTE: declaration order matters. The constructor initializes
  // then_call_builder_ / else_call_builder_ by calling NewName() (which reads
  // graph_ and name_) and by taking &debug_info_, so these members must stay
  // declared above the two builders.
  Graph* graph_ = nullptr;
  string name_;
  bool keep_node_fetchable_ = false;
  NodeDebugInfo debug_info_;
  NodeBuilder then_call_builder_;
  NodeBuilder else_call_builder_;
};
// Captures the If op, its graph and its debug info, and prepares (but does not
// finalize) one call builder per branch. The then builder is created before
// the else builder, so its generated name is requested first.
CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn,
                         const NameAttrList& else_fn,
                         bool keep_node_fetchable, Graph* graph)
    : if_op_(if_op),
      graph_(graph),
      name_(if_op->name()),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*if_op_),
      then_call_builder_(NewName("then"), then_fn.name(), graph->op_registry(),
                         &debug_info_),
      else_call_builder_(NewName("else"), else_fn.name(), graph->op_registry(),
                         &debug_info_) {
  // Input 0 of the If op is the predicate.
  TF_CHECK_OK(if_op_->input_tensor(0, &pred_));

  // Both branch calls get the same treatment. Each is placed on the If op's
  // requested device, tagged with kLowerAsMultiDeviceFunctionAttr = true, and
  // given every attribute carried by its function reference.
  const auto configure_call = [this](const NameAttrList& fn,
                                     NodeBuilder* call) {
    call->Device(if_op_->requested_device());
    call->Attr(kLowerAsMultiDeviceFunctionAttr, true);
    for (const auto& fn_attr : fn.attr()) {
      call->Attr(fn_attr.first, fn_attr.second);
    }
  };
  configure_call(then_fn, &then_call_builder_);
  configure_call(else_fn, &else_call_builder_);
}
// Builds the skeleton of the conditional:
//   * a Switch that routes the predicate on itself, and
//   * one Identity "pivot" per Switch output.
// Switch output kElseBranch feeds pivot_f_ and output kThenBranch feeds
// pivot_t_. The Switch also becomes control_predecessor_.
Status CondBuilder::CreatePivotNodes() {
  Node* pred_switch;
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("switch_pred"), "Switch",
                                 graph_->op_registry(), &debug_info_)
                         .Input(NodeOut(pred_))
                         .Input(NodeOut(pred_))
                         .Device(if_op_->requested_device())
                         .Finalize(graph_, &pred_switch));
  control_predecessor_ = pred_switch;

  // Builds one Identity pivot reading Switch output `branch`.
  const auto make_pivot = [&](const string& infix, Branch branch,
                              Node** pivot) -> Status {
    return NodeBuilder(NewName(infix), "Identity", graph_->op_registry(),
                       &debug_info_)
        .Input(pred_switch, branch)
        .Device(if_op_->requested_device())
        .Finalize(graph_, pivot);
  };
  // The else pivot is built first, matching the original naming order.
  TF_RETURN_IF_ERROR(make_pivot("pivot_f", kElseBranch, &pivot_f_));
  TF_RETURN_IF_ERROR(make_pivot("pivot_t", kThenBranch, &pivot_t_));
  return Status::OK();
}
// Returns "<If op name>/<infix>" made unique within graph_ by Graph::NewName.
string CondBuilder::NewName(const string& infix) {
  const string scoped_prefix = strings::StrCat(name_, "/", infix);
  return graph_->NewName(scoped_prefix);
}
// Creates a Switch that gates `src:src_output` on the predicate, and appends
// its kThenBranch output to the then call and its kElseBranch output to the
// else call.
Status CondBuilder::AddInput(Node* src, int src_output) {
  Node* input = nullptr;
  NodeDebugInfo debug_info(*src);
  // Colocate the Switch node with the `src` node.
  //
  // This is to avoid unnecessary Host<->Device copies between src and the
  // Switch node. This aligns with the implementation of legacy tf.cond in
  // control_flow_ops.py. The legacy impl colocates the Switch with the
  // input tensor which resets the device stack and forces the Switch to have
  // the same device as the input node (if set) and sets the colocation _class
  // attr. It also ignores the existing colocation constraints on the input node
  // using colocate_with(ignore_existing=True).
  //
  // Colocation groups in the `_class` attr are spelled "loc:@<node name>".
  // A bare node name is not recognized as a colocation group, so the prefix
  // is required for the constraint to take effect.
  TF_RETURN_IF_ERROR(
      NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(),
                  &debug_info)
          .Input(src, src_output)
          .Input(pred_)
          .Device(src->requested_device())
          .Attr("_class", {strings::StrCat("loc:@", src->name())})
          .Finalize(graph_, &input));
  then_call_builder_.Input(input, kThenBranch);
  else_call_builder_.Input(input, kElseBranch);
  return Status::OK();
}
// Wires the If op's inputs into the lowered form:
//   * each data input other than the predicate goes through AddInput(), and
//   * each incoming control edge is redirected to control_predecessor_ (the
//     predicate Switch), so it dominates both branches.
Status CondBuilder::AddInputs() {
  // Add input data edges.
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(if_op_->input_edges(&edges));
  // Start at index 1 as the first input is the predicate. Use size_t to match
  // edges.size() and avoid a signed/unsigned comparison.
  for (size_t i = 1; i < edges.size(); ++i) {
    const Edge* e = edges[i];
    TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
  }
  // Add input control edges.
  for (const Edge* e : if_op_->in_edges()) {
    if (e->IsControlEdge()) {
      graph_->AddControlEdge(e->src(), control_predecessor_);
    }
  }
  return Status::OK();
}
// Finalizes both branch calls, merges their results, builds the replacement
// node for the If op, and moves every outgoing edge of the If op onto the
// lowered nodes. After this returns, AddInput() must not be called again.
Status CondBuilder::AddOutputs() {
  // Materialize the two calls. Each is gated by a control edge from the pivot
  // of its own branch.
  TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_));
  graph_->AddControlEdge(pivot_t_, then_call_node_);
  TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_));
  graph_->AddControlEdge(pivot_f_, else_call_node_);

  // Join output k of the then call with output k of the else call in one
  // Merge node. The output count is taken from the then call.
  const int num_outputs = then_call_node_->num_outputs();
  std::vector<Node*> merge_nodes(num_outputs);
  outputs_.resize(merge_nodes.size());
  for (int k = 0; k < num_outputs; ++k) {
    TF_RETURN_IF_ERROR(
        NodeBuilder(NewName("output"), "Merge", graph_->op_registry(),
                    &debug_info_)
            .Input({NodeOut(then_call_node_, k), NodeOut(else_call_node_, k)})
            .Device(if_op_->requested_device())
            .Finalize(graph_, &merge_nodes[k]));
    outputs_[k] = NodeOut(merge_nodes[k], 0);
  }

  // An extra Merge fed by both pivots and control-dependent on both calls.
  // It becomes the source of the If op's outgoing control edges. Relying on
  // data outputs alone would not be enough: a branch may have no outputs, or
  // its outputs may go unused once the function is inlined. Routing control
  // through this node keeps the branch calls (and their side effects) required.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
                                 graph_->op_registry(), &debug_info_)
                         .Input({pivot_t_, pivot_f_})
                         .ControlInputs({then_call_node_, else_call_node_})
                         .Device(if_op_->requested_device())
                         .Finalize(graph_, &branch_executed_node_));

  TF_RETURN_IF_ERROR(BuildLoweredIfOutput());

  // Re-point the If op's consumers. Data consumers read directly from the
  // matching Merge, so they can start before every output is ready. Control
  // consumers wait on branch_executed_node_.
  for (const Edge* out_edge : if_op_->out_edges()) {
    if (!out_edge->IsControlEdge()) {
      graph_->AddEdge(merge_nodes[out_edge->src_output()], 0, out_edge->dst(),
                      out_edge->dst_input());
      continue;
    }
    graph_->AddControlEdge(branch_executed_node_, out_edge->dst());
  }
  return Status::OK();
}
// Creates the node that inherits the If op's name (stored in
// lowered_if_output_). It is control-dependent on branch_executed_node_.
//   * If keep_node_fetchable_ is set and there are merged outputs, it is an
//     IdentityN over outputs_, so the original name stays fetchable.
//   * Otherwise it is a NoOp, which still serves as a control-edge source.
//
// Like every other node this builder creates, it is built against
// graph_->op_registry() and carries the If op's debug info.
Status CondBuilder::BuildLoweredIfOutput() {
  // If outputs are empty, it means that we might have only output control
  // edges (already connected to the `branch_executed_node`). Furthermore it's
  // illegal to have an IdentityN with empty inputs.
  //
  // We still must keep lowered If node as a valid source of control edges,
  // because it might be a part of function control output set.
  const bool use_identity_n = keep_node_fetchable_ && !outputs_.empty();
  NodeBuilder builder =
      use_identity_n ? NodeBuilder(name_, "IdentityN", graph_->op_registry(),
                                   &debug_info_)
                           .Input(outputs_)
                     : NodeBuilder(name_, "NoOp", graph_->op_registry(),
                                   &debug_info_);
  return builder.Device(if_op_->requested_device())
      .ControlInput(branch_executed_node_)
      .Finalize(graph_, &lowered_if_output_);
}
} // namespace
// Replaces the If node `n` in `g` with Switch/Merge-based control flow around
// calls to its then/else branch functions, then removes `n`.
//
// Returns InvalidArgument if `n` lacks a "then_branch" or "else_branch" attr.
// Any error from building the lowered nodes is returned as-is; in that case
// `n` is not removed.
Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) {
  VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable
          << "): " << SummarizeNode(*n);

  const AttrValue* then_branch = n->attrs().Find("then_branch");
  if (!then_branch) {
    return errors::InvalidArgument("Then branch function missing");
  }
  const AttrValue* else_branch = n->attrs().Find("else_branch");
  if (!else_branch) {
    return errors::InvalidArgument("Else branch function missing");
  }

  CondBuilder builder(n, then_branch->func(), else_branch->func(),
                      keep_node_fetchable, g);
  TF_RETURN_IF_ERROR(builder.CreatePivotNodes());
  TF_RETURN_IF_ERROR(builder.AddInputs());
  TF_RETURN_IF_ERROR(builder.AddOutputs());

  // The lowered subgraph now stands in for `n`; drop the original node.
  g->RemoveNode(n);
  return Status::OK();
}
} // namespace tensorflow