// blob: 6b43d7dc68d2120b5bc06cb6eaa12ef460188bda [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// OpList fixture in protobuf text format describing a single op, "testop",
// with two inputs (arg_a, arg_b), one output (arg_c), one attr (attr_a) and a
// deprecation entry. Tests below parse it into an OpList and build an
// ApiDefMap from it.
constexpr char kTestOpList[] = R"(op {
name: "testop"
input_arg {
name: "arg_a"
}
input_arg {
name: "arg_b"
}
output_arg {
name: "arg_c"
}
attr {
name: "attr_a"
}
deprecation {
version: 123
explanation: "foo"
}
})";
// Text-format API definition overrides (top-level `op` entries) for "testop".
// It marks the op VISIBLE, names its endpoint "testop1", lists its
// args/attr, and supplies a summary. It also supplies a multiline
// (<<END ... END) description and an explicit arg_order. Most tests load it
// with ApiDefMap::LoadApiDef on top of the defaults derived from kTestOpList.
constexpr char kTestApiDef[] = R"(op {
graph_op_name: "testop"
visibility: VISIBLE
endpoint {
name: "testop1"
}
in_arg {
name: "arg_a"
}
in_arg {
name: "arg_b"
}
out_arg {
name: "arg_c"
}
attr {
name: "attr_a"
}
summary: "Mock op for testing."
description: <<END
Description for the
testop.
END
arg_order: "arg_a"
arg_order: "arg_b"
}
)";
// Round-trips string fields through PBTxtToMultiline / PBTxtFromMultiline.
// A selected field's string value V is expected to become the heredoc
// "<<END\nV\nEND". When V itself contains an "END" line, a different
// terminator (END0) is used. An empty value, or one with leading or trailing
// newlines, therefore shows up as blank lines inside the heredoc.
TEST(OpGenLibTest, MultilinePBTxt) {
  // Non-multiline pbtxt.
  const string pbtxt = R"(foo: "abc"
foo: ""
foo: "\n\n"
foo: "abc\nEND"
foo: "ghi\njkl\n"
bar: "quotes:\""
)";
  // Field "foo" converted to multiline but not "bar". The blank lines are
  // significant. They encode "" (one blank line), "\n\n" (three) and the
  // trailing "\n" of "ghi\njkl\n". Without them "" and "\n\n" would map to the
  // same text and could not round-trip.
  const string ml_foo = R"(foo: <<END
abc
END
foo: <<END

END
foo: <<END



END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl

END
bar: "quotes:\""
)";
  // Both fields "foo" and "bar" converted to multiline.
  const string ml_foo_bar = R"(foo: <<END
abc
END
foo: <<END

END
foo: <<END



END
foo: <<END0
abc
END
END0
foo: <<END
ghi
jkl

END
bar: <<END
quotes:"
END
)";
  // ToMultiline: only the requested fields are rewritten.
  EXPECT_EQ(ml_foo, PBTxtToMultiline(pbtxt, {"foo"}));
  EXPECT_EQ(pbtxt, PBTxtToMultiline(pbtxt, {"baz"}));
  EXPECT_EQ(ml_foo_bar, PBTxtToMultiline(pbtxt, {"foo", "bar"}));
  // FromMultiline: every form collapses back to the single-line original, and
  // already single-line input is left unchanged.
  EXPECT_EQ(pbtxt, PBTxtFromMultiline(pbtxt));
  EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml_foo));
  EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml_foo_bar));
}
// PBTxtToMultiline converts a well-formed quoted string field, and leaves any
// line it cannot parse exactly as it was.
TEST(OpGenLibTest, PBTxtToMultilineErrorCases) {
  // Happy path: a quoted string value becomes a heredoc.
  EXPECT_EQ("f: <<END\n7\nEND\n", PBTxtToMultiline("f: \"7\"\n", {"f"}));

  // Malformed inputs that must come back unchanged.
  const char* const kLeftAlone[] = {
      "f \"7\"\n",     // missing colon
      "f: 7\n",        // value is not a string
      "f: 7\"\n",      // no opening quote right after the colon
      "f: \"7\n",      // only one quote
      "f: \"7\\\"\n",  // closing quote is escaped
  };
  for (const char* line : kLeftAlone) {
    EXPECT_EQ(line, PBTxtToMultiline(line, {"f"}));
  }
}
// Trailing "# ..." comments survive conversion to and from multiline form;
// they end up after the heredoc terminator.
TEST(OpGenLibTest, PBTxtToMultilineComments) {
  const string pbtxt = R"(f: "bar" # Comment 1
f: "\n" # Comment 2
)";
  // The value "\n" is encoded as "<<END\n" + "\n" + "\nEND", i.e. two blank
  // lines between the header and END. They are what distinguish it from "".
  const string ml = R"(f: <<END
