blob: 4134bb8669de7b9199700898f233bea22c719077 [file] [log] [blame] [edit]
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/threadpool/threadpool.h>
#include <algorithm>
#include <atomic>
#include <memory>
#include <executorch/extension/threadpool/threadpool_guard.h>
#include <executorch/runtime/platform/assert.h>
#include <cpuinfo.h>
namespace executorch::extension::threadpool {
#if !(defined(WIN32))
namespace {

// Fork-recovery flag.
//
// A forked child receives a copy of the parent's thread-pool bookkeeping, yet
// none of the parent's worker threads exist in the child, so that copy is
// unusable. Destroying it could segfault, so the pool is intentionally leaked
// instead; this flag records that such a leak is pending.
// Ref: https://github.com/pytorch/pytorch/issues/54752#issuecomment-810315302
bool leak_corrupted_threadpool{false};

// Fork handler for the child process: marks the inherited pool as corrupted.
void child_atfork() {
  leak_corrupted_threadpool = true;
}

} // namespace
#endif
// Constructs a ThreadPool around a pthreadpool created with `thread_count`
// threads. The handle is stored in threadpool_ together with
// pthreadpool_destroy (presumably as the deleter of an owning smart pointer;
// the member is declared in the header).
// NOTE(review): the result of pthreadpool_create() is not checked here; a null
// handle is only caught later by the "Invalid threadpool!" checks in
// get_thread_count() and run().
ThreadPool::ThreadPool(size_t thread_count)
: threadpool_(pthreadpool_create(thread_count), pthreadpool_destroy) {}
// Reports the number of threads in the underlying pthreadpool.
//
// Takes mutex_ for the duration of the query and fails the ET_CHECK_MSG
// assertion if no underlying pool is present.
size_t ThreadPool::get_thread_count() const {
  const std::lock_guard<std::mutex> pool_lock{mutex_};
  ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");
  const size_t worker_count = pthreadpool_get_threads_count(threadpool_.get());
  return worker_count;
}
bool ThreadPool::_unsafe_reset_threadpool(uint32_t new_thread_count) {
// No need to do anything if the count is same or 0
if (new_thread_count == get_thread_count() || new_thread_count == 0) {
return true;
}
std::lock_guard<std::mutex> lock{mutex_};
threadpool_.reset(pthreadpool_create(new_thread_count));
return true;
}
// Invokes fn(i) for every i in [0, range).
//
// If a NoThreadPoolGuard is enabled, the indices are processed one after
// another on the calling thread. Otherwise the work is handed to the
// underlying pthreadpool while mutex_ is held, and this call returns once
// pthreadpool_parallelize_1d() has returned.
void ThreadPool::run(
    const std::function<void(size_t)>& fn,
    const size_t range) {
  // Guard enabled: bypass the pool and execute inline, in index order.
  if (NoThreadPoolGuard::is_enabled()) {
    size_t index = 0;
    while (index < range) {
      fn(index);
      ++index;
    }
    return;
  }

  const std::lock_guard<std::mutex> pool_lock{mutex_};
  ET_CHECK_MSG(!NoThreadPoolGuard::is_enabled(), "Inside a threadpool guard!");
  ET_CHECK_MSG(threadpool_.get(), "Invalid threadpool!");

  // pthreadpool accepts a plain function pointer plus an opaque argument, so
  // the std::function travels by reference inside this small carrier struct.
  struct CallbackArgs final {
    const std::function<void(size_t)>& fn;
  };
  CallbackArgs args{fn};

  // Captureless lambda, so it converts to the function pointer that
  // pthreadpool_parallelize_1d() expects. A NoThreadPoolGuard is instantiated
  // around each item (presumably so that nested run() calls made from fn take
  // the inline path above; confirm against threadpool_guard.h).
  const auto trampoline = [](void* const opaque, const size_t item) {
    NoThreadPoolGuard guard;
    static_cast<CallbackArgs*>(opaque)->fn(item);
  };

  // Note: pthreadpool_parallelize_1d() is a blocking function, so `args` and
  // `trampoline` stay alive for as long as the pool may use them.
  pthreadpool_parallelize_1d(
      threadpool_.get(), trampoline, &args, range, 0u);
}
// Returns the process-wide ThreadPool, creating it on first use with a default
// size derived from the processor count reported by cpuinfo.
//
// get_threadpool is not thread safe due to leak_corrupted_threadpool
// Make this part threadsafe: TODO(kimishpatel)
ThreadPool* get_threadpool() {
  // The default thread count is only needed to construct the pool, so cpuinfo
  // is queried inside the one-time static initializer rather than on every
  // call to get_threadpool().
  static auto threadpool = []() {
    ET_CHECK_MSG(cpuinfo_initialize(), "cpuinfo initialization failed");
    int num_threads = cpuinfo_get_processors_count();
    /*
     * For llvm-tsan, holding limit for the number of locks for a single thread
     * is 63 (because of comparison < 64 instead of <=). pthreadpool's worst
     * case is the number of threads in a pool. So we want to limit the
     * threadpool size to 63 when running with tsan. However, sometimes it is
     * tricky to detect if we are running under tsan, for now capping the
     * default threadcount to the tsan limit unconditionally.
     */
    constexpr int tsan_thread_limit = 63;
    num_threads = std::min(num_threads, tsan_thread_limit);
    return std::make_unique<ThreadPool>(num_threads);
  }();

  // Inheriting from old threadpool to get around segfault issue
  // commented above at child_atfork
#if !(defined(WIN32))
  // Register the child-side fork handler exactly once.
  // @lint-ignore CLANGTIDY facebook-hte-std::once_flag
  static std::once_flag flag;
  // @lint-ignore CLANGTIDY facebook-hte-std::call_once
  std::call_once(
      flag, []() { pthread_atfork(nullptr, nullptr, child_atfork); });
  if ET_UNLIKELY (leak_corrupted_threadpool) {
    leak_corrupted_threadpool = false;
    // release() gives up ownership without destroying the inherited pool: it
    // is deliberately leaked, and a new pool of the same size takes its place.
    if (auto leaked = threadpool.release()) {
      auto t = leaked->get_thread_count();
      threadpool = std::make_unique<ThreadPool>(t);
    }
  }
#endif
  return threadpool.get();
}
// Returns the raw pthreadpool handle held by the shared ThreadPool, or nullptr
// when a NoThreadPoolGuard is enabled.
//
// The handle is read directly from the ThreadPool's threadpool_ member without
// taking its mutex. Fails the ET_CHECK_MSG assertion if get_threadpool()
// yields no instance.
pthreadpool_t get_pthreadpool() {
  if (!NoThreadPoolGuard::is_enabled()) {
    ThreadPool* const threadpool = get_threadpool();
    ET_CHECK_MSG(threadpool, "Failed to acquire an instance of ThreadPool!");
    return threadpool->threadpool_.get();
  }
  // Guard enabled: callers get no pool.
  return nullptr;
}
} // namespace executorch::extension::threadpool