blob: e6e49ae7957f1fc748f08922f603f47514fff8c9 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
// Node names given to the functional control-flow nodes the tests build.
constexpr char kFunctionalWhileNodeName[] = "While";
constexpr char kFunctionalIfNodeName[] = "If";

// Function names paired with the name of the single node in each function's
// body (the node names double as the functions' output argument names).
constexpr char kCompilableFunctionName[] = "CompilableFn";
constexpr char kCompilableFunctionNodeName[] = "n_c";

constexpr char kUncompilableFunctionName[] = "UncompilableFn";
constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable";

constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo";
constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable";
// A dummy OpKernel for testing.
class DummyCompilableOp : public XlaOpKernel {
public:
explicit DummyCompilableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, ctx->Input(0));
}
};
// Test ops. "InputFloatOp" is a source op with one float output; the tests
// expect it to be reported as uncompilable. "CompilableOp" is given the
// DummyCompilableOp XLA kernel for DEVICE_CPU_XLA_JIT, the device type the
// fixture's checker targets.
REGISTER_OP("InputFloatOp").Output("o: float");
REGISTER_OP("CompilableOp").Input("i: float").Output("o: float");
REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT),
DummyCompilableOp);
// Op with no XLA kernel registered in this file, so the checker is expected
// to report it as uncompilable for DEVICE_CPU_XLA_JIT.
REGISTER_OP("MissingKernel").Input("i: float").Output("o: float");
// Fixture owning a RecursiveCompilabilityChecker that targets
// DEVICE_CPU_XLA_JIT, together with the function library and runtime the
// checker is queried with.
class CompilabilityCheckUtilTest : public ::testing::Test {
 protected:
  // Registers the XLA compilation kernels, disables every optional op
  // category in the operation filter, and creates the checker under test.
  void SetUp() override {
    XlaOpRegistry::RegisterCompilationKernels();

    auto& filter = op_filter_;
    filter.allow_resource_ops_in_called_functions = false;
    filter.allow_stack_ops = false;
    filter.allow_tensor_array_ops = false;
    filter.allow_stateful_rng_ops = false;
    filter.allow_control_trigger = false;
    filter.allow_eliding_assert_and_checknumerics_ops = false;
    filter.allow_ops_producing_or_consuming_variant = false;
    filter.allow_inaccurate_ops = false;
    filter.allow_slow_ops = false;

    checker_ = absl::make_unique<RecursiveCompilabilityChecker>(
        &op_filter_, &device_type_);
  }

