| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" |
| |
| #include "absl/memory/memory.h" |
| #include "absl/strings/match.h" |
| #include "tensorflow/cc/framework/ops.h" |
| #include "tensorflow/cc/ops/array_ops.h" |
| #include "tensorflow/cc/ops/control_flow_ops_internal.h" |
| #include "tensorflow/cc/ops/function_ops.h" |
| #include "tensorflow/cc/ops/resource_variable_ops.h" |
| #include "tensorflow/cc/ops/sendrecv_ops.h" |
| #include "tensorflow/cc/ops/standard_ops.h" |
| #include "tensorflow/compiler/jit/defs.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" |
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" |
| #include "tensorflow/core/framework/node_def_util.h" |
| #include "tensorflow/core/framework/op.h" |
| #include "tensorflow/core/graph/algorithm.h" |
| #include "tensorflow/core/graph/graph_constructor.h" |
| #include "tensorflow/core/graph/graph_def_builder.h" |
| #include "tensorflow/core/graph/graph_def_builder_util.h" |
| #include "tensorflow/core/lib/core/status_test_util.h" |
| #include "tensorflow/core/platform/test.h" |
| |
| namespace tensorflow { |
| namespace { |
| |
| REGISTER_OP("UncompilableNullary").Output("o: float"); |
| REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); |
| |
| std::unordered_map<string, string> GetClusters(const Graph& graph) { |
| std::unordered_map<string, string> ids; |
| for (Node* node : graph.nodes()) { |
| string cluster; |
| if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { |
| CHECK(!cluster.empty()); |
| ids[node->name()] = cluster; |
| } |
| } |
| |
| if (VLOG_IS_ON(2)) { |
| VLOG(2) << "Clusters:"; |
| for (const auto& p : ids) { |
| VLOG(2) << " " << p.first << " -> " << p.second; |
| } |
| } |
| return ids; |
| } |
| |
| gtl::FlatMap<string, std::vector<string>> GetClusterSets( |
| const Graph& g, std::vector<string>* cluster_names = nullptr) { |
| CHECK(cluster_names == nullptr || cluster_names->empty()); |
| gtl::FlatMap<string, std::vector<string>> cluster_sets; |
| for (const auto& p : GetClusters(g)) { |
| cluster_sets[p.second].push_back(p.first); |
| } |
| for (auto& p : cluster_sets) { |
| if (cluster_names != nullptr) { |
| cluster_names->push_back(p.first); |
| } |
| std::sort(p.second.begin(), p.second.end()); |
| } |
| if (cluster_names != nullptr) { |
| std::sort(cluster_names->begin(), cluster_names->end()); |
| } |
| return cluster_sets; |
| } |
| |
| TEST(XlaCompilationTest, Chains) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = |
| ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); |
| Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); |
| Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); |
| Node* d = |
| ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); |
| Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); |
| ops::UnaryOp("Relu", e, builder.opts().WithName("F")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| EXPECT_EQ(4, clusters.size()); |
| EXPECT_EQ(clusters["B"], clusters["C"]); |
| EXPECT_EQ(clusters["E"], clusters["F"]); |
| EXPECT_NE(clusters["B"], clusters["E"]); |
| EXPECT_TRUE(clusters.find("A") == clusters.cend()); |
| EXPECT_TRUE(clusters.find("D") == clusters.cend()); |
| } |
| |
| TEST(XlaCompilationTest, UncompilableCycles) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor())); |
| Node* b = |
| ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); |
| ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_TRUE(clusters.empty()); |
| } |
| |
| TEST(XlaCompilationTest, CompilableCycles) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor())); |
| Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); |
| ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_EQ(3, clusters.size()); |
| EXPECT_EQ(clusters["A"], clusters["B"]); |
| EXPECT_EQ(clusters["A"], clusters["C"]); |
| } |
| |
| TEST(XlaCompilationTest, Complex128Unsupported) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp( |
| "Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_COMPLEX128) |
| .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); |
| Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); |
| ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| EXPECT_TRUE(clusters.empty()); |
| } |
| |
| TEST(XlaCompilationTest, HalfSupported) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Tensor t(DT_HALF, TensorShape()); |
| t.scalar<Eigen::half>()() = static_cast<Eigen::half>(0.0f); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_HALF) |
| .WithAttr("value", t)); |
| Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); |
| ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| EXPECT_FALSE(clusters.empty()); |
| } |
| |
| TEST(XlaCompilationTest, ConcatWithConstArg) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| Tensor t(DT_INT32, TensorShape()); |
| t.scalar<int32>()() = 0; |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* dim = ops::SourceOp("Const", builder.opts() |
| .WithName("Dim") |
| .WithAttr("dtype", DT_INT32) |
| .WithAttr("value", t)); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", t)); |
| |
| NodeBuilder concat_builder("Concat", "Concat", |
| builder.opts().op_registry()); |
| concat_builder.Input(dim).Input({a, a}).Attr("N", 2); |
| builder.opts().FinalizeBuilder(&concat_builder); |
| |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| EXPECT_EQ(3, clusters.size()); // Everything should be compiled. |
| } |
| |
| TEST(XlaCompilationTest, FunctionCalls) { |
| FunctionDef compilable = FunctionDefHelper::Define( |
| "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {}, |
| {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}}); |
| FunctionDef uncompilable = |
| FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"}, |
| {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}}); |
| FunctionDef noinline = compilable; |
| noinline.mutable_signature()->set_name("NoInlineFn"); |
| AddAttr("_noinline", static_cast<bool>(true), noinline.mutable_attr()); |
| |
| FunctionDefLibrary flib; |
| *flib.add_function() = compilable; |
| *flib.add_function() = uncompilable; |
| *flib.add_function() = noinline; |
| FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib); |
| |
| std::unique_ptr<Graph> graph(new Graph(&flib_def)); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def); |
| Node* a = |
| ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); |
| Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B")); |
| Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); |
| ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); |
| ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK( |
| MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_EQ(2, clusters.size()); |
| EXPECT_FALSE(clusters["B"].empty()); |
| EXPECT_EQ(clusters["B"], clusters["C"]); |
| EXPECT_TRUE(clusters.find("A") == clusters.cend()); |
| EXPECT_TRUE(clusters.find("D") == clusters.cend()); |
| EXPECT_TRUE(clusters.find("E") == clusters.cend()); |
| } |
| |
| // Metadata-only operators such as Shape/Rank/Size may not be the root of a |
| // cluster. This is partially to work around b/26800664, and partially because |
| // we should probably prefer to compile metadata operators with their producers |
| // wherever possible, rather than their consumers. |
| TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = |
| ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); |
| // While all of the following ops are notionally compilable, none is |
| // permitted |
| // to start a cluster. So nothing should be compiled. |
| Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B")); |
| Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); |
| Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); |
| ops::UnaryOp("Shape", d, builder.opts().WithName("E")); |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. |
| } |
| |
| static Status GradForUnaryCwise(FunctionDef* g, |
| std::vector<FunctionDefHelper::Node> nodes) { |
| for (auto& n : nodes) { |
| if (n.attr.empty()) { |
| n.attr = {{"T", DT_FLOAT}}; |
| } |
| } |
| *g = FunctionDefHelper::Define( |
| // Arg defs |
| {"x: float", "dy: float"}, |
| // Ret val defs |
| {"dx: float"}, |
| // Attr defs |
| {}, |
| // Nodes |
| nodes); |
| return Status::OK(); |
| } |
| |
| // A gradient containing only supported operators |
| Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) { |
| // clang-format off |
| return GradForUnaryCwise(g, { |
| {{"y"}, "Tanh", {"x"}}, |
| {{"y2"}, "Square", {"y"}, {}, {"dy"}}, |
| FunctionDefHelper::Const("one", 1.0f), |
| {{"a"}, "Sub", {"one", "y2"}}, |
| {{"dx"}, "Mul", {"dy", "a"}}, |
| }); |
| // clang-format on |
| } |
| REGISTER_OP_GRADIENT("Supported", SupportedGrad); |
| |
| // A gradient containing an unsupported operator. |
| Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) { |
| // clang-format off |
| return GradForUnaryCwise(g, { |
| {{"y"}, "Tanh", {"x"}}, |
| {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}}, |
| FunctionDefHelper::Const("one", 1.0f), |
| {{"a"}, "Sub", {"one", "y2"}}, |
| {{"dx"}, "Mul", {"dy", "a"}}, |
| }); |
| // clang-format on |
| } |
| REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad); |
| |
| TEST(XlaCompilationTest, SymbolicGradients) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = |
| ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); |
| |
| // Builds a Symbolic gradient for Supported |
| NodeBuilder b_builder("B", "SymbolicGradient", |
| builder.opts().op_registry()); |
| NameAttrList b_name_attr; |
| b_name_attr.set_name("Supported"); |
| b_builder.Attr("f", b_name_attr); |
| b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT}); |
| b_builder.Attr("Tout", {DT_FLOAT}); |
| b_builder.Input({a, a}); |
| Node* b = builder.opts().FinalizeBuilder(&b_builder); |
| |
| Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); |
| |
| // Builds a Symbolic gradient for Unsupported |
| NodeBuilder d_builder("D", "SymbolicGradient", |
| builder.opts().op_registry()); |
| NameAttrList d_name_attr; |
| d_name_attr.set_name("Unsupported"); |
| d_builder.Attr("f", d_name_attr); |
| d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT}); |
| d_builder.Attr("Tout", {DT_FLOAT}); |
| d_builder.Input({c, c}); |
| builder.opts().FinalizeBuilder(&d_builder); |
| |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_EQ(2, clusters.size()); |
| EXPECT_FALSE(clusters["B"].empty()); |
| EXPECT_EQ(clusters["B"], clusters["C"]); |
| EXPECT_TRUE(clusters.find("A") == clusters.cend()); |
| EXPECT_TRUE(clusters.find("D") == clusters.cend()); |
| } |
| |
| TEST(XlaCompilationTest, Loops) { |
| // Regression test for b/32350199, where the autoclustering code introduced a |
| // deadlock in a graph containing a while loop. |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT); |
| auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT); |
| auto c = ops::Add(root.WithOpName("C"), a, b); |
| auto enter = ops::internal::Enter(root, c, "aframe"); |
| auto next_iter = ops::NextIteration(root, enter); |
| auto exit = ops::internal::Exit(root, next_iter); |
| auto d = ops::Add(root.WithOpName("D"), c, exit); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_EXPECT_OK(root.ToGraph(graph.get())); |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| // Nothing should be compiled. In particular, 'd' and 'c' must not be |
| // compiled. |
| EXPECT_EQ(0, clusters.size()); |
| } |
| |
| TEST(XlaCompilationTest, CyclesWithAllDifferentScopesGlobalJitOverridden) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor()) |
| .WithAttr(kXlaScopeAttr, "ScopeA")); |
| Node* b = ops::UnaryOp( |
| "Relu", a, |
| builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); |
| ops::BinaryOp( |
| "MatMul", a, b, |
| builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); |
| TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| FunctionDefLibrary flib; |
| FunctionLibraryDefinition flib_def(graph->op_registry(), flib); |
| SessionOptions session_options; |
| session_options.config.mutable_graph_options() |
| ->mutable_optimizer_options() |
| ->set_global_jit_level(OptimizerOptions::ON_2); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( |
| &graph, &flib_def, &session_options)); |
| auto clusters = GetClusters(*graph); |
| |
| // The computation is: C = A + relu(A) |
| // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. |
| // In this case, the GlobalJitLevel overrides the scopes to cluster while |
| // ignoring scopes. |
| EXPECT_EQ(3, clusters.size()); |
| EXPECT_EQ(clusters["A"], clusters["B"]); |
| EXPECT_EQ(clusters["A"], clusters["C"]); |
| } |
| |
| TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor()) |
| .WithAttr(kXlaScopeAttr, "ScopeA")); |
| Node* b = ops::UnaryOp( |
| "Relu", a, |
| builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); |
| ops::BinaryOp( |
| "MatMul", a, b, |
| builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); |
| TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| // The computation is: C = A + relu(A) |
| // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. |
| // In this case, we cannot fuse anything, and there are no clusters. |
| EXPECT_EQ(0, clusters.size()); |
| } |
| |
| TEST(XlaCompilationTest, CyclesWithSplittingScopes) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor()) |
| .WithAttr(kXlaScopeAttr, "Scope1")); |
| Node* b = ops::UnaryOp( |
| "Relu", a, |
| builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "Scope1")); |
| Node* c = ops::BinaryOp( |
| "MatMul", a, b, |
| builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "Scope2")); |
| ops::BinaryOp( |
| "Add", b, c, |
| builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); |
| TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| // The computation is: D = relu(A) + (A @ relu(A)) |
| // where A and relu(A) are in Scope1, and the @, + ops are in Scope2. |
| // In this case, we can fuse the A and relu(A), and we can fuse the |
| // second half of the operations; there are two clusters. |
| EXPECT_EQ(4, clusters.size()); |
| EXPECT_EQ(clusters["A"], clusters["B"]); |
| EXPECT_NE(clusters["A"], clusters["C"]); |
| EXPECT_EQ(clusters["C"], clusters["D"]); |
| } |
| |
| TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor()) |
| .WithAttr(kXlaScopeAttr, "ScopeA")); |
| Node* b = ops::UnaryOp( |
| "Relu", a, |
| builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); |
| ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); |
| TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| // The computation is: C = A @ relu(A) |
| // where A sits in ScopeA, relu(A) sits in ScopeB, and C sits in ScopeC. |
| // In this case, we cannot fuse anything. |
| EXPECT_EQ(2, clusters.size()); |
| EXPECT_NE(clusters["A"], clusters["B"]); |
| EXPECT_EQ(clusters["B"], clusters["C"]); |
| } |
| |
| namespace { |
| Node* MakeRead(const Scope& scope, const string& id) { |
| Output var_handle = |
| ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); |
| Output read = |
| ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); |
| return read.node(); |
| } |
| |
| Node* MakeWrite(const Scope& scope, const string& id) { |
| Output var_handle = |
| ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); |
| Output value_to_write = |
| ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); |
| ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id), |
| var_handle, value_to_write); |
| return assign_op.operation.node(); |
| } |
| |
| Node* MakeNeutral(const Scope& scope, const string& id) { |
| return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); |
| } |
| } // namespace |
| |
| TEST(XlaCompilationTest, ResourcesClusteringAllowed) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| |
| Node* read = MakeRead(root, "R"); |
| Node* write = MakeWrite(root, "W"); |
| |
| root.graph()->AddControlEdge(read, write); |
| |
| FixupSourceAndSinkEdges(root.graph()); |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_EXPECT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| gtl::FlatMap<string, std::vector<string>> cluster_sets = |
| GetClusterSets(*graph); |
| ASSERT_EQ(cluster_sets.size(), 1); |
| std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR", |
| "ValueToAssignW"}; |
| ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); |
| } |
| |
| TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| |
| Node* read = MakeRead(root, "R"); |
| Node* write = MakeWrite(root, "W"); |
| |
| root.graph()->AddControlEdge(write, read); |
| |
| FixupSourceAndSinkEdges(root.graph()); |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_EXPECT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| gtl::FlatMap<string, std::vector<string>> cluster_sets = |
| GetClusterSets(*graph); |
| ASSERT_EQ(cluster_sets.size(), 1); |
| std::vector<string> expected_clustered_nodes = {"AssignmentW", |
| "ValueToAssignW"}; |
| ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); |
| } |
| |
| TEST(XlaCompilationTest, ChainOfOps) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| |
| Node* write_0 = MakeWrite(root, "W0"); |
| Node* neutral_0 = MakeNeutral(root, "N0"); |
| Node* read_0 = MakeRead(root, "R0"); |
| Node* write_1 = MakeWrite(root, "W1"); |
| Node* neutral_1 = MakeNeutral(root, "N1"); |
| Node* read_1 = MakeRead(root, "R1"); |
| |
| root.graph()->AddControlEdge(write_0, neutral_0); |
| root.graph()->AddControlEdge(neutral_0, read_0); |
| root.graph()->AddControlEdge(read_0, write_1); |
| root.graph()->AddControlEdge(write_1, neutral_1); |
| root.graph()->AddControlEdge(neutral_1, read_1); |
| |
| FixupSourceAndSinkEdges(root.graph()); |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_EXPECT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::vector<string> cluster_names; |
| gtl::FlatMap<string, std::vector<string>> cluster_sets = |
| GetClusterSets(*graph, &cluster_names); |
| |
| ASSERT_EQ(cluster_sets.size(), 2); |
| |
| std::vector<string> expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", |
| "ValueToAssignW0"}; |
| ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); |
| |
| std::vector<string> expected_clustered_nodes_b = { |
| "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; |
| ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); |
| } |
| |
| TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| { |
| auto BuildNoopNode = [](absl::string_view name, Graph* graph) { |
| NodeDefBuilder builder(name, "NoOp"); |
| NodeDef def; |
| TF_CHECK_OK(builder.Finalize(&def)); |
| |
| Status status; |
| Node* node = graph->AddNode(def, &status); |
| TF_CHECK_OK(status); |
| return node; |
| }; |
| |
| Node* a = BuildNoopNode("a", graph.get()); |
| Node* b = BuildNoopNode("b", graph.get()); |
| Node* c = BuildNoopNode("c", graph.get()); |
| graph->AddControlEdge(a, b); |
| graph->AddControlEdge(b, c); |
| graph->AddControlEdge(c, a); |
| } |
| |
| TF_EXPECT_OK(root.ToGraph(graph.get())); |
| |
| Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); |
| EXPECT_FALSE(status.ok()); |
| EXPECT_TRUE(absl::StrContains(status.ToString(), |
| "Edge from c to a would create a cycle.\n" |
| "+-> a\n" |
| "| b\n" |
| "+-- c\n")); |
| } |
| |
| TEST(XlaCompilationTest, Retval) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| GraphDef graphdef; |
| { |
| GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); |
| Node* a = ops::SourceOp("Const", builder.opts() |
| .WithName("A") |
| .WithAttr("dtype", DT_FLOAT) |
| .WithAttr("value", Tensor())); |
| Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); |
| ops::UnaryOp("_Retval", b, |
| builder.opts() |
| .WithName("R") |
| .WithAttr("T", DT_FLOAT) |
| .WithAttr("index", 0)); |
| |
| TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); |
| } |
| |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_EQ(2, clusters.size()); |
| EXPECT_TRUE(clusters.find("R") == clusters.cend()); |
| EXPECT_EQ(clusters["A"], clusters["B"]); |
| } |
| |
| TEST(XlaCompilationTest, DontCountIdentityOps) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| { |
| auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); |
| auto b = ops::Identity(root.WithOpName("B"), a); |
| auto c = ops::Identity(root.WithOpName("C"), b); |
| auto r = ops::_Retval(root.WithOpName("R"), c, 0); |
| } |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_TRUE(clusters.empty()); |
| } |
| |
| TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| { |
| auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0); |
| auto b = ops::Identity(root.WithOpName("B"), a); |
| b.node()->AddAttr(kXlaCompileAttr, true); |
| auto r = ops::_Retval(root.WithOpName("R"), b, 0); |
| } |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| auto clusters = GetClusters(*graph); |
| |
| EXPECT_TRUE(clusters.empty()); |
| } |
| |
| TEST(XlaCompilationTest, ConstOp) { |
| // valid data type |
| { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| auto c = ops::Const(root.WithOpName("const"), 0.5f); |
| c.node()->AddAttr(kXlaCompileAttr, true); |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| EXPECT_EQ(1, GetClusters(*graph).size()); |
| } |
| |
| // invalid data type |
| { |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| auto c = ops::Const(root.WithOpName("const"), string("string")); |
| c.node()->AddAttr(kXlaCompileAttr, true); |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| EXPECT_TRUE(GetClusters(*graph).empty()); |
| } |
| } |
| |
| TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| Output variable = ops::Variable(root.WithOpName("variable"), |
| PartialTensorShape{}, DT_FLOAT); |
| Output read = ops::Identity(root.WithOpName("read"), variable); |
| Output neg = ops::Negate(root.WithOpName("negate"), read); |
| Output add = ops::Add(root.WithOpName("add"), neg, neg); |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| |
| ASSERT_FALSE(clusters.empty()); |
| string cluster_name = clusters.begin()->second; |
| |
| std::unordered_map<string, string> expected_clusters( |
| {{"negate", cluster_name}, {"add", cluster_name}}); |
| EXPECT_EQ(clusters, expected_clusters); |
| } |
| |
| TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| Output variable = ops::Variable(root.WithOpName("variable"), |
| PartialTensorShape{}, DT_FLOAT); |
| Output read = ops::Identity(root.WithOpName("read"), variable); |
| Output neg = ops::Negate(root.WithOpName("negate"), read); |
| Output identity = ops::Negate(root.WithOpName("identity"), neg); |
| Output add = ops::Add(root.WithOpName("add"), identity, neg); |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| |
| ASSERT_FALSE(clusters.empty()); |
| string cluster_name = clusters.begin()->second; |
| |
| std::unordered_map<string, string> expected_clusters( |
| {{"negate", cluster_name}, |
| {"identity", cluster_name}, |
| {"add", cluster_name}}); |
| EXPECT_EQ(clusters, expected_clusters); |
| } |
| |
| TEST(XlaCompilationTest, ClusterControlTrigger) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| |
| Output recv_a = ops::_Recv(root.WithOpName("recv_a"), DT_BOOL, "tensor_a", |
| "sender", 0, "receiver"); |
| Output recv_b = ops::_Recv(root.WithOpName("recv_b"), DT_BOOL, "tensor_b", |
| "sender", 0, "receiver"); |
| Output const_a = ops::Const(root.WithOpName("const_a"), 42); |
| |
| ops::ControlTrigger ctrl_trigger_a(root.WithOpName("ctrl_trigger_a")); |
| ops::ControlTrigger ctrl_trigger_b(root.WithOpName("ctrl_trigger_b")); |
| root.graph()->AddControlEdge(recv_a.node(), ctrl_trigger_a.operation.node()); |
| root.graph()->AddControlEdge(recv_b.node(), ctrl_trigger_a.operation.node()); |
| root.graph()->AddControlEdge(ctrl_trigger_b.operation.node(), const_a.node()); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| |
| ASSERT_FALSE(clusters.empty()); |
| string cluster_name = clusters.begin()->second; |
| |
| // ctrl_trigger_a has inputs with mismatching deadness so it won't be |
| // clustered. ctrl_trigger_b is okay to cluster. |
| std::unordered_map<string, string> expected_clusters( |
| {{"const_a", cluster_name}, {"ctrl_trigger_b", cluster_name}}); |
| EXPECT_EQ(clusters, expected_clusters); |
| } |
| |
| TEST(XlaCompilationTest, RandomShape) { |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1}); |
| Output shape = |
| ops::RandomUniformInt(root.WithOpName("shape"), shape_shape, |
| ops::Const(root.WithOpName("minval"), 1), |
| ops::Const(root.WithOpName("maxval"), 20)); |
| Output reshape_input = |
| ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, |
| ops::Placeholder::Shape(TensorShape({500, 500}))); |
| Output reshape = |
| ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| EXPECT_EQ(clusters["shape"], ""); |
| } |
| |
| TEST(XlaCompilationTest, RandomShapeWithFunc) { |
| Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); |
| |
| FunctionDefLibrary flib_def; |
| FunctionDef func = FunctionDefHelper::Create( |
| /*function_name=*/"Stateful_func", /*in_def=*/{}, |
| /*out_def=*/{"out: int32"}, |
| /*attr_def*/ |
| {}, /*node_def=*/ |
| {FunctionDefHelper::Const("shape_shape", 2), |
| FunctionDefHelper::Const("minval", 1), |
| FunctionDefHelper::Const("maxval", 20), |
| {{"shape"}, |
| "RandomUniformInt", |
| {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, |
| {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, |
| /*ret_def=*/{{"out", "shape:output:0"}}); |
| |
| func.mutable_signature()->set_is_stateful(true); |
| *flib_def.add_function() = std::move(func); |
| TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); |
| NodeDef call_node; |
| call_node.set_name("fn_call"); |
| call_node.set_op("Stateful_func"); |
| Status status; |
| Node* call = root.graph()->AddNode(call_node, &status); |
| TF_ASSERT_OK(status); |
| |
| Output shape = Output(call, 0); |
| Output reshape_input = |
| ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, |
| ops::Placeholder::Shape(TensorShape({500, 500}))); |
| Output reshape = |
| ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), |
| flib_def); |
| TF_ASSERT_OK( |
| MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| EXPECT_EQ(clusters["fn_call"], ""); |
| } |
| |
| TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { |
| absl::string_view xla_gpu_device = |
| "/job:worker/replica:0/task:0/device:XLA_GPU:0"; |
| |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| Output shape_shape = |
| ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); |
| Output shape = |
| ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, |
| ops::Const(root.WithOpName("test/minval"), 1), |
| ops::Const(root.WithOpName("test/maxval"), 20)); |
| Output reshape_input = |
| ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, |
| ops::Placeholder::Shape(TensorShape({500, 500}))); |
| Output reshape = |
| ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| |
| for (Node* n : graph->nodes()) { |
| if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { |
| n->set_assigned_device_name(string(xla_gpu_device)); |
| } |
| } |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| EXPECT_NE(clusters["test/shape_rng"], ""); |
| EXPECT_NE(clusters["test/reshape"], ""); |
| EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); |
| } |
| |
| TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { |
| absl::string_view xla_gpu_device = |
| "/job:worker/replica:0/task:0/device:XLA_GPU:0"; |
| Scope root = Scope::NewRootScope().ExitOnError(); |
| ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, |
| DT_INT32); |
| Output zero = ops::Const(root.WithOpName("test/zero"), 0); |
| ops::TensorArrayWrite tensor_array_write( |
| root.WithOpName("test/write"), tensor_array.handle, zero, |
| ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); |
| Output tensor_array_read = |
| ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, |
| zero, tensor_array_write.flow_out, DT_INT32); |
| Output reshape = |
| ops::Reshape(root.WithOpName("test/reshape"), |
| ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), |
| tensor_array_read); |
| |
| std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); |
| TF_ASSERT_OK(root.ToGraph(graph.get())); |
| |
| for (Node* n : graph->nodes()) { |
| if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { |
| n->set_assigned_device_name(string(xla_gpu_device)); |
| } |
| } |
| TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); |
| |
| std::unordered_map<string, string> clusters = GetClusters(*graph); |
| EXPECT_NE(clusters["test/read"], ""); |
| EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); |
| } |
| |
| } // namespace |
| } // namespace tensorflow |