bar
END # Comment 1
f: <<END


END # Comment 2
)";
  EXPECT_EQ(ml, PBTxtToMultiline(pbtxt, {"f"}));
  EXPECT_EQ(pbtxt, PBTxtFromMultiline(ml));
}
// Looking up an op name that is not in the OpList returns nullptr.
TEST(OpGenLibTest, ApiDefAccessInvalidName) {
  OpList op_list;
  // Fail loudly if the fixture does not parse. Otherwise the map would be
  // empty and the nullptr check below would pass vacuously.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  // Sanity check: the map is populated, so the miss below is meaningful.
  ASSERT_NE(nullptr, api_map.GetApiDef("testop"));
  ASSERT_EQ(nullptr, api_map.GetApiDef("testop5"));
}
// Without any overrides, the ApiDef for an op is derived from its OpDef. The
// endpoint is named after the op, every arg/attr is "renamed" to itself, and
// arg_order lists the inputs in declaration order.
TEST(OpGenLibTest, ApiDefInitializedFromOpDef) {
  // Must match Message::DebugString() byte-for-byte, which indents the fields
  // of nested messages by two spaces.
  const string expected_api_def = R"(graph_op_name: "testop"
visibility: VISIBLE
endpoint {
  name: "testop"
}
in_arg {
  name: "arg_a"
  rename_to: "arg_a"
}
in_arg {
  name: "arg_b"
  rename_to: "arg_b"
}
out_arg {
  name: "arg_c"
  rename_to: "arg_c"
}
attr {
  name: "attr_a"
  rename_to: "attr_a"
}
arg_order: "arg_a"
arg_order: "arg_b"
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  const auto* api_def = api_map.GetApiDef("testop");
  // Guard the dereference below.
  ASSERT_NE(nullptr, api_def);
  ASSERT_EQ(expected_api_def, api_def->DebugString());
}
// Loading kTestApiDef merges its endpoint, summary and description into the
// OpDef-derived defaults.
TEST(OpGenLibTest, ApiDefLoadSingleApiDef) {
  // Must match ApiDefs::DebugString() byte-for-byte, which indents each
  // nesting level by two spaces.
  const string expected_api_def = R"(op {
  graph_op_name: "testop"
  visibility: VISIBLE
  endpoint {
    name: "testop1"
  }
  in_arg {
    name: "arg_a"
    rename_to: "arg_a"
  }
  in_arg {
    name: "arg_b"
    rename_to: "arg_b"
  }
  out_arg {
    name: "arg_c"
    rename_to: "arg_c"
  }
  attr {
    name: "attr_a"
    rename_to: "attr_a"
  }
  summary: "Mock op for testing."
  description: "Description for the\ntestop."
  arg_order: "arg_a"
  arg_order: "arg_b"
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  // ASSERT (not EXPECT): endpoint(0) below is out of range if this fails.
  ASSERT_EQ(1, api_def->endpoint_size());
  EXPECT_EQ("testop1", api_def->endpoint(0).name());
  ApiDefs api_defs;
  *api_defs.add_op() = *api_def;
  EXPECT_EQ(expected_api_def, api_defs.DebugString());
}
// An override that leaves `visibility` unset keeps the current value. One that
// sets it (HIDDEN here) replaces it.
TEST(OpGenLibTest, ApiDefOverrideVisibility) {
  const string api_def1 = R"(
op {
graph_op_name: "testop"
endpoint {
name: "testop2"
}
}
)";
  const string api_def2 = R"(
