| #include <ATen/Config.h> |
| #if AT_PARALLEL_NATIVE |
| #include <ATen/Parallel.h> |
| #include <ATen/PTThreadPool.h> |
| |
| #ifndef C10_MOBILE |
| #include <c10/core/thread_pool.h> |
| #else |
| #include <caffe2/utils/threadpool/pthreadpool-cpp.h> |
| #endif // C10_MOBILE |
| |
| #include <atomic> |
| |
| #ifdef _OPENMP |
| #include <omp.h> |
| #endif |
| |
| #ifdef TH_BLAS_MKL |
| #include <mkl.h> |
| #endif |
| |
| namespace at { |
| namespace { |
| // used with _set_in_parallel_region to mark master thread |
| // as in parallel region while executing parallel primitives |
| thread_local bool in_parallel_region_ = false; |
| |
| // thread number (task_id) set by parallel primitive |
| thread_local size_t thread_num_ = 0; |
| |
| void _set_in_parallel_region(bool in_region) { |
| in_parallel_region_ = in_region; |
| } |
| |
| void _set_thread_num(size_t thread_num) { |
| thread_num_ = thread_num; |
| } |
| |
| void _unset_thread_num() { |
| thread_num_ = 0; |
| } |
| |
| #ifndef C10_MOBILE |
| |
| const int NOT_SET = -1; |
| const int CONSUMED = -2; |
| |
| // Number of threads set by the user |
| // NOT_SET -> positive value -> CONSUMED |
| // or |
| // NOT_SET -> CONSUMED |
| // Meaning: |
| // - NOT_SET - pool not initialized, user value is not set |
| // - positive value - pool not initialized, user value set |
| // - CONSUMED - pool is initialized |
| std::atomic<int> num_intraop_threads{NOT_SET}; |
| |
| int _num_pool_threads(int nthreads) { |
| if (nthreads == NOT_SET) { |
| nthreads = intraop_default_num_threads(); |
| } else { |
| TORCH_INTERNAL_ASSERT(nthreads > 0); |
| } |
| // minus one because of the master thread |
| return nthreads - 1; |
| } |
| |
| TaskThreadPoolBase& _get_intraop_pool() { |
| static std::shared_ptr<TaskThreadPoolBase> pool = |
| ThreadPoolRegistry()->Create( |
| "C10", |
| /* device_id */ 0, |
| /* pool_size */ _num_pool_threads(num_intraop_threads.exchange(CONSUMED)), |
| /* create_new */ true); // create a separate thread pool for intra-op |
| return *pool; |
| } |
| |
| #endif // C10_MOBILE |
| |
| // Run lambda function `fn` over `task_id` in [0, `range`) with threadpool. |
| // `fn` will be called with params: (thread_pool_task_id, task_id). |
| void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) { |
| #ifndef C10_MOBILE |
| for (size_t i = 1; i < range; ++i) { |
| _get_intraop_pool().run([fn, i]() { fn((int)i, i); }); |
| } |
| // Run the first task on the current thread directly. |
| fn(0, 0); |
| #else |
| caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
| TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); |
| |
| pool->run( |
| // PThreadPool::run() is blocking. A std::function [const] reference to |
| // this lambda cannot go out of scope before PThreadPool::run() returns. |
| [&fn](const size_t task_id) { |
| fn(0 /* unused */, task_id); |
| }, range); |
| #endif // C10_MOBILE |
| } |
| |
| // RAII guard helps to support in_parallel_region() and get_thread_num() API. |
| struct ParallelRegionGuard { |
| ParallelRegionGuard(int64_t task_id) { |
| _set_thread_num(task_id); |
| _set_in_parallel_region(true); |
| } |
| |
| ~ParallelRegionGuard() { |
| _set_in_parallel_region(false); |
| _unset_thread_num(); |
| } |
| }; |
| |
| } // namespace |
| |
| namespace internal { |
| |
| void _parallel_run( |
| const int64_t begin, |
| const int64_t end, |
| const int64_t grain_size, |
| const std::function<void(int64_t, int64_t, size_t)>& f) { |
| at::internal::lazy_init_num_threads(); |
| |
| size_t num_tasks, chunk_size; |
| std::tie(num_tasks, chunk_size) = |
| internal::calc_num_tasks_and_chunk_size(begin, end, grain_size); |
| |
| struct { |
| std::atomic_flag err_flag = ATOMIC_FLAG_INIT; |
| std::exception_ptr eptr; |
| std::mutex mutex; |
| volatile size_t remaining; |
| std::condition_variable cv; |
| } state; |
| |
| auto task = [f, &state, begin, end, chunk_size] |
| (int /* unused */, size_t task_id) { |
| int64_t local_start = begin + task_id * chunk_size; |
| if (local_start < end) { |
| int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start)); |
| try { |
| ParallelRegionGuard guard(task_id); |
| f(local_start, local_end, task_id); |
| } catch (...) { |
| if (!state.err_flag.test_and_set()) { |
| state.eptr = std::current_exception(); |
| } |
| } |
| } |
| { |
| std::unique_lock<std::mutex> lk(state.mutex); |
| if (--state.remaining == 0) { |
| state.cv.notify_one(); |
| } |
| } |
| }; |
| state.remaining = num_tasks; |
| _run_with_pool(task, num_tasks); |
| |
| // Wait for all tasks to finish. |
| { |
| std::unique_lock<std::mutex> lk(state.mutex); |
| if (state.remaining != 0) { |
| state.cv.wait(lk); |
| } |
| } |
| if (state.eptr) { |
| std::rethrow_exception(state.eptr); |
| } |
| } |
| |
| } // namespace internal |
| |
| void init_num_threads() { |
| #ifdef _OPENMP |
| omp_set_num_threads(1); |
| #endif |
| |
| #ifdef TH_BLAS_MKL |
| mkl_set_num_threads(1); |
| #endif |
| |
| #ifdef C10_MOBILE |
| caffe2::pthreadpool(); |
| #endif |
| } |
| |
| void set_num_threads(int nthreads) { |
| #ifndef C10_MOBILE |
| TORCH_CHECK(nthreads > 0, "Expected positive number of threads"); |
| int no_value = NOT_SET; |
| if (!num_intraop_threads.compare_exchange_strong(no_value, nthreads)) { |
| // num_intraop_threads either stores a positive integer or CONSUMED, |
| // check that requested size is the same as the current one |
| int stored_nthreads = num_intraop_threads.load(); |
| if (stored_nthreads <= 0) { |
| // plus one because of master thread |
| stored_nthreads = _get_intraop_pool().size() + 1; |
| } |
| if (stored_nthreads != nthreads) { |
| TORCH_WARN( |
| "Cannot set number of intraop threads " |
| "after parallel work has started or after set_num_threads call " |
| "when using native parallel backend"); |
| } |
| } |
| #else |
| caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
| TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!"); |
| pool->set_thread_count(nthreads); |
| #endif // C10_MOBILE |
| } |
| |
| int get_num_threads() { |
| #ifndef C10_MOBILE |
| // not initializing pool unnecessarily, |
| // because pool cannot be resized after initialization |
| int nthreads = num_intraop_threads.load(); |
| if (nthreads > 0) { |
| return nthreads; |
| } else if (nthreads == NOT_SET) { |
| return intraop_default_num_threads(); |
| } else { |
| TORCH_INTERNAL_ASSERT(nthreads == CONSUMED); |
| return _get_intraop_pool().size() + 1; |
| } |
| #else |
| caffe2::PThreadPool* const pool = caffe2::pthreadpool(); |
| TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!") |
| return in_parallel_region() ? 1 /* current thread */ : pool->get_thread_count(); |
| #endif // C10_MOBILE |
| } |
| |
| int get_thread_num() { |
| return thread_num_; |
| } |
| |
| bool in_parallel_region() { |
| #ifndef C10_MOBILE |
| return in_parallel_region_ || ( |
| num_intraop_threads.load() == CONSUMED && |
| // Needed as intraop_launch() doesn't set in_parallel_region(). |
| _get_intraop_pool().inThreadPool() |
| ); |
| #else |
| return in_parallel_region_; |
| #endif // C10_MOBILE |
| } |
| |
| void intraop_launch(std::function<void()> func) { |
| #ifndef C10_MOBILE |
| if (!in_parallel_region() && get_num_threads() > 1) { |
| _get_intraop_pool().run(func); |
| } else { |
| // execute inline if we're in parallel region |
| func(); |
| } |
| #else |
| // TODO: caffe2::PThreadPool only provides a data-parallel API. |
| // Task parallelism is not currently supported. |
| func(); |
| #endif // C10_MOBILE |
| } |
| |
| std::shared_ptr<c10::ivalue::Future> intraop_launch_future( |
| std::function<void()> func) { |
| #ifndef C10_MOBILE |
| auto future = std::make_shared<c10::ivalue::Future>(c10::NoneType::get()); |
| if (!in_parallel_region() && get_num_threads() > 1) { |
| _get_intraop_pool().run( |
| [func, future]() { |
| func(); |
| future->markCompleted(); |
| } |
| ); |
| } else { |
| func(); |
| future->markCompleted(); |
| } |
| return future; |
| #else |
| // TODO: caffe2::PThreadPool only provides a data-parallel API. |
| // Task parallelism is not currently supported. |
| auto future = std::make_shared<c10::ivalue::Future>(NoneType::get()); |
| func(); |
| future->markCompleted(); |
| return future; |
| #endif // C10_MOBILE |
| } |
| |
| } // namespace at |
| #endif |