| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" |
| |
| #include "tensorflow/cc/framework/ops.h" |
| #include "tensorflow/cc/ops/control_flow_ops_internal.h" |
| #include "tensorflow/cc/ops/function_ops.h" |
| #include "tensorflow/cc/ops/functional_ops.h" |
| #include "tensorflow/cc/ops/resource_variable_ops.h" |
| #include "tensorflow/cc/ops/standard_ops.h" |
| #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" |
| #include "tensorflow/compiler/tf2xla/test_util.h" |
| #include "tensorflow/compiler/xla/status_macros.h" |
| #include "tensorflow/core/common_runtime/function.h" |
| #include "tensorflow/core/framework/function.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/graph_def_builder.h" |
| #include "tensorflow/core/graph/validate.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/test.h" |
| #include "tensorflow/core/util/equal_graph_def.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| // Returns the names of the "then" and "else" functions for the If node in a |
| // graph. |
| Status FindIfThenAndElse(const GraphDef& graph, string* op_name, |
| NameAttrList* then_fn, NameAttrList* else_fn) { |
| for (const NodeDef& node : graph.node()) { |
| if (node.op() == "If") { |
| *op_name = node.name(); |
| const NameAttrList* result; |
| TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); |
| *then_fn = *result; |
| TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); |
| *else_fn = *result; |
| return Status::OK(); |
| } |
| } |
| return errors::NotFound("No If node found in graph"); |
| } |
| |
| // Graph: |
| // x = array_ops.placeholder(dtypes.int32) |
| // y = array_ops.placeholder(dtypes.int32) |
| // z = control_flow_ops.cond( |
| // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), |
| // lambda: math_ops.add(x, 23)) |
| TEST(FunctionalizeControlFlow, Conditional) { |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| |
| auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); |
| auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); |
| auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); |
| auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less); |
| |
| auto identity_t = |
| ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); |
| auto seventeen = ops::Const<int32>( |
| scope.WithOpName("cond").WithControlDependencies(identity_t), 17); |
| auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); |
| auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, |
| seventeen); |
| |
| auto identity_f = |
| ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); |
| auto twenty_three = ops::Const<int32>( |
| scope.WithOpName("cond").WithControlDependencies(identity_f), 23); |
| auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); |
| auto add = ops::Add(scope.WithOpName("cond/false/add"), |
| switch_3.output_false, twenty_three); |
| |
| auto merge = ops::Merge(scope.WithOpName("cond/Merge"), |
| std::initializer_list<Input>{add, mul}); |
| |
| TF_EXPECT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(OpRegistry::Global(), {}); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| string op_name; |
| NameAttrList then_fn; |
| NameAttrList else_fn; |
| TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); |
| InstantiationResultForTest else_result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(else_fn.name(), library, &else_result)); |
| |
| // Outer graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); |
| auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); |
| auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); |
| auto if_op = ops::If(scope.WithOpName(op_name), less, |
| std::initializer_list<Input>{less, y, x}, {DT_INT32}, |
| then_fn, else_fn); |
| auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // then body. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); |
| auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); |
| auto cond = ops::Const( |
| scope.WithOpName("cond").WithControlDependencies(identity), 17); |
| auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); |
| auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(then_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); |
| EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), |
| result.arg_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // else body. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); |
| auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); |
| auto cond_1 = ops::Const( |
| scope.WithOpName("cond_1").WithControlDependencies(identity), 23); |
| auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); |
| auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(else_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); |
| EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), |
| result.arg_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| // Returns the names of the "cond" and "body" functions for the While node |
| // in a graph. |
| Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, |
| NameAttrList* body) { |
| for (const NodeDef& node : graph.node()) { |
| if (node.op() == "While") { |
| const NameAttrList* result; |
| TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); |
| *cond = *result; |
| TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); |
| *body = *result; |
| return Status::OK(); |
| } |
| } |
| return errors::NotFound("No While node found in graph"); |
| } |
| |
| // Graph: |
| // x = array_ops.placeholder(dtypes.int32) |
| // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) |
| TEST(FunctionalizeControlFlow, OneLoopVar) { |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| |
| auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); |
| |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto enter = |
| ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); |
| // Add an unused Enter node. These should be ignored. |
| auto enter2 = |
| ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop"); |
| auto merge = ops::Merge(scope.WithOpName("while/Merge"), |
| std::initializer_list<Input>{enter, dummy}); |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); |
| auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); |
| auto switch_ = |
| ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); |
| auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), |
| switch_.output_false); |
| auto identity = |
| ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); |
| auto add = ops::Add(scope.WithOpName("while/add"), identity, one); |
| auto next_iteration = |
| ops::NextIteration(scope.WithOpName("while/NextIteration"), add); |
| |
| auto sink = ops::Identity(scope.WithOpName("sink"), exit); |
| |
| // Remove the dummy node and add the loop backedge. |
| scope.graph()->RemoveNode(dummy.node()); |
| scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); |
| |
| TF_EXPECT_OK(scope.ToGraph(&graph)); |
| } |
| |
| // Regression test: control edges from an Enter node to the graph sink should |
| // be ignored. |
| for (Node* n : graph.nodes()) { |
| if (n->name() == "while/Enter") { |
| graph.AddControlEdge(n, graph.sink_node()); |
| } |
| } |
| |
| FunctionLibraryDefinition library(OpRegistry::Global(), {}); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| NameAttrList cond_fn, body_fn; |
| TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); |
| |
| // Outer graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto while_op = |
| ops::While(scope.WithOpName("while/LoopCond"), |
| std::initializer_list<Input>{source}, cond_fn, body_fn); |
| auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // Condition graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); |
| auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(cond_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Body graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); |
| auto add = ops::Add(scope.WithOpName("while/add"), identity, one); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(body_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| FunctionDef GetNoinlineFunctionDef() { |
| FunctionDef fdef = FunctionDefHelper::Create( |
| "increment_fn", {"x:int32"}, {"add:int32"}, {}, |
| { |
| {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, |
| {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, |
| }, |
| {{"add", "add_0:z:0"}}); |
| (*fdef.mutable_attr())["_noinline"].set_b(true); |
| return fdef; |
| } |
| |
| // @function.Defun(noinline=True) |
| // def increment_fn(x): |
| // return [x + 1] |
| // Define the above function, and add it to the given graph. It's used as the |
| // while loop body in NoinlineLoopBody test. |
| Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { |
| FunctionDefLibrary fdef_lib; |
| *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); |
| TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); |
| NodeDef increment_fn; |
| increment_fn.set_name(node_name); |
| increment_fn.set_op("increment_fn"); |
| *increment_fn.add_input() = "while/Identity"; |
| *increment_fn.add_input() = "^while/Identity"; |
| Status status; |
| graph->AddNode(increment_fn, &status); |
| return status; |
| } |
| |
| // Graph: |
| // x = array_ops.placeholder(dtypes.int32) |
| // y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) |
| TEST(FunctionalizeControlFlow, NoinlineLoopBody) { |
| const string& noinline_node_name = "while/increment_fn"; |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, |
| "while/while_context"); |
| auto merge = ops::Merge(scope.WithOpName("while/Merge"), |
| std::initializer_list<Input>{enter, dummy}); |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); |
| auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); |
| auto switch_ = |
| ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); |
| auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), |
| switch_.output_false); |
| auto identity = |
| ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); |
| |
| TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); |
| |
| NodeDef next_iter; |
| next_iter.set_name("while/NextIteration"); |
| next_iter.set_op("NextIteration"); |
| *next_iter.add_input() = noinline_node_name; |
| (*next_iter.mutable_attr())["T"].set_type(DT_INT32); |
| |
| Status status; |
| Node* n = scope.graph()->AddNode(next_iter, &status); |
| TF_ASSERT_OK(status); |
| |
| // Remove the dummy node and add the loop backedge. |
| scope.graph()->RemoveNode(dummy.node()); |
| scope.graph()->AddEdge(n, 0, merge.output.node(), 1); |
| TF_ASSERT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(graph.flib_def()); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| |
| *(optimized_graph_def.mutable_library()->add_function()) = |
| GetNoinlineFunctionDef(); |
| |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| NameAttrList cond_fn, body_fn; |
| TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); |
| |
| // Outer graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto while_op = |
| ops::While(scope.WithOpName("while/LoopCond"), |
| std::initializer_list<Input>{source}, cond_fn, body_fn); |
| GraphDef expected; |
| TF_ASSERT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // Body graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| TF_ASSERT_OK( |
| AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); |
| auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); |
| NodeDef retval; |
| retval.set_name("_retval0_RetVal"); |
| retval.set_op(FunctionLibraryDefinition::kRetOp); |
| *retval.add_input() = noinline_node_name; |
| (*retval.mutable_attr())["T"].set_type(DT_INT32); |
| (*retval.mutable_attr())["index"].set_i(0); |
| Status status; |
| scope.graph()->AddNode(retval, &status); |
| TF_ASSERT_OK(status); |
| |
| GraphDef expected; |
| TF_ASSERT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| // Verify that increment_fn has been copied to library. |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(body_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); |
| // Ignore the function library when comparing the graphs. |
| expected.clear_library(); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { |
| const string& noinline_node_name = "while/increment_fn"; |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto identity = ops::Identity(scope.WithOpName("while/Identity"), source); |
| TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); |
| TF_ASSERT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(graph.flib_def()); |
| GraphDef graph_def; |
| graph.ToGraphDef(&graph_def); |
| graph_def.clear_library(); |
| |
| Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library); |
| EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); |
| } |
| |
| // Tests functionalizing OneLoopVar where the loop value is not used post the |
| // loop. |
| // Graph: |
| // x = array_ops.placeholder(dtypes.int32) |
| // control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) |
| TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| |
| auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); |
| |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto enter = |
| ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); |
| auto merge = ops::Merge(scope.WithOpName("while/Merge"), |
| std::initializer_list<Input>{enter, dummy}); |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); |
| auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); |
| auto switch_ = |
| ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); |
| auto identity = |
| ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); |
| auto add = ops::Add(scope.WithOpName("while/add"), identity, one); |
| auto next_iteration = |
| ops::NextIteration(scope.WithOpName("while/NextIteration"), add); |
| |
| // Remove the dummy node and add the loop backedge. |
| scope.graph()->RemoveNode(dummy.node()); |
| scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); |
| |
| TF_EXPECT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(OpRegistry::Global(), {}); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| NameAttrList cond_fn, body_fn; |
| TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); |
| |
| // Outer graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); |
| auto while_op = |
| ops::While(scope.WithOpName("while/LoopCond"), |
| std::initializer_list<Input>{source}, cond_fn, body_fn); |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // Condition graph |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); |
| auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(cond_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Body graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); |
| auto add = ops::Add(scope.WithOpName("while/add"), identity, one); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(body_fn.name(), library, &result)); |
| |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| // Graph: |
| // x = array_ops.placeholder(dtypes.int32) |
| // y = array_ops.placeholder(dtypes.int32) |
| // cond = lambda (i, j): i + 3 < 10 |
| // body = lambda (i, j): (i < 10, j * 2) |
| // z = control_flow_ops.while_loop(cond, body, [x, y]) |
| TEST(FunctionalizeControlFlow, TwoLoopVars) { |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| |
| auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); |
| |
| auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); |
| auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); |
| auto enter_x = |
| ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop"); |
| auto enter_y = |
| ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop"); |
| auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"), |
| std::initializer_list<Input>{enter_x, dummy}); |
| auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"), |
| std::initializer_list<Input>{enter_y, dummy}); |
| |
| // Loop condition |
| auto three = ops::Const<int32>(scope.WithOpName("while/cond/three") |
| .WithControlDependencies(merge_x.output), |
| 3); |
| auto cond_add = |
| ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three); |
| auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten") |
| .WithControlDependencies(merge_x.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); |
| auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); |
| |
| auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"), |
| merge_x.output, loop_cond); |
| auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"), |
| merge_y.output, loop_cond); |
| |
| auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"), |
| switch_x.output_false); |
| auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"), |
| switch_y.output_false); |
| |
| auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), |
| switch_x.output_true); |
| auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), |
| switch_y.output_true); |
| |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/one").WithControlDependencies(identity_x), |
| 1); |
| auto two = ops::Const<int32>( |
| scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), |
| 2); |
| |
| auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); |
| auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); |
| auto next_iteration_x = |
| ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add); |
| auto next_iteration_y = |
| ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul); |
| |
| auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x); |
| auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y); |
| |
| // Remove the dummy node and add the loop backedges. |
| scope.graph()->RemoveNode(dummy.node()); |
| scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(), |
| 1); |
| scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(), |
| 1); |
| |
| TF_EXPECT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(OpRegistry::Global(), {}); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| NameAttrList cond_fn, body_fn; |
| TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); |
| |
| // Outer graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); |
| auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); |
| auto while_op = |
| ops::While(scope.WithOpName("while/LoopCond"), |
| std::initializer_list<Input>{x, y}, cond_fn, body_fn); |
| auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); |
| auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // Condition graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto three = ops::Const<int32>(scope.WithOpName("while/cond/three") |
| .WithControlDependencies(arg0.output), |
| 3); |
| auto cond_add = |
| ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); |
| auto ten = ops::Const<int32>(scope.WithOpName("while/cond/ten") |
| .WithControlDependencies(arg0.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(cond_fn.name(), library, &result)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Body graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| |
| auto identity_x = |
| ops::Identity(scope.WithOpName("while/Identity/x"), arg0); |
| auto identity_y = |
| ops::Identity(scope.WithOpName("while/Identity/y"), arg1); |
| |
| auto one = ops::Const<int32>( |
| scope.WithOpName("while/add/one").WithControlDependencies(identity_x), |
| 1); |
| auto two = ops::Const<int32>( |
| scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), |
| 2); |
| |
| auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); |
| auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); |
| auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); |
| auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(body_fn.name(), library, &result)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| // Example with nesting, loop-invariant arguments, and resource variables. |
| // |
| // accum = resource_variable_ops.ResourceVariable(1) |
| // x = array_ops.placeholder(2, dtype=dtypes.int32) |
| // y = 3 + x |
| // |
| // def inner_body(j, k): |
| // add = state_ops.assign_add(accum, k * j + x) |
| // with ops.control_dependencies([add]): |
| // return [j + 1, k] |
| // |
| // def body(i): |
| // m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, |
| // [1, y], name="inner") |
| // with ops.control_dependencies(m): |
| // return [i + 1] |
| // |
| // z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") |
| TEST(FunctionalizeControlFlow, Complex) { |
| Graph graph(OpRegistry::Global()); |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| |
| auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); |
| |
| auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); |
| auto three = ops::Const<int32>(scope.WithOpName("three"), 3); |
| auto y = ops::Add(scope.WithOpName("y"), x, three); |
| |
| auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, |
| TensorShape({})); |
| |
| // Outer loop |
| auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0); |
| auto enter_i = |
| ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); |
| auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), |
| std::initializer_list<Input>{enter_i, dummy}); |
| auto ten = ops::Const<int32>(scope.WithOpName("outer/Less/y") |
| .WithControlDependencies(merge_i.output), |
| 10); |
| auto less_i = |
| ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); |
| auto outer_loop_cond = |
| ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i); |
| auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"), |
| merge_i.output, outer_loop_cond); |
| auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"), |
| switch_i.output_false); |
| auto identity_i = |
| ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true); |
| |
| auto enter_x_outer = |
| ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer", |
| ops::internal::Enter::Attrs().IsConstant(true)); |
| auto enter_k_outer = |
| ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer", |
| ops::internal::Enter::Attrs().IsConstant(true)); |
| auto enter_var_outer = |
| ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer", |
| ops::internal::Enter::Attrs().IsConstant(true)); |
| |
| // Inner loop |
| auto one_j = ops::Const<int32>( |
| scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); |
| auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), |
| one_j, "inner"); |
| auto enter_k = |
| ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k") |
| .WithControlDependencies(identity_i), |
| enter_k_outer, "inner"); |
| auto enter_x = ops::internal::Enter( |
| scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner", |
| ops::internal::Enter::Attrs().IsConstant(true)); |
| auto enter_var = ops::internal::Enter( |
| scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner", |
| ops::internal::Enter::Attrs().IsConstant(true)); |
| |
| auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"), |
| std::initializer_list<Input>{enter_j, dummy}); |
| auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), |
| std::initializer_list<Input>{enter_k, dummy}); |
| |
| auto five = ops::Const<int32>(scope.WithOpName("outer/inner/Five") |
| .WithControlDependencies(merge_j.output), |
| 5); |
| auto less_j = |
| ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); |
| auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); |
| |
| auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), |
| merge_j.output, loop_cond); |
| auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"), |
| merge_k.output, loop_cond); |
| auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"), |
| switch_j.output_false); |
| auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"), |
| switch_k.output_false); |
| auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"), |
| switch_j.output_true); |
| auto identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"), |
| switch_k.output_true); |
| |
| // Variable update |
| auto mul_jk = |
| ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); |
| auto add_jkx = |
| ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x); |
| auto assign = ops::AssignAddVariableOp( |
| scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); |
| |
| auto one = ops::Const<int32>( |
| scope.WithOpName("outer/inner/One") |
| .WithControlDependencies( |
| absl::Span<const Operation>{assign.operation}), |
| 1); |
| auto add_j = |
| ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); |
| |
| auto next_iteration_j = ops::NextIteration( |
| scope.WithOpName("outer/inner/NextIteration_j"), add_j); |
| auto next_iteration_k = ops::NextIteration( |
| scope.WithOpName("outer/inner/NextIteration_k"), identity_k); |
| |
| // Body and backedge for outer loop. |
| auto one_outer = ops::Const<int32>( |
| scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); |
| auto add_i = |
| ops::Add(scope.WithOpName("outer/add") |
| .WithControlDependencies(absl::Span<const Operation>{ |
| exit_j.output.op(), exit_k.output.op()}), |
| identity_i, one_outer); |
| auto next_iteration_i = |
| ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i); |
| |
| auto sink = ops::Identity(scope.WithOpName("sink"), exit_i); |
| |
| // Remove the dummy node and add the loop backedge. |
| scope.graph()->RemoveNode(dummy.node()); |
| scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(), |
| 1); |
| scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(), |
| 1); |
| scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(), |
| 1); |
| |
| TF_EXPECT_OK(scope.ToGraph(&graph)); |
| } |
| |
| FunctionLibraryDefinition library(OpRegistry::Global(), {}); |
| GraphDef optimized_graph_def; |
| graph.ToGraphDef(&optimized_graph_def); |
| TF_ASSERT_OK( |
| FunctionalizeControlFlowForGraphDef(&optimized_graph_def, &library)); |
| TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); |
| GraphDef converted_graph_def; |
| graph.ToGraphDef(&converted_graph_def); |
| |
| for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { |
| NameAttrList outer_cond_fn, outer_body_fn; |
| TF_EXPECT_OK( |
| FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); |
| |
| // Outer graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); |
| auto three = ops::Const<int32>(scope.WithOpName("three"), 3); |
| auto y = ops::Add(scope.WithOpName("y"), x, three); |
| |
| auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, |
| TensorShape({})); |
| |
| auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0); |
| |
| auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), |
| std::initializer_list<Input>{zero, y, x, var}, |
| outer_cond_fn, outer_body_fn); |
| auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| TF_EXPECT_GRAPH_EQ(expected, graph_def); |
| } |
| |
| // Outer condition graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); |
| |
| auto ten = ops::Const<int32>( |
| scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), |
| 10); |
| auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); |
| auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Outer body graph. |
| NameAttrList inner_cond_fn, inner_body_fn; |
| { |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); |
| |
| // Find the inner condition and body names. |
| TF_EXPECT_OK( |
| FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); |
| |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); |
| |
| auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); |
| auto one_j = ops::Const<int32>( |
| scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); |
| auto while_op = |
| ops::While(scope.WithOpName("outer/LoopCond_1"), |
| std::initializer_list<Input>{one_j, arg1, arg2, arg3}, |
| inner_cond_fn, inner_body_fn); |
| |
| auto one_outer = ops::Const<int32>( |
| scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), |
| 1); |
| auto add_i = |
| ops::Add(scope.WithOpName("outer/add") |
| .WithControlDependencies(absl::Span<const Operation>{ |
| while_op[0].op(), while_op[1].op()}), |
| identity_i, one_outer); |
| |
| auto retval0 = |
| ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); |
| auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); |
| auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); |
| auto retval3 = ops::_Retval(scope.WithOpName("_retval3_RetVal"), arg3, 3); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.arg_types); |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Inner condition graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); |
| |
| auto five = ops::Const<int32>( |
| scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), |
| 5); |
| auto less_j = |
| ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); |
| auto retval = |
| ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.arg_types); |
| EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| |
| // Inner body graph. |
| { |
| Scope scope = Scope::NewRootScope().ExitOnError(); |
| auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); |
| auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); |
| auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); |
| auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); |
| |
| auto identity_j = |
| ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); |
| auto identity_k = |
| ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); |
| |
| auto mul_jk = |
| ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); |
| auto add_jkx = |
| ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); |
| auto assign = ops::AssignAddVariableOp( |
| scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); |
| |
| auto one = ops::Const<int32>( |
| scope.WithOpName("outer/inner/One") |
| .WithControlDependencies( |
| absl::Span<const Operation>{assign.operation}), |
| 1); |
| auto add_j = |
| ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); |
| |
| auto retval0 = |
| ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); |
| auto retval1 = |
| ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); |
| auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); |
| auto retval3 = ops::_Retval(scope.WithOpName("_retval3_RetVal"), arg3, 3); |
| |
| GraphDef expected; |
| TF_EXPECT_OK(scope.ToGraphDef(&expected)); |
| |
| InstantiationResultForTest result; |
| TF_EXPECT_OK( |
| InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); |
| |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.arg_types); |
| EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), |
| result.ret_types); |
| TF_EXPECT_GRAPH_EQ(expected, result.gdef); |
| } |
| } |
| } |
| |
| } // namespace |
| } // namespace tensorflow |