op {
graph_op_name: "testop"
visibility: HIDDEN
endpoint {
name: "testop2"
}
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  // The same pointer is re-read after each LoadApiDef below. This relies on
  // the entry returned by GetApiDef being updated in place.
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  EXPECT_EQ(ApiDef::VISIBLE, api_def->visibility());
  // Loading ApiDef with default visibility should keep current visibility.
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  EXPECT_EQ(ApiDef::VISIBLE, api_def->visibility());
  // Loading ApiDef with non-default visibility should update visibility.
  TF_CHECK_OK(api_map.LoadApiDef(api_def2));
  EXPECT_EQ(ApiDef::HIDDEN, api_def->visibility());
}
// A later override's endpoint list replaces the earlier one rather than
// being appended to it.
TEST(OpGenLibTest, ApiDefOverrideEndpoints) {
  const string api_def1 = R"(
op {
graph_op_name: "testop"
endpoint {
name: "testop2"
}
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  ASSERT_EQ(1, api_def->endpoint_size());
  EXPECT_EQ("testop1", api_def->endpoint(0).name());
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  ASSERT_EQ(1, api_def->endpoint_size());
  EXPECT_EQ("testop2", api_def->endpoint(0).name());
}
// An override can rename individual args (others keep their names) and
// replace arg_order.
TEST(OpGenLibTest, ApiDefOverrideArgs) {
  const string api_def1 = R"(
op {
graph_op_name: "testop"
in_arg {
name: "arg_a"
rename_to: "arg_aa"
}
out_arg {
name: "arg_c"
rename_to: "arg_cc"
}
arg_order: "arg_b"
arg_order: "arg_a"
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  ASSERT_EQ(2, api_def->in_arg_size());
  EXPECT_EQ("arg_aa", api_def->in_arg(0).rename_to());
  // 2nd in_arg is not renamed.
  EXPECT_EQ("arg_b", api_def->in_arg(1).rename_to());
  ASSERT_EQ(1, api_def->out_arg_size());
  EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to());
  ASSERT_EQ(2, api_def->arg_order_size());
  EXPECT_EQ("arg_b", api_def->arg_order(0));
  EXPECT_EQ("arg_a", api_def->arg_order(1));
}
// Overrides replace summary/description. description_prefix and
// description_suffix are folded into the description (joined with "\n") and
// then cleared, so successive overrides keep wrapping the current text.
TEST(OpGenLibTest, ApiDefOverrideDescriptions) {
  const string api_def1 = R"(
op {
graph_op_name: "testop"
summary: "New summary"
description: <<END
New description
END
description_prefix: "A"
description_suffix: "Z"
}
)";
  const string api_def2 = R"(
op {
graph_op_name: "testop"
description_prefix: "B"
description_suffix: "Y"
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  EXPECT_EQ("New summary", api_def->summary());
  EXPECT_EQ("A\nNew description\nZ", api_def->description());
  EXPECT_EQ("", api_def->description_prefix());
  EXPECT_EQ("", api_def->description_suffix());
  TF_CHECK_OK(api_map.LoadApiDef(api_def2));
  EXPECT_EQ("B\nA\nNew description\nZ\nY", api_def->description());
  EXPECT_EQ("", api_def->description_prefix());
  EXPECT_EQ("", api_def->description_suffix());
}
// An override naming an op that is not in the OpList loads without error.
// It does not create an entry for that op, and it leaves the known op alone.
TEST(OpGenLibTest, ApiDefInvalidOpInOverride) {
  const string api_def1 = R"(
op {
graph_op_name: "different_testop"
endpoint {
name: "testop2"
}
}
)";
  OpList op_list;
  // Without this check an unparsed (empty) OpList would make every assertion
  // below trivially true.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  ASSERT_EQ(nullptr, api_map.GetApiDef("different_testop"));
  // The override for the other op must not leak into "testop".
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  ASSERT_EQ(1, api_def->endpoint_size());
  EXPECT_EQ("testop1", api_def->endpoint(0).name());
}
// arg_order must be a permutation of the op's inputs. An unknown name, a
// missing input, or a duplicate is rejected with FAILED_PRECONDITION.
TEST(OpGenLibTest, ApiDefInvalidArgOrder) {
  const string api_def1 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "unexpected_arg"
}
)";
  const string api_def2 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
}
)";
  const string api_def3 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "arg_a"
}
)";
  OpList op_list;
  // If the fixture failed to parse, "testop" would be unknown. The overrides
  // below might then be skipped instead of validated.
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestOpList, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
  // The three cases are independent, so EXPECT lets all of them report.
  // Loading with incorrect arg name in arg_order should fail.
  auto status = api_map.LoadApiDef(api_def1);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code())
      << status.ToString();
  // Loading with incorrect number of args in arg_order should fail.
  status = api_map.LoadApiDef(api_def2);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code())
      << status.ToString();
  // Loading with the same argument twice in arg_order should fail.
  status = api_map.LoadApiDef(api_def3);
  EXPECT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code())
      << status.ToString();
}
// Text that does not parse as API definitions (here: an unrecognized field)
// is rejected by LoadApiDef with INVALID_ARGUMENT.
TEST(OpGenLibTest, ApiDefInvalidSyntax) {
  // `bad_op_name` is not a field of the `op` entry.
  const string malformed_api_def = R"pb(
op { bad_op_name: "testop" }
)pb";
  OpList empty_op_list;
  ApiDefMap api_map(empty_op_list);
  const auto load_status = api_map.LoadApiDef(malformed_api_def);
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, load_status.code());
}
// After renames are loaded, UpdateDocs() rewrites backtick-quoted references
// in the op, arg and attr descriptions to the new names. `testop` becomes the
// first endpoint name, `testop2`. An override's own description is rewritten
// too.
TEST(OpGenLibTest, ApiDefUpdateDocs) {
  const string op_list1 = R"(op {
name: "testop"
input_arg {
name: "arg_a"
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
}
output_arg {
name: "arg_c"
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
}
attr {
name: "attr_a"
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
}
description: "`arg_a`, `arg_c`, `attr_a`, `testop`"
}
)";
  const string api_def1 = R"(
