blob: 9a53b00275ec4fc648de2f47c1804e6c97be0a9b [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace grappler {
namespace function_utils {
namespace {
// Verifies that FunctionDefTensorDesc splits a tensor string into its node
// name, node output and position, both for the three-part form and for a bare
// node name.
TEST(FunctionDefTensorDesc, Parsing) {
  {
    // Three-part form: every component is populated.
    const FunctionDefTensorDesc qualified("Cast:y:0");
    EXPECT_EQ(qualified.full_str, "Cast:y:0");
    EXPECT_EQ(qualified.node_name, "Cast");
    EXPECT_EQ(qualified.node_output, "y");
    EXPECT_EQ(qualified.position, 0);
  }
  {
    // Bare name: the output is empty and the position is -1.
    const FunctionDefTensorDesc bare("Arg0");
    EXPECT_EQ(bare.full_str, "Arg0");
    EXPECT_EQ(bare.node_name, "Arg0");
    EXPECT_EQ(bare.node_output, "");
    EXPECT_EQ(bare.position, -1);
  }
}
// Verifies that ReplaceReferences rewrites a tensor reference both where it is
// consumed as a node input and where it is used as a function return value.
TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
  FunctionDef outer = FunctionDefHelper::Create(
      "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
      {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
  // Node "X" consumes the tensor that is about to be replaced.
  NodeDef* consumer =
      AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);

  ReplaceReferences("MapDefun:output:0", "arg0", &outer);

  // Both the input of "X" and the "out" retval of "outer" must now be "arg0".
  EXPECT_EQ(outer.ret().at("out"), "arg0");
  EXPECT_EQ(consumer->input(0), "arg0");
}
// Verifies that adding an output with the prefix "y" to XTimesTwo yields an
// output named "y/_1" that is mapped to the requested tensor.
TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
  FunctionDef x_times_two = test::function::XTimesTwo();

  AddFunctionOutputWithUniqueName("y", "two", &x_times_two, DT_INT64);

  EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", x_times_two));
  const auto& return_values = x_times_two.ret();
  EXPECT_EQ(return_values.at("y/_1"), "two");
}
// Verifies that AddFunctionInput appends arguments to the signature in call
// order and hands back the argument it just added.
TEST(FunctionUtilsTest, AddFunctionInput) {
  FunctionDef fdef;
  auto first = AddFunctionInput("arg0", &fdef, DT_INT32);
  auto second = AddFunctionInput("arg1", &fdef, DT_BOOL);

  // The returned pointers must be the entries stored in the signature.
  const auto& signature_inputs = fdef.signature().input_arg();
  EXPECT_EQ(signature_inputs.data()[0], first);
  EXPECT_EQ(signature_inputs.data()[1], second);

  EXPECT_EQ(first->name(), "arg0");
  EXPECT_EQ(first->type(), DT_INT32);

  EXPECT_EQ(second->name(), "arg1");
  EXPECT_EQ(second->type(), DT_BOOL);
}
// Verifies node lookup by name on XTimesTwo: "two" is present, an arbitrary
// unused name is not.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const bool has_two = ContainsFunctionNodeWithName("two", x_times_two);
  const bool has_bogus = ContainsFunctionNodeWithName(
      "weird_name_that_should_not_be_there", x_times_two);

  EXPECT_TRUE(has_two);
  EXPECT_FALSE(has_bogus);
}
// Verifies node lookup by op type on XTimesTwo: a "Mul" node is present, a
// made-up op is not.
TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const bool has_mul = ContainsFunctionNodeWithOp("Mul", x_times_two);
  const bool has_bogus = ContainsFunctionNodeWithOp(
      "weird_op_that_should_not_be_there", x_times_two);

  EXPECT_TRUE(has_mul);
  EXPECT_FALSE(has_bogus);
}
// Verifies output lookup by name on XTimesTwo: "y" is an output, whereas the
// tensor-style string "Add:z:0" is not.
TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const bool has_y = ContainsFunctionOutputWithName("y", x_times_two);
  const bool has_tensor_name =
      ContainsFunctionOutputWithName("Add:z:0", x_times_two);

  EXPECT_TRUE(has_y);
  EXPECT_FALSE(has_tensor_name);
}
// Verifies that FindFunctionNodeWithName yields -1 for an unknown node name
// and some other value for the existing node "two".
TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const auto missing_index = FindFunctionNodeWithName(
      "weird_name_that_should_not_be_there", x_times_two);
  EXPECT_EQ(missing_index, -1);

  const auto two_index = FindFunctionNodeWithName("two", x_times_two);
  EXPECT_NE(two_index, -1);
}
// Verifies that FindFunctionNodeWithOp yields -1 for an unknown op type and
// some other value for the "Mul" op contained in XTimesTwo.
TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const auto missing_index = FindFunctionNodeWithOp(
      "weird_op_that_should_not_be_there", x_times_two);
  EXPECT_EQ(missing_index, -1);

  const auto mul_index = FindFunctionNodeWithOp("Mul", x_times_two);
  EXPECT_NE(mul_index, -1);
}
// Verifies that FindFunctionInputWithName yields index 0 for XTimesTwo's input
// "x" and -1 for a name that is not an input.
TEST(FunctionUtilsTest, FindFunctionInputWithName) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const auto x_index = FindFunctionInputWithName("x", x_times_two);
  const auto missing_index =
      FindFunctionInputWithName("not_a_name", x_times_two);

  EXPECT_EQ(x_index, 0);
  EXPECT_EQ(missing_index, -1);
}
// Verifies that FindFunctionOutputWithName yields index 0 for XTimesTwo's
// output "y" and -1 for the tensor-style string "Add:z:0".
TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
  const FunctionDef x_times_two = test::function::XTimesTwo();

  const auto y_index = FindFunctionOutputWithName("y", x_times_two);
  const auto missing_index =
      FindFunctionOutputWithName("Add:z:0", x_times_two);

  EXPECT_EQ(y_index, 0);
  EXPECT_EQ(missing_index, -1);
}
// Verifies that SetUniqueFunctionNodeName picks a name that collides neither
// with the nodes already in the function nor with a node that was named with
// the same prefix and then added to the function.
TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
  FunctionDef x_times_two = test::function::XTimesTwo();

  // First request: the chosen name must differ from every existing node name.
  NodeDef first;
  SetUniqueFunctionNodeName("abc", &x_times_two, &first);
  for (int i = 0; i < x_times_two.node_def_size(); ++i) {
    EXPECT_NE(first.name(), x_times_two.node_def(i).name());
  }

  // Register the first node so that the second request has to avoid its name.
  NodeDef* registered = x_times_two.add_node_def();
  *registered = first;

  NodeDef second;
  SetUniqueFunctionNodeName("abc", &x_times_two, &second);
  EXPECT_NE(second.name(), registered->name());
}
// Exercises AddNode on a FunctionDef in three configurations:
//   1. explicit node name, no inputs, no attributes;
//   2. empty node name (a name is generated from the op), with inputs;
//   3. empty node name, with attributes.
// Each node is looked up by its expected name; the lookup result is asserted
// to be valid before it is used as an index, so that a naming regression is
// reported as a test failure rather than an out-of-range access.
TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
  FunctionDef func;
  const char* op_name = "xxx";

  // Case 1: explicit name.
  AddNode(op_name, op_name, {}, {}, &func);
  const int index1 = FindFunctionNodeWithName("xxx", func);
  ASSERT_NE(index1, -1);
  const NodeDef& node1 = func.node_def(index1);
  EXPECT_EQ(node1.op(), op_name);
  EXPECT_EQ(node1.input_size(), 0);
  EXPECT_EQ(node1.attr_size(), 0);

  // Case 2: generated name, inputs only.
  const std::vector<string> inputs({"input1", "input2"});
  AddNode("", op_name, inputs, {}, &func);
  const int index2 = FindFunctionNodeWithName("xxx/_2", func);
  ASSERT_NE(index2, -1);
  const NodeDef& node2 = func.node_def(index2);
  EXPECT_EQ(node2.op(), op_name);
  EXPECT_EQ(node2.attr_size(), 0);
  // Fatal on mismatch: the loop below indexes both containers in lockstep.
  ASSERT_EQ(node2.input_size(), static_cast<int>(inputs.size()));
  for (int i = 0; i < node2.input_size(); ++i) {
    EXPECT_EQ(node2.input(i), inputs[i]);
  }

  // Case 3: generated name, attributes only.
  AttrValue a1, a2;
  a1.set_type(DT_INT32);
  a2.set_type(DT_INT64);
  const std::vector<std::pair<string, AttrValue>> attrs(
      {{"attr1", a1}, {"attr2", a2}});
  AddNode("", op_name, {}, attrs, &func);
  const int index3 = FindFunctionNodeWithName("xxx/_3", func);
  ASSERT_NE(index3, -1);
  const NodeDef& node3 = func.node_def(index3);
  EXPECT_EQ(node3.op(), op_name);
  EXPECT_EQ(node3.input_size(), 0);
  EXPECT_EQ(node3.attr_size(), static_cast<int>(attrs.size()));
  for (const auto& expected : attrs) {
    EXPECT_EQ(expected.second.type(),
              node3.attr().at(expected.first).type());
  }
}
// Text-format GraphDef for a graph whose function library contains functions
// using the "If" and "Assert" ops. It corresponds to this Python program:
/*
  @eager_function.defun
  def test_function():
    pred = constant_op.constant(True)

    def fn1():
      return control_flow_ops.no_op()

    def fn2():
      return control_flow_ops.Assert(False, ["Wrong branch!!!"])

    return control_flow_ops.cond(pred, fn1, fn2)

  r = test_function()
*/
// The proto below was generated in Python from the code block above. To
// regenerate it, run that code and dump the graph_def of the default (or the
// specified) graph, e.g. ops.get_default_graph().as_graph_def().
constexpr char kCondGraphProto[] = R"proto(
node {
name: "StatefulPartitionedCall"
op: "StatefulPartitionedCall"
attr {
key: "Tin"
value { list {} }
}
attr {
key: "Tout"
value { list { type: DT_BOOL } }
}
attr {
key: "_gradient_op_type"
value { s: "PartitionedCall-20" }
}
attr {
key: "config"
value { s: "" }
}
attr {
key: "config_proto"
value { s: "" }
}
attr {
key: "executor_type"
value { s: "" }
}
attr {
key: "f"
value { func { name: "__inference_test_function_19" } }
}
}
library {
function {
signature {
name: "cond_true_3"
input_arg { name: "identity_const" type: DT_BOOL }
output_arg { name: "identity_1" type: DT_BOOL }
}
node_def { name: "NoOp" op: "NoOp" }
node_def {
name: "Identity"
op: "Identity"
input: "identity_const"
input: "^NoOp"
attr {
key: "T"
value { type: DT_BOOL }
}
}
node_def {
name: "Identity_1"
op: "Identity"
input: "Identity:output:0"
attr {
key: "T"
value { type: DT_BOOL }
}
}
ret { key: "identity_1" value: "Identity_1:output:0" }
}
function {
signature {
name: "cond_false_4"
input_arg { name: "identity_const" type: DT_BOOL }
output_arg { name: "identity_1" type: DT_BOOL }
is_stateful: true
}
node_def {
name: "Assert/Const"
op: "Const"
attr {
key: "dtype"
value { type: DT_STRING }
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {}
string_val: "Wrong branch!!!"
}
}
}
}
node_def {
name: "Assert/Assert/condition"
op: "Const"
attr {
key: "dtype"
value { type: DT_BOOL }
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {}
bool_val: false
}
}
}
}
node_def {
name: "Assert/Assert/data_0"
op: "Const"
attr {
key: "dtype"
value { type: DT_STRING }
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {}
string_val: "Wrong branch!!!"
}
}
}
}
node_def {
name: "Assert/Assert"
op: "Assert"
input: "Assert/Assert/condition:output:0"
input: "Assert/Assert/data_0:output:0"
attr {
key: "T"
value { list { type: DT_STRING } }
}
attr {
key: "summarize"
value { i: 3 }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "identity_const"
input: "^Assert/Assert"
attr {
key: "T"
value { type: DT_BOOL }
}
}
node_def {
name: "Identity_1"
op: "Identity"
input: "Identity:output:0"
input: "^Assert/Assert"
attr {
key: "T"
value { type: DT_BOOL }
}
}
ret { key: "identity_1" value: "Identity_1:output:0" }
}
function {
signature {
name: "__inference_test_function_19"
output_arg { name: "identity" type: DT_BOOL }
is_stateful: true
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value { type: DT_BOOL }
}
attr {
key: "value"
value {
tensor {
dtype: DT_BOOL
tensor_shape {}
bool_val: true
}
}
}
}
node_def {
name: "cond"
op: "If"
input: "Const:output:0"
input: "Const:output:0"
attr {
key: "Tcond"
value { type: DT_BOOL }
}
attr {
key: "Tin"
value { list { type: DT_BOOL } }
}
attr {
key: "Tout"
value { list { type: DT_BOOL } }
}
attr {
key: "_lower_using_switch_merge"
value { b: true }
}
attr {
key: "else_branch"
value { func { name: "cond_false_4" } }
}
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "then_branch"
value { func { name: "cond_true_3" } }
}
}
node_def {
name: "cond/Identity"
op: "Identity"
input: "cond:output:0"
attr {
key: "T"
value { type: DT_BOOL }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "cond/Identity:output:0"
input: "^cond"
attr {
key: "T"
value { type: DT_BOOL }
}
}
ret { key: "identity" value: "Identity:output:0" }
}
}
versions { producer: 27 min_consumer: 12 })proto";
// Text-format GraphDef for a graph whose function library contains a function
// using the "While" op. It corresponds to this Python program:
/*
  @eager_function.defun
  def test_function():
    return control_flow_ops.while_loop(
        lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)

  r = test_function()
*/
// The proto below was generated in Python from the code block above. To
// regenerate it, run that code and dump the graph_def of the default (or the
// specified) graph, e.g. ops.get_default_graph().as_graph_def().
constexpr char kWhileGraphProto[] = R"proto(
node {
name: "StatefulPartitionedCall"
op: "StatefulPartitionedCall"
attr {
key: "Tin"
value { list {} }
}
attr {
key: "Tout"
value { list { type: DT_INT32 } }
}
attr {
key: "_gradient_op_type"
value { s: "PartitionedCall-35" }
}
attr {
key: "config"
value { s: "" }
}
attr {
key: "config_proto"
value { s: "" }
}
attr {
key: "executor_type"
value { s: "" }
}
attr {
key: "f"
value { func { name: "__inference_test_function_34" } }
}
}
library {
function {
signature {
name: "while_body_5"
input_arg { name: "while_loop_counter" type: DT_INT32 }
input_arg { name: "const" type: DT_INT32 }
input_arg { name: "maximum_iterations" type: DT_INT32 }
output_arg { name: "identity" type: DT_INT32 }
output_arg { name: "identity_1" type: DT_INT32 }
output_arg { name: "identity_2" type: DT_INT32 }
}
node_def {
name: "add/y"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 1
}
}
}
}
node_def {
name: "add"
op: "Add"
input: "const"
input: "add/y:output:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "add_1/y"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 1
}
}
}
}
node_def {
name: "add_1"
op: "Add"
input: "while_loop_counter"
input: "add_1/y:output:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "add_1:z:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "Identity_1"
op: "Identity"
input: "add:z:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "Identity_2"
op: "Identity"
input: "maximum_iterations"
attr {
key: "T"
value { type: DT_INT32 }
}
}
ret { key: "identity" value: "Identity:output:0" }
ret { key: "identity_1" value: "Identity_1:output:0" }
ret { key: "identity_2" value: "Identity_2:output:0" }
}
function {
signature {
name: "__inference_test_function_34"
output_arg { name: "identity" type: DT_INT32 }
is_stateful: true
}
node_def {
name: "maximum_iterations"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 1
}
}
}
}
node_def {
name: "Const"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 0
}
}
}
}
node_def {
name: "while/loop_counter"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 0
}
}
}
}
node_def {
name: "while"
op: "While"
input: "while/loop_counter:output:0"
input: "Const:output:0"
input: "maximum_iterations:output:0"
attr {
key: "T"
value { list { type: DT_INT32 type: DT_INT32 type: DT_INT32 } }
}
attr {
key: "_lower_using_switch_merge"
value { b: true }
}
attr {
key: "body"
value { func { name: "while_body_5" } }
}
attr {
key: "cond"
value { func { name: "while_cond_4" } }
}
attr {
key: "output_shapes"
value {
list {
shape {}
shape {}
shape {}
}
}
}
}
node_def {
name: "while/Identity"
op: "Identity"
input: "while:output:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "while/Identity_1"
op: "Identity"
input: "while:output:1"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "while/Identity_2"
op: "Identity"
input: "while:output:2"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "while/Identity_1:output:0"
input: "^while"
attr {
key: "T"
value { type: DT_INT32 }
}
}
ret { key: "identity" value: "Identity:output:0" }
}
function {
signature {
name: "while_cond_4"
input_arg { name: "while_loop_counter" type: DT_INT32 }
input_arg { name: "const" type: DT_INT32 }
input_arg { name: "less_maximum_iterations" type: DT_INT32 }
output_arg { name: "identity" type: DT_BOOL }
}
node_def {
name: "Less"
op: "Less"
input: "while_loop_counter"
input: "less_maximum_iterations"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "Less_1/y"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT32 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {}
int_val: 3
}
}
}
}
node_def {
name: "Less_1"
op: "Less"
input: "const"
input: "Less_1/y:output:0"
attr {
key: "T"
value { type: DT_INT32 }
}
}
node_def {
name: "LogicalAnd"
op: "LogicalAnd"
input: "Less:z:0"
input: "Less_1:z:0"
}
node_def {
name: "Identity"
op: "Identity"
input: "LogicalAnd:z:0"
attr {
key: "T"
value { type: DT_BOOL }
}
}
ret { key: "identity" value: "Identity:output:0" }
}
}
versions { producer: 27 min_consumer: 12 })proto";
// TODO(shivaniagrawal): split the test into multiple tests for better
// readability and add full coverage i.e. add/separate out the tests for all
// branches of IsNodeStateful and IsFunctionStateful:
// - test for IsNodeStateful for Cond that has a stateful branch
// - test for IsNodeStateful for Cond that does not have a stateful branch
// - test for IsNodeStateful for While that has a stateful branch
// - test for IsNodeStateful for While that does not have a stateful branch
// - test for IsNodeStateful for Assert
// - test for IsNodeStateful for a stateful op
// - test for IsNodeStateful for a stateless op
//
// - test for IsFunctionStateful for a function that contains a Cond
// - test for IsFunctionStateful for a function that contains a While
// - test for IsFunctionStateful for a function that contains an Assert (and no
// other stateful op)
// - test for IsFunctionStateful for a function that contains a stateful op
// other than Assert
// - test for IsFunctionStateful for a function that does not contain a stateful
// op
// Exercises IsFunctionStateful and IsNodeStateful (with and without the
// trailing `true` flag; per the expectations below, that flag makes "Assert"
// count as not stateful) on three libraries:
//   * a library holding only XTimesTwo,
//   * the library parsed from kCondGraphProto ("If" + "Assert"),
//   * the library parsed from kWhileGraphProto ("While").
// Proto parsing and function lookups are asserted so that a broken fixture
// fails the test cleanly instead of dereferencing a null FunctionDef.
TEST(FunctionUtilsTest, IsFunctionStateful) {
  GraphDef graph_def;
  MutableGraphView graph(&graph_def);

  NodeDef* nodeA = graph_utils::AddNode("", "A", {}, {}, &graph);
  FunctionDef* function = graph_def.mutable_library()->add_function();
  *function = test::function::XTimesTwo();

  FunctionLibraryDefinition lib_def(OpRegistry::Global(),
                                    *graph_def.mutable_library());

  EXPECT_FALSE(IsFunctionStateful(lib_def, *function));

  // Op "A" is not a registered Op.
  EXPECT_TRUE(IsNodeStateful(lib_def, *nodeA));

  // Get graph_def for the graph `kCondGraphProto`, graph with function
  // containing "If" and "Assert" Op.
  GraphDef graph_def_cond;
  ASSERT_TRUE(
      protobuf::TextFormat::ParseFromString(kCondGraphProto, &graph_def_cond));
  FunctionLibraryDefinition cond_lib(OpRegistry::Global(),
                                     graph_def_cond.library());

  const FunctionDef* no_op_fnc = cond_lib.Find("cond_true_3");
  ASSERT_TRUE(no_op_fnc != nullptr);
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *no_op_fnc, true));

  const FunctionDef* assert_func = cond_lib.Find("cond_false_4");
  ASSERT_TRUE(assert_func != nullptr);
  EXPECT_TRUE(IsFunctionStateful(cond_lib, *assert_func));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *assert_func, true));

  // Make sure the loop below actually visits both op kinds it checks.
  EXPECT_TRUE(ContainsFunctionNodeWithOp("Const", *assert_func));
  EXPECT_TRUE(ContainsFunctionNodeWithOp("Assert", *assert_func));

  // Iterate by const reference to avoid copying each NodeDef.
  for (const NodeDef& node : assert_func->node_def()) {
    if (node.op() == "Const") {
      EXPECT_FALSE(IsNodeStateful(lib_def, node));
    }
    if (node.op() == "Assert") {
      EXPECT_TRUE(IsNodeStateful(lib_def, node));
      EXPECT_FALSE(IsNodeStateful(lib_def, node, true));
    }
  }

  const FunctionDef* cond_func = cond_lib.Find("__inference_test_function_19");
  ASSERT_TRUE(cond_func != nullptr);
  EXPECT_TRUE(IsFunctionStateful(cond_lib, *cond_func));
  EXPECT_FALSE(IsFunctionStateful(cond_lib, *cond_func, true));

  // Get graph def for the graph `kWhileGraphProto`, graph with function
  // containing "While" Op.
  GraphDef graph_def_while;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kWhileGraphProto,
                                                    &graph_def_while));
  FunctionLibraryDefinition while_lib(OpRegistry::Global(),
                                      graph_def_while.library());

  const FunctionDef* while_function =
      while_lib.Find("__inference_test_function_34");
  ASSERT_TRUE(while_function != nullptr);
  EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function));
  EXPECT_FALSE(IsFunctionStateful(while_lib, *while_function, true));
}
} // namespace
} // namespace function_utils
} // namespace grappler
} // namespace tensorflow