/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/run_handler.h"
#include <algorithm>
#include <cmath>
#include <list>
#include <memory>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/run_handler_util.h"
#include "tensorflow/core/lib/core/threadpool_interface.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// LINT.IfChange
static constexpr int32 kMaxConcurrentHandlers = 128;
// LINT.ThenChange(//tensorflow/core/framework/run_handler_test.cc)
typedef typename internal::RunHandlerEnvironment::Task Task;
typedef Eigen::RunQueue<Task, 1024> Queue;
} // namespace
namespace internal {
RunHandlerEnvironment::RunHandlerEnvironment(
Env* env, const ThreadOptions& thread_options, const string& name)
: env_(env), thread_options_(thread_options), name_(name) {}
RunHandlerEnvironment::EnvThread* RunHandlerEnvironment::CreateThread(
std::function<void()> f) {
return env_->StartThread(thread_options_, name_, [=]() {
// Set the processor flag to flush denormals to zero.
port::ScopedFlushDenormal flush;
// Set the processor rounding mode to ROUND TO NEAREST.
port::ScopedSetRound round(FE_TONEAREST);
if (thread_options_.numa_node != port::kNUMANoAffinity) {
port::NUMASetThreadNodeAffinity(thread_options_.numa_node);
}
f();
});
}
RunHandlerEnvironment::Task RunHandlerEnvironment::CreateTask(
std::function<void()> f) {
uint64 id = 0;
if (tracing::EventCollector::IsEnabled()) {
id = tracing::GetUniqueArg();
tracing::RecordEvent(tracing::EventCategory::kScheduleClosure, id);
}
return Task{
std::unique_ptr<TaskImpl>(new TaskImpl{
std::move(f),
Context(ContextKind::kThread),
id,
}),
};
}
void RunHandlerEnvironment::ExecuteTask(const Task& t) {
WithContext wc(t.f->context);
tracing::ScopedRegion region(tracing::EventCategory::kRunClosure,
t.f->trace_id);
t.f->f();
}
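// Blocks the calling thread for up to `max_sleep_micros`. The waiter is
// linked into the doubly-linked LIFO list headed by `queue_head` (guarded by
// `mutex`), the thread then sleeps on the waiter's condition variable, and on
// wake-up the waiter is unlinked again if it is still in the list.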
void WaitOnWaiter(Waiter* waiter, Waiter* queue_head, mutex* mutex,
int max_sleep_micros) {
{
mutex_lock l(*mutex);
CHECK_EQ(waiter->next, waiter); // Crash OK.
CHECK_EQ(waiter->prev, waiter); // Crash OK.
// Add waiter to the LIFO queue
waiter->prev = queue_head;
waiter->next = queue_head->next;
waiter->next->prev = waiter;
waiter->prev->next = waiter;
}
{
mutex_lock l(waiter->mu);
// Wait on the condition variable
waiter->cv.wait_for(l, std::chrono::microseconds(max_sleep_micros));
}
mutex_lock l(*mutex);
  // Remove the waiter from the LIFO queue. Note that even when a waiter wakes
  // up due to a notification, we cannot conclude that it is no longer in the
  // queue: a thread preempted right before notifying may resume only after
  // the waiter has been re-added.
if (waiter->next != waiter) {
CHECK(waiter->prev != waiter); // Crash OK.
waiter->next->prev = waiter->prev;
waiter->prev->next = waiter->next;
waiter->next = waiter;
waiter->prev = waiter;
} else {
CHECK_EQ(waiter->prev, waiter); // Crash OK.
}
}
ThreadWorkSource::ThreadWorkSource()
: non_blocking_work_sharding_factor_(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_OF_NON_BLOCKING_QUEUES", 1))),
non_blocking_work_queues_(non_blocking_work_sharding_factor_),
blocking_inflight_(0),
non_blocking_inflight_(0),
traceme_id_(0),
version_(0),
sub_thread_pool_waiter_(nullptr) {
queue_waiters_.next = &queue_waiters_;
queue_waiters_.prev = &queue_waiters_;
for (int i = 0; i < NonBlockingWorkShardingFactor(); ++i) {
non_blocking_work_queues_.emplace_back(new NonBlockingQueue());
}
}
ThreadWorkSource::~ThreadWorkSource() {
for (int i = 0; i < non_blocking_work_queues_.size(); ++i) {
delete non_blocking_work_queues_[i];
}
}
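// Enqueues `t` onto the blocking queue or one of the sharded non-blocking
// queues and wakes up one waiter, if any. If the target queue is full,
// PushFront hands the task back and it is returned to the caller, which is
// expected to run it inline (see RunHandlerThreadPool::AddWorkToQueue).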
Task ThreadWorkSource::EnqueueTask(Task t, bool is_blocking) {
mutex* mu = nullptr;
Queue* task_queue = nullptr;
thread_local int64 closure_counter = 0;
if (!is_blocking) {
int queue_index = ++closure_counter % non_blocking_work_sharding_factor_;
task_queue = &(non_blocking_work_queues_[queue_index]->queue);
mu = &non_blocking_work_queues_[queue_index]->queue_op_mu;
} else {
task_queue = &blocking_work_queue_;
mu = &blocking_queue_op_mu_;
}
{
mutex_lock l(*mu);
// For a given queue, only one thread can call PushFront.
t = task_queue->PushFront(std::move(t));
}
Waiter* w = nullptr;
static const bool use_sub_thread_pool =
ParamFromEnvBoolWithDefault("TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false);
Waiter* waiter_queue;
mutex* waiter_queue_mu;
if (use_sub_thread_pool) {
    // When multiple sub thread pools are in use, free threads wait on the sub
    // thread pool waiting queues (defined in RunHandlerPool), so wake up
    // threads from those queues.
    // Get the waiter_queue and the corresponding mutex. Note that the thread
    // work source may change afterwards if a new request arrives or an old
    // request finishes.
tf_shared_lock lock(run_handler_waiter_mu_);
waiter_queue = sub_thread_pool_waiter_;
waiter_queue_mu = sub_thread_pool_waiter_mu_;
} else {
waiter_queue = &queue_waiters_;
waiter_queue_mu = &waiters_mu_;
}
{
mutex_lock l(*waiter_queue_mu);
if (waiter_queue->next != waiter_queue) {
// Remove waiter from the LIFO queue
w = waiter_queue->next;
CHECK(w->prev != w); // Crash OK.
CHECK(w->next != w); // Crash OK.
w->next->prev = w->prev;
w->prev->next = w->next;
      // Use `w->next == w` to indicate that the waiter has been removed
      // from the queue.
w->next = w;
w->prev = w;
}
}
if (w != nullptr) {
    // We call notify_one() without holding any locks, so notifications can be
    // missed. The wake-up logic is best effort: a thread will wake up within
    // a short period of time even if a notification is missed.
w->cv.notify_one();
}
VLOG(3) << "Added " << (is_blocking ? "inter" : "intra") << " work from "
<< traceme_id_.load(std::memory_order_relaxed);
return t;
}
Task ThreadWorkSource::PopBlockingTask() {
return blocking_work_queue_.PopBack();
}
Task ThreadWorkSource::PopNonBlockingTask(int start_index,
bool search_from_all_queue) {
Task t;
unsigned sharding_factor = NonBlockingWorkShardingFactor();
for (unsigned j = 0; j < sharding_factor; ++j) {
t = non_blocking_work_queues_[(start_index + j) % sharding_factor]
->queue.PopBack();
if (t.f) {
return t;
}
if (!search_from_all_queue) {
break;
}
}
return t;
}
void ThreadWorkSource::WaitForWork(int max_sleep_micros) {
thread_local Waiter waiter;
WaitOnWaiter(&waiter, &queue_waiters_, &waiters_mu_, max_sleep_micros);
}
int ThreadWorkSource::TaskQueueSize(bool is_blocking) {
if (is_blocking) {
return blocking_work_queue_.Size();
} else {
unsigned total_size = 0;
for (int i = 0; i < non_blocking_work_sharding_factor_; ++i) {
total_size += non_blocking_work_queues_[i]->queue.Size();
}
return total_size;
}
}
int64 ThreadWorkSource::GetTracemeId() {
return traceme_id_.load(std::memory_order_relaxed);
}
void ThreadWorkSource::SetTracemeId(int64 value) { traceme_id_ = value; }
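// Points this work source at the waiter queue (and its mutex) of the sub
// thread pool it is assigned to. `version` increases on every pool-stats
// recomputation; updates carrying an older version are ignored.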
void ThreadWorkSource::SetWaiter(uint64 version, Waiter* waiter, mutex* mutex) {
{
tf_shared_lock lock(run_handler_waiter_mu_);
    // Most requests do not change their sub thread pool on recomputation.
    // As an optimization, avoid taking the exclusive lock here to reduce
    // contention.
if (sub_thread_pool_waiter_ == waiter) {
return;
}
// If the current version is a newer version, no need to update.
if (version_ > version) {
return;
}
}
mutex_lock l(run_handler_waiter_mu_);
sub_thread_pool_waiter_ = waiter;
sub_thread_pool_waiter_mu_ = mutex;
version_ = version;
}
int64 ThreadWorkSource::GetInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
return counter->load(std::memory_order_relaxed);
}
void ThreadWorkSource::IncrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_add(1, std::memory_order_relaxed);
}
void ThreadWorkSource::DecrementInflightTaskCount(bool is_blocking) {
std::atomic<int64>* counter =
is_blocking ? &blocking_inflight_ : &non_blocking_inflight_;
counter->fetch_sub(1, std::memory_order_relaxed);
}
unsigned ThreadWorkSource::NonBlockingWorkShardingFactor() {
return non_blocking_work_sharding_factor_;
}
std::string ThreadWorkSource::ToString() {
return strings::StrCat("traceme_id = ", GetTracemeId(),
", inter queue size = ", TaskQueueSize(true),
", inter inflight = ", GetInflightTaskCount(true),
", intra queue size = ", TaskQueueSize(false),
", intra inflight = ", GetInflightTaskCount(false));
}
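// A RunHandlerThreadPool owns `num_blocking_threads` threads that may run
// blocking (inter-op) work and `num_non_blocking_threads` threads that only
// run non-blocking (intra-op) work. Sub thread pool sizes and the request
// ranges they serve can be overridden via the TF_RUN_HANDLER_* environment
// variables read below.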
RunHandlerThreadPool::RunHandlerThreadPool(
int num_blocking_threads, int num_non_blocking_threads, Env* env,
const ThreadOptions& thread_options, const string& name,
Eigen::MaxSizeVector<mutex>* waiters_mu,
Eigen::MaxSizeVector<Waiter>* queue_waiters)
: num_threads_(num_blocking_threads + num_non_blocking_threads),
num_blocking_threads_(num_blocking_threads),
num_non_blocking_threads_(num_non_blocking_threads),
thread_data_(num_threads_),
env_(env, thread_options, name),
name_(name),
waiters_mu_(waiters_mu),
queue_waiters_(queue_waiters),
use_sub_thread_pool_(ParamFromEnvBoolWithDefault(
"TF_RUN_HANDLER_USE_SUB_THREAD_POOL", false)),
num_threads_in_sub_thread_pool_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_NUM_THREADS_IN_SUB_THREAD_POOL",
std::vector<int>({num_blocking_threads / 2,
num_blocking_threads - num_blocking_threads / 2}))),
sub_thread_pool_start_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_START_REQUEST_PERCENTAGE",
std::vector<double>({0, 0.4}))),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({0.4, 1}))) {
thread_data_.resize(num_threads_);
VLOG(1) << "Creating RunHandlerThreadPool " << name << " with "
<< num_blocking_threads_ << " blocking threads and "
<< num_non_blocking_threads_ << " non-blocking threads.";
}
RunHandlerThreadPool::~RunHandlerThreadPool() {
VLOG(1) << "Exiting RunHandlerThreadPool " << name_;
cancelled_ = true;
for (size_t i = 0; i < thread_data_.size(); ++i) {
{
mutex_lock l(thread_data_[i].mu);
thread_data_[i].sources_not_empty.notify_all();
}
thread_data_[i].thread.reset();
}
}
void RunHandlerThreadPool::Start() {
cancelled_ = false;
int num_blocking_threads = num_blocking_threads_;
for (int i = 0; i < num_threads_; i++) {
int sub_thread_pool_id = num_threads_in_sub_thread_pool_.size() - 1;
for (int j = 0; j < num_threads_in_sub_thread_pool_.size(); ++j) {
if (i < num_threads_in_sub_thread_pool_[j]) {
sub_thread_pool_id = j;
break;
}
}
thread_data_[i].sub_thread_pool_id = sub_thread_pool_id;
thread_data_[i].thread.reset(
env_.CreateThread([this, i, num_blocking_threads]() {
WorkerLoop(i, i < num_blocking_threads);
}));
}
}
void RunHandlerThreadPool::StartOneThreadForTesting() {
cancelled_ = false;
thread_data_[0].sub_thread_pool_id = 0;
thread_data_[0].thread.reset(
env_.CreateThread([this]() { WorkerLoop(0, true); }));
}
void RunHandlerThreadPool::AddWorkToQueue(ThreadWorkSource* tws,
bool is_blocking,
std::function<void()> fn) {
Task t = env_.CreateTask(std::move(fn));
t = tws->EnqueueTask(std::move(t), is_blocking);
if (t.f) {
VLOG(3) << "Running " << (is_blocking ? "inter" : "intra") << " work for "
<< tws->GetTracemeId();
env_.ExecuteTask(t);
}
}
// TODO(donglin): Change the task-stealing order to be round-robin, so that if
// an attempt to steal a task from request i fails, the next attempt targets
// the next request in arrival-time order. This approach may provide better
// performance due to less lock contention. The drawback is that the profiler
// output will be a bit harder to read.
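// Publishes an ordered list of thread work sources to thread `tid`. The list
// is staged in new_thread_work_sources together with `version`; the worker
// swaps it in the next time it looks for work. Updates carrying an older
// version than the one already staged are ignored.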
void RunHandlerThreadPool::SetThreadWorkSources(
int tid, int start_request_idx, uint64 version,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources) {
mutex_lock l(thread_data_[tid].mu);
if (version > thread_data_[tid].new_version) {
thread_data_[tid].new_version = version;
} else {
    // An equal or newer version has already been staged. No need to update.
return;
}
thread_data_[tid].new_thread_work_sources->resize(0);
if (use_sub_thread_pool_) {
for (int i = 0; i < thread_work_sources.size(); ++i) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[i]);
}
} else {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[start_request_idx]);
    // The number of shards for the queue. Threads in each shard will
    // prioritize different thread_work_sources. Increasing the number of
    // shards can decrease contention in the queue. For example, when
    // num_shards == 1, thread_work_sources are ordered as start_request_idx,
    // 0, 1, 2, 3, 4, ... for all threads. When num_shards == 2,
    // thread_work_sources are ordered as start_request_idx, 0, 2, 4, ...,
    // 1, 3, 5, ... for half of the threads and start_request_idx, 1, 3, 5,
    // ..., 0, 2, 4, ... for the other half of the threads.
static const int num_shards =
ParamFromEnvWithDefault("TF_RUN_HANDLER_QUEUE_SHARDS", 1);
int token = tid % num_shards;
for (int i = 0; i < num_shards; ++i) {
for (int j = token; j < thread_work_sources.size(); j += num_shards) {
if (j != start_request_idx) {
thread_data_[tid].new_thread_work_sources->emplace_back(
thread_work_sources[j]);
}
}
token = (token + 1) % num_shards;
}
thread_data_[tid].sources_not_empty.notify_all();
}
}
RunHandlerThreadPool::PerThread* RunHandlerThreadPool::GetPerThread() {
thread_local RunHandlerThreadPool::PerThread per_thread_;
RunHandlerThreadPool::PerThread* pt = &per_thread_;
return pt;
}
int RunHandlerThreadPool::CurrentThreadId() const {
const PerThread* pt = const_cast<RunHandlerThreadPool*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
return -1;
}
}
int RunHandlerThreadPool::NumThreads() const { return num_threads_; }
int RunHandlerThreadPool::NumBlockingThreads() const {
return num_blocking_threads_;
}
int RunHandlerThreadPool::NumNonBlockingThreads() const {
return num_non_blocking_threads_;
}
RunHandlerThreadPool::ThreadData::ThreadData()
: new_version(0),
current_index(0),
new_thread_work_sources(
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))),
current_version(0),
current_thread_work_sources(
new Eigen::MaxSizeVector<ThreadWorkSource*>(static_cast<int32>(
ParamFromEnvWithDefault("TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers)))) {}
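// Searches the half-open range [searching_range_start, searching_range_end)
// of `thread_work_sources`, resuming from the thread's saved current_index so
// that successive calls rotate through the requests. Blocking tasks are
// preferred when `may_steal_blocking_work` is set and the request is below
// `max_blocking_inflight`; otherwise non-blocking tasks are popped. On
// return, `*task_from_blocking_queue` and `*tws` describe the task found.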
Task RunHandlerThreadPool::FindTask(
int searching_range_start, int searching_range_end, int thread_id,
int sub_thread_pool_id, int max_blocking_inflight,
bool may_steal_blocking_work,
const Eigen::MaxSizeVector<ThreadWorkSource*>& thread_work_sources,
bool* task_from_blocking_queue, ThreadWorkSource** tws) {
Task t;
int current_index = thread_data_[thread_id].current_index;
*task_from_blocking_queue = false;
for (int i = 0; i < searching_range_end - searching_range_start; ++i) {
if (current_index >= searching_range_end ||
current_index < searching_range_start) {
current_index = searching_range_start;
}
*tws = thread_work_sources[current_index];
++current_index;
    // Threads that may steal blocking work search for blocking tasks first.
if (may_steal_blocking_work &&
(*tws)->GetInflightTaskCount(true) < max_blocking_inflight) {
t = (*tws)->PopBlockingTask();
if (t.f) {
*task_from_blocking_queue = true;
break;
}
}
// Search for non-blocking tasks.
t = (*tws)->PopNonBlockingTask(thread_id, true);
if (t.f) {
break;
}
}
thread_data_[thread_id].current_index = current_index;
return t;
}
// Main worker thread loop.
void RunHandlerThreadPool::WorkerLoop(int thread_id,
bool may_steal_blocking_work) {
PerThread* pt = GetPerThread();
pt->pool = this;
pt->thread_id = thread_id;
static constexpr int32 kMaxBlockingInflight = 10;
while (!cancelled_) {
Task t;
ThreadWorkSource* tws = nullptr;
bool task_from_blocking_queue = true;
int sub_thread_pool_id;
// Get the current thread work sources.
{
mutex_lock l(thread_data_[thread_id].mu);
if (thread_data_[thread_id].current_version <
thread_data_[thread_id].new_version) {
thread_data_[thread_id].current_version =
thread_data_[thread_id].new_version;
thread_data_[thread_id].current_thread_work_sources.swap(
thread_data_[thread_id].new_thread_work_sources);
}
}
Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
thread_data_[thread_id].current_thread_work_sources.get();
if (use_sub_thread_pool_) {
sub_thread_pool_id = thread_data_[thread_id].sub_thread_pool_id;
int active_requests = thread_work_sources->size();
if (may_steal_blocking_work) {
        // Each thread first looks for tasks from requests that belong to
        // its own sub thread pool.
int search_range_start =
active_requests *
sub_thread_pool_start_request_percentage_[sub_thread_pool_id];
int search_range_end =
active_requests *
sub_thread_pool_end_request_percentage_[sub_thread_pool_id];
search_range_end =
std::min(active_requests,
std::max(search_range_end, search_range_start + 1));
t = FindTask(search_range_start, search_range_end, thread_id,
sub_thread_pool_id, kMaxBlockingInflight,
/*may_steal_blocking_work=*/true, *thread_work_sources,
&task_from_blocking_queue, &tws);
if (!t.f) {
// Search from all requests if the thread cannot find tasks from
// requests that belong to its own sub thread pool.
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
kMaxBlockingInflight,
/*may_steal_blocking_work=*/true, *thread_work_sources,
&task_from_blocking_queue, &tws);
}
} else {
        // Non-blocking threads always search across all pending requests.
t = FindTask(0, active_requests, thread_id, sub_thread_pool_id,
kMaxBlockingInflight,
/*may_steal_blocking_work=*/false, *thread_work_sources,
&task_from_blocking_queue, &tws);
}
} else {
// TODO(chaox): Refactor the following code to share the logic with
// FindTask.
for (int i = 0; i < thread_work_sources->size(); ++i) {
tws = (*thread_work_sources)[i];
        // We want a smallish number of inter-op threads, since otherwise
        // there will be contention in PropagateOutputs. This is a
        // best-effort policy.
if (may_steal_blocking_work &&
tws->GetInflightTaskCount(true) < kMaxBlockingInflight) {
t = tws->PopBlockingTask();
if (t.f) {
break;
}
}
if (i == 0) {
// Always look for any work from the "primary" work source.
// This way when we wake up a thread for a new closure we are
// guaranteed it can be worked on.
t = tws->PopNonBlockingTask(thread_id, true);
if (t.f) {
task_from_blocking_queue = false;
break;
}
} else {
t = tws->PopNonBlockingTask(thread_id, false);
if (t.f) {
task_from_blocking_queue = false;
break;
}
}
}
}
if (t.f) {
profiler::TraceMe activity(
[=] {
return strings::StrCat(task_from_blocking_queue ? "inter" : "intra",
" #id = ", tws->GetTracemeId(), " ",
thread_id, "#");
},
profiler::TraceMeLevel::kInfo);
VLOG(2) << "Running " << (task_from_blocking_queue ? "inter" : "intra")
<< " work from " << tws->GetTracemeId();
tws->IncrementInflightTaskCount(task_from_blocking_queue);
env_.ExecuteTask(t);
tws->DecrementInflightTaskCount(task_from_blocking_queue);
} else {
profiler::TraceMe activity(
[=] {
return strings::StrCat("Sleeping#thread_id=", thread_id, "#");
},
profiler::TraceMeLevel::kInfo);
if (VLOG_IS_ON(4)) {
for (int i = 0; i < thread_work_sources->size(); ++i) {
VLOG(4) << "source id " << i << " "
<< (*thread_work_sources)[i]->ToString();
}
}
if (use_sub_thread_pool_) {
WaitForWorkInSubThreadPool(may_steal_blocking_work, sub_thread_pool_id);
} else {
WaitForWork(may_steal_blocking_work, thread_id, kMaxBlockingInflight);
}
}
}
}
void RunHandlerThreadPool::WaitForWorkInSubThreadPool(bool is_blocking,
int sub_thread_pool_id) {
const int kMaxSleepMicros = 250;
// The non-blocking thread will just sleep.
if (!is_blocking) {
Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
return;
}
thread_local Waiter waiter;
WaitOnWaiter(&waiter, &(*queue_waiters_)[sub_thread_pool_id],
&(*waiters_mu_)[sub_thread_pool_id], kMaxSleepMicros);
}
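// Wait path used when sub thread pools are disabled. Non-blocking threads
// simply sleep for a short period. Blocking threads wait until at least one
// thread work source has been assigned and then briefly wait on the primary
// work source for new work.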
void RunHandlerThreadPool::WaitForWork(bool is_blocking, int thread_id,
int32 max_blocking_inflight) {
const int kMaxSleepMicros = 250;
// The non-blocking thread will just sleep.
if (!is_blocking) {
Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
return;
}
ThreadWorkSource* tws = nullptr;
{
mutex_lock l(thread_data_[thread_id].mu);
if (thread_data_[thread_id].new_version >
thread_data_[thread_id].current_version) {
thread_data_[thread_id].current_thread_work_sources.swap(
thread_data_[thread_id].new_thread_work_sources);
thread_data_[thread_id].current_version =
thread_data_[thread_id].new_version;
}
Eigen::MaxSizeVector<ThreadWorkSource*>* thread_work_sources =
thread_data_[thread_id].current_thread_work_sources.get();
while (!cancelled_ && thread_work_sources->empty()) {
      // Wait until there is a new request.
thread_data_[thread_id].sources_not_empty.wait(l);
if (thread_data_[thread_id].new_version >
thread_data_[thread_id].current_version) {
thread_data_[thread_id].current_thread_work_sources.swap(
thread_data_[thread_id].new_thread_work_sources);
thread_data_[thread_id].current_version =
thread_data_[thread_id].new_version;
thread_work_sources =
thread_data_[thread_id].current_thread_work_sources.get();
}
}
if (cancelled_) {
return;
}
tws = (*thread_work_sources)[0];
}
if (tws->GetInflightTaskCount(true) >= max_blocking_inflight) {
// Sleep to reduce contention in PropagateOutputs
Env::Default()->SleepForMicroseconds(kMaxSleepMicros);
}
tws->WaitForWork(kMaxSleepMicros);
}
} // namespace internal
// Contains the concrete implementation of the RunHandler.
// The externally visible RunHandler class simply forwards work to this one.
class RunHandler::Impl {
public:
explicit Impl(RunHandlerPool::Impl* pool_impl);
~Impl() {}
thread::ThreadPoolInterface* thread_pool_interface() {
return thread_pool_interface_.get();
}
  // Returns the time (in microseconds since the Unix epoch) at which the
  // handler was requested via RunHandlerPool::Get().
uint64 start_time_us() const { return start_time_us_; }
int64 step_id() const { return step_id_; }
void ScheduleInterOpClosure(std::function<void()> fn);
void ScheduleIntraOpClosure(std::function<void()> fn);
void Reset(int64 step_id,
const RunOptions::Experimental::RunHandlerPoolOptions& options);
RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
internal::ThreadWorkSource* tws() { return &tws_; }
int64 priority() { return options_.priority(); }
private:
class ThreadPoolInterfaceWrapper : public thread::ThreadPoolInterface {
public:
explicit ThreadPoolInterfaceWrapper(Impl* run_handler_impl)
: run_handler_impl_(run_handler_impl) {}
~ThreadPoolInterfaceWrapper() override {}
void Schedule(std::function<void()> fn) override;
int NumThreads() const override;
int CurrentThreadId() const override;
private:
RunHandler::Impl* run_handler_impl_ = nullptr;
};
RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
uint64 start_time_us_;
int64 step_id_;
std::unique_ptr<thread::ThreadPoolInterface> thread_pool_interface_;
internal::ThreadWorkSource tws_;
RunOptions::Experimental::RunHandlerPoolOptions options_;
};
// Contains shared state across all run handlers present in the pool. Also
// responsible for pool management decisions.
// This class is thread safe.
class RunHandlerPool::Impl {
public:
explicit Impl(int num_inter_op_threads, int num_intra_op_threads)
: max_handlers_(static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", kMaxConcurrentHandlers))),
waiters_mu_(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
queue_waiters_(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2)),
run_handler_thread_pool_(new internal::RunHandlerThreadPool(
num_inter_op_threads, num_intra_op_threads, Env::Default(),
ThreadOptions(), "tf_run_handler_pool", &waiters_mu_,
&queue_waiters_)),
iterations_(0),
version_(0),
sub_thread_pool_end_request_percentage_(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_SUB_THREAD_POOL_END_REQUEST_PERCENTAGE",
std::vector<double>({1}))) {
VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
free_handlers_.reserve(max_handlers_);
handlers_.reserve(max_handlers_);
for (int i = 0; i < max_handlers_; ++i) {
handlers_.emplace_back(new RunHandler::Impl(this));
free_handlers_.push_back(handlers_.back().get());
}
queue_waiters_.resize(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
waiters_mu_.resize(
ParamFromEnvWithDefault("TF_RUN_HANDLER_NUM_SUB_THREAD_POOL", 2));
for (auto& queue_waiter : queue_waiters_) {
queue_waiter.next = &queue_waiter;
queue_waiter.prev = &queue_waiter;
}
run_handler_thread_pool_->Start();
}
~Impl() {
    // Sanity check that all handlers have been returned to the pool before
    // destruction.
DCHECK_EQ(handlers_.size(), max_handlers_);
DCHECK_EQ(free_handlers_.size(), handlers_.size());
DCHECK_EQ(sorted_active_handlers_.size(), 0);
// Stop the threads in run_handler_thread_pool_ before freeing other
// pointers. Otherwise a thread may try to access a pointer after the
// pointer has been freed.
run_handler_thread_pool_.reset();
}
internal::RunHandlerThreadPool* run_handler_thread_pool() {
return run_handler_thread_pool_.get();
}
bool has_free_handler() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return !free_handlers_.empty();
}
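  // Returns a free handler for `step_id`, waiting until one becomes available
  // (if `timeout_in_ms` is non-zero and expires first, returns nullptr). The
  // handler is inserted into sorted_active_handlers_ according to its
  // priority, and the per-thread work source assignments are then recomputed.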
std::unique_ptr<RunHandler> Get(
int64 step_id, int64 timeout_in_ms,
const RunOptions::Experimental::RunHandlerPoolOptions& options)
TF_LOCKS_EXCLUDED(mu_) {
thread_local std::unique_ptr<
Eigen::MaxSizeVector<internal::ThreadWorkSource*>>
thread_work_sources =
std::unique_ptr<Eigen::MaxSizeVector<internal::ThreadWorkSource*>>(
new Eigen::MaxSizeVector<internal::ThreadWorkSource*>(
static_cast<int32>(ParamFromEnvWithDefault(
"TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS",
kMaxConcurrentHandlers))));
uint64 version;
int num_active_requests;
RunHandler::Impl* handler_impl;
{
mutex_lock l(mu_);
if (!has_free_handler()) {
profiler::TraceMe activity(
[&] {
return strings::StrCat("WaitingForHandler#step_id=", step_id,
"#");
},
profiler::TraceMeLevel::kInfo);
if (timeout_in_ms == 0) {
mu_.Await(Condition(this, &Impl::has_free_handler));
} else if (!mu_.AwaitWithDeadline(
Condition(this, &Impl::has_free_handler),
EnvTime::NowNanos() + timeout_in_ms * 1000 * 1000)) {
return nullptr;
}
}
      // Take the last entry from free_handlers_ and insert it into
      // sorted_active_handlers_ in priority order.
handler_impl = free_handlers_.back();
handler_impl->Reset(step_id, options);
free_handlers_.pop_back();
num_active_requests = sorted_active_handlers_.size() + 1;
thread_work_sources->resize(num_active_requests);
int priority = options.priority();
auto it = sorted_active_handlers_.cbegin();
bool new_handler_inserted = false;
for (int i = 0; i < num_active_requests; ++i) {
if (!new_handler_inserted && (it == sorted_active_handlers_.cend() ||
priority > (*it)->priority())) {
sorted_active_handlers_.insert(it, handler_impl);
new_handler_inserted = true;
// Point to the newly added handler.
--it;
}
(*thread_work_sources)[i] = (*it)->tws();
++it;
}
version = ++version_;
}
RecomputePoolStats(num_active_requests, version, *thread_work_sources);
return WrapUnique<RunHandler>(new RunHandler(handler_impl));
}
void ReleaseHandler(RunHandler::Impl* handler) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
DCHECK_GT(sorted_active_handlers_.size(), 0);
CHECK_EQ(handler->tws()->TaskQueueSize(true), 0); // Crash OK.
CHECK_EQ(handler->tws()->TaskQueueSize(false), 0); // Crash OK.
uint64 now = tensorflow::EnvTime::NowMicros();
double elapsed = (now - handler->start_time_us()) / 1000.0;
time_hist_.Add(elapsed);
// Erase from and update sorted_active_handlers_. Add it to the end of
// free_handlers_.
auto iter = std::find(sorted_active_handlers_.begin(),
sorted_active_handlers_.end(), handler);
DCHECK(iter != sorted_active_handlers_.end())
<< "Unexpected handler: " << handler
<< " is being requested for release";
// Remove this handler from this list and add it to the list of free
// handlers.
sorted_active_handlers_.erase(iter);
free_handlers_.push_back(handler);
DCHECK_LE(free_handlers_.size(), max_handlers_);
LogInfo();
// We do not recompute pool stats during release. The side effect is that
// there may be empty thread work sources in the queue. However, any new
// requests will trigger recomputation.
}
std::vector<int64> GetActiveHandlerPrioritiesForTesting()
TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
std::vector<int64> ret;
for (const auto& handler_impl : sorted_active_handlers_) {
ret.push_back(handler_impl->priority());
}
return ret;
}
private:
void RecomputePoolStats(
int num_active_requests, uint64 version,
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
thread_work_sources);
void LogInfo() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Maximum number of handlers pre-created at pool construction time. The
  // number was chosen expecting that each handler will want at least one
  // inter-op thread for execution (during compute-intensive workloads such
  // as inference).
const int max_handlers_;
Eigen::MaxSizeVector<mutex> waiters_mu_;
Eigen::MaxSizeVector<internal::Waiter> queue_waiters_;
std::unique_ptr<internal::RunHandlerThreadPool> run_handler_thread_pool_;
  // Thread-compatible state below, accessed only while holding mu_ within
  // RunHandlerPool.
  // Handlers are sorted by priority and, within a priority level, by start
  // time.
  // TODO(azaks): sort by the remaining latency budget.
  // TODO(chaox): Consider another data structure for maintaining the sorted
  // active handlers if the search overhead (currently O(n)) becomes the
  // bottleneck.
std::list<RunHandler::Impl*> sorted_active_handlers_ TF_GUARDED_BY(mu_);
std::vector<RunHandler::Impl*> free_handlers_ TF_GUARDED_BY(mu_);
std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ TF_GUARDED_BY(mu_);
// Histogram of elapsed runtime of every handler (in ms).
histogram::Histogram time_hist_ TF_GUARDED_BY(mu_);
int64 iterations_ TF_GUARDED_BY(mu_);
mutex mu_;
int64 version_ TF_GUARDED_BY(mu_);
const std::vector<double> sub_thread_pool_end_request_percentage_;
};
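// Assigns each active request to a sub thread pool (by pointing its
// ThreadWorkSource at that pool's waiter queue) based on the request's
// position among the active requests, then distributes start_request_idx
// values over the blocking and non-blocking threads via
// ChooseRequestsWithExponentialDistribution, so that requests earlier in the
// sorted order tend to be served by more threads.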
void RunHandlerPool::Impl::RecomputePoolStats(
int num_active_requests, uint64 version,
const Eigen::MaxSizeVector<internal::ThreadWorkSource*>&
thread_work_sources) {
if (num_active_requests == 0) return;
int sub_thread_pool_id = 0;
for (int i = 0; i < num_active_requests; ++i) {
while (
sub_thread_pool_id <
sub_thread_pool_end_request_percentage_.size() - 1 &&
i >= num_active_requests *
sub_thread_pool_end_request_percentage_[sub_thread_pool_id]) {
sub_thread_pool_id++;
}
thread_work_sources[i]->SetWaiter(version,
&queue_waiters_[sub_thread_pool_id],
&waiters_mu_[sub_thread_pool_id]);
}
int num_threads = run_handler_thread_pool()->NumThreads();
int num_blocking_threads = run_handler_thread_pool()->NumBlockingThreads();
int num_non_blocking_threads = num_threads - num_blocking_threads;
std::vector<int> request_idx_list = ChooseRequestsWithExponentialDistribution(
num_active_requests, num_blocking_threads);
for (int i = 0; i < num_blocking_threads; ++i) {
VLOG(2) << "Set work for tid=" << i
<< " with start_request_idx=" << request_idx_list[i];
run_handler_thread_pool()->SetThreadWorkSources(
i, request_idx_list[i], version, thread_work_sources);
}
request_idx_list = ChooseRequestsWithExponentialDistribution(
num_active_requests, num_non_blocking_threads);
for (int i = 0; i < num_non_blocking_threads; ++i) {
VLOG(2) << "Set work for tid=" << (i + num_blocking_threads)
<< " with start_request_idx=" << request_idx_list[i];
run_handler_thread_pool()->SetThreadWorkSources(
i + num_blocking_threads, request_idx_list[i], version,
thread_work_sources);
}
}
void RunHandlerPool::Impl::LogInfo() {
if (iterations_++ % 50000 == 10 && VLOG_IS_ON(1)) {
int num_active_requests = sorted_active_handlers_.size();
VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
VLOG(1) << "Active session runs: " << num_active_requests;
uint64 now = tensorflow::Env::Default()->NowMicros();
string times_str = "";
string ids_str = "";
auto it = sorted_active_handlers_.cbegin();
for (int i = 0; i < num_active_requests; ++i) {
if (i > 0) {
times_str += " ";
ids_str += " ";
}
times_str +=
strings::StrCat((now - (*it)->start_time_us()) / 1000.0, " ms.");
ids_str += strings::StrCat((*it)->tws()->GetTracemeId());
++it;
}
VLOG(1) << "Elapsed times are: " << times_str;
VLOG(1) << "Step ids are: " << ids_str;
}
}
// It is important that CurrentThreadId() returns a value in
// [0, NumThreads).
int RunHandler::Impl::ThreadPoolInterfaceWrapper::NumThreads() const {
return run_handler_impl_->pool_impl_->run_handler_thread_pool()->NumThreads();
}
int RunHandler::Impl::ThreadPoolInterfaceWrapper::CurrentThreadId() const {
return run_handler_impl_->pool_impl_->run_handler_thread_pool()
->CurrentThreadId();
}
void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule(
std::function<void()> fn) {
return run_handler_impl_->ScheduleIntraOpClosure(std::move(fn));
}
RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl)
: pool_impl_(pool_impl) {
thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this));
Reset(0, RunOptions::Experimental::RunHandlerPoolOptions());
}
void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
VLOG(3) << "Scheduling inter work for " << tws()->GetTracemeId();
pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), true,
std::move(fn));
}
void RunHandler::Impl::ScheduleIntraOpClosure(std::function<void()> fn) {
VLOG(3) << "Scheduling intra work for " << tws()->GetTracemeId();
pool_impl_->run_handler_thread_pool()->AddWorkToQueue(tws(), false,
std::move(fn));
}
void RunHandler::Impl::Reset(
int64 step_id,
const RunOptions::Experimental::RunHandlerPoolOptions& options) {
start_time_us_ = tensorflow::Env::Default()->NowMicros();
step_id_ = step_id;
options_ = options;
tws_.SetTracemeId(step_id);
}
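// Illustrative usage sketch (not part of this file's API surface; `step_id`,
// `session_work`, and `op_work` are caller-supplied):
//
//   RunHandlerPool pool(/*num_inter_op_threads=*/16,
//                       /*num_intra_op_threads=*/16);
//   std::unique_ptr<RunHandler> handler = pool.Get(
//       step_id, /*timeout_in_ms=*/0,
//       RunOptions::Experimental::RunHandlerPoolOptions());
//   handler->ScheduleInterOpClosure(session_work);
//   handler->AsIntraThreadPoolInterface()->Schedule(op_work);
//   handler.reset();  // Returns the handler to the pool.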
RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
: impl_(new Impl(num_inter_op_threads, 0)) {}
RunHandlerPool::RunHandlerPool(int num_inter_op_threads,
int num_intra_op_threads)
: impl_(new Impl(num_inter_op_threads, num_intra_op_threads)) {}
RunHandlerPool::~RunHandlerPool() {}
std::unique_ptr<RunHandler> RunHandlerPool::Get(
int64 step_id, int64 timeout_in_ms,
const RunOptions::Experimental::RunHandlerPoolOptions& options) {
return impl_->Get(step_id, timeout_in_ms, options);
}
std::vector<int64> RunHandlerPool::GetActiveHandlerPrioritiesForTesting()
const {
return impl_->GetActiveHandlerPrioritiesForTesting();
}
RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
impl_->ScheduleInterOpClosure(std::move(fn));
}
thread::ThreadPoolInterface* RunHandler::AsIntraThreadPoolInterface() {
return impl_->thread_pool_interface();
}
RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
} // namespace tensorflow