blob: 853e6bcfc8c09951d3491ac2821460b2bda656ba [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include <memory>
#include <utility>
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/compiler/jit/extract_outside_compilation_pass.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
// Name of the node attribute that Sequencer() sets, to the host-compute call
// node's name, on the NoOp node it creates.
constexpr const char* kXlaHostTransferSequencerAttr =
    "_xla_host_transfer_sequencer";
// Serializes `graphdef_builder`, converts the result into a Graph (internal ops
// allowed), and appends it to `library` as a FunctionDef named
// "_outside_compilation_shape_inference_<name_suffix>". Any error from the
// serialization, conversion, or FunctionDef generation is returned as-is.
Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder,
                                    const string& name_suffix,
                                    FunctionDefLibrary* library) {
  GraphDef serialized;
  TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&serialized));

  std::unique_ptr<Graph> converted(new Graph(OpRegistry::Global()));
  GraphConstructorOptions constructor_options;
  constructor_options.allow_internal_ops = true;
  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(constructor_options, serialized, converted.get()));

  // The FunctionDef slot is added before conversion, so it stays in `library`
  // even if GraphToFunctionDef fails.
  FunctionDef* function_slot = library->add_function();
  const string function_name =
      absl::StrCat("_outside_compilation_shape_inference_", name_suffix);
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*converted, function_name, function_slot));
  return Status::OK();
}
// Returns true if protobuf maps `a` (expected) and `b` (actual) hold the same
// keys and, for every key, `compare(key, a_value, b_value)` accepts the pair.
// On the first discrepancy returns false and, when `diff` is non-null, stores
// a message prefixed with `map_name`. `key_to_string` / `value_to_string` are
// only invoked to build that message.
template <class Tkey, class Tvalue>
bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
                   const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
                   const std::function<string(const Tkey&)>& key_to_string,
                   const std::function<string(const Tvalue&)>& value_to_string,
                   const std::function<bool(const Tkey&, const Tvalue&,
                                            const Tvalue&)>& compare,
                   const string& map_name, string* diff) {
  // Pass 1: every expected entry must exist in `b` with a matching value.
  for (const auto& expected_entry : a) {
    const auto actual_entry = b.find(expected_entry.first);
    const bool present = actual_entry != b.end();
    // `compare` is only consulted when the key is present.
    if (present && compare(expected_entry.first, expected_entry.second,
                           actual_entry->second)) {
      continue;
    }
    if (diff != nullptr) {
      if (!present) {
        *diff = absl::StrCat(map_name, " expected: contains element with key '",
                             key_to_string(expected_entry.first),
                             "' got: map has no such element");
      } else {
        *diff = absl::StrCat(map_name, " expected: element with key '",
                             key_to_string(expected_entry.first),
                             "' has value '",
                             value_to_string(expected_entry.second),
                             "' got: '", value_to_string(actual_entry->second),
                             "'");
      }
    }
    return false;
  }
  // Pass 2: `b` must not contain keys that `a` lacks.
  for (const auto& actual_entry : b) {
    if (a.find(actual_entry.first) != a.end()) continue;
    if (diff != nullptr) {
      *diff = absl::StrCat(map_name, " got: contains element with key '",
                           key_to_string(actual_entry.first),
                           "' expected: map has no such element");
    }
    return false;
  }
  return true;
}
// Compares two function-body NodeDefs, `a` (expected) and `b` (actual).
// They match when:
//   * op, device and the number of inputs are equal;
//   * every data input is equal position by position;
//   * wherever `a` has a control input ("^"-prefixed), `b` has one too, and
//     the two sets of control inputs are equal (order-insensitive);
//   * attrs are equal. The "ancestors" attr is compared as a set because it
//     is built from a set and its order is unpredictable.
// On mismatch returns false and, if `diff` is non-null, stores a message that
// starts with `diff_preamble`.
bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
                          const string& diff_preamble, string* diff) {
  if (a.op() != b.op()) {
    if (diff) {
      // Close the quote opened around b.op() so the message is well formed.
      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                           ", expected op '", a.op(), "' got '", b.op(), "'");
    }
    return false;
  }
  if (a.device() != b.device()) {
    if (diff) {
      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                           ", expected device '", a.device(), "' got '",
                           b.device(), "'");
    }
    return false;
  }
  if (a.input_size() != b.input_size()) {
    if (diff) {
      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                           ", expected ", a.input_size(), " inputs got ",
                           b.input_size(), " expected:\n", a.DebugString(),
                           "\ngot:\n", b.DebugString());
    }
    return false;
  }
  std::unordered_set<string> control_input_a;
  std::unordered_set<string> control_input_b;
  for (int i = 0; i < a.input_size(); ++i) {
    if (absl::StartsWith(a.input(i), "^")) {
      if (!absl::StartsWith(b.input(i), "^")) {
        if (diff) {
          *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                               " input ", i, ", expected control input ",
                               a.input(i), " got ", b.input(i), " expected:\n",
                               a.DebugString(), "\ngot:\n", b.DebugString());
        }
        return false;
      }
      // Control inputs are collected and compared as sets after the loop.
      control_input_a.insert(a.input(i));
      control_input_b.insert(b.input(i));
    } else if (a.input(i) != b.input(i)) {
      if (diff) {
        *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                             " input ", i, ", expected ", a.input(i), " got ",
                             b.input(i), " expected:\n", a.DebugString(),
                             "\ngot:\n", b.DebugString());
      }
      return false;
    }
  }
  if (control_input_a != control_input_b) {
    if (diff) {
      *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(),
                           " control inputs differ expected:\n",
                           a.DebugString(), "\ngot:\n", b.DebugString());
    }
    return false;
  }
  return EqualProtoMap<string, AttrValue>(
      a.attr(), b.attr(), [](const string& s) { return s; },
      [](const AttrValue& v) { return v.DebugString(); },
      [](const string& key, const AttrValue& av, const AttrValue& bv) {
        if (key == "ancestors") {
          // The ancestors are added from a set so the order is unpredictable;
          // just compare set equality not list equality.
          std::unordered_set<string> a_set(av.list().s().begin(),
                                           av.list().s().end());
          std::unordered_set<string> b_set(bv.list().s().begin(),
                                           bv.list().s().end());
          return a_set == b_set;
        } else {
          return av.DebugString() == bv.DebugString();
        }
      },
      absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff);
}
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
if (a.signature().DebugString() != b.signature().DebugString()) {
if (diff) {
*diff =
absl::StrCat("Signature mismatch for function ", a.signature().name(),
", expected:\n", a.signature().DebugString(), "\ngot:\n",
b.signature().DebugString());
}
return false;
}
if (!EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
return av.DebugString() == bv.DebugString();
},
absl::StrCat("attr mismatch for function ", a.signature().name()),
diff)) {
return false;
}
if (!EqualProtoMap<string, string>(
a.ret(), b.ret(), [](const string& s) { return s; },
[](const string& s) { return s; },
[](const string& key, const string& av, const string& bv) {
return av == bv;
},
absl::StrCat("ret mismatch for function ", a.signature().name()),
diff)) {
return false;
}
for (int i = 0; i < a.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < b.node_def_size(); ++j) {
if (a.node_def(i).name() == b.node_def(j).name()) {
if (!EqualFunctionNodeDef(
a.node_def(i), b.node_def(j),
absl::StrCat("Function ", a.signature().name()), diff)) {
return false;
}
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = absl::StrCat("Function ", a.signature().name(),
", expected: has node '", a.node_def(i).name(),
"' got: no node of that name");
}
return false;
}
}
for (int i = 0; i < b.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < a.node_def_size(); ++j) {
if (b.node_def(i).name() == a.node_def(j).name()) {
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = absl::StrCat("Function ", a.signature().name(),
", got: has node '", b.node_def(i).name(),
"' expected: no node of that name");
}
return false;
}
}
return true;
}
bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
const FunctionDefLibrary& actual, string* diff) {
std::unordered_map<string, const FunctionDef*> actual_index;
for (const FunctionDef& function : actual.function()) {
actual_index[function.signature().name()] = &function;
}
for (const FunctionDef& expected_function : expected.function()) {
auto it = actual_index.find(expected_function.signature().name());
if (it == actual_index.end()) {
if (diff) {
*diff = absl::StrCat("Did not find expected function '",
expected_function.signature().name(), "'");
}
return false;
}
if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
actual_index.erase(it);
}
if (!actual_index.empty()) {
if (diff != nullptr) {
*diff =
absl::StrCat("Found unexpected function '",
actual_index.begin()->second->signature().name(), "'");
}
return false;
}
return true;
}
// gtest expectation that FunctionDefLibrary `actual` equals `expected` per
// EqualFunctionDefLibrary(). On failure, the diff and actual.DebugString()
// are streamed into the failure message.
// Hygiene:
//   * both arguments are parenthesized;
//   * `actual` is evaluated exactly once;
//   * the locals use prefixed names, so a caller variable such as `diff`
//     used in either argument is not shadowed.
#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual)                     \
  do {                                                                        \
    const FunctionDefLibrary& tf_fdl_eq_actual_ = (actual);                   \
    string tf_fdl_eq_diff_;                                                   \
    EXPECT_TRUE(EqualFunctionDefLibrary((expected), tf_fdl_eq_actual_,        \
                                        &tf_fdl_eq_diff_))                    \
        << tf_fdl_eq_diff_                                                    \
        << "\nActual: " << tf_fdl_eq_actual_.DebugString();                   \
  } while (false)
