blob: 8f31cda9310e174ecbf800e79a734fa1d256f31f [file] [log] [blame]
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include <atomic>
#include <utility>
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/equal_graph_def.h"
namespace tensorflow {
namespace {
// Test fixture that builds a ProcessFunctionLibraryRuntime over three local
// CPU devices and provides helpers to instantiate and run library functions.
class FunctionLibraryRuntimeTest : public ::testing::Test {
 protected:
  // Creates three CPU devices, a FunctionLibraryDefinition containing `flib`,
  // and a ProcessFunctionLibraryRuntime over them. It then caches the runtime
  // for cpu:0 in `flr0_`.
  //
  // `default_thread_pool` is forwarded to the ProcessFunctionLibraryRuntime
  // constructor as a raw pointer. NOTE(review): assumed non-owning, so keep
  // the pool alive for as long as the runtime may use it.
  void Init(const std::vector<FunctionDef>& flib,
            thread::ThreadPool* default_thread_pool) {
    SessionOptions options;
    auto* device_count = options.config.mutable_device_count();
    device_count->insert({"CPU", 3});
    std::vector<std::unique_ptr<Device>> devices;
    TF_CHECK_OK(DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices));
    FunctionDefLibrary proto;
    for (const auto& fdef : flib) *(proto.add_function()) = fdef;
    lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
        OpRegistry::Global(), proto);
    OptimizerOptions opts;
    device_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(devices));
    pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
        device_mgr_.get(), Env::Default(), /*config=*/nullptr,
        TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool);
    flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
  }

  // Runs the instantiated function `handle` on `flr` with `args` and blocks
  // until the done callback fires.
  //
  // On success, output i is copied into `*rets[i]`. This CHECK-fails if the
  // number of outputs differs from `rets.size()`. On failure the error status
  // is returned and `rets` is left untouched.
  //
  // If `add_runner` is true, `opts.runner` is replaced by a counting runner
  // that schedules closures through FunctionTestSchedClosure, and the test
  // expects it to be invoked at least once. If false, `opts.runner` is set to
  // nullptr.
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts,
             const std::vector<Tensor>& args, std::vector<Tensor*> rets,
             bool add_runner = true) {
    std::atomic<int32> call_count(0);
    std::function<void(std::function<void()>)> runner =
        [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };
    if (add_runner) {
      opts.runner = &runner;
    } else {
      opts.runner = nullptr;
    }
    Notification done;
    std::vector<Tensor> out;
    Status status;
    flr->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }
    CHECK_EQ(rets.size(), out.size());
    for (size_t i = 0; i < rets.size(); ++i) {
      *rets[i] = out[i];
    }
    if (add_runner) {
      EXPECT_GE(call_count.load(), 1);  // Test runner is used.
    }
    return Status::OK();
  }

  // Instantiates `name` with `attrs` on `flr` and stores the handle in
  // `*handle`.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, handle);
  }

  // Same as above, but with explicit instantiation `options`.
  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                     test::function::Attrs attrs,
                     const FunctionLibraryRuntime::InstantiateOptions& options,
                     FunctionLibraryRuntime::Handle* handle) {
    return flr->Instantiate(name, attrs, options, handle);
  }

  // Convenience overload that uses default InstantiateOptions.
  Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
                           test::function::Attrs attrs,
                           const std::vector<Tensor>& args,
                           std::vector<Tensor*> rets, bool add_runner = true) {
    return InstantiateAndRun(flr, name, attrs,
                             FunctionLibraryRuntime::InstantiateOptions(), args,
                             std::move(rets), add_runner);
  }

  // Instantiates `name`, runs it once (see Run() for `rets`/`add_runner`),
  // then releases the handle. After the release it checks that a second run
  // with the same handle and runner setting fails with NotFound ("Handle ...
  // not found").
  //
  // Returns the first error from instantiation, the first run, or the
  // release. Otherwise returns OK.
  Status InstantiateAndRun(
      FunctionLibraryRuntime* flr, const string& name,
      test::function::Attrs attrs,
      const FunctionLibraryRuntime::InstantiateOptions& options,
      const std::vector<Tensor>& args, std::vector<Tensor*> rets,
      bool add_runner = true) {
    FunctionLibraryRuntime::Handle handle;
    Status status = flr->Instantiate(name, attrs, options, &handle);
    if (!status.ok()) {
      return status;
    }
    FunctionLibraryRuntime::Options opts;
    status = Run(flr, handle, opts, args, rets, add_runner);
    if (!status.ok()) return status;

    // Release the handle and try running again, configured the same way as
    // the first run. It should not succeed.
    status = flr->ReleaseHandle(handle);
    if (!status.ok()) return status;
    Status status2 =
        Run(flr, handle, opts, args, std::move(rets), add_runner);
    EXPECT_TRUE(errors::IsNotFound(status2));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "Handle"));
    EXPECT_TRUE(absl::StrContains(status2.error_message(), "not found"));
    return Status::OK();
  }

  // Call-frame variant of Run(): arguments and results travel through
  // `frame` instead of tensor vectors. `add_runner` behaves as in the vector
  // overload. Returns the run status.
  Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
             FunctionLibraryRuntime::Options opts, CallFrameInterface* frame,
             bool add_runner = true) {
    std::atomic<int32> call_count(0);
    std::function<void(std::function<void()>)> runner =
        [&call_count](std::function<void()> fn) {
          ++call_count;
          test::function::FunctionTestSchedClosure(fn);
        };
    if (add_runner) {
      opts.runner = &runner;
    } else {
      opts.runner = nullptr;
    }
    Notification done;
    Status status;
    flr->Run(opts, handle, frame, [&status, &done](const Status& s) {
      status = s;
      done.Notify();
    });
    done.WaitForNotification();
    if (!status.ok()) {
      return status;
    }
    if (add_runner) {
      EXPECT_GE(call_count.load(), 1);  // Test runner is used.
    }
    return Status::OK();
  }

  // Runtime for /job:localhost/replica:0/task:0/cpu:0. It is obtained from
  // pflr_ in Init() and is not owned by the fixture.
  FunctionLibraryRuntime* flr0_ = nullptr;
  // Members are destroyed in reverse declaration order. pflr_ is built from
  // device_mgr_ and lib_def_, so it must be declared after them and is
  // therefore torn down first.
  std::unique_ptr<DeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> lib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
};
// Checks that when Run() is called without a per-call runner, the function
// executes on the default thread pool passed to Init(). That pool has one
// thread, so while BlockingOpFn occupies it, the other Run() calls must wait.
TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
  using test::function::blocking_op_state;
  using test::function::BlockingOpState;

  // Single-threaded pool handed to the runtime as its default thread pool.
  auto tp = absl::make_unique<thread::ThreadPool>(Env::Default(), "FLRTest", 1);
  Init({test::function::BlockingOpFn(), test::function::XTimesTwo()}, tp.get());

  auto x = test::AsScalar<float>(1.3);
  Tensor y;
  blocking_op_state = new BlockingOpState();

  // Pool from which the test issues its Run() calls concurrently.
  auto tp1 = absl::make_unique<thread::ThreadPool>(Env::Default(), "tp1", 5);
  // Written on a tp1 worker and read on the test thread, so it must be atomic
  // to avoid a data race.
  std::atomic<bool> finished_running(false);
  tp1->Schedule([&x, &y, &finished_running, this]() {
    TF_CHECK_OK(InstantiateAndRun(flr0_, "BlockingOpFn", {}, {x}, {&y},
                                  false /* add_runner */));
    finished_running = true;
  });
  // InstantiateAndRun shouldn't finish because BlockingOpFn should be blocked.
  EXPECT_FALSE(finished_running.load());

  FunctionLibraryRuntime::Handle h;
  TF_CHECK_OK(Instantiate(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, &h));
  auto x1 = test::AsTensor<float>({1, 2, 3, 4});
  std::atomic<int32> num_done(0);
  FunctionLibraryRuntime::Options opts;
  for (int i = 0; i < 4; ++i) {
    tp1->Schedule([&h, &x1, &opts, &num_done, this]() {
      Tensor y1;
      TF_CHECK_OK(Run(flr0_, h, opts, {x1}, {&y1}, false /* add_runner */));
      num_done.fetch_add(1);
    });
  }
  // All the 4 Run() calls should be blocked because the runner is occupied.
  EXPECT_EQ(0, num_done.load());

  blocking_op_state->AwaitState(1);
  blocking_op_state->MoveToState(1, 2);
  // Now the runner should be unblocked and all the other Run() calls should
  // proceed.
  blocking_op_state->AwaitState(3);
  blocking_op_state->MoveToState(3, 0);

  // The checks below rely on destroying tp1 waiting for every closure
  // scheduled on it to finish.
  tp1.reset();
  EXPECT_TRUE(finished_running.load());
  EXPECT_EQ(4, num_done.load());

  delete blocking_op_state;
  blocking_op_state = nullptr;
  // Tear down the default pool at the same point as before (after the
  // blocking state). The unique_ptr also guarantees cleanup on any early exit.
  tp.reset();
}
} // namespace
} // namespace tensorflow