  // Builds a fresh ProcessFunctionLibraryRuntime over the current `flib_def_`
  // and returns the FunctionLibraryRuntime for its default device. Call this
  // after a test replaces `flib_def_`. Every call replaces `pflr_`, so a
  // runtime returned by an earlier call is presumably owned by the replaced
  // object and should not be reused.
  FunctionLibraryRuntime* GetFunctionLibraryRuntime() {
    OptimizerOptions optimizer_options;
    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
        flib_def_.get(), optimizer_options);
    auto* default_runtime =
        pflr_->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
    return default_runtime;
  }

  // Filter passed to `checker_` by pointer; configured in SetUp().
  RecursiveCompilabilityChecker::OperationFilter op_filter_;
  // Device type the checker evaluates compilability against.
  DeviceType device_type_{DEVICE_CPU_XLA_JIT};
  // Empty library backing the default `flib_def_`; tests may replace
  // `flib_def_` with a definition built from their own library.
  std::unique_ptr<FunctionDefLibrary> func_library_ =
      absl::make_unique<FunctionDefLibrary>();
  std::unique_ptr<FunctionLibraryDefinition> flib_def_ =
      absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
                                                   *func_library_);
  std::unique_ptr<RecursiveCompilabilityChecker> checker_;
  // Declared after `flib_def_`, so it is destroyed before the definition it
  // was built from.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
};
// Checks plain (non-function) nodes along the chain
// InputFloatOp -> CompilableOp -> MissingKernel. A top-level uncompilable node
// is reported under the key of an empty NameAttrList, with a single stack
// frame whose function name is empty.
TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) {
  GraphDefBuilder graph_builder(GraphDefBuilder::kFailImmediately);
  auto build_opts = graph_builder.opts();
  Node* source_node = ops::SourceOp("InputFloatOp", build_opts);
  Node* xla_node = ops::UnaryOp("CompilableOp", source_node, build_opts);
  Node* missing_kernel_node =
      ops::UnaryOp("MissingKernel", xla_node, build_opts);

  GraphDef serialized_graph;
  TF_EXPECT_OK(graph_builder.ToGraphDef(&serialized_graph));

  FunctionLibraryRuntime* runtime = GetFunctionLibraryRuntime();

  // The source op is expected to be reported as not compilable.
  EXPECT_FALSE(checker_->IsCompilableNode(*source_node, runtime));
  EXPECT_TRUE(checker_->IsCompilableNode(*xla_node, runtime));
  // MissingKernel is uncompilable because the checker only considers the
  // DEVICE_CPU_XLA_JIT device type.
  EXPECT_FALSE(checker_->IsCompilableNode(*missing_kernel_node, runtime));

  const auto reported =
      checker_->FindUncompilableNodes(*missing_kernel_node, runtime);
  ASSERT_EQ(1, reported.size());

  // Nodes outside any function are grouped under an empty function
  // description.
  const auto entry = reported.find(NameAttrList().ShortDebugString());
  ASSERT_NE(reported.end(), entry);

  const auto& nodes_in_entry = entry->second.second;
  ASSERT_EQ(1, nodes_in_entry.size());

  const auto& offending = nodes_in_entry.at(0);
  EXPECT_TRUE(
      absl::StrContains(offending.uncompilable_reason, "unsupported op"));
  ASSERT_EQ(1, offending.stack_trace.size());
  ASSERT_TRUE(offending.stack_trace.at(0).function_name.empty());
}
// A direct call ("D") to a function whose body contains an op without an XLA
// kernel is reported under that function. The stack trace has two frames: the
// top-level call node (empty function name), then the offending node inside
// the function body (tagged with the function's name).
TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());
  auto graph = absl::make_unique<Graph>(flib_def_.get());
  Node* const0 = ops::SourceOp("InputFloatOp", builder.opts());
  Node* functional_node = ops::UnaryOp(kUncompilableFunctionName, const0,
                                       builder.opts().WithName("D"));
  TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));

  auto* flib_runtime = GetFunctionLibraryRuntime();
  EXPECT_FALSE(checker_->IsCompilableNode(*functional_node, flib_runtime));

  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(*functional_node, flib_runtime);
  // ASSERT (fatal) for consistency with the other tests: a wrong count makes
  // the remaining expectations meaningless.
  ASSERT_EQ(1, uncompilable_nodes.size());

  NameAttrList function;
  function.set_name(kUncompilableFunctionName);
  const auto node_info_it =
      uncompilable_nodes.find(function.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), node_info_it);

  const auto& uncompilable_node_list = node_info_it->second.second;
  ASSERT_EQ(1, uncompilable_node_list.size());
  const auto& node_info = uncompilable_node_list.at(0);
  const auto& node_stack = node_info.stack_trace;
  ASSERT_EQ(2, node_stack.size());
  // Frame 0: the top-level call node, outside any function.
  EXPECT_EQ("D", node_stack.at(0).name);
  EXPECT_EQ("", node_stack.at(0).function_name);
  // Frame 1: the offending node inside the called function's body.
  EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name);
  EXPECT_EQ(kUncompilableFunctionName, node_stack.at(1).function_name);
  EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
  EXPECT_TRUE(
      absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
// A While node whose `cond` is compilable but whose `body` contains an op
// without an XLA kernel is reported only under the body function. The stack
// trace is the While node itself (top level), then the offending body node.
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kCompilableFunctionName,
      /*Inputs*/ {"n_a:float", "n_b:float"},
      /*Outputs*/ {"n_c:float"},
      /*Attribute*/ {},
      // Node info
      {{{kCompilableFunctionNodeName},
        "Add",
        {"n_a", "n_b"},
        {{"T", DT_FLOAT}}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));

  GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, flib_def_.get());
  Node* const0 = ops::SourceOp("InputFloatOp", builder.opts());
  Node* input_node = ops::UnaryOp("CompilableOp", const0, builder.opts());
  NameAttrList compilable;
  compilable.set_name(kCompilableFunctionName);
  NameAttrList uncompilable;
  uncompilable.set_name(kUncompilableFunctionName);
  NodeBuilder while_builder(kFunctionalWhileNodeName, "While",
                            builder.opts().op_registry());
  while_builder.Input({input_node, input_node})
      .Attr("cond", compilable)
      .Attr("body", uncompilable);
  builder.opts().FinalizeBuilder(&while_builder);

  auto graph = absl::make_unique<Graph>(flib_def_.get());
  // Fatal on failure (instead of TF_CHECK_OK aborting the whole binary).
  TF_ASSERT_OK(GraphDefBuilderToGraph(builder, graph.get()));

  auto while_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [](const Node* n) { return n->name() == kFunctionalWhileNodeName; });
  // Must be fatal: the iterator is dereferenced below, and dereferencing
  // end() is undefined behavior.
  ASSERT_NE(while_node_it, graph->nodes().end());

  auto* flib_runtime = GetFunctionLibraryRuntime();
  EXPECT_FALSE(checker_->IsCompilableNode(**while_node_it, flib_runtime));

  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(**while_node_it, flib_runtime);
  ASSERT_EQ(1, uncompilable_nodes.size());

  NameAttrList function;
  function.set_name(kUncompilableFunctionName);
  const auto node_info_it =
      uncompilable_nodes.find(function.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), node_info_it);

  const auto& uncompilable_node_list = node_info_it->second.second;
  ASSERT_EQ(1, uncompilable_node_list.size());
  const auto& node_info = uncompilable_node_list.at(0);
  const auto& node_stack = node_info.stack_trace;
  ASSERT_EQ(2, node_stack.size());

  // Frame 0: the top-level While node.
  const auto& stacktrace_first_node_info = node_stack.at(0);
  EXPECT_EQ(kFunctionalWhileNodeName, stacktrace_first_node_info.name);
  EXPECT_EQ("", stacktrace_first_node_info.function_name);

  // Frame 1: the offending node inside the `body` function.
  const auto& stacktrace_second_node_info = node_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
  EXPECT_EQ(kUncompilableFunctionName,
            stacktrace_second_node_info.function_name);

  EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name);
  EXPECT_TRUE(
      absl::StrContains(node_info.uncompilable_reason, "unsupported op"));
}
// An If node whose `then_branch` and `else_branch` both contain an op without
// an XLA kernel is reported under both branch functions. Each entry has a
// two-frame stack trace: the top-level If node, then the offending node inside
// that branch function.
TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) {
  FunctionDefLibrary flib;
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_c_uncompilable:float"},
      /*Attributes*/ {},
      // Node info
      {{{kUncompilableFunctionNodeName}, "MissingKernel", {"n_a"}}});
  *flib.add_function() = FunctionDefHelper::Define(
      /*Function*/ kUncompilableFunctionTwoName,
      /*Inputs*/ {"n_a:float"},
      /*Outputs*/ {"n_d_uncompilable:float"},
      /*Attribute*/ {},
      // Node info
      {{{kUncompilableFunctionNodeTwoName}, "MissingKernel", {"n_a"}}});
  NameAttrList uncompilable_fn1_attr;
  uncompilable_fn1_attr.set_name(kUncompilableFunctionName);
  NameAttrList uncompilable_fn2_attr;
  uncompilable_fn2_attr.set_name(kUncompilableFunctionTwoName);

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));
  auto predicate = ops::Placeholder(root.WithOpName("pred"), DT_BOOL);
  auto placeholder = ops::Placeholder(root.WithOpName("A"), DT_INT32);
  std::vector<NodeBuilder::NodeOut> if_inputs(
      {NodeBuilder::NodeOut(placeholder.node())});
  Node* if_node;
  TF_ASSERT_OK(
      NodeBuilder(kFunctionalIfNodeName, "If", &root.graph()->flib_def())
          .Input(predicate.node())
          .Input(if_inputs)
          .Attr("then_branch", uncompilable_fn1_attr)
          .Attr("else_branch", uncompilable_fn2_attr)
          .Attr("Tout", {DT_INT32})
          .Finalize(root.graph(), &if_node));
  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));
  flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));

  auto if_node_it = std::find_if(
      graph->nodes().begin(), graph->nodes().end(),
      [](const Node* n) { return n->name() == kFunctionalIfNodeName; });
  // Must be fatal: the iterator is dereferenced below, and dereferencing
  // end() is undefined behavior.
  ASSERT_NE(if_node_it, graph->nodes().end());

  auto* flib_runtime = GetFunctionLibraryRuntime();
  EXPECT_FALSE(checker_->IsCompilableNode(**if_node_it, flib_runtime));

  const auto uncompilable_nodes =
      checker_->FindUncompilableNodes(**if_node_it, flib_runtime);
  ASSERT_EQ(2, uncompilable_nodes.size());

  // Entry for `then_branch`.
  NameAttrList function_one;
  function_one.set_name(kUncompilableFunctionName);
  auto it = uncompilable_nodes.find(function_one.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), it);
  const auto& uncompilable_node_list = it->second.second;
  ASSERT_EQ(1, uncompilable_node_list.size());
  const auto& uncompilable_node_one = uncompilable_node_list.at(0);
  const auto& node_one_stack = uncompilable_node_one.stack_trace;
  ASSERT_EQ(2, node_one_stack.size());
  const auto& node_one_stacktrace_first_node = node_one_stack.at(0);
  EXPECT_EQ(kFunctionalIfNodeName, node_one_stacktrace_first_node.name);
  EXPECT_EQ("", node_one_stacktrace_first_node.function_name);
  const auto& stacktrace_second_node_info = node_one_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeName, stacktrace_second_node_info.name);
  EXPECT_EQ(kUncompilableFunctionName,
            stacktrace_second_node_info.function_name);
  EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name);
  EXPECT_TRUE(absl::StrContains(uncompilable_node_one.uncompilable_reason,
                                "unsupported op"));

  // Entry for `else_branch`.
  NameAttrList function_two;
  function_two.set_name(kUncompilableFunctionTwoName);
  it = uncompilable_nodes.find(function_two.ShortDebugString());
  ASSERT_NE(uncompilable_nodes.end(), it);
  const auto& uncompilable_node_two_list = it->second.second;
  ASSERT_EQ(1, uncompilable_node_two_list.size());
  const auto& uncompilable_node_two = uncompilable_node_two_list.at(0);
  const auto& node_two_stack = uncompilable_node_two.stack_trace;
  ASSERT_EQ(2, node_two_stack.size());
  const auto& node_two_stacktrace_first_node = node_two_stack.at(0);
  EXPECT_EQ(kFunctionalIfNodeName, node_two_stacktrace_first_node.name);
  EXPECT_EQ("", node_two_stacktrace_first_node.function_name);
  const auto& node_two_stacktrace_second_node = node_two_stack.at(1);
  EXPECT_EQ(kUncompilableFunctionNodeTwoName,
            node_two_stacktrace_second_node.name);
  EXPECT_EQ(kUncompilableFunctionTwoName,
            node_two_stacktrace_second_node.function_name);
  EXPECT_EQ(kUncompilableFunctionNodeTwoName, uncompilable_node_two.name);
  // Previously checked `uncompilable_node_one` by mistake (copy-paste), so the
  // else-branch reason was never verified.
  EXPECT_TRUE(absl::StrContains(uncompilable_node_two.uncompilable_reason,
                                "unsupported op"));
}
} // namespace
} // namespace tensorflow