// These dummy Op registrations are here because the real Op registrations live
// in contrib and there can't be a dependence from this test to contrib.
REGISTER_OP("XlaHostCompute")
.Input("inputs: Tinputs")
.Output("outputs: Toutputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("Toutputs: list(type) >= 0")
.Attr("ancestors: list(string) >= 0")
.Attr("key: string")
.Attr("shape_inference_graph: func")
.Attr("shapes: list(shape) >= 0")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("InputTest")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();
});
REGISTER_OP("InputTestShaped")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
});
REGISTER_OP("UnaryTest")
.Input("a: float")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest")
.Input("a: float")
.Input("b: float")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest2")
.Input("a: float")
.Input("b: float")
.Output("o: float")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("AddNLikeTest")
.Input("inputs: N * T")
.Output("sum: T")
.Attr("N: int >= 1")
.Attr("T: numbertype")
.SetIsCommutative()
.SetIsAggregate();
// Adds a NoOp node whose kXlaHostTransferSequencerAttr attribute holds
// `call_node_name`. Returns nullptr if `opts` already carries an error.
Node* Sequencer(const GraphDefBuilder::Options& opts,
                const string& call_node_name) {
  if (opts.HaveError()) return nullptr;
  const string op_type = "NoOp";
  NodeBuilder builder(opts.GetNameForOp(op_type), op_type,
                      opts.op_registry());
  return opts.WithAttr(kXlaHostTransferSequencerAttr, call_node_name)
      .FinalizeBuilder(&builder);
}
// Adds an "InputTest" source node (one float output).
Node* Input(const GraphDefBuilder::Options& opts) {
  const string op_type = "InputTest";
  return ops::SourceOp(op_type, opts);
}
// Adds an "InputTestShaped" source node (one float output).
Node* InputShaped(const GraphDefBuilder::Options& opts) {
  const string op_type = "InputTestShaped";
  return ops::SourceOp(op_type, opts);
}
// Adds a Const node of type `dtype` whose "value" TensorProto declares the
// given `shape` but holds no element data. Returns nullptr if `opts` already
// carries an error.
Node* KnownShapeBase(DataType dtype, absl::Span<const int> shape,
                     const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  TensorProto constant;
  constant.set_dtype(dtype);
  // mutable_tensor_shape() is only touched when there is a dimension to add,
  // so a rank-0 `shape` leaves tensor_shape unset.
  for (const int extent : shape) {
    constant.mutable_tensor_shape()->add_dim()->set_size(extent);
  }
  NodeBuilder builder(opts.GetNameForOp("Const"), "Const", opts.op_registry());
  return opts.WithAttr("value", constant)
      .WithAttr("dtype", dtype)
      .FinalizeBuilder(&builder);
}
// Adds a DT_FLOAT Const node declaring `shape` (see KnownShapeBase).
Node* KnownShape(absl::Span<const int> shape,
                 const GraphDefBuilder::Options& opts) {
  const DataType float_type = DT_FLOAT;
  return KnownShapeBase(float_type, shape, opts);
}
// Adds a DT_STRING Const node declaring shape [2] (see KnownShapeBase).
Node* KeyPlaceholderShape(const GraphDefBuilder::Options& opts) {
  const int key_dims[] = {2};
  return KnownShapeBase(DT_STRING, key_dims, opts);
}
// Adds a DT_STRING Placeholder of shape [2] named "<call_node>_key_placeholder"
// and tagged with "_host_compute_call_node" = `call_node`. The node name is
// always derived from `call_node`; a name set on `opts` is not used for it.
// Returns nullptr if `opts` already carries an error.
Node* KeyPlaceholder(const string& call_node,
                     const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  const string placeholder_name = absl::StrCat(call_node, "_key_placeholder");
  NodeBuilder builder(placeholder_name, "Placeholder", opts.op_registry());
  TensorShapeProto key_shape;
  key_shape.add_dim()->set_size(2);
  return opts.WithAttr("shape", key_shape)
      .WithAttr("dtype", DT_STRING)
      .WithAttr("_host_compute_call_node", call_node)
      .FinalizeBuilder(&builder);
}
// Adds an _XlaRecvAtHost node fed by `key_input` with output types `dtypes`.
// Its "key" attr is "host_compute_channel_<cluster>_<new_func_name>_<oc_cluster>"
// and its default name is
// "outside_compilation_<cluster>_<new_func_name>_<oc_cluster>_recv". It is
// tagged with "_encapsulate" = `cluster`, "_outside" = `oc_cluster` and
// device_ordinal 0. Returns nullptr if `opts` already carries an error.
Node* RecvAtHost(ops::NodeOut key_input, const string& cluster,
                 const string& new_func_name, const string& oc_cluster,
                 absl::Span<const DataType> dtypes,
                 const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  const string channel =
      absl::StrCat(cluster, "_", new_func_name, "_", oc_cluster);
  const string key = absl::StrCat("host_compute_channel_", channel);
  const string default_name =
      absl::StrCat("outside_compilation_", channel, "_recv");
  NodeBuilder builder(
      opts.WithName(default_name).GetNameForOp("_XlaRecvAtHost"),
      "_XlaRecvAtHost", opts.op_registry());
  builder.Input(std::move(key_input));
  return opts.WithAttr("Toutputs", dtypes)
      .WithAttr("key", key)
      .WithAttr("device_ordinal", 0)
      .WithAttr("_encapsulate", cluster)
      .WithAttr("_outside", oc_cluster)
      .FinalizeBuilder(&builder);
}
// Adds an _XlaSendFromHost node that consumes `inputs` followed by
// `key_input`. "Tinputs" lists the dtype of each entry of `inputs`. The "key"
// attr is "host_compute_channel_<cluster>_<new_func_name>_<oc_cluster>" and
// the default name is
// "outside_compilation_<cluster>_<new_func_name>_<oc_cluster>_send". It is
// tagged with "_encapsulate" = `cluster`, "_outside" = `oc_cluster` and
// device_ordinal 0. Returns nullptr if `opts` already carries an error.
Node* SendFromHost(ops::NodeOut key_input, const string& cluster,
                   const string& new_func_name, const string& oc_cluster,
                   const std::vector<ops::NodeOut>& inputs,
                   const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  std::vector<DataType> input_types;
  input_types.reserve(inputs.size());
  for (const ops::NodeOut& input : inputs) input_types.push_back(input.dt);

  const string channel =
      absl::StrCat(cluster, "_", new_func_name, "_", oc_cluster);
  const string key = absl::StrCat("host_compute_channel_", channel);
  const string default_name =
      absl::StrCat("outside_compilation_", channel, "_send");
  NodeBuilder builder(
      opts.WithName(default_name).GetNameForOp("_XlaSendFromHost"),
      "_XlaSendFromHost", opts.op_registry());
  builder.Input(inputs);
  builder.Input(std::move(key_input));
  return opts.WithAttr("Tinputs", input_types)
      .WithAttr("key", key)
      .WithAttr("device_ordinal", 0)
      .WithAttr("_encapsulate", cluster)
      .WithAttr("_outside", oc_cluster)
      .FinalizeBuilder(&builder);
}
// Adds a "UnaryTest" node consuming `a`.
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
  const string op_type = "UnaryTest";
  return ops::UnaryOp(op_type, std::move(a), opts);
}
// Adds a "BinaryTest" node consuming `a` and `b`.
Node* Binary(ops::NodeOut a, ops::NodeOut b,
             const GraphDefBuilder::Options& opts) {
  const string op_type = "BinaryTest";
  return ops::BinaryOp(op_type, std::move(a), std::move(b), opts);
}
// Adds a "BinaryTest2" node (registered with an unknown output shape)
// consuming `a` and `b`.
Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
                         const GraphDefBuilder::Options& opts) {
  const string op_type = "BinaryTest2";
  return ops::BinaryOp(op_type, std::move(a), std::move(b), opts);
}
// Adds an "AddNLikeTest" node consuming all of `inputs`. Returns nullptr if
// `opts` already carries an error.
Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
               const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  // GetNameForOp is given "AddN", not the op type "AddNLikeTest".
  const string name = opts.GetNameForOp("AddN");
  NodeBuilder builder(name, "AddNLikeTest", opts.op_registry());
  builder.Input(inputs);
  return opts.FinalizeBuilder(&builder);
}
// Adds an _Arg source node with attrs "T" = `type` and "index" = `index`.
Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
  const string op_type = "_Arg";
  return ops::SourceOp(op_type,
                       opts.WithAttr("T", type).WithAttr("index", index));
}
// Adds a _Retval node consuming `a` with attr "index" = `index`. Returns
// nullptr if `opts` already carries an error.
Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
  if (opts.HaveError()) return nullptr;
  NodeBuilder builder(opts.GetNameForOp("Retval"), "_Retval",
                      opts.op_registry());
  builder.Input(std::move(a));
  builder.Attr("index", index);
  return opts.FinalizeBuilder(&builder);
}
// Rewrites `graphdef` and `library` in place:
//   1. converts the GraphDef to a Graph and runs static shape inference;
//   2. encapsulates nodes tagged "_encapsulate" into functions;
//   3. for each function named in `encapsulated_functions`, extracts the
//      nodes tagged "_outside" (outside compilation).
// On success `graphdef` holds the rewritten graph and `library` the resulting
// function library, with "_xla_inferred_shapes" attrs stripped from function
// nodes.
//
// Returns Internal if a name in `encapsulated_functions` matches no node of the
// encapsulated graph, or if no FunctionLibraryRuntime exists for the local
// CPU device. Otherwise returns the first error from any step.
Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library,
                   const std::vector<string>& encapsulated_functions) {
  // Convert the GraphDef to a Graph.
  std::unique_ptr<FunctionLibraryDefinition> lib_def(
      new FunctionLibraryDefinition(OpRegistry::Global(), *library));
  GraphConstructorOptions options;
  options.allow_internal_ops = true;
  std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, *graphdef, graph.get()));
  TF_RETURN_IF_ERROR(
      PerformStaticShapeInferenceBeforeEncapsulation(graph.get()));

  // Create the FunctionLibraryRuntime that ExtractOutsideCompilation needs.
  // Errors are propagated rather than CHECK-failing, matching the rest of
  // this function.
  SessionOptions session_options;
  std::vector<std::unique_ptr<Device>> devices;
  TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
      session_options, "/job:localhost/replica:0/task:0", &devices));
  OptimizerOptions opts;
  auto device_mgr = absl::make_unique<StaticDeviceMgr>(std::move(devices));
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      device_mgr.get(), Env::Default(), /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, lib_def.get(), opts,
      /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
  auto flr = pflr->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  if (flr == nullptr) {
    return errors::Internal(
        "Cannot find FunctionLibraryRuntime for "
        "/job:localhost/replica:0/task:0/cpu:0");
  }

  std::unique_ptr<Graph> graph_out;
  TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
      "_encapsulate", *graph,
      /*rewrite_subgraph_fn=*/{},
      /*reuse_existing_functions=*/false, &graph_out, lib_def.get()));

  std::unordered_map<string, XlaClusterInfo> clusters;
  for (const auto& func : encapsulated_functions) {
    // Must start as nullptr so the "not found" check below is meaningful;
    // reading an uninitialized pointer is undefined behavior.
    Node* xla_computation_node = nullptr;
    for (Node* n : graph_out->nodes()) {
      if (n->name() == func) {
        xla_computation_node = n;
      }
    }
    if (xla_computation_node == nullptr) {
      return errors::Internal("Cannot find node ", func);
    }
    NameAttrList func_name_attrs;
    func_name_attrs.set_name(func);
    clusters.emplace(func,
                     XlaClusterInfo{func, func_name_attrs, xla_computation_node,
                                    std::map<string, int>{}});
  }
  bool modified = false;
  TF_RETURN_IF_ERROR(ExtractOutsideCompilation(
      "_encapsulate", "_outside", clusters, graph_out.get(), flr,
      lib_def.get(), &modified));

  GraphDef graphdef_out;
  graph_out->ToGraphDef(&graphdef_out);
  graphdef->Swap(&graphdef_out);
  *library = lib_def->ToProto();
  // Remove "_xla_inferred_shapes" attr. They are added by
  // `PerformStaticShapeInferenceBeforeEncapsulation`.
  for (FunctionDef& fdef : *library->mutable_function()) {
    for (NodeDef& node_def : *fdef.mutable_node_def()) {
      node_def.mutable_attr()->erase("_xla_inferred_shapes");
    }
  }
  return Status::OK();
}
// Convenience overload: calls the three-argument Encapsulate() with an empty
// `encapsulated_functions` list, so no function is selected for
// outside-compilation extraction.
Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
  return Encapsulate(graphdef, library, /*encapsulated_functions=*/{});
}
// If there are no marked nodes, funcification should be a no-op.
TEST(EncapsulateSubgraphsTest, NoFunctions) {
// Build A, B -> C = Unary(A) -> D = Binary(B, C). No node carries the
// "_encapsulate" attribute.
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = Input(builder.opts().WithName("A"));
Node* b = Input(builder.opts().WithName("B"));
Node* c = Unary(a, builder.opts().WithName("C"));
Binary(b, c, builder.opts().WithName("D"));
GraphDef graphdef_in;
FunctionDefLibrary library_in;
TF_EXPECT_OK(builder.ToGraphDef(&graphdef_in));
// Pre-populate the library so we also check that existing functions survive.
*library_in.add_function() = test::function::XTimesTwo();
GraphDef graphdef_out = graphdef_in;
FunctionDefLibrary library_out = library_in;
TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
// Both the graph and the function library must come back unchanged.
TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
}
// Test with one function to transform.
TEST(EncapsulateSubgraphsTest, OneFunction) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Name node `d` "c" so that it collides with node "C" after lowercasing.
// C and c (with a control edge C -> c) form cluster F1; E stays outside.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d = Binary(b, c,
b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Binary(a, d, b1.opts().WithName("E"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
// Expected F1: takes A and B as arguments, contains C and c (c keeps its
// control dependency on C), and returns c's output.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"c_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
},
{{"c_0_retval", "c:o:0"}});
{
// Expected outer graph: a single call node F1(A, B) replaces C and c,
// and E consumes the call's output.
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Binary(a, call, b2.opts().WithName("E"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform.
TEST(EncapsulateSubgraphsTest, TwoFunctions) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* control = Input(b1.opts().WithName("Control"));
// C goes to cluster F1 and D to cluster F2. Both depend on "Control"
// through control edges, and D consumes C's output.
Node* c =
Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
"_encapsulate", "F1"));
Node* d = Binary(b, c,
b1.opts().WithName("D").WithControlInput(control).WithAttr(
"_encapsulate", "F2"));
Binary(a, d, b1.opts().WithName("E"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float"}, {"c_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
},
{{"c_0_retval", "C:o:0"}});
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"b_0_arg:float", "c_0_arg:float"}, {"d_0_retval:float"}, {},
{
{{"D"}, "BinaryTest", {"b_0_arg", "c_0_arg"}},
},
{{"d_0_retval", "D:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* control = Input(b2.opts().WithName("Control"));
// Each call node keeps the control dependency on "Control"; F2 consumes
// F1's output.
NodeBuilder nb("F1", "F1", lib_def.get());
nb.Input(a).ControlInput(control);
Node* call1 = b2.opts().FinalizeBuilder(&nb);
NodeBuilder nb2("F2", "F2", lib_def.get());
nb2.Input(b).Input(call1).ControlInput(control);
Node* call2 = b2.opts().FinalizeBuilder(&nb2);
Binary(a, call2, b2.opts().WithName("E"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
// Each cluster must become its own function call, wired as built above.
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Returns a vector of node names in 'graph', sorted by name.
std::vector<string> GraphNodes(const Graph& graph) {
std::vector<string> nodes;
for (const auto& node : graph.nodes()) {
if (!node->IsSource() && !node->IsSink()) {
nodes.push_back(node->name());
}
}
std::sort(nodes.begin(), nodes.end());
return nodes;
}
// Returns a sorted vector of (src, dst) edges in 'graph'.
std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
std::vector<std::pair<string, string>> edges;
for (const Edge* edge : graph.edges()) {
if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
edges.emplace_back(
absl::StrCat(edge->src()->name(), ":", edge->src_output()),
absl::StrCat(edge->dst()->name(), ":", edge->dst_input()));
}
std::sort(edges.begin(), edges.end());
return edges;
}
// add1 = x + x is placed in cluster1 and add2 = add1 + add1 in cluster2, so
// each cluster reads one tensor through two inputs. After encapsulation, each
// such tensor must reach its cluster through a single edge (see the expected
// edge list below).
TEST(EncapsulateSubgraphsTest, InputDeduplication) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto input = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
  auto first_sum = ops::Add(root.WithOpName("add1"), input, input);
  first_sum.node()->AddAttr("_cluster", "cluster1");
  auto second_sum = ops::Add(root.WithOpName("add2"), first_sum, first_sum);
  second_sum.node()->AddAttr("_cluster", "cluster2");
  ops::Mul(root.WithOpName("mul"), first_sum, second_sum);

  Graph before(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&before));

  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  std::unique_ptr<Graph> after;
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_cluster", before,
      /*rewrite_subgraph_fn=*/{},
      /*reuse_existing_functions=*/false, &after, &library));

  const std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul",
                                              "x"};
  EXPECT_EQ(expected_nodes, GraphNodes(*after));

  const std::vector<std::pair<string, string>> expected_edges = {
      {"cluster1:0", "cluster2:0"},
      {"cluster1:0", "mul:0"},
      {"cluster2:0", "mul:1"},
      {"x:0", "cluster1:0"}};
  EXPECT_EQ(expected_edges, GraphEdges(*after));
}
// Returns the first node of `graph` named `name`, or nullptr if none exists.
const Node* FindNodeByName(const Graph& graph, const string& name) {
  const Node* match = nullptr;
  for (const Node* candidate : graph.nodes()) {
    if (candidate->name() == name) {
      match = candidate;
      break;
    }
  }
  return match;
}
bool HasGuaranteeConstAttr(const Node& n) {
bool is_guaranteed_constant = false;
if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant",
&is_guaranteed_constant)
.ok()) {
return false;
}
return is_guaranteed_constant;
}
// add1 (the only encapsulated node) adds two GuaranteeConst outputs. The
// rewrite callback must see exactly two _Arg nodes whose names start with
// "const", each with "_is_guaranteed_constant" set, and no other node with it.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
  auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
  auto const_guarantee_x2 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
  auto const_guarantee_x1 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
  auto sum =
      ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2);
  sum.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph before(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&before));
  std::unique_ptr<Graph> after;
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  int guaranteed_consts = 0;
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", before,
      /*rewrite_subgraph_fn=*/
      [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
                           std::unique_ptr<Graph>* graph_ptr,
                           std::vector<int>* input_permutation,
                           std::vector<int>* output_permutation,
                           NodeDef* call_def) {
        for (const Node* node : (*graph_ptr)->nodes()) {
          const bool is_const_arg = node->type_string() == "_Arg" &&
                                    absl::StartsWith(node->name(), "const");
          if (is_const_arg) ++guaranteed_consts;
          // Only the const-named _Arg nodes may carry the attribute.
          EXPECT_EQ(is_const_arg, HasGuaranteeConstAttr(*node));
        }
        return Status::OK();
      },
      /*reuse_existing_functions=*/false, &after, &library));
  EXPECT_EQ(2, guaranteed_consts);
}
// Only mul1 is tagged for encapsulation; it multiplies const_guarantee_add1
// (a sum of two GuaranteeConst outputs) by add2 (a GuaranteeConst output plus
// the plain placeholder x2). The callback must see exactly one _Arg named
// "const..." carrying "_is_guaranteed_constant", and no other node with it.
TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
  Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
      "/job:localhost/replica:0/task:0/cpu:0");
  auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
  auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
  auto const_guarantee_x1 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
  auto const_guarantee_x2 =
      ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
  auto const_sum = ops::Add(root.WithOpName("const_guarantee_add1"),
                            const_guarantee_x1, const_guarantee_x2);
  auto mixed_sum = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2);
  auto product = ops::Mul(root.WithOpName("mul1"), const_sum, mixed_sum);
  product.node()->AddAttr("_encapsulate", "encapsulate1");

  Graph before(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&before));
  std::unique_ptr<Graph> after;
  FunctionLibraryDefinition library(OpRegistry::Global(), {});
  int guaranteed_consts = 0;
  TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
      "_encapsulate", before,
      /*rewrite_subgraph_fn=*/
      [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
                           std::unique_ptr<Graph>* graph_ptr,
                           std::vector<int>* input_permutation,
                           std::vector<int>* output_permutation,
                           NodeDef* call_def) {
        for (const Node* node : (*graph_ptr)->nodes()) {
          const bool is_const_arg = node->type_string() == "_Arg" &&
                                    absl::StartsWith(node->name(), "const");
          if (is_const_arg) ++guaranteed_consts;
          // Only the const-named _Arg nodes may carry the attribute.
          EXPECT_EQ(is_const_arg, HasGuaranteeConstAttr(*node));
        }
        return Status::OK();
      },
      /*reuse_existing_functions=*/false, &after, &library));
  // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const
  // and another non-const, so overall non-const.
  EXPECT_EQ(1, guaranteed_consts);
}
// Test with one function to transform and one outside_compilation cluster.
TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
{
// Input graph: C, "c" and F are encapsulated in F1. E is also in F1 but
// tagged as outside-compilation cluster O1. G consumes F outside F1.
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Name node `d` "c" so that it collides with node "C" after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d = Binary(b, c,
b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Encapsulate F1 and extract its outside-compilation cluster O1.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
{
// Expected shape-inference function for O1
// ("_outside_compilation_shape_inference_F1_F1_O1"): E fed by an
// _XlaRecvAtHost node and feeding an _XlaSendFromHost node, both keyed by
// the F1 key placeholder.
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape.opts());
Node* recv = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
}
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
*library_expected.add_function() = test::function::XTimesTwo();
// Expected F1 body: E is replaced by an XlaHostCompute node that takes C and
// c as inputs and whose output 0 feeds F. F keeps a control dependency on
// that node, mirroring the original control edge from E.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "c:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"c"}},
},
{{"f_0_retval_retval", "F:o:0"}});
{
// Expected outer graph:
//   * E sits between the recv and send nodes on the host side;
//   * a sequencer NoOp depends on both recv and send;
//   * the F1 call node depends on the sequencer and on B;
//   * G consumes the call output and has a control edge from the call.
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts()
.WithName("E")
.WithControlInputs({recv})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithControlInput(e).WithAttr(
kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
"F1");
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call =
b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one function to transform and two outside_compilation clusters.
TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, F and I are encapsulated into F1. Within F1, E forms
// outside_compilation cluster O1, and G and H form cluster O2. E consumes C
// and D and feeds F, G and H, so O2 depends on O1's result. A, B and the
// final consumer J stay outside F1.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts()
.WithName("G")
.WithControlInputs({e, f})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(d, e,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Run the pass on F1; the result is compared against the expected library
// and outer graph built below.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference graph for O1: a KeyPlaceholder for F1 feeds a
// RecvAtHost of E's two float inputs, then E, then a SendFromHost of E.
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
// Expected shape-inference graph for O2. It also rebuilds O1's E from O1's
// RecvAtHost, because G and H both consume E.
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
Node* recv1 = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
shape2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* recv2 = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Binary(e, ops::NodeOut(recv2, 0),
shape2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(ops::NodeOut(recv2, 1), e,
shape2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
}
// The XlaHostCompute nodes below refer to the two shape-inference graphs by
// these names. NOTE(review): assumes AddGraphDefToFunctionLibrary registers
// them as "_outside_compilation_shape_inference_<suffix>"; confirm against
// its definition earlier in the file.
NameAttrList shape_inference_graph1, shape_inference_graph2;
shape_inference_graph1.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
shape_inference_graph2.set_name(
"_outside_compilation_shape_inference_F1_F1_O2");
// Expected F1 body: the compiled nodes C, D, F and I plus one XlaHostCompute
// per cluster. O1's host compute sends C and D and returns E (used by F).
// O2's host compute sends F and D and returns G (output 0, returned from F1)
// and H (output 1, used by I). O2's host compute lists O1's host compute both
// as a control input and in "_xla_token_input_nodes".
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:1"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0", "D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O2"},
{"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"F", "outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph1},
{"shapes", absl::Span<const DataType>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"D"}},
},
{{"g_0_retval_retval", "outside_compilation_O2_host_compute:outputs:0"},
{"i_0_retval_retval", "I:o:0"}});
// Expected outer graph. On the host, each cluster's nodes sit between a
// RecvAtHost and a SendFromHost keyed by F1's placeholder. F1_sequencer takes
// control edges from all four transfers and is a control input of the F1
// call. The call also has a control edge from B, presumably carried over
// from E's control input on B. J reads both call outputs.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
.WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithControlInput(e).WithAttr(
kXlaHasHostTransferAttrName, true));
Node* recv2 = RecvAtHost(
ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT, DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Binary(e, ops::NodeOut(recv2, 0),
b2.opts()
.WithName("G")
.WithControlInputs({recv2, e})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* h = Binary(ops::NodeOut(recv2, 1), e,
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
Node* send2 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g, h},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send1, recv2, send2}),
"F1");
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(a).Input(b);
Node* call =
b2.opts().WithControlInputs({s, b}).FinalizeBuilder(&node_builder);
Binary(ops::NodeOut(call, 0), ops::NodeOut(call, 1),
b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform, each with one outside_compilation
// cluster.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, E and F go into F1, with E in cluster O1. G, H and I go
// into F2, with H in F2's own cluster O1. F2 reads D, E and F from F1. A and
// B are created with InputShaped, and the host computes expected below carry
// a concrete "shapes" attr with an empty shape_inference_graph. Presumably
// this is because the shapes can then be determined without a
// shape-inference graph.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g = Binary(e, f,
b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr(
"_encapsulate", "F2"));
Node* h = Binary(d, g,
b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* i =
Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(g, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1", "F2"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// The "shapes" attr expected on each XlaHostCompute: a single dimension of
// size 2.
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
// Expected F1: its host compute sends C and D and returns E. F1 returns E
// (from the host compute), F and D, all of which F2 consumes.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "f_0_retval_retval:float",
"d_0_retval_retval:float"},
{},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"D"}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"d_0_retval_retval", "D:o:0"},
{"f_0_retval_retval", "F:o:0"}});
// Expected F2: it takes e, f and d from F1 as arguments. Its host compute
// sends d_0_arg and G to the host and returns H, which I consumes. Unlike
// F1's host compute, it has no control inputs.
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"e_0_arg:float", "f_0_arg:float", "d_0_arg:float"},
{"g_0_retval_retval:float", "i_0_retval_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"d_0_arg", "G:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_F2_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"g_0_retval_retval", "G:o:0"}, {"i_0_retval_retval", "I:o:0"}});
// Expected outer graph: each function gets its own key placeholder, host
// transfers and sequencer. The F2 call consumes all three outputs of the F1
// call and has control edges from F2_sequencer and the F1 call. J reads both
// outputs of the F2 call.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = InputShaped(b2.opts().WithName("B"));
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(
ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
.WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1", {e},
b2.opts().WithControlInput(e).WithAttr(
kXlaHasHostTransferAttrName, true));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
Node* recv2 = RecvAtHost(
ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* h = Binary(recv2, ops::NodeOut(recv2, 1),
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* send2 =
SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1", {h},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(call1)
.Input(ops::NodeOut(call1, 1))
.Input(ops::NodeOut(call1, 2));
Node* call2 = b2.opts()
.WithControlInputs({s2, call1})
.FinalizeBuilder(&node_builder2);
Binary(call2, ops::NodeOut(call2, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two functions to transform, each with one outside_compilation
// cluster, with the dependency between them purely from an outside_compilation
// edge.
TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, E and F go into F1, with E in cluster O1. G, H and I go
// into F2, with H in F2's cluster O1. J consumes F (from F1) and I (from F2).
// NOTE(review): as written, no F2 node reads anything from F1 (G uses A and
// B directly), and the expected F2 call below has no edge from the F1 call.
// The cross-function dependency suggested by the test name is therefore not
// exercised; confirm this is intended.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Binary(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Node* g =
Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2"));
Node* h = Unary(g, b1.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2"));
Binary(f, i, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1", "F2"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// The "shapes" attr expected on each XlaHostCompute: a single dimension of
// size 2.
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
// Expected F1: its host compute sends C and D and returns E, which F
// consumes.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"D"}},
},
{{"f_0_retval_retval", "F:o:0"}});
// Expected F2: its host compute sends G and returns H, which I consumes.
*library_expected.add_function() = FunctionDefHelper::Create(
"F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval_retval:float"}, {},
{
{{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}},
{{"I"},
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"G:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F2_F2_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"i_0_retval_retval", "I:o:0"}});
// Expected outer graph: F1 and F2 each have their own host transfers and
// sequencer, and both calls read A and B directly.
// NOTE(review): unlike OneFunctionTwoOutside, the RecvAtHost and SendFromHost
// nodes here are built without kXlaHasHostTransferAttrName. Presumably the
// graph comparison does not check that attribute; confirm.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = InputShaped(b2.opts().WithName("B"));
Node* key_constant1 =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
{DT_FLOAT, DT_FLOAT}, b2.opts());
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts()
.WithName("E")
.WithControlInputs({recv1})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "F1", "O1",
{e}, b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInputs({s1, b}).FinalizeBuilder(&node_builder1);
Node* key_constant2 =
KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
{DT_FLOAT}, b2.opts());
Node* h = Unary(recv2, b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F2")
.WithAttr("_outside", "O1"));
Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "F2", "O1",
{h}, b2.opts());
Node* s2 = Sequencer(
b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}),
"F2");
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(a).Input(b);
Node* call2 =
b2.opts().WithControlInputs({s2}).FinalizeBuilder(&node_builder2);
Binary(call1, call2, b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no inputs from the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: E (cluster O1 of F1) reads only A, which is outside F1, so O1
// has no data inputs from compiled nodes. F consumes D and E.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// The "shapes" attr expected on the host compute: a single dimension of
// size 2.
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
// Expected F1: A is still routed through the host compute, whose only input
// is the function argument a_0_arg. This host compute has no control inputs.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"a_0_arg"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"f_0_retval_retval", "F:o:0"}});
// Expected outer graph: on the host, E reads from RecvAtHost rather than
// from A directly.
// NOTE(review): the sequencer's control inputs are listed as {send1, recv1},
// the reverse of the sibling tests. Presumably the comparison ignores
// control-input order; confirm.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{e}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({send1, recv1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
Unary(call1, b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data inputs but has a
// control input from the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: E (cluster O1 of F1) reads only A, which is outside F1, but
// also has a control input from the compiled node D. F consumes D and E.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithControlInput(d)
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f =
Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Unary(f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// The "shapes" attr expected on the host compute: a single dimension of
// size 2.
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
// Expected F1: the host compute's only data input is a_0_arg. It has a
// control input from D, presumably carrying E's control dependency on D.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"a_0_arg"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", NameAttrList()},
{"shapes",
absl::Span<const TensorShapeProto>({shape_proto_expected})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}},
{"D"}},
},
{{"f_0_retval_retval", "F:o:0"}});
// Expected outer graph: on the host, E takes both its data input and a
// control input from RecvAtHost.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithControlInput(recv1)
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{e}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
Unary(call1, b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no outputs to the
// compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: E (cluster O1 of F1) reads D and feeds only G, which is
// outside F1; no compiled node consumes E. A and B use Input rather than
// InputShaped. Accordingly, the host compute expected below gets a
// shape_inference_graph and an empty "shapes" list.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference graph for O1: RecvAtHost -> E -> SendFromHost,
// keyed by a KeyPlaceholder for F1.
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
// The name by which the host compute refers to the shape-inference graph.
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
// Expected F1: the host compute sends D. Its output is returned from F1 as
// e_0_retval_retval (output 0 of the call), which is how G reads E's value.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"f_0_retval_retval", "F:o:0"}});
// Expected outer graph: G reads E and F as outputs 0 and 1 of the F1 call.
// NOTE(review): the host transfer nodes here omit kXlaHasHostTransferAttrName,
// unlike the shape-inference graph above. Presumably the graph comparison
// does not check that attribute; confirm.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{e}, b2.opts());
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster that has no data outputs but has a
// control output to the compiled subgraph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: E (cluster O1 of F1) reads D, and its data output goes only to
// G, which is outside F1. The compiled node F has a control input from E.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference graph for O1: RecvAtHost -> E -> SendFromHost,
// keyed by a KeyPlaceholder for F1.
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
// The name by which the host compute refers to the shape-inference graph.
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
// Expected F1: the host compute sends D, and its output is returned as
// e_0_retval_retval. F's control edge from E becomes a control input on the
// host compute.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"UnaryTest",
{"D:o:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"f_0_retval_retval", "F:o:0"}});
// Expected outer graph: on the host, SendFromHost has a control input from E.
// G reads outputs 0 and 1 of the F1 call.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{DT_FLOAT}, b2.opts());
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 = SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1",
{e}, b2.opts().WithControlInput(e));
Node* s1 = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 =
b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two outside_compilation clusters that interact outside the compiled
// subgraph through the control edge E -> G. The ancestor cluster O1 (node E)
// reads only A, which is not produced by any node inside the compiled subgraph.
// NOTE(review): an earlier wording said the ancestor "has no HostCompute Op",
// but the expected F1 below does contain outside_compilation_O1_host_compute;
// that wording appears to be stale.
TEST(EncapsulateSubgraphsTest,
OutsideCompilationClusterDependencyNoSrcCluster) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, F, H are compiled in F1; E (cluster O1) and G (cluster
// O2) are outside_compilation nodes of F1, with a control edge E -> G; I
// consumes E and H outside F1.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Node* g = Unary(f, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
Binary(e, h, b1.opts().WithName("I"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Encapsulate cluster F1; graphdef and library are compared against the
// expectations built below.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference function for O1, registered as F1_F1_O1:
// key -> RecvAtHost -> E -> SendFromHost.
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv1, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
// Expected shape-inference function for O2, registered as F1_F1_O2:
// key -> RecvAtHost -> G -> SendFromHost.
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape2.opts());
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
shape2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape2, "F1_F1_O2", &library_expected));
}
NameAttrList shape_inference_graph1;
shape_inference_graph1.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
NameAttrList shape_inference_graph2;
shape_inference_graph2.set_name(
"_outside_compilation_shape_inference_F1_F1_O2");
// Expected F1: the compiled nodes plus one XlaHostCompute per outside cluster.
// O2's host compute lists O1's host compute both in _xla_token_input_nodes and
// as a control input, presumably to keep the E -> G ordering on the device.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"H"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"a_0_arg"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph1},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"F:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O2"},
{"shape_inference_graph", shape_inference_graph2},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
// Expected host graph: E and G run between their clusters' RecvAtHost and
// SendFromHost nodes, all four of which feed F1_sequencer; the F1 call node
// has the sequencer as a control input, and I consumes both call outputs.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send1 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* send2 =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {g},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s1 = Sequencer(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send1, recv2, send2}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with two outside_compilation clusters that interact outside the compiled
// subgraph through the control edge E -> G. The successor cluster O2 (node G)
// produces no output consumed by the compiled subgraph, so its expected
// XlaHostCompute has empty Toutputs and no shape-inference graph.
// NOTE(review): an earlier wording said the successor "has no HostCompute Op",
// but the expected F1 below does contain outside_compilation_O2_host_compute;
// that wording appears to be stale.
TEST(EncapsulateSubgraphsTest,
OutsideCompilationClusterDependencyNoDstCluster) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, F, H are compiled in F1; E (cluster O1) sits between D
// and F; G (cluster O2) reads A and has a control input from E, and its
// output is unused; I consumes E and H outside F1.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
/*Node* g =*/Unary(a, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
Binary(e, h, b1.opts().WithName("I"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Encapsulate cluster F1; graphdef and library are compared against the
// expectations built below.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference function for O1 only (F1_F1_O1); O2 gets none.
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
// Expected F1: O2's host compute has empty Toutputs and an empty
// NameAttrList as shape_inference_graph, and it lists O1's host compute both
// in _xla_token_input_nodes and as a control input.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"UnaryTest",
{"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"a_0_arg"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O2"},
{"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node",
"outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
// Expected host graph: O1 gets a RecvAtHost/SendFromHost pair, while O2 only
// gets a RecvAtHost (G's output goes nowhere); F1_sequencer orders all three
// before the F1 call.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
/*Node* g =*/Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* s1 = Sequencer(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, recv2, send}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("I"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with three outside_compilation clusters (O1, O2, O3) that interact
// outside the compiled subgraph: G (O2) has a control input from E (O1), and
// I (O3) has a control input from G. Only O1 feeds a value back into the
// compiled subgraph; O2 and O3 have XlaHostComputes with empty Toutputs and
// no shape-inference graph.
TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, F, H are compiled in F1. E (O1) reads D and feeds F;
// G (O2) reads D; I (O3) reads D and E. J consumes E and H outside F1.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(d, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Node* g = Unary(d, b1.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1"));
/*Node* i =*/Binary(d, e,
b1.opts()
.WithName("I")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O3")
.WithControlInput(g));
Binary(e, h, b1.opts().WithName("J"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Encapsulate cluster F1; graphdef and library are compared against the
// expectations built below.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference function for O1 only (F1_F1_O1).
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
// Expected F1: the host-compute dependency chain mirrors the control edges.
// O2's host compute depends on O1's, and O3's depends on both O1's and O2's,
// via _xla_token_input_nodes and via control inputs.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "h_0_retval_retval:float"}, {},
{{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}},
{{"H"}, "UnaryTest", {"F:o:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O2"},
{"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O2"},
{"_xla_token_input_nodes",
absl::Span<const string>(
{"_xla_token_arg_node", "outside_compilation_O1_host_compute"})}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O3_host_compute"},
"XlaHostCompute",
{"D:o:0"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O3"},
{"shape_inference_graph", NameAttrList()},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O3"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node",
"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"})}},
{"outside_compilation_O1_host_compute",
"outside_compilation_O2_host_compute"}}},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"h_0_retval_retval", "H:o:0"}});
// Expected host graph: O1 gets a RecvAtHost/SendFromHost pair; O2 and O3 get
// only a RecvAtHost each. I reads recv3 plus E directly on the host, and
// F1_sequencer gates the F1 call on all four transfer nodes.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv1 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv1, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O2", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* g = Unary(recv2, b2.opts()
.WithName("G")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2")
.WithControlInput(e));
Node* recv3 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O3", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
/*Node* i =*/Binary(recv3, e,
b2.opts()
.WithName("I")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O3")
.WithControlInput(g));
Node* s1 = Sequencer(b2.opts()
.WithName("F1_sequencer")
.WithControlInputs({recv1, send, recv2, recv3}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s1);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("J"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test with one outside_compilation cluster (node E) that neither reads from
// nor writes to any node of the compiled subgraph: E reads A and is consumed
// only by G, both of which are outside F1.
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
FunctionDefLibrary library;
GraphDef graphdef;
// Input graph: C, D, F are compiled in F1; E is an outside_compilation node
// of F1 (cluster O1); G consumes E and F outside F1.
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
Node* e = Unary(a, b1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1"));
Binary(e, f, b1.opts().WithName("G"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
// Encapsulate cluster F1; graphdef and library are compared against the
// expectations built below.
std::vector<string> encapsulated_functions{"F1"};
TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
// Expected shape-inference function for O1 (F1_F1_O1).
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* key_constant = KeyPlaceholder("F1", shape1.opts());
Node* recv2 =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
shape1.opts().WithAttr(kXlaHasHostTransferAttrName, true));
TF_EXPECT_OK(
AddGraphDefToFunctionLibrary(shape1, "F1_F1_O1", &library_expected));
}
NameAttrList shape_inference_graph;
shape_inference_graph.set_name(
"_outside_compilation_shape_inference_F1_F1_O1");
// Expected F1: E is still routed through an XlaHostCompute that reads the
// a_0_arg argument, and its output becomes the e_0_retval_retval result.
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"e_0_retval_retval:float", "f_0_retval_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"outside_compilation_O1_host_compute"},
"XlaHostCompute",
{"a_0_arg"},
{{"Tinputs", absl::Span<const DataType>({DT_FLOAT})},
{"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
{"ancestors", absl::Span<const string>({})},
{"key", "host_compute_channel_F1_F1_O1"},
{"shape_inference_graph", shape_inference_graph},
{"shapes", absl::Span<const TensorShapeProto>({})},
{"_outside_compilation_subgraph", "O1"},
{"_xla_token_input_nodes",
absl::Span<const string>({"_xla_token_arg_node"})}}},
},
{{"e_0_retval_retval", "outside_compilation_O1_host_compute:outputs:0"},
{"f_0_retval_retval", "F:o:0"}});
// Expected host graph: E runs between the O1 RecvAtHost/SendFromHost pair,
// F1_sequencer gates the F1 call, and G consumes both call outputs.
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* key_constant =
KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
Node* recv =
RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* e = Unary(recv, b2.opts()
.WithName("E")
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* send =
SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
Node* s = Sequencer(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
"F1");
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b).ControlInput(s);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
Binary(call1, ops::NodeOut(call1, 1), b2.opts().WithName("G"));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test for shape inference of outside compilation.
//
// Node E (cluster O1) is built with BinaryUnknownShape (presumably an op with no
// statically known output shape), so the expected XlaHostCompute for O1
// references a shape-inference function and carries an empty "shapes" list.
// Node C (outside F1) and node d (named "c", inside F1) have names that
// collide after lowercasing; the expected F1 takes arguments b_0_arg and
// c_0_arg and also contains a node named "c".
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
  FunctionDefLibrary library;
  GraphDef graphdef;

  // Input graph. XTimesTwo is pre-registered in the library and is expected to
  // be carried through unchanged (see library_expected below).
  {
    *library.add_function() = test::function::XTimesTwo();
    GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
    Node* a = InputShaped(b1.opts().WithName("A"));
    Node* b = Input(b1.opts().WithName("B"));
    // Give nodes 'c' and 'd' names that collide after lowercasing.
    Node* c = Unary(a, b1.opts().WithName("C"));
    Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
                           "_encapsulate", "F1"));
    Node* e = BinaryUnknownShape(c, d,
                                 b1.opts()
                                     .WithName("E")
                                     .WithControlInputs({b, d})
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    Node* f = Binary(c, e,
                     b1.opts().WithName("F").WithControlInput(e).WithAttr(
                         "_encapsulate", "F1"));
    Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
    TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
  }

  // Encapsulate cluster F1; graphdef and library are compared against the
  // expectations built below.
  std::vector<string> encapsulated_functions{"F1"};
  TF_EXPECT_OK(Encapsulate(&graphdef, &library, encapsulated_functions));

  FunctionDefLibrary library_expected;
  GraphDef graphdef_expected;

  // Expected shape-inference function for O1 (F1_F1_O1): both of E's inputs
  // arrive through a single two-output RecvAtHost.
  {
    GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
    Node* key_constant = KeyPlaceholder("F1", shape.opts());
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
                                 shape.opts()
                                     .WithName("E")
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                 shape.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    TF_EXPECT_OK(
        AddGraphDefToFunctionLibrary(shape, "F1_F1_O1", &library_expected));
  }

  NameAttrList shape_inference_graph;
  shape_inference_graph.set_name(
      "_outside_compilation_shape_inference_F1_F1_O1");
  *library_expected.add_function() = test::function::XTimesTwo();
  // Expected F1: the host compute reads c_0_arg (from host node C) and c:o:0
  // (from the compiled node "c"), and F waits on the host compute.
  *library_expected.add_function() = FunctionDefHelper::Create(
      "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval_retval:float"}, {},
      {
          {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
          {{"F"},
           "BinaryTest",
           {"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
           {},
           {"outside_compilation_O1_host_compute"}},
          {{"outside_compilation_O1_host_compute"},
           "XlaHostCompute",
           {"c_0_arg", "c:o:0"},
           {{"Tinputs", absl::Span<const DataType>({DT_FLOAT, DT_FLOAT})},
            {"Toutputs", absl::Span<const DataType>({DT_FLOAT})},
            {"ancestors", absl::Span<const string>({})},
            {"key", "host_compute_channel_F1_F1_O1"},
            {"shape_inference_graph", shape_inference_graph},
            // "shapes" holds TensorShapeProtos, matching the other tests; it
            // is empty here, presumably because E's output shape has to be
            // recovered through shape_inference_graph instead.
            {"shapes", absl::Span<const TensorShapeProto>({})},
            {"_outside_compilation_subgraph", "O1"},
            {"_xla_token_input_nodes",
             absl::Span<const string>({"_xla_token_arg_node"})}},
           {"c"}},
      },
      {{"f_0_retval_retval", "F:o:0"}});

  // Expected host graph: C stays on the host, E runs between the O1
  // RecvAtHost/SendFromHost pair, and the F1 call is gated on F1_sequencer as
  // well as on B and C.
  {
    std::unique_ptr<FunctionLibraryDefinition> lib_def(
        new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
    GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
    Node* a = InputShaped(b2.opts().WithName("A"));
    Node* b = Input(b2.opts().WithName("B"));
    Node* c = Unary(a, b2.opts().WithName("C"));
    Node* key_constant =
        KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder"));
    Node* recv = RecvAtHost(
        ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {DT_FLOAT, DT_FLOAT},
        b2.opts().WithAttr(kXlaHasHostTransferAttrName, true));
    Node* e = BinaryUnknownShape(recv, ops::NodeOut(recv, 1),
                                 b2.opts()
                                     .WithName("E")
                                     .WithControlInputs({recv})
                                     .WithAttr("_encapsulate", "F1")
                                     .WithAttr("_outside", "O1"));
    Node* send =
        SendFromHost(ops::NodeOut(key_constant, 0), "F1", "F1", "O1", {e},
                     b2.opts().WithControlInput(e).WithAttr(
                         kXlaHasHostTransferAttrName, true));
    Node* s = Sequencer(
        b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}),
        "F1");
    NodeBuilder node_builder("F1", "F1", lib_def.get());
    node_builder.Input(b).Input(c);
    Node* call =
        b2.opts().WithControlInputs({s, b, c}).FinalizeBuilder(&node_builder);
    Binary(a, call, b2.opts().WithName("G").WithControlInputs({call}));
    TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
  }

  TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
  TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