op {
graph_op_name: "testop"
endpoint {
name: "testop2"
}
in_arg {
name: "arg_a"
rename_to: "arg_aa"
}
out_arg {
name: "arg_c"
rename_to: "arg_cc"
description: "New description: `arg_a`, `arg_c`, `attr_a`, `testop`"
}
attr {
name: "attr_a"
rename_to: "attr_aa"
}
}
)";
  OpList op_list;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(op_list1, &op_list));
  ApiDefMap api_map(op_list);
  TF_CHECK_OK(api_map.LoadApiDef(api_def1));
  api_map.UpdateDocs();
  const string expected_description =
      "`arg_aa`, `arg_cc`, `attr_aa`, `testop2`";
  const auto* api_def = api_map.GetApiDef("testop");
  ASSERT_NE(nullptr, api_def);
  EXPECT_EQ(expected_description, api_def->description());
  // Sizes are asserted first so the index-0 accesses below stay in range.
  ASSERT_EQ(1, api_def->in_arg_size());
  EXPECT_EQ(expected_description, api_def->in_arg(0).description());
  ASSERT_EQ(1, api_def->out_arg_size());
  EXPECT_EQ("New description: " + expected_description,
            api_def->out_arg(0).description());
  ASSERT_EQ(1, api_def->attr_size());
  EXPECT_EQ(expected_description, api_def->attr(0).description());
}
} // namespace
} // namespace tensorflow