/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include <atomic>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
using FDH = ::tensorflow::FunctionDefHelper;
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
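// Looks up the OpDef for `op` in the global registry. Used below as the
// signature-lookup callback for InstantiateFunction().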
Status GetOpSig(const string& op, const OpDef** sig) {
return OpRegistry::Global()->LookUpOpDef(op, sig);
}
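// Expects `s` to be an error whose message contains `substr`.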
void HasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(absl::StrContains(s.ToString(), substr))
<< s << ", expected substring " << substr;
}
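// Instantiates a FunctionDef into a Graph and runs it directly on a local
// executor, without going through a FunctionLibraryRuntime.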
class FunctionTest : public ::testing::Test {
protected:
FunctionTest()
: device_(DeviceFactory::NewDevice("CPU", {},
"/job:localhost/replica:0/task:0")) {}
void Create(const FunctionDef& fdef, test::function::Attrs attrs) {
exec_ = nullptr;
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, attrs, GetOpSig, &result));
arg_types_ = result.arg_types;
ret_types_ = result.ret_types;
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
const int version = g->versions().producer();
LocalExecutorParams params;
params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
};
params.rendezvous_factory = [](const int64, const DeviceMgr* device_mgr,
Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
};
Executor* exec;
TF_CHECK_OK(NewLocalExecutor(params, *g, &exec));
exec_.reset(exec);
}
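// Runs the executor built by Create() with `args` and copies the computed
// return values into `rets`.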
void Run(const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
FunctionCallFrame frame(arg_types_, ret_types_);
TF_CHECK_OK(frame.SetArgs(args));
Executor::Args exec_args;
exec_args.call_frame = &frame;
exec_args.runner = test::function::FunctionTestSchedClosure;
TF_CHECK_OK(exec_->Run(exec_args));
std::vector<Tensor> computed;
TF_CHECK_OK(frame.GetRetvals(&computed));
CHECK_EQ(computed.size(), rets.size());
for (int i = 0; i < rets.size(); ++i) {
*(rets[i]) = computed[i];
}
}
std::unique_ptr<Device> device_;
std::unique_ptr<Executor> exec_;
DataTypeVector arg_types_;
DataTypeVector ret_types_;
};
TEST_F(FunctionTest, XTimesTwo) {
Create(test::function::XTimesTwo(), {{"T", DT_FLOAT}});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
Run({x}, {&y});
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
TEST_F(FunctionTest, WXPlusB) {
Create(test::function::WXPlusB(), {{"T", DT_FLOAT}});
auto w = test::AsTensor<float>({1., 2., 3., 4.}, {2, 2});
auto x = test::AsTensor<float>({1., 3., 2., 4.}, {2, 2});
auto b = test::AsTensor<float>({0.5, 2.5}, {2});
Tensor y;
Run({w, x, b}, {&y});
test::ExpectTensorEqual<float>(
y, test::AsTensor<float>({5.5, 13.5, 11.5, 27.5}, {2, 2}));
}
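// Exercises the FunctionLibraryRuntime: instantiation, execution, handle
// release, inlining, graph optimization, and symbolic gradients. Init()
// creates three CPU devices and one runtime (flr0_..flr2_) per device.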
class FunctionLibraryRuntimeTest : public ::testing::Test {
protected:
void Init(const std::vector<FunctionDef>& flib) {
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), &options.config,
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
fdef_lib_ = lib_def_->ToProto();
}
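// Runs the function identified by `handle` with `args` and copies the
// outputs into `rets`. If `add_runner` is true, installs a counting runner
// so the test can verify that the supplied runner is actually used.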
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
if (add_runner) {
opts.runner = &runner;
} else {
opts.runner = nullptr;
}
Notification done;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
status = s;
done.Notify();
});
done.WaitForNotification();
if (!status.ok()) {
return status;
}
CHECK_EQ(rets.size(), out.size());
for (size_t i = 0; i < rets.size(); ++i) {
*rets[i] = out[i];
}
if (add_runner) {
EXPECT_GE(call_count, 1); // Test runner is used.
}
return Status::OK();
}
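// Thin wrappers around FunctionLibraryRuntime::Instantiate().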
Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
FunctionLibraryRuntime::Handle* handle) {
return flr->Instantiate(name, attrs, handle);
}
Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
return flr->Instantiate(name, attrs, options, handle);
}
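// Instantiates `name`, runs it once, then releases the handle and verifies
// that a second run fails with a "Handle ... not found" error.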
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const std::vector<Tensor>& args,
std::vector<Tensor*> rets, bool add_runner = true) {
return InstantiateAndRun(flr, name, attrs,
FunctionLibraryRuntime::InstantiateOptions(), args,
std::move(rets), add_runner);
}
Status InstantiateAndRun(
FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
const FunctionLibraryRuntime::InstantiateOptions& options,
const std::vector<Tensor>& args, std::vector<Tensor*> rets,
bool add_runner = true) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, options, &handle);
if (!status.ok()) {
return status;
}
FunctionLibraryRuntime::Options opts;
status = Run(flr, handle, opts, args, rets, add_runner);
if (!status.ok()) return status;
// Release the handle and try running again. It should not succeed.
status = flr->ReleaseHandle(handle);
if (!status.ok()) return status;
Status status2 = Run(flr, handle, opts, args, std::move(rets));
EXPECT_TRUE(errors::IsNotFound(status2))
<< "Actual status: " << status2.ToString();
EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
return status;
}
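// Same as Run() above, but drives the executor through a CallFrameInterface
// instead of explicit argument/return vectors.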
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
bool add_runner = true) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
[&call_count](std::function<void()> fn) {
++call_count;
test::function::FunctionTestSchedClosure(fn);
};
if (add_runner) {
opts.runner = &runner;
} else {
opts.runner = nullptr;
}
Notification done;
std::vector<Tensor> out;
Status status;
flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
status = s;
done.Notify();
});
done.WaitForNotification();
if (!status.ok()) {
return status;
}
if (add_runner) {
EXPECT_GE(call_count, 1); // Test runner is used.
}
return Status::OK();
}
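// Like InstantiateAndRun(), but passes arguments and collects results via a
// FunctionCallFrame.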
Status InstantiateAndRunViaCallFrameInterface(FunctionLibraryRuntime* flr,
const string& name,
test::function::Attrs attrs,
const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, &handle);
if (!status.ok()) {
return status;
}
const FunctionBody* fbody = flr->GetFunctionBody(handle);
FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
TF_RETURN_IF_ERROR(frame.SetArgs(args));
FunctionLibraryRuntime::Options opts;
status = Run(flr, handle, opts, &frame);
if (!status.ok()) return status;
std::vector<Tensor> retvals;
TF_RETURN_IF_ERROR(frame.GetRetvals(&retvals));
CHECK_EQ(rets.size(), retvals.size());
for (size_t i = 0; i < rets.size(); ++i) {
*rets[i] = retvals[i];
}
// Release the handle and try running again. It should not succeed.
status = flr->ReleaseHandle(handle);
if (!status.ok()) return status;
Status status2 = Run(flr, handle, opts, args, std::move(rets));
EXPECT_TRUE(errors::IsNotFound(status2));
EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
return status;
}
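// Instantiates `name` with `attrs` and returns a copy of its function body
// graph, or nullptr if instantiation fails.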
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
const string& name,
test::function::Attrs attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(name, attrs, &handle);
if (!status.ok()) {
LOG(ERROR) << status;
return nullptr;
}
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
CopyGraph(*fbody->graph, ret.get());
return ret;
}
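// Instantiates `func`, builds the symbolic gradient of its body, and returns
// a copy of the gradient graph, or nullptr if instantiation fails.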
std::unique_ptr<Graph> GetGradBody(FunctionLibraryRuntime* flr,
const string& func,
test::function::Attrs attrs) {
FunctionLibraryRuntime::Handle handle;
Status status = flr->Instantiate(func, attrs, &handle);
if (!status.ok()) {
LOG(ERROR) << status;
return nullptr;
}
const FunctionBody* fbody = flr->GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
std::unique_ptr<FunctionBody> gbody(SymbolicGradient(*fbody));
CHECK_NOTNULL(gbody);
std::unique_ptr<Graph> ret(new Graph(lib_def_.get()));
CopyGraph(*gbody->graph, ret.get());
return ret;
}
FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
FunctionDefLibrary fdef_lib_;
};
TEST_F(FunctionLibraryRuntimeTest, IsStateful) {
Init({});
EXPECT_TRUE(flr0_->IsStateful("Variable"));
EXPECT_TRUE(flr0_->IsStateful("VariableV2"));
EXPECT_FALSE(flr0_->IsStateful("Matmul"));
}
TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) {
Init({test::function::XTimesTwo()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
TF_CHECK_OK(
InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK(InstantiateAndRunViaCallFrameInterface(
flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
TEST_F(FunctionLibraryRuntimeTest, XTimesN) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
TF_CHECK_OK(
InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK(
InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
TF_CHECK_OK(
InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));
}
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDef) {
Init({});
FunctionDefLibrary proto;
*proto.add_function() = test::function::XTimesTwo();
*proto.add_function() = test::function::XTimesFour();
*proto.add_function() = test::function::XTimes16();
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), proto));
FunctionLibraryRuntime::InstantiateOptions options;
options.lib_def = lib_def.get();
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
// Ensure that the function is not installed in the base library.
HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
{} /* options */, {x}, {&y}),
"Not found: Function XTimesTwo is not defined.");
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
{x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesFour", {{"T", DT_FLOAT}}, options,
{x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, options,
{x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({16, 32, 48, 64}));
// Ensure that the function is still not installed in the base library.
HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
{} /* options */, {x}, {&y}),
"Not found: Function XTimesTwo is not defined.");
}
TEST_F(FunctionLibraryRuntimeTest, XTimesNInLibDefAndDelayedInstantiation) {
using FDH = ::tensorflow::FunctionDefHelper;
Init({});
// Call XTimesFour via PartitionedCall, which delays function instantiation
// until the first call to Compute/ComputeAsync.
FunctionDef my_xt4 = FunctionDefHelper::Create(
"MyXTimesFour", {"x:float"}, {"z:float"}, {},
{{{"x_times_four"},
"PartitionedCall",
{"x"},
{{"Tin", DataTypeSlice({DT_FLOAT})},
{"Tout", DataTypeSlice({DT_FLOAT})},
{"f", FDH::FunctionRef("XTimesFour", {{"T", DT_FLOAT}})}}}},
/* Mapping between function returns and function node outputs. */
{{"z", "x_times_four:output:0"}});
FunctionDefLibrary lib;
*lib.add_function() = test::function::XTimesTwo();
*lib.add_function() = test::function::XTimesFour();
*lib.add_function() = my_xt4;
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), lib));
FunctionLibraryRuntime::InstantiateOptions options;
options.lib_def = lib_def.get();
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
// When we instantiate with `options` we should get x*4.
TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
// Create options that override XTimesFour body with XTimesTwo body.
FunctionDef xt4_override = test::function::XTimesTwo();
xt4_override.mutable_signature()->set_name("XTimesFour");
FunctionDefLibrary lib_override;
*lib_override.add_function() = xt4_override;
*lib_override.add_function() = my_xt4;
std::unique_ptr<FunctionLibraryDefinition> lib_def_override(
new FunctionLibraryDefinition(OpRegistry::Global(), lib_override));
options.lib_def = lib_def_override.get();
// When we instantiate with `options` we should get x*2.
TF_CHECK_OK(InstantiateAndRun(flr0_, "MyXTimesFour", {}, options, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
auto T = DT_INT32;
// The expected sequence of outputs from this function is [6, 4, 0, 1, ...].
FunctionDef stateful_func = FDH::Define(
// Name
"RandomUniformWrapper",
// Args
{},
// Return values
{"y: int32"},
// Attrs
{},
// Nodes
{FDH::Const<int32>("shape", gtl::ArraySlice<int32>({1})),
FDH::Const<int32>("minval", 0),
FDH::Const<int32>("maxval", 10),
// A stateful node.
{{"y"},
"RandomUniformInt",
{"shape", "minval", "maxval"},
{{"seed", 37}, {"seed2", 48}, {"Tout", T}, {"T", T}}}});
Init({stateful_func});
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, &handle));
FunctionLibraryRuntime::Options opts;
Tensor y;
{
// Simple case: instantiating with no state_handle.
for (int32 expected : {6, 4}) {
TF_CHECK_OK(Run(flr0_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
{
// Instantiating again with no state_handle should yield the same handle and
// the continuation of the same sequence.
FunctionLibraryRuntime::Handle handle_non_isolated;
TF_CHECK_OK(
Instantiate(flr0_, "RandomUniformWrapper", {}, &handle_non_isolated));
EXPECT_EQ(handle, handle_non_isolated);
for (int32 expected : {0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_non_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
{
// Instantiating with a given state handle will create new state and yield
// the original sequence.
FunctionLibraryRuntime::InstantiateOptions options;
FunctionLibraryRuntime::Handle handle_isolated;
options.state_handle = "handle_1";
TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
{
// Instantiating with a different given state handle will create new state
// and yield the original sequence.
FunctionLibraryRuntime::InstantiateOptions options;
FunctionLibraryRuntime::Handle handle_isolated;
options.state_handle = "handle_2";
TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
}
{
// Reinstantiating after releasing a handle will yield the original sequence
// multiple times.
FunctionLibraryRuntime::InstantiateOptions options;
FunctionLibraryRuntime::Handle handle_isolated;
options.state_handle = "handle_3";
for (int i = 0; i < 2; ++i) {
TF_CHECK_OK(Instantiate(flr0_, "RandomUniformWrapper", {}, options,
&handle_isolated));
EXPECT_NE(handle, handle_isolated);
for (int32 expected : {6, 4, 0, 1}) {
TF_CHECK_OK(Run(flr0_, handle_isolated, opts, {}, {&y}, true));
test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({expected}));
}
TF_CHECK_OK(flr0_->ReleaseHandle(handle_isolated));
}
}
}
namespace {
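// Registers a "DUMMY" executor type whose factory always fails with an
// Internal error. The ExecutorFactory test below uses it to check that
// InstantiateOptions::executor_type and the "_executor" attr are honored.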
class DummyExecutorRegistrar {
public:
DummyExecutorRegistrar() {
ExecutorFactory::Register("DUMMY", new Factory());
}
private:
class Factory : public ExecutorFactory {
Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
std::unique_ptr<Executor>* out_executor) override {
return errors::Internal("This is a dummy.");
}
};
};
static DummyExecutorRegistrar registrar;
} // namespace
TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
Init({test::function::XTimesTwo()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
// Test that the default executor works.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "";
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
options, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
// Test the explicit registration for the default executor.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "DEFAULT";
TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
options, {x}, {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
}
// Test that a non-default executor factory can be invoked.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "DUMMY";
HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
{x}, {&y}),
"Internal: This is a dummy.");
}
// Test that a non-default executor factory can be invoked via an attr.
{
FunctionLibraryRuntime::InstantiateOptions options;
HasError(InstantiateAndRun(flr0_, "XTimesTwo",
{{"T", DT_FLOAT}, {"_executor", "DUMMY"}},
options, {x}, {&y}),
"Internal: This is a dummy.");
}
// Test that a non-default executor factory specified via an
// `InstantiateOptions` supersedes the attr when both are present.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "DUMMY";
HasError(
InstantiateAndRun(flr0_, "XTimesTwo",
{{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
options, {x}, {&y}),
"Internal: This is a dummy.");
}
// Test that non-existent executor types trigger an error.
{
FunctionLibraryRuntime::InstantiateOptions options;
options.executor_type = "UNKNOWN_EXECUTOR";
HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
{x}, {&y}),
"Not found: No executor factory registered for the given executor "
"type: UNKNOWN_EXECUTOR");
}
{
FunctionLibraryRuntime::InstantiateOptions options;
HasError(
InstantiateAndRun(flr0_, "XTimesTwo",
{{"T", DT_FLOAT}, {"_executor", "UNKNOWN_EXECUTOR"}},
options, {x}, {&y}),
"Not found: No executor factory registered for the given executor "
"type: UNKNOWN_EXECUTOR");
}
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
ASSERT_TRUE(g != nullptr);
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto arg = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto a = test::function::Call(&s, "x4", "XTimesFour", {arg});
auto b = test::function::Call(&s, "y", "XTimesFour", {a});
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), b, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
ExpandInlineFunctions(flr0_, g.get());
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
auto x4_x2 = test::function::Call(&s, "x4/x2", "XTimesTwo", {func0});
auto x4_y = test::function::Call(&s, "x4/y", "XTimesTwo", {x4_x2});
auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), x4_y);
auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
auto y_x2 = test::function::Call(&s, "y/x2", "XTimesTwo", {func2});
auto y_y = test::function::Call(&s, "y/y", "XTimesTwo", {y_x2});
auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), y_y);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
ExpandInlineFunctions(flr0_, g.get());
GraphDef e2;
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL);
auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL);
auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL);
auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL);
auto x4_x2_scale =
ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
auto func0 = ops::Identity(s.WithOpName("Func/x4/input/_0"), x);
auto func4 = ops::Identity(s.WithOpName("Func/x4/x2/input/_4"), func0);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), func4, x4_x2_scale);
auto func5 = ops::Identity(s.WithOpName("Func/x4/x2/output/_5"), x4_x2_y);
auto func6 = ops::Identity(s.WithOpName("Func/x4/y/input/_6"), func5);
auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), func6, x4_y_scale);
auto func7 = ops::Identity(s.WithOpName("Func/x4/y/output/_7"), x4_y_y);
auto func1 = ops::Identity(s.WithOpName("Func/x4/output/_1"), func7);
auto func2 = ops::Identity(s.WithOpName("Func/y/input/_2"), func1);
auto func8 = ops::Identity(s.WithOpName("Func/y/x2/input/_8"), func2);
auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), func8, y_x2_scale);
auto func9 = ops::Identity(s.WithOpName("Func/y/x2/output/_9"), y_x2_y);
auto func10 = ops::Identity(s.WithOpName("Func/y/y/input/_10"), func9);
auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), func10, y_y_scale);
auto func11 = ops::Identity(s.WithOpName("Func/y/y/output/_11"), y_y_y);
auto func3 = ops::Identity(s.WithOpName("Func/y/output/_3"), func11);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), func3, 0);
TF_ASSERT_OK(s.ToGraphDef(&e2));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(e2, actual);
}
// No further inlining.
ExpandInlineFunctions(flr0_, g.get());
{
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(e2, actual);
}
// Get rid of redundant Identity nodes.
RemoveIdentityNodes(g.get());
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_two = ops::Const<int64>(s.WithOpName("x4/x2/two"), 2LL);
auto x4_y_two = ops::Const<int64>(s.WithOpName("x4/y/two"), 2LL);
auto y_x2_two = ops::Const<int64>(s.WithOpName("y/x2/two"), 2LL);
auto y_y_two = ops::Const<int64>(s.WithOpName("y/y/two"), 2LL);
auto x4_x2_scale =
ops::Cast(s.WithOpName("x4/x2/scale"), x4_x2_two, DT_FLOAT);
auto x4_y_scale = ops::Cast(s.WithOpName("x4/y/scale"), x4_y_two, DT_FLOAT);
auto y_x2_scale = ops::Cast(s.WithOpName("y/x2/scale"), y_x2_two, DT_FLOAT);
auto y_y_scale = ops::Cast(s.WithOpName("y/y/scale"), y_y_two, DT_FLOAT);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_y_scale);
auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, y_x2_scale);
auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, y_y_scale);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
// Verifies that control dependencies on the caller are added as control
// dependencies on any function calls created by inlining.
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto b = test::function::Call(&s, "b", "XTimesFour", {a});
s.graph()->AddControlEdge(c.operation.node(), b.node());
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), b, 0);
TF_ASSERT_OK(s.ToGraph(g.get()));
}
ExpandInlineFunctions(flr0_, g.get());
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
.WithControlDependencies({c}));
auto func1 = ops::Identity(
s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);
auto b_x2 = test::function::Call(&s, "b/x2", "XTimesTwo", {func1});
s.graph()->AddControlEdge(func0.operation.node(), b_x2.node());
auto b_y = test::function::Call(&s, "b/y", "XTimesTwo", {b_x2});
s.graph()->AddControlEdge(func0.operation.node(), b_y.node());
auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), b_y);
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
ExpandInlineFunctions(flr0_, g.get());
{
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto c = ops::NoOp(s.WithOpName("c"));
auto func0 = ops::NoOp(s.WithOpName("Func/b/input_control_node/_0")
.WithControlDependencies({c}));
auto func1 = ops::Identity(
s.WithOpName("Func/b/input/_1").WithControlDependencies({func0}), a);
auto func3 = ops::NoOp(s.WithOpName("Func/b/x2/input_control_node/_3")
.WithControlDependencies({func0}));
auto func4 = ops::Identity(
s.WithOpName("Func/b/x2/input/_4").WithControlDependencies({func3}),
func1);
auto b_x2_two = ops::Const(
s.WithOpName("b/x2/two").WithControlDependencies({func3}), 2LL);
auto b_x2_scale = ops::Cast(s.WithOpName("b/x2/scale"), b_x2_two, DT_FLOAT);
auto b_x2_y = ops::Mul(s.WithOpName("b/x2/y"), func4, b_x2_scale);
auto func5 = ops::Identity(s.WithOpName("Func/b/x2/output/_5"), b_x2_y);
auto func6 = ops::NoOp(s.WithOpName("Func/b/y/input_control_node/_6")
.WithControlDependencies({func0}));
auto func7 = ops::Identity(
s.WithOpName("Func/b/y/input/_7").WithControlDependencies({func6}),
func5);
auto b_y_two = ops::Const(
s.WithOpName("b/y/two").WithControlDependencies({func6}), 2LL);
auto b_y_scale = ops::Cast(s.WithOpName("b/y/scale"), b_y_two, DT_FLOAT);
auto b_y_y = ops::Mul(s.WithOpName("b/y/y"), func7, b_y_scale);
auto func8 = ops::Identity(s.WithOpName("Func/b/y/output/_8"), b_y_y);
auto func2 = ops::Identity(s.WithOpName("Func/b/output/_2"), func8);
auto ret = ops::_Retval(s.WithOpName("b_RetVal"), func2, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest,
ExpandInlineFunctionsWithOutputControlEdges) {
using test::function::NDef;
// The `add` node is not required to compute the regular output `o`, but it
// must execute because it is listed in `control_ret`.
const FunctionDef func =
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/{{"o", "ret:z:0"}},
/*control_ret_def=*/{{"must_execute", "add"}});
Init({func});
// Construct a graph for the function call:
//
// a = Arg[dtype=DT_FLOAT]
// b = AddAndMul(a)
// c = NoOp(^b)
// ret = RetVal(b, ^c)
const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
*g = absl::make_unique<Graph>(OpRegistry::Global());
Scope s = Scope::NewRootScope();
TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
auto c = ops::NoOp(s.WithOpName("c"));
auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
s.graph()->AddControlEdge(b.node(), c.operation.node());
s.graph()->AddControlEdge(c.operation.node(), ret.operation.node());
TF_ASSERT_OK(s.ToGraph(g->get()));
};
std::unique_ptr<Graph> g;
ExpandInlineFunctionsOptions opts;
const string input_node = "Func/b/input/_0";
const string output_node = "Func/b/output/_1";
const string output_control_node = "Func/b/output_control_node/_2";
// Use data outputs as output control source.
opts.native_options.output_control_src = OutputControlSrc::kDataOutputs;
init_graph(&g);
ExpandInlineFunctions(flr0_, g.get(), opts);
{
GraphDef expected = test::function::GDef(
{NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
NDef(output_control_node, "NoOp", {"^Func/b/output/_1"}, {}),
NDef("c", "NoOp", {"^" + output_control_node}, {}),
NDef("ret", "_Retval", {output_node, "^c"},
{{"T", DT_FLOAT}, {"index", 0}})},
{func});
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
// Use control outputs as output control source.
opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
init_graph(&g);
ExpandInlineFunctions(flr0_, g.get(), opts);
{
GraphDef expected = test::function::GDef(
{NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
NDef(output_control_node, "NoOp", {"^b/add"}, {}),
NDef("c", "NoOp", {"^" + output_control_node}, {}),
NDef("ret", "_Retval", {output_node, "^c"},
{{"T", DT_FLOAT}, {"index", 0}})},
{func});
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndKeepCallerNode) {
using test::function::NDef;
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
const FunctionDef func =
FDH::Create("AddAndMul", {"i: float"}, {"o: float"}, {},
{{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/{{"o", "ret:z:0"}},
/*control_ret_def=*/{{"must_execute", "add"}});
Init({func});
// Construct a graph:
// a = Arg[dtype=DT_FLOAT]
// b = AddAndMul(a)
auto construct_graph = [this](std::unique_ptr<Graph>* g) -> Status {
Scope s = Scope::NewRootScope();
TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
auto b = test::function::Call(&s, "b", "AddAndMul", {a});
TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
return Status::OK();
};
const string input_node = "Func/b/input/_0";
const string output_node = "Func/b/output/_1";
const string output_control_node = "Func/b/output_control_node/_2";
// Construct expected graph after function inlining.
auto expected_graph = [&](const NodeDef& caller) -> GraphDef {
return test::function::GDef(
{
NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
NDef(output_control_node, "NoOp", {"^b/add"}, {}),
caller, // Keep a node with the same name as the caller node in the graph.
},
{func});
};
ExpandInlineFunctionsOptions opts;
opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
// Keep inlined function call node fetchable.
{
opts.native_options.keep_caller_node = KeepCallerNode::kFetchable;
std::unique_ptr<Graph> g = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(construct_graph(&g));
ExpandInlineFunctions(flr0_, g.get(), opts);
GraphDef expected =
expected_graph(/*caller=*/
NDef("b", "IdentityN",
{output_node, "^" + output_control_node},
{{"T", DataTypeSlice{DT_FLOAT}}}));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
// Keep inlined function call node targetable.
{
opts.native_options.keep_caller_node = KeepCallerNode::kTargetable;
std::unique_ptr<Graph> g = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(construct_graph(&g));
ExpandInlineFunctions(flr0_, g.get(), opts);
GraphDef expected =
expected_graph(/*caller=*/
NDef("b", "NoOp", {"^" + output_control_node}, {}));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsAndPlaceInlinedNodes) {
using test::function::NDef;
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
const string arg_device = "/job:arg/replica:0/task:0/device:GPU";
const string call_device = "/job:call/replica:0/task:1/device:GPU";
const string body_device = "/job:body/replica:0/task:1/device:CPU";
const FunctionDef func = FDH::Create(
"AddFunc", {"i: float"}, {"o: float"}, {},
{{{"ret"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}, {}, body_device}},
/*ret_def=*/{{"o", "ret:z:0"}});
Init({func});
// Construct a graph:
// a = Arg[dtype=DT_FLOAT, _device=arg_device]
// b = AddFunc[_device=call_device](a)
auto construct_graph = [&](std::unique_ptr<Graph>* g) -> Status {
Scope s = Scope::NewRootScope();
TF_RETURN_IF_ERROR(s.graph()->AddFunctionLibrary(fdef_lib_));
auto a = ops::_Arg(s.WithOpName("a").WithDevice(arg_device), DT_FLOAT, 0);
auto b = test::function::Call(&s, "b", "AddFunc", {a});
TF_RETURN_IF_ERROR(s.ToGraph(g->get()));
for (Node* node : (*g)->op_nodes()) {
if (node->name() == "b") node->set_requested_device(call_device);
}
return Status::OK();
};
const string input_node = "Func/b/input/_0";
const string output_node = "Func/b/output/_1";
const string output_control_node = "Func/b/output_control_node/_2";
// Construct expected graph after function inlining.
auto expected_graph = [&](const std::vector<string>& placed) -> GraphDef {
return test::function::GDef(
{
NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}, placed[0]),
NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}, placed[1]),
NDef("b/ret", "Add", {input_node, input_node}, {{"T", DT_FLOAT}},
placed[2]),
NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}},
placed[3]),
NDef(output_control_node, "NoOp", {"^" + output_node}, {},
placed[4]),
},
{func});
};
ExpandInlineFunctionsOptions opts;
opts.native_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
// Place only input nodes to match input device.
{
opts.native_options.inlined_function_body_placer =
InlinedFunctionBodyPlacer::Default();
auto g = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(construct_graph(&g));
ExpandInlineFunctions(flr0_, g.get(), opts);
GraphDef expected = expected_graph({/*a*/ arg_device, //
/*input*/ arg_device, //
/*body*/ body_device, //
/*output*/ "", //
/*control_output*/ ""} //
);
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
// Place all nodes on the call node device.
{
opts.native_options.inlined_function_body_placer =
InlinedFunctionBodyPlacer::SingleDevice();
auto g = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(construct_graph(&g));
ExpandInlineFunctions(flr0_, g.get(), opts);
GraphDef expected = expected_graph({/*a*/ arg_device, //
/*input*/ call_device, //
/*body*/ call_device, //
/*output*/ call_device, //
/*control_output*/ call_device} //
);
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
// Multi device function placement.
{
opts.native_options.inlined_function_body_placer =
InlinedFunctionBodyPlacer::MultiDevice();
auto g = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(construct_graph(&g));
const string merged_device = "/job:call/replica:0/task:1/device:CPU:*";
ExpandInlineFunctions(flr0_, g.get(), opts);
GraphDef expected = expected_graph({/*a*/ arg_device, //
/*input*/ arg_device, //
/*body*/ merged_device, //
/*output*/ "", //
/*control_output*/ call_device} //
);
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
auto T = DT_INT32;
FunctionDef stateful_func = FDH::Define(
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
{"x: int32", "y: float32"},
// Return values
{"z: int32"},
// Attrs
{},
// Nodes
{// a = Square<T>(x)
{{"a"}, "Square", {"x"}, {{"T", T}}},
// 1
FDH::Const("o", 1),
// A bunch of extra arithmetic that the output z doesn't depend on.
{{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
{{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
{{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
FDH::Const<int32>("shape", {1, 2}),
// A stateful node.
{{"keep_me"},
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
// z = Add<T>(a, o)
{{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
Instantiate(flr0_, "SquareAndAddOneWithStatefulNodes", {}, &handle));
StepStats stats;
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
{x, y}, {&z}));
test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
// Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
// execute.
std::set<string> expected_node_names(
{"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
}
EXPECT_EQ(expected_node_names, executed_node_names);
}
TEST_F(FunctionLibraryRuntimeTest, DoNotPruneControlOutputsFromBody) {
// The `add` node is not required to compute the regular output `o`, but it
// must execute because it is listed in `control_ret`.
const FunctionDef func =
FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
{
{{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
{{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
},
/*ret_def=*/{{"o", "ret:z:0"}},
/*control_ret_def=*/{{"must_execute", "add"}});
Init({func});
auto x = test::AsTensor<float>({1.25});
Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(Instantiate(flr1_, "FunctionWithControlOutputs", {}, &handle));
StepStats stats;
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
TF_CHECK_OK(Run(flr1_, handle, opts, {x}, {&z}));
TF_CHECK_OK(flr1_->ReleaseHandle(handle));
TF_CHECK_OK(
InstantiateAndRun(flr1_, "FunctionWithControlOutputs", {}, {x}, {&z}));
test::ExpectTensorEqual<float>(z, test::AsTensor<float>({1.25 * 1.25}));
stats_collector.FinalizeAndSwap(&stats);
std::set<string> expected_node_names(
{"_SOURCE", "i", "add", "ret", "o_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
}
EXPECT_EQ(expected_node_names, executed_node_names);
}
// Constant folding generates names using a global counter.
// This function invokes constant folding and parses the counter
// from the generated node name.
int GetConstantFoldingCounter() {
Graph g(OpRegistry::Global());
Scope s = Scope::NewRootScope();
auto a = ops::Const<float>(s, {1.0}, {});
auto b = ops::Const<float>(s, {2.0}, {});
auto add = ops::Add(s.WithOpName("add"), a, b);
auto send =
ops::_Send(s.WithOpName("s1"), add, "add", "sender", 0, "receiver");
TF_CHECK_OK(s.ToGraph(&g));
bool was_mutated;
ConstantFoldingOptions opt{};
TF_CHECK_OK(
ConstantFold(opt, nullptr, Env::Default(), nullptr, &g, &was_mutated));
GraphDef def;
g.ToGraphDef(&def);
for (const NodeDef& node : def.node()) {
if (absl::StartsWith(node.name(), "add/")) {
std::vector<std::string> v = absl::StrSplit(node.name(), "__cf__");
CHECK_GT(v.size(), 1);
int counter;
CHECK(absl::SimpleAtoi(v[v.size() - 1], &counter));
return counter;
}
}
LOG(FATAL) << "Should have found a node that replaced add";
}
TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "XTimes16", {{"T", DT_FLOAT}});
ASSERT_TRUE(g != nullptr);
ExpandInlineFunctions(flr0_, g.get());
int cf_counter = GetConstantFoldingCounter();
OptimizeGraph(flr0_, &g);
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
s.WithOpName("x4/x2/scale/_12__cf__" + std::to_string(cf_counter + 1))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
auto x4_y_y = ops::Mul(s.WithOpName("x4/y/y"), x4_x2_y, x4_x2_scale);
auto y_x2_y = ops::Mul(s.WithOpName("y/x2/y"), x4_y_y, x4_x2_scale);
auto y_y_y = ops::Mul(s.WithOpName("y/y/y"), y_x2_y, x4_x2_scale);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y_y_y, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
auto func = FDH::Create( // Creates a FunctionDef using NodeDefs
// Name
"ManySwapsNodeDef",
// Input
{"x: float", "y: float"},
// Output
{"o: float"},
// Attr
{},
// Nodes
{{{"a"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}},
{{"b"}, "Swap", {"a:o0", "a:o1"}, {{"T", DT_FLOAT}}},
{{"c"}, "Swap", {"b:o0", "b:o1"}, {{"T", DT_FLOAT}}},
{{"d"}, "Swap", {"c:o0", "c:o1"}, {{"T", DT_FLOAT}}},
{{"e"}, "Swap", {"d:o0", "d:o1"}, {{"T", DT_FLOAT}}},
{{"f"}, "Swap", {"e:o0", "e:o1"}, {{"T", DT_FLOAT}}},
{{"g"}, "Identity", {"f:o0"}, {{"T", DT_FLOAT}}}},
// Return
{{"o", "g:output"}});
Init({test::function::Swap(), func});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsNodeDef", {});
ASSERT_TRUE(g != nullptr);
OptimizeGraph(flr0_, &g);
const char* e0 = R"P(
(n2:float, n3:float) -> (n2:float) {
}
)P";
EXPECT_EQ(e0, DebugString(g.get()));
}
TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
auto func = FDH::Create(
// Name
"ManySwapsFirst",
// Args
{"x: float", "y: float"},
// Return values
{"o: float"},
// attr def
{},
// Nodes
//
// o = x*x + y*y. Furthermore, the 1st swap depends on x2, and
// y2 depends on the 2nd swap. The 2nd swap has a data dependency
// on the 1st swap. The optimization should maintain the control
// dependencies.
{{{"a0"}, "Swap", {"x", "y"}, {{"T", DT_FLOAT}}, {"x2"}},
{{"a1"}, "Swap", {"a0:o0:0", "a0:o1:0"}, {{"T", DT_FLOAT}}},
{{"x2"}, "Mul", {"x", "x"}, {{"T", DT_FLOAT}}},
{{"y2"}, "Mul", {"y", "y"}, {{"T", DT_FLOAT}}, {"a1"}},
{{"o"}, "Add", {"x2:z:0", "y2:z:0"}, {{"T", DT_FLOAT}}}},
{{"o", "o:z:0"}});
Init({test::function::Swap(), func});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "ManySwapsFirst", {});
ASSERT_TRUE(g != nullptr);
OptimizeGraph(flr0_, &g);
// NOTE: We can remove func0, func1, func2, func9 with a control edge
// n8->n5. But we don't have a pass doing that.
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto x2 = ops::Mul(s.WithOpName("x2"), x, x);
auto func0 = ops::NoOp(s.WithOpName("Func/a0/input_control_node/_0")
.WithControlDependencies(x2));
auto func1 = ops::Identity(
s.WithOpName("Func/a0/input/_1").WithControlDependencies({func0}), x);
auto func2 = ops::Identity(
s.WithOpName("Func/a0/input/_2").WithControlDependencies({func0}), y);
auto func9 = ops::NoOp(
s.WithOpName("Func/a1/output_control_node/_9")
.WithControlDependencies({func1.output.op(), func2.output.op()}));
auto y2 =
ops::Mul(s.WithOpName("y2").WithControlDependencies({func9}), y, y);
auto o = ops::Add(s.WithOpName("o"), x2, y2);
auto ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
HasError(InstantiateAndRun(flr0_, "Foo", {{"T", DT_FLOAT}}, {x}, {&y}),
"Not found: Function Foo is not defined.");
}
TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
auto bad_x_times_two = FDH::Define(
// Name
"XTimesTwo",
// Args
{"x: T"},
// Return values
{"y: T"},
// Attr def
{"T: {float, double, int32, int64}"},
// Nodes
{
{{"y"}, "Add", {"x", "x"}, {{"no_T", "$T"}}},
});
Init({bad_x_times_two, test::function::XTimesFour(),
test::function::XTimes16()});
// Instantiating "XTimesTwo" should fail.
FunctionLibraryRuntime::Handle handle;
HasError(flr0_->Instantiate(
"XTimesTwo", test::function::Attrs({{"T", DT_FLOAT}}), &handle),
"Not found: type attr not found");
// But XTimesFour and XTimes16 instantiation should succeed. They only
// fail when they run, because XTimesTwo is bad.
TF_CHECK_OK(flr0_->Instantiate(
"XTimesFour", test::function::Attrs({{"T", DT_FLOAT}}), &handle));
TF_CHECK_OK(flr0_->Instantiate(
"XTimes16", test::function::Attrs({{"T", DT_FLOAT}}), &handle));
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
HasError(InstantiateAndRun(flr0_, "XTimes16", {{"T", DT_FLOAT}}, {x}, {&y}),
"type attr not found");
}
TEST_F(FunctionLibraryRuntimeTest, Error_BadControlFlow) {
Init({test::function::InvalidControlFlow()});
auto x = test::AsTensor<int32>({0});
DCHECK_EQ(x.dtype(), DT_INT32);
Tensor y;
HasError(InstantiateAndRun(flr0_, "InvalidControlFlow", {}, {x}, {&y}),
"{{node add}} has inputs from different frames. The input"
" {{node enter}} is in frame 'while'. The input {{node i}} is in"
" frame ''.");
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
std::unique_ptr<Graph> f = GetFuncBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto two = ops::Const(s.WithOpName("two"), 2LL);
auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
auto y = ops::Mul(s.WithOpName("y"), x, scale);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
f->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
std::unique_ptr<Graph> g = GetGradBody(flr0_, "XTimesTwo", {{"T", DT_FLOAT}});
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto two = ops::Const(s.WithOpName("two"), 2LL);
auto scale = ops::Cast(s.WithOpName("scale"), two, DT_FLOAT);
auto y = ops::Mul(s.WithOpName("y"), x, scale);
NameAttrList fn0;
fn0.set_name("Mul");
(*fn0.mutable_attr())["T"].set_type(DT_FLOAT);
auto func1 = ops::SymbolicGradient(
s.WithOpName("Func/_1"), std::initializer_list<Input>{x, scale, func0},
{DT_FLOAT, DT_FLOAT}, fn0);
NameAttrList fn1;
fn1.set_name("Cast");
(*fn1.mutable_attr())["SrcT"].set_type(DT_INT64);
(*fn1.mutable_attr())["DstT"].set_type(DT_FLOAT);
(*fn1.mutable_attr())["Truncate"].set_b(false);
auto func2 = ops::SymbolicGradient(
s.WithOpName("Func/_2"),
std::initializer_list<Input>{two, func1.output[1]}, {DT_INT64}, fn1);
auto func3 = ops::_Retval(s.WithOpName("Func/_3"), func1[0], 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
int cf_counter = GetConstantFoldingCounter();
OptimizeGraph(flr0_, &g);
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
s.WithOpName("scale/_6__cf__" + std::to_string(cf_counter + 2))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
s.WithOpName("Func/_1/sy/_5__cf__" + std::to_string(cf_counter + 1))
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
s.WithOpName("Func/_1/rx"), func1_sx, const0);
auto func1_sum_gx =
ops::Sum(s.WithOpName("Func/_1/sum_gx"), func1_gx, func1_rx.r0);
auto func1_dx =
ops::Reshape(s.WithOpName("Func/_1/dx"), func1_sum_gx, func1_sx);
auto func2 = ops::_Retval(s.WithOpName("Func/_3"), func1_dx, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_Select) {
FunctionDef my_select = FunctionDefHelper::Create(
"MySelect",
// Args
{"condition: bool", "t: float32", "e: float32"},
// Return values
{"z: float32"},
// Attrs
{},
// Nodes
{
{{"select0"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
{{"select1"}, "Select", {"condition", "t", "e"}, {{"T", DT_FLOAT}}},
{{"add"},
"Add",
{"select0:output", "select1:output"},
{{"T", DT_FLOAT}}},
},
// Output mapping
{{"z", "add:z"}});
FunctionDef select_grad = FunctionDefHelper::Create(
"MySelectGrad",
// Args
{"condition: bool", "t:float32", "e: float32", "dz: float32"},
// Return values
{"dt: float32"},
// Attrs
{},
// Nodes
{{
{"grad"},
"SymbolicGradient",
{"condition", "t", "e", "dz"},
{
{"f", FunctionDefHelper::FunctionRef("MySelect")},
{"Tin", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT, DT_FLOAT})},
{"Tout", DataTypeSlice({DT_BOOL, DT_FLOAT, DT_FLOAT})},
},
}},
// Output mapping
{{"dt", "grad:output:1"}});
Init({my_select, select_grad});
auto condition = test::AsTensor<bool>({false});
auto t = test::AsTensor<float>({13.0});
auto e = test::AsTensor<float>({15.0});
auto dz = test::AsTensor<float>({1.0});
Tensor y;
TF_EXPECT_OK(InstantiateAndRun(flr0_, "MySelectGrad", {},
{condition, t, e, dz}, {&y}));
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_Add) {
Init({});
auto T = DT_FLOAT;
std::unique_ptr<Graph> g = GetFuncBody(
flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Add", {{"T", T}})}});
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
auto gx = ops::Identity(s.WithOpName("gx"), dz);
auto gy = ops::Identity(s.WithOpName("gy"), dz);
auto sx = ops::Shape(s.WithOpName("sx"), x);
auto sy = ops::Shape(s.WithOpName("sy"), y);
auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_Mul) {
Init({});
auto T = DT_FLOAT;
std::unique_ptr<Graph> g = GetFuncBody(
flr0_, "SymbolicGradient", {{"f", FDH::FunctionRef("Mul", {{"T", T}})}});
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto dz = ops::_Arg(s.WithOpName("dz"), DT_FLOAT, 2);
auto gx = ops::Mul(s.WithOpName("gx"), dz, y);
auto sx = ops::Shape(s.WithOpName("sx"), x);
auto gy = ops::Mul(s.WithOpName("gy"), x, dz);
auto sy = ops::Shape(s.WithOpName("sy"), y);
auto rx = ops::internal::BroadcastGradientArgs(s.WithOpName("rx"), sx, sy);
auto sum_gx = ops::Sum(s.WithOpName("sum_gx"), gx, rx.r0);
auto sum_gy = ops::Sum(s.WithOpName("sum_gy"), gy, rx.r1);
auto dx = ops::Reshape(s.WithOpName("dx"), sum_gx, sx);
auto dy = ops::Reshape(s.WithOpName("dy"), sum_gy, sy);
auto dx_ret = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
auto dy_ret = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
}
TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
// Sum(Add(x, y))
auto T = DT_FLOAT;
auto test = FDH::Define("Test", {"x:float", "y:float"}, {"l:float"}, {},
{
{{"z"}, "Add", {"x", "y"}, {{"T", T}}},
FDH::Const("zero", 0),
FDH::Const("one", 1),
{{"r"}, "Rank", {"z"}, {{"T", T}}},
{{"indices"}, "Range", {"zero", "r", "one"}},
{{"l"}, "Sum", {"z", "indices"}, {{"T", T}}},
});
// TestGrad = Test'(x, y)
auto grad = FDH::Define("TestGrad", {"x:float", "y:float"},
{"dx:float", "dy:float"}, {},
{FDH::Const<float>("dz", 1),
{{"grad0", "grad1"},
"SymbolicGradient",
{"x", "y", "dz"},
{
{"f", FDH::FunctionRef("Test")},
{"Tin", DataTypeSlice{T, T, T}},
{"Tout", DataTypeSlice{T, T}},
}},
{{"dx"}, "Identity", {"grad0"}, {{"T", DT_FLOAT}}},
{{"dy"}, "Identity", {"grad1"}, {{"T", DT_FLOAT}}}});
Init({test, grad});
std::unique_ptr<Graph> g = GetFuncBody(flr0_, "TestGrad", {});
ASSERT_TRUE(g != nullptr);
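// Before inlining: the body of TestGrad contains a single SymbolicGradient
// node that calls Test.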
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
NameAttrList fn;
fn.set_name("Test");
auto grad0 = ops::SymbolicGradient(s.WithOpName("grad0"),
std::initializer_list<Input>{x, y, dz},
{DT_FLOAT, DT_FLOAT}, fn);
auto dx = ops::Identity(s.WithOpName("dx"), grad0[0]);
auto dy = ops::Identity(s.WithOpName("dy"), grad0[1]);
auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
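// Expand the grad0 SymbolicGradient call. Its inlined body still contains
// SymbolicGradient nodes for the Sum and Add ops inside Test.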
ExpandInlineFunctions(flr0_, g.get());
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
auto func0 = ops::Identity(s.WithOpName("Func/grad0/input/_0"), x);
auto func1 = ops::Identity(s.WithOpName("Func/grad0/input/_1"), y);
auto func2 = ops::Identity(s.WithOpName("Func/grad0/input/_2"), dz);
auto grad0_z = ops::Add(s.WithOpName("grad0/z"), func0, func1);
auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
grad0_r, grad0_one);
auto grad0_l = ops::Sum(s.WithOpName("grad0/l"), grad0_z, grad0_indices);
NameAttrList sum;
sum.set_name("Sum");
(*sum.mutable_attr())["T"].set_type(DT_FLOAT);
(*sum.mutable_attr())["Tidx"].set_type(DT_INT32);
(*sum.mutable_attr())["keep_dims"].set_b(false);
auto grad0_func1 = ops::SymbolicGradient(
s.WithOpName("grad0/Func/_1"),
std::initializer_list<Input>{grad0_z, grad0_indices, func2},
{DT_FLOAT, DT_INT32}, sum);
auto grad0_func2 =
ops::ZerosLike(s.WithOpName("grad0/Func/_2"), grad0_zero);
auto grad0_func3 = ops::ZerosLike(s.WithOpName("grad0/Func/_3"), grad0_r);
auto grad0_func4 = ops::ZerosLike(s.WithOpName("grad0/Func/_4"), grad0_one);
NameAttrList add;
add.set_name("Add");
(*add.mutable_attr())["T"].set_type(DT_FLOAT);
auto grad0_func5 = ops::SymbolicGradient(
s.WithOpName("grad0/Func/_5"),
std::initializer_list<Input>{func0, func1, grad0_func1[0]},
{DT_FLOAT, DT_FLOAT}, add);
auto func3 =
ops::Identity(s.WithOpName("Func/grad0/output/_3"), grad0_func5[0]);
auto func4 =
ops::Identity(s.WithOpName("Func/grad0/output/_4"), grad0_func5[1]);
auto dx = ops::Identity(s.WithOpName("dx"), func3);
auto dy = ops::Identity(s.WithOpName("dy"), func4);
auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
TF_EXPECT_GRAPH_EQ(expected, actual);
}
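// Optimize the expanded graph; the hand-built graph below approximates the
// fully optimized gradient.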
OptimizeGraph(flr0_, &g);
{
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto y = ops::_Arg(s.WithOpName("y"), DT_FLOAT, 1);
auto dz = ops::Const(s.WithOpName("dz"), 1.0f);
auto grad0_zero = ops::Const(s.WithOpName("grad0/zero"), 0);
auto grad0_one = ops::Const(s.WithOpName("grad0/one"), 1);
auto grad0_z = ops::Add(s.WithOpName("grad0/z"), x, y);
auto grad0_r = ops::Rank(s.WithOpName("grad0/r"), grad0_z);
auto grad0_indices = ops::Range(s.WithOpName("grad0/indices"), grad0_zero,
grad0_r, grad0_one);
auto i_shape =
ops::Shape(s.WithOpName("grad0/Func/_1/i_shape"), grad0_indices);
auto stitch_val = ops::Fill(s.WithOpName("grad0/Func/_1/stitch_val1"),
i_shape, grad0_one);
auto x_shape = ops::Shape(s.WithOpName("grad0/Func/_1/x_shape"), grad0_z);
auto y_shape = ops::DynamicStitch(
s.WithOpName("grad0/Func/_1/y_shape"),
std::initializer_list<Input>{grad0_indices, grad0_indices},
std::initializer_list<Input>{x_shape, stitch_val});
auto dy_reshaped =
ops::Reshape(s.WithOpName("grad0/Func/_1/dy_reshaped"), dz, y_shape);
auto tile_scaling =
ops::Div(s.WithOpName("grad0/Func/_1/tile_scaling"), x_shape, y_shape);
auto func1_dx =
ops::Tile(s.WithOpName("grad0/Func/_1/dx"), dy_reshaped, tile_scaling);
auto sx = ops::Shape(s.WithOpName("grad0/Func/_3/sx"), x);
auto sy = ops::Shape(s.WithOpName("grad0/Func/_3/sy"), y);
auto rx = ops::internal::BroadcastGradientArgs(
s.WithOpName("grad0/Func/_3/rx"), sx, sy);
auto sum_gx =
ops::Sum(s.WithOpName("grad0/Func/_3/sum_gx"), func1_dx, rx.r0);
auto sum_gy =
ops::Sum(s.WithOpName("grad0/Func/_3/sum_gy"), func1_dx, rx.r1);
auto dx = ops::Reshape(s.WithOpName("grad0/Func/_3/dx"), sum_gx, sx);
auto dy = ops::Reshape(s.WithOpName("grad0/Func/_3/dy"), sum_gy, sy);
auto dx_retval = ops::_Retval(s.WithOpName("dx_RetVal"), dx, 0);
auto dy_retval = ops::_Retval(s.WithOpName("dy_RetVal"), dy, 1);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
GraphDef actual;
g->ToGraphDef(&actual);
// The optimizer is non-deterministic, so we only check that the number of
// nodes is not greater than expected.
EXPECT_LE(actual.node_size(), expected.node_size());
}
}
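// Instantiates FindDevice with target /device:CPU:1 and runs it through both
// flr1_ and flr2_ to check that the function executes on the requested device.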
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
instantiate_opts.target = "/device:CPU:1";
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, instantiate_opts, &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run via flr1_ and flr2_, and make sure in both cases the function ran on
// device CPU:1.
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<tstring>(
y,
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}, true));
test::ExpectTensorEqual<tstring>(
y,
test::AsTensor<tstring>({"/job:localhost/replica:0/task:0/device:CPU:1"},
TensorShape({})));
opts.rendezvous->Unref();
}
namespace {
bool DoNothing(Graph* g) { return false; }
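// Instantiates `fdef`, runs `pass` over the resulting graph, and returns the
// graph as a GraphDef.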
GraphDef Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
TF_CHECK_OK(InstantiateFunction(fdef, AttrSlice(), GetOpSig, &result));
std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = false;
TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get()));
pass(g.get());
std::unique_ptr<Graph> g1(new Graph(OpRegistry::Global()));
CopyGraph(*g, g1.get());
g = nullptr;
GraphDef gdef;
g1->ToGraphDef(&gdef);
return gdef;
}
} // end namespace
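// RemoveDeadNodes currently leaves this graph unchanged: both the unoptimized
// and the optimized graphs must match `expected` (see the TODO below).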
TEST(OptimizationTest, RemoveDeadNodes) {
auto T = DT_INT32;
auto func = FDH::Define(
// Name
"F",
// Args
{"x: int32"},
// Return values
{"y: int32"},
// Attrs
{},
// Nodes
{// a = Square<T>(x)
{{"a"}, "Square", {"x"}, {{"T", T}}},
// 1
FDH::Const("o", 1),
// A bunch of extra arithmetic that y doesn't depend on
{{"x1"}, "Add", {"o", "o"}, {{"T", T}}},
{{"x2"}, "Mul", {"a", "x1"}, {{"T", T}}},
{{"x3"}, "Mul", {"x1", "x2"}, {{"T", T}}},
// A stateful node.
{{"keep_me"}, "RandomUniform", {"o"}, {{"T", T}, {"dtype", DT_FLOAT}}},
// y = Add<T>(a, o)
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
GraphDef expected;
{
Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto keep_me = ops::RandomUniform(s.WithOpName("keep_me"), {o}, DT_FLOAT);
auto x1 = ops::Add(s.WithOpName("x1"), o, o);
auto a = ops::Square(s.WithOpName("a"), x);
auto y = ops::Add(s.WithOpName("y"), a, o);
auto x2 = ops::Mul(s.WithOpName("x2"), a, x1);
auto x3 = ops::Mul(s.WithOpName("x3"), x1, x2);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
TF_ASSERT_OK(s.ToGraphDef(&expected));
}
TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
// TODO(zhifengc): Come up with another test case.
TF_EXPECT_GRAPH_EQ(expected, Optimize(::tensorflow::RemoveDeadNodes, func));
}
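// The Identity node reading a ref-typed variable must survive
// RemoveIdentityNodes, so the pass leaves this graph unchanged.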
TEST(OptimizationTest, RemoveIdentityNodes_Ref) {
auto T = DT_FLOAT;
auto func = FDH::Define(
// Name
"F",
// Args
{},
// Return values
{"ret: float"},
// Attrs
{},
// Nodes
{// variable
{{"v"}, "VariableV2", {}, {{"dtype", T}, {"shape", TensorShape({})}}},
// Read the variable. This Identity node shouldn't be removed.
{{"v_read"}, "Identity", {"v"}, {{"T", T}}},
// returns v + v
{{"ret"}, "Add", {"v_read", "v_read"}, {{"T", T}}}});
GraphDef expected;
{
Scope s = Scope::NewRootScope();
auto v = ops::Variable(s.WithOpName("v"), PartialTensorShape({}), DT_FLOAT);
auto v_read = ops::Identity(s.WithOpName("v_read"), v);
auto ret = ops::Add(s.WithOpName("ret"), v_read, v_read);
auto ret_retval = ops::_Retval(s.WithOpName("ret_RetVal"), ret, 0);
TF_ASSERT_OK(s.ToGraphDef(&expected));
}
TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
TF_EXPECT_GRAPH_EQ(expected,
Optimize(::tensorflow::RemoveIdentityNodes, func));
}
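// RemoveIdentityNodes should drop x1/x2/x3 and move keep_me's control
// dependency from x3 onto a.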
TEST(OptimizationTest, RemoveIdentityNodes) {
auto T = DT_INT32;
auto func = FDH::Define(
// Name
"F",
// Args
{"x: int32"},
// Return values
{"y: int32"},
// Attrs
{},
// Nodes
{// a = Square<T>(x)
{{"a"}, "Square", {"x"}, {{"T", T}}},
// 1
FDH::Const("o", 1),
// A bunch of extra arithmetic that y doesn't depend on
{{"x1"}, "Identity", {"a"}, {{"T", T}}},
{{"x2"}, "Identity", {"x1"}, {{"T", T}}},
{{"x3"}, "Identity", {"x2"}, {{"T", T}}},
// A stateful node.
{{"keep_me"},
"RandomUniform",
{"o"},
{{"T", T}, {"dtype", DT_FLOAT}},
{"x3"}},
// y = Add<T>(a, o)
{{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
{
Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
auto y = ops::Add(s.WithOpName("y"), a, o);
auto x1 = ops::Identity(s.WithOpName("x1"), a);
auto x2 = ops::Identity(s.WithOpName("x2"), x1);
auto x3 = ops::Identity(s.WithOpName("x3"), x2);
auto keep_me = ops::RandomUniform(
s.WithOpName("keep_me").WithControlDependencies(x3), {o}, DT_FLOAT);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
}
{
Scope s = Scope::DisabledShapeInferenceScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_INT32, 0);
auto o = ops::Const(s.WithOpName("o"), 1);
auto a = ops::Square(s.WithOpName("a"), x);
auto y = ops::Add(s.WithOpName("y"), a, o);
auto keep_me = ops::RandomUniform(
s.WithOpName("keep_me").WithControlDependencies(a), {o}, DT_FLOAT);
auto ret = ops::_Retval(s.WithOpName("y_RetVal"), y, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected,
Optimize(::tensorflow::RemoveIdentityNodes, func));
}
}
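// RemoveListArrayConverter rewrites the _ArrayToList/_ListToArray nodes into
// Identity nodes; combined with RemoveIdentityNodes, the converters disappear
// entirely.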
TEST(OptimizationTest, RemoveListArrayConverter) {
auto func = FDH::Create(
// Name
"Test",
// Args
{"i: float"},
// Return values
{"o: float"},
// Attrs
{},
// Nodes
{FDH::Const("zero", 0),
{{"s"},
"Split",
{"zero:output:0", "i"},
{{"num_split", 4}, {"T", DT_FLOAT}}},
{{"a"},
"_ArrayToList",
{"s:output"},
{{"N", 4},
{"T", DT_FLOAT},
{"out_types", DataTypeSlice{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}}}},
{{"l"}, "Mul", {"a:output:0", "a:output:1"}, {{"T", DT_FLOAT}}},
{{"r"}, "Mul", {"a:output:2", "a:output:3"}, {{"T", DT_FLOAT}}},
{{"x"},
"_ListToArray",
{"l:z", "r:z"},
{{"N", 2},
{"T", DT_FLOAT},
{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}},
{{"o"}, "AddN", {"x:output"}, {{"N", 2}, {"T", DT_FLOAT}}}},
// Output mapping
{{"o", "o:sum"}});
{
Scope scope = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
auto a = ops::_ArrayToList(scope.WithOpName("a"), s.output,
{DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
auto r = ops::Mul(scope.WithOpName("r"), a[2], a[3]);
auto l = ops::Mul(scope.WithOpName("l"), a[0], a[1]);
auto x = ops::_ListToArray(scope.WithOpName("x"),
std::initializer_list<Input>{l, r}, DT_FLOAT, 2);
auto o = ops::AddN(scope.WithOpName("o"), x.output);
auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
}
{
Scope scope = Scope::NewRootScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
auto func_0 = ops::Identity(scope.WithOpName("Func/a/input/_0"), s[0]);
auto func_1 = ops::Identity(scope.WithOpName("Func/a/input/_1"), s[1]);
auto func_2 = ops::Identity(scope.WithOpName("Func/a/input/_2"), s[2]);
auto func_3 = ops::Identity(scope.WithOpName("Func/a/input/_3"), s[3]);
auto r = ops::Mul(scope.WithOpName("r"), func_2, func_3);
auto l = ops::Mul(scope.WithOpName("l"), func_0, func_1);
auto func_4 = ops::Identity(scope.WithOpName("Func/x/input/_4"), l);
auto func_5 = ops::Identity(scope.WithOpName("Func/x/input/_5"), r);
auto o = ops::AddN(scope.WithOpName("o"),
std::initializer_list<Input>{func_4, func_5});
auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
}
{
Scope scope = Scope::NewRootScope();
auto i = ops::_Arg(scope.WithOpName("i"), DT_FLOAT, 0);
auto zero = ops::Const(scope.WithOpName("zero"), 0);
auto s = ops::Split(scope.WithOpName("s"), zero, i, 4);
auto r = ops::Mul(scope.WithOpName("r"), s[2], s[3]);
auto l = ops::Mul(scope.WithOpName("l"), s[0], s[1]);
auto o =
ops::AddN(scope.WithOpName("o"), std::initializer_list<Input>{l, r});
auto o_ret = ops::_Retval(scope.WithOpName("o_RetVal"), o, 0);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
auto remove_listarray_and_identity = [](Graph* g) {
return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
};
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
}
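// Same as above, but the converter nodes carry control dependencies, which
// must be preserved through NoOp input/output control nodes.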
TEST(OptimizationTest, RemoveListArrayConverter_WithControlDeps) {
auto func = FDH::Create(
// Name
"Test",
// Args
{"i: float"},
// Return values
{"o: float"},
// Attrs
{},
// Nodes
{FDH::Const("dummy", 0),
{{"x"},
"_ListToArray",
{"i", "i"},
{{"N", 2}, {"T", DT_FLOAT}, {"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
// Control dep
{"dummy"}},
{{"o"},
"AddN",
{"x:output"},
{{"N", 2}, {"T", DT_FLOAT}},
// Control dep
{"x"}}},
{{"o", "o:sum"}});
{
Scope s = Scope::DisabledShapeInferenceScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto x = ops::_ListToArray(s.WithOpName("x").WithControlDependencies(dummy),
std::initializer_list<Input>{i, i}, DT_FLOAT, 2);
auto o =
ops::AddN(s.WithOpName("o").WithControlDependencies({x.output[0].op()}),
x.output);
auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
GraphDef expected;
TF_ASSERT_OK(s.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, Optimize(DoNothing, func));
}
GraphDef expected;
{
Scope s = Scope::NewRootScope();
auto i = ops::_Arg(s.WithOpName("i"), DT_FLOAT, 0);
auto dummy = ops::Const(s.WithOpName("dummy"), 0);
auto func_2 = ops::NoOp(s.WithOpName("Func/x/input_control_node/_2")
.WithControlDependencies(dummy));
auto func_0 = ops::Identity(
s.WithOpName("Func/x/input/_0").WithControlDependencies({func_2}), i);
auto func_1 = ops::Identity(
s.WithOpName("Func/x/input/_1").WithControlDependencies({func_2}), i);
auto func_3 = ops::NoOp(
s.WithOpName("Func/x/output_control_node/_3")
.WithControlDependencies({func_0.output.op(), func_1.output.op()}));
auto o = ops::AddN(s.WithOpName("o").WithControlDependencies({func_3}),
std::initializer_list<Input>{func_0, func_1});
auto o_ret = ops::_Retval(s.WithOpName("o_RetVal"), o, 0);
TF_ASSERT_OK(s.ToGraphDef(&expected));
}
TF_EXPECT_GRAPH_EQ(expected, Optimize(RemoveListArrayConverter, func));
auto remove_listarray_and_identity = [](Graph* g) {
return RemoveListArrayConverter(g) && RemoveIdentityNodes(g);
};
// NOTE: Identity nodes that carry control dependencies are not removed
// yet, so the combined pass produces the same graph as above.
TF_EXPECT_GRAPH_EQ(expected, Optimize(remove_listarray_and_identity, func));
}
} // namespace
} // namespace tensorflow