blob: acac2f7d05525a95175ff1f9b90e7f5fdf90186c [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output enter =
ops::internal::Enter(root.WithOpName("enter"), a, "only_frame");
Output exit = ops::internal::Exit(root.WithOpName("exit"), enter);
Output b = ops::Add(root.WithOpName("b"), a, exit);
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output enter_0 =
ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
Output enter_1 =
ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1");
Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
Output b = ops::Add(root.WithOpName("b"), a, exit_1);
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles).status());
EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
}
TEST(CreateCycleDetectionGraph, ReachingEnterExit) {
// TODO(b/127521408): We can lift this limitation with some work.
Scope root = Scope::NewRootScope().ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output enter_0 =
ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
Output add = ops::Add(root.WithOpName("add"), exit_0, exit_0);
Output enter_1 =
ops::internal::Enter(root.WithOpName("enter_1"), add, "frame_0");
Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
FixupSourceAndSinkEdges(root.graph());
GraphCycles cycles;
TF_ASSERT_OK_AND_ASSIGN(bool ok,
CreateCycleDetectionGraph(root.graph(), &cycles));
EXPECT_FALSE(ok);
}
const char* kCPU0 = "/job:localhost/replica:0/task:0/device:CPU:0";
const char* kGPU0 = "/job:localhost/replica:0/task:0/device:GPU:0";
const char* kGPU1 = "/job:localhost/replica:0/task:0/device:GPU:1";
TEST(IsSingleGpuGraph, ReturnsTrue) {
Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output b = ops::Add(root.WithOpName("b"), a, a);
Output c = ops::Add(root.WithOpName("c"), b, b);
FixupSourceAndSinkEdges(root.graph());
EXPECT_TRUE(IsSingleGpuGraph(*root.graph()));
}
TEST(IsSingleGpuGraph, ReturnsFalseForCpuGraph) {
Scope root = Scope::NewRootScope().WithAssignedDevice(kCPU0).ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output b = ops::Add(root.WithOpName("b"), a, a);
Output c = ops::Add(root.WithOpName("c"), b, b);
FixupSourceAndSinkEdges(root.graph());
EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}
TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) {
Scope root = Scope::NewRootScope().WithAssignedDevice(kGPU0).ExitOnError();
Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
Output b = ops::Add(root.WithOpName("b").WithAssignedDevice(kGPU1), a, a);
Output c = ops::Add(root.WithOpName("c"), b, b);
FixupSourceAndSinkEdges(root.graph());
EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}
xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib);
if (flib_def == nullptr) {
flib_def = &flib_def_local;
}
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(
nullptr, Env::Default(), /*config=*/nullptr, TF_GRAPH_DEF_VERSION,
flib_def, OptimizerOptions{}));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
GetNodesRelatedToRefVariables(*graph, lib_runtime));
std::vector<string> names;
absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names),
[](Node* n) { return n->name(); });
absl::c_sort(names);
return names;
}
void CreateSubgraphTouchingRefVar(const Scope& s) {
Output variable =
ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
Output constant =
ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
s.graph()->AddControlEdge(constant.node(), variable.node());
}
void CreateSubgraphNotTouchingRefVar(const Scope& s) {
Output constant =
ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
}
void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) {
NameAttrList ref_float_function;
ref_float_function.set_name("RefFloatFn");
ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span<Input>{}},
{DT_FLOAT}, ref_float_function);
Output constant =
ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0));
s.graph()->AddControlEdge(call.operation.node(), constant.node());
}
void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) {
NameAttrList regular_float_function;
regular_float_function.set_name("RegularFloatFn");
ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span<Input>{}},
{DT_FLOAT}, regular_float_function);
Output constant =
ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0));
s.graph()->AddControlEdge(call.operation.node(), constant.node());
}
void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
FunctionDef make_ref_float = FunctionDefHelper::Define(
"RefFloatFn", {}, {"r:float"}, {},
{{{"var"},
"VariableV2",
{},
{{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}},
{{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}});
*fdef_lib->add_function() = make_ref_float;
}
void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
Tensor seven(DT_FLOAT, {});
seven.scalar<float>()() = 7;
FunctionDef make_regular_float = FunctionDefHelper::Define(
"RegularFloatFn", {}, {"r:float"}, {},
{{{"r"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", seven}}}});
*fdef_lib->add_function() = make_regular_float;
}
TEST(NodesRelatedToRefVariables, Basic) {
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary fdef_lib;
CreateSubgraphTouchingRefVar(root);
CreateSubgraphNotTouchingRefVar(root);
AddRefFunctionFunctionDef(&fdef_lib);
CreateSubgraphCallingFunctionWithRefVar(root);
AddRegularFunctionFunctionDef(&fdef_lib);
CreateSubgraphCallingFunctionWithoutRefVar(root);
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
GetNodesRelatedToRefVarsSorted(root, &flib_def));
std::vector<string> expected({
"RefFloat",
"add_ref",
"constant_ref",
"constant_ref_pco",
"negate_ref",
"read_ref_var",
"variable",
});
EXPECT_EQ(names, expected);
}
Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) {
s = s.NewSubScope(std::string(loop_name));
ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name);
ops::Merge merge(s.WithOpName("merge"), {init_value, init_value});
Output next_iteration =
ops::NextIteration(s.WithOpName("next_itr"), merge.output);
return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(),
1);
}
TEST(NodesRelatedToRefVariables, Cycles) {
Scope root = Scope::NewRootScope().ExitOnError();
Output variable = ops::Variable(root.WithOpName("variable"),
PartialTensorShape{}, DT_FLOAT);
TF_ASSERT_OK(
MakeLoop(root, ops::Identity(root.WithOpName("read_ref_var"), variable),
"ref_loop"));
TF_ASSERT_OK(MakeLoop(
root, ops::Const(root.WithOpName("constant"), Input::Initializer(0.0)),
"normal_loop"));
TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
GetNodesRelatedToRefVarsSorted(root));
std::vector<string> expected({"read_ref_var", "ref_loop/init_value",
"ref_loop/merge", "ref_loop/next_itr",
"variable"});
EXPECT_EQ(names, expected);
}
} // namespace
} // namespace tensorflow