void CreateSubgraphTouchingRefVar(const Scope& s) {
Output variable =
ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
Output constant =
ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
s.graph()->AddControlEdge(constant.node(), variable.node());
}
// Runs EncapsulateSubgraphsPass over a graph that reads a reference variable.
// Every node must carry kXlaHasReferenceVarsAttr afterwards, and every node
// except the implicit source and sink must have it set to true.
TEST(EncapsulateSubgraphsTest, RefVariablesMarked) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  CreateSubgraphTouchingRefVar(scope);

  auto g = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(g.get()));

  GraphOptimizationPassWrapper pass_wrapper;
  GraphOptimizationPassOptions pass_options =
      pass_wrapper.CreateGraphOptimizationPassOptions(&g);
  EncapsulateSubgraphsPass encapsulate_pass;
  TF_ASSERT_OK(encapsulate_pass.Run(pass_options));

  for (const Node* n : g->nodes()) {
    bool touches_ref_var = false;
    TF_ASSERT_OK(
        GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &touches_ref_var));
    const bool is_source_or_sink = n->IsSource() || n->IsSink();
    EXPECT_TRUE(is_source_or_sink || touches_ref_var)
        << "All nodes apart from source and sink can access reference variable";
  }
}
void CreateSubgraphNotTouchingRefVar(const Scope& s) {
Output constant =
ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
}
// Runs EncapsulateSubgraphsPass over a graph without reference variables.
// Every node must still carry kXlaHasReferenceVarsAttr, and its value must be
// false on all of them.
TEST(EncapsulateSubgraphsTest, NoRefVarsNoAttr) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  CreateSubgraphNotTouchingRefVar(scope);

  auto g = absl::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(g.get()));

  GraphOptimizationPassWrapper pass_wrapper;
  GraphOptimizationPassOptions pass_options =
      pass_wrapper.CreateGraphOptimizationPassOptions(&g);
  EncapsulateSubgraphsPass encapsulate_pass;
  TF_ASSERT_OK(encapsulate_pass.Run(pass_options));

  for (const Node* n : g->nodes()) {
    bool touches_ref_var = false;
    TF_ASSERT_OK(
        GetNodeAttr(n->attrs(), kXlaHasReferenceVarsAttr, &touches_ref_var));
    EXPECT_FALSE(touches_ref_var)
        << "The graph does not have reference variables";
  }
}
} // namespace
} // namespace tensorflow