blob: f643fb0cfe136caba42272d72f3972ec63a94bf3 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
static std::unique_ptr<Graph> MakeOuterGraph(
const FunctionLibraryDefinition& flib_def, const string& function) {
Scope scope = Scope::NewRootScope().ExitOnError();
TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
NodeDef def;
TF_CHECK_OK(
NodeDefBuilder("launch0", function, &flib_def)
.Input(a.node()->name(), 0, DT_INT32)
.Input(b.node()->name(), 0, DT_FLOAT)
.Input(c.node()->name(), 0, DT_INT32)
.Input(d.node()->name(), 0, DT_FLOAT)
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
Status status;
Node* launch = scope.graph()->AddNode(def, &status);
TF_CHECK_OK(status);
TF_CHECK_OK(scope.DoShapeInference(launch));
scope.graph()->AddEdge(a.node(), 0, launch, 0);
scope.graph()->AddEdge(b.node(), 0, launch, 1);
scope.graph()->AddEdge(c.node(), 0, launch, 2);
scope.graph()->AddEdge(d.node(), 0, launch, 3);
scope.graph()->AddEdge(u.node(), 0, launch, 4);
scope.graph()->AddEdge(v.node(), 0, launch, 5);
scope.graph()->AddEdge(w.node(), 0, launch, 6);
auto out0 =
ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
auto out1 =
ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
auto out2 =
ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
auto out3 =
ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_CHECK_OK(scope.ToGraph(graph.get()));
return graph;
}
// Makes an encapsulate body graph for use in tests.
//
// The graph takes seven _Arg nodes (a, b, c, d as values; u, v, w as
// DT_RESOURCE) and produces four _Retval nodes:
//   0: B_identity = Identity(b)
//   1: E = a + c
//   2: G = (ReadV + ReadW) + d
//   3: ReadU
// Every op between the _Arg and _Retval nodes is tagged with the "launch0"
// cluster attribute, mirroring how the Encapsulate test tags the same ops in
// its input graph.
static std::unique_ptr<Graph> MakeBodyGraph() {
  Scope scope = Scope::NewRootScope().ExitOnError();

  // Function arguments; the trailing integer is the argument index.
  auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
  auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
  auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
  auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
  auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
  auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
  auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);

  // Tags a node as a member of the "launch0" cluster.
  auto add_attrs = [](Node* node) {
    node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
  };

  auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
  // B_identity belongs to the cluster like every other op here; it was
  // previously the only one left untagged.
  add_attrs(b_identity.node());
  auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
  add_attrs(read_u.node());
  auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
  add_attrs(read_v.node());
  auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
  add_attrs(read_w.node());
  auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
  add_attrs(e.node());
  auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
  add_attrs(f.node());
  auto g = ops::Add(scope.WithOpName("G"), f, arg3);
  add_attrs(g.node());

  // Function results; the trailing integer is the return-value index.
  auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
                           b_identity, 0);
  auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
  auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
  auto out3 =
      ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_CHECK_OK(scope.ToGraph(graph.get()));
  return graph;
}
TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
  // Test that control edge insertion order doesn't affect the cache key
  // (cluster name) generated by TPU encapsulate pass.
  //
  // The helper builds E = Add over placeholders A0 and A1, tags E as part of
  // cluster "launch0", adds control edges A0->E and A1->E, runs Encapsulate
  // and returns the deterministically serialized GraphDef.
  //   control_input_reversed: selects which control edge is inserted first.
  //   operand_reversed: selects which of the two data-operand orders E uses.
  auto get_serialized_graph = [](bool control_input_reversed,
                                 bool operand_reversed) -> string {
    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
    std::unique_ptr<Graph> graph(new Graph(&flib_def));
    {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
      auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);

      ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
                                    : ops::Add(scope.WithOpName("E"), a1, a0);

      auto add_attrs = [](Node* node) {
        node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
                      "launch0");
      };
      add_attrs(e.node());

      TF_CHECK_OK(scope.ToGraph(graph.get()));

      // Maps a node of the scope's graph to the node with the same id in
      // `graph`.
      auto get_node_in_graph = [&graph](Node* node) {
        return graph->FindNodeId(node->id());
      };
      // Insert control edge in different order. The order should not affect
      // the encapsulated or serialized graph.
      if (!control_input_reversed) {
        graph->AddControlEdge(get_node_in_graph(a0.node()),
                              get_node_in_graph(e.node()), true);
        graph->AddControlEdge(get_node_in_graph(a1.node()),
                              get_node_in_graph(e.node()), true);
      } else {
        graph->AddControlEdge(get_node_in_graph(a1.node()),
                              get_node_in_graph(e.node()), true);
        graph->AddControlEdge(get_node_in_graph(a0.node()),
                              get_node_in_graph(e.node()), true);
      }
    }
    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
    GraphDef gdef;
    graph->ToGraphDef(&gdef);
    // Before serialization, sort control inputs first to remove
    // nondeterminism.
    SortControlInputs(&gdef);
    string serialized;
    // Do not ignore the result: if serialization failed for every
    // configuration, the equality check below could pass without comparing
    // real graphs.
    EXPECT_TRUE(SerializeToStringDeterministic(gdef, &serialized));
    return serialized;
  };

  // Changing the order of control input shouldn't affect the graph generated.
  EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
                                 /*operand_reversed=*/false),
            get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/false));

  // Changing the order of data input should affect the graph generated.
  EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/true),
            get_serialized_graph(/*control_input_reversed=*/false,
                                 /*operand_reversed=*/false));
}
// Runs Encapsulate on a graph whose "launch0" cluster mixes value inputs with
// resource-variable reads, then checks the rewritten outer graph, the
// generated function body, and that a second run on an identical graph
// produces the same function name.
TEST(EncapsulateXlaComputations, Encapsulate) {
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
  std::unique_ptr<Graph> graph(new Graph(&flib_def));
  {
    Scope root = Scope::NewRootScope().ExitOnError();
    auto ph_a = ops::Placeholder(root.WithOpName("A"), DT_INT32);
    auto ph_b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
    auto ph_c = ops::Placeholder(root.WithOpName("C"), DT_INT32);
    auto ph_d = ops::Placeholder(root.WithOpName("D"), DT_FLOAT);
    auto ph_u = ops::Placeholder(root.WithOpName("U"), DT_RESOURCE);
    auto ph_v = ops::Placeholder(root.WithOpName("V"), DT_RESOURCE);
    auto ph_w = ops::Placeholder(root.WithOpName("W"), DT_RESOURCE);

    // Ops that make up the cluster.
    auto identity_b = ops::Identity(root.WithOpName("B_identity"), ph_b);
    auto read_u =
        ops::ReadVariableOp(root.WithOpName("ReadU"), ph_u, DT_FLOAT);
    auto read_v =
        ops::ReadVariableOp(root.WithOpName("ReadV"), ph_v, DT_FLOAT);
    auto read_w =
        ops::ReadVariableOp(root.WithOpName("ReadW"), ph_w, DT_FLOAT);
    auto sum_e = ops::Add(root.WithOpName("E"), ph_a, ph_c);
    auto sum_f = ops::Add(root.WithOpName("F"), read_v, read_w);
    auto sum_g = ops::Add(root.WithOpName("G"), sum_f, ph_d);

    // Tag every cluster op with the "launch0" cluster attribute.
    for (Node* clustered :
         {identity_b.node(), read_u.node(), read_v.node(), read_w.node(),
          sum_e.node(), sum_f.node(), sum_g.node()}) {
      clustered->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
                         "launch0");
    }

    auto cluster_out0 =
        ops::XlaClusterOutput(root.WithOpName("Out0"), identity_b);
    auto cluster_out1 = ops::XlaClusterOutput(root.WithOpName("Out1"), sum_e);
    auto cluster_out2 = ops::XlaClusterOutput(root.WithOpName("Out2"), sum_g);
    auto cluster_out3 = ops::XlaClusterOutput(root.WithOpName("Out3"), read_u);

    // Out0 has three consumers; every other cluster output has exactly one.
    ops::Identity(root.WithOpName("consumer0_a"), cluster_out0);
    ops::Identity(root.WithOpName("consumer0_b"), cluster_out0);
    ops::Identity(root.WithOpName("consumer0_c"), cluster_out0);
    ops::Identity(root.WithOpName("consumer1"), cluster_out1);
    ops::Identity(root.WithOpName("consumer2"), cluster_out2);
    ops::Identity(root.WithOpName("consumer3"), cluster_out3);
    TF_ASSERT_OK(root.ToGraph(graph.get()));
  }

  // Keep an untouched copy so the pass can be run a second time below.
  std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
  CopyGraph(*graph, graph_copy.get());

  TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));

  // The op type of the "launch0" node is the name of the generated function.
  const string function =
      BuildNodeIndex(*graph).at("launch0")->type_string();

  // The rewritten outer graph must match the hand-built one, internal
  // attributes included.
  {
    GraphDef outer_expected;
    MakeOuterGraph(flib_def, function)->ToGraphDef(&outer_expected);
    GraphDef outer_actual;
    graph->ToGraphDef(&outer_actual);
    TF_EXPECT_GRAPH_EQ_INTERNAL(outer_expected, outer_actual);
  }

  // The generated function must have the expected signature and body.
  {
    GraphDef body_expected;
    MakeBodyGraph()->ToGraphDef(&body_expected);

    InstantiationResultForTest instantiated;
    TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &instantiated));

    const DataTypeVector want_arg_types{DT_INT32,    DT_FLOAT,    DT_INT32,
                                        DT_FLOAT,    DT_RESOURCE, DT_RESOURCE,
                                        DT_RESOURCE};
    const DataTypeVector want_ret_types{DT_FLOAT, DT_INT32, DT_FLOAT,
                                        DT_FLOAT};
    EXPECT_EQ(want_arg_types, instantiated.arg_types);
    EXPECT_EQ(want_ret_types, instantiated.ret_types);
    TF_EXPECT_GRAPH_EQ(body_expected, instantiated.gdef);
  }

  // Encapsulating the identical copy must yield the same function name;
  // encapsulation should be deterministic to avoid recompilation.
  TF_ASSERT_OK(
      EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
  const string function_copy =
      BuildNodeIndex(*graph_copy).at("launch0")->type_string();
  EXPECT_EQ(function, function_copy);
}
// Checks that BuildXlaLaunchOps rewrites the outer graph produced by
// MakeOuterGraph into one where an XlaLaunch op named "launch0" feeds the
// consumers directly.
TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
  // Register the body graph as function "launch0" and build the graph under
  // test around it.
  std::unique_ptr<Graph> body = MakeBodyGraph();
  FunctionDefLibrary fdef_lib;
  TF_ASSERT_OK(GraphToFunctionDef(*body, "launch0", fdef_lib.add_function()));
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
  TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));

  // Hand-build the expected graph.
  Scope expected = Scope::DisabledShapeInferenceScope().ExitOnError();
  TF_EXPECT_OK(expected.graph()->AddFunctionLibrary(fdef_lib));

  auto ph_a = ops::Placeholder(expected.WithOpName("A"), DT_INT32);
  auto ph_b = ops::Placeholder(expected.WithOpName("B"), DT_FLOAT);
  auto ph_c = ops::Placeholder(expected.WithOpName("C"), DT_INT32);
  auto ph_d = ops::Placeholder(expected.WithOpName("D"), DT_FLOAT);
  auto ph_u = ops::Placeholder(expected.WithOpName("U"), DT_RESOURCE);
  auto ph_v = ops::Placeholder(expected.WithOpName("V"), DT_RESOURCE);
  auto ph_w = ops::Placeholder(expected.WithOpName("W"), DT_RESOURCE);

  NameAttrList launch_fn;
  launch_fn.set_name("launch0");
  // The three input lists are, in order: an empty list, the value inputs
  // A..D, and the DT_RESOURCE inputs U..W.
  auto launch = ops::XlaLaunch(
      expected.WithOpName("launch0"), std::initializer_list<Input>{},
      std::initializer_list<Input>{ph_a, ph_b, ph_c, ph_d},
      std::initializer_list<Input>{ph_u, ph_v, ph_w},
      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, launch_fn);

  // Identity consumers read the launch results directly: three on result 0,
  // one on each of the others.
  struct Consumer {
    const char* name;
    int result_index;
  };
  const Consumer consumers[] = {
      {"consumer0_a", 0}, {"consumer0_b", 0}, {"consumer0_c", 0},
      {"consumer1", 1},   {"consumer2", 2},   {"consumer3", 3},
  };
  for (const Consumer& consumer : consumers) {
    ops::Identity(expected.WithOpName(consumer.name),
                  launch.results[consumer.result_index]);
  }

  GraphDef expected_def;
  TF_ASSERT_OK(expected.ToGraphDef(&expected_def));

  GraphDef actual_def;
  graph->ToGraphDef(&actual_def);
  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
}
} // namespace tensorflow