Move OMP/MKL thread initialization into ATen/Parallel (#19011)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19011
ghimport-source-id: 432e31eccfd0e59fa21a790f861e6b2ff4fdbac6
Differential Revision: D14846034
Pulled By: ilia-cher
fbshipit-source-id: d9d03c761d34bac80e09ce776e41c20fd3b04389
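For orientation, here is a hedged usage sketch (not part of this diff) of the consolidated threading API, based on the declarations added to aten/src/ATen/Parallel.h below; it assumes a build that links against ATen.

    #include <ATen/Parallel.h>
    #include <thread>

    void worker() {
      // Every newly spawned thread initializes its own OpenMP/MKL state first.
      at::init_num_threads();
      // ... run tensor code ...
    }

    int main() {
      at::init_num_threads();          // main-thread initialization
      at::set_num_threads(4);          // forwarded to OpenMP and MKL when available
      auto n = at::get_num_threads();  // effective size of the parallel region
      std::thread t(worker);
      t.join();
      return n > 0 ? 0 : 1;
    }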
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index 16adfc3..3d55be6 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -1,7 +1,7 @@
#pragma once
#include <c10/core/Allocator.h>
-#include <ATen/CPUGeneral.h>
+#include <ATen/core/ATenGeneral.h>
#include <ATen/Context.h>
#include <ATen/Device.h>
#include <ATen/DeviceGuard.h>
diff --git a/aten/src/ATen/CPUGeneral.cpp b/aten/src/ATen/CPUGeneral.cpp
deleted file mode 100644
index 910e3ae..0000000
--- a/aten/src/ATen/CPUGeneral.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <ATen/CPUGeneral.h>
-#include <atomic>
-#include <memory>
-#include <thread>
-
-namespace at {
-// Lock free atomic type
-std::atomic<int> num_threads(-1);
-
-void set_num_threads(int num_threads_) {
- if (num_threads_ >= 0)
- num_threads.store(num_threads_);
-}
-
-int get_num_threads() { return num_threads.load(); }
-}
diff --git a/aten/src/ATen/CPUGeneral.h b/aten/src/ATen/CPUGeneral.h
deleted file mode 100644
index 246af37..0000000
--- a/aten/src/ATen/CPUGeneral.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#pragma once
-
-// Using CAFFE2_API is crucial as otherwise you'll see
-// linking errors using MSVC
-// See https://msdn.microsoft.com/en-us/library/a90k134d.aspx
-// This header adds this if using CAFFE2_API
-#include <ATen/core/ATenGeneral.h>
-
-namespace at {
-CAFFE2_API void set_num_threads(int);
-CAFFE2_API int get_num_threads();
-}
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 9033494..925b327 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -1,6 +1,6 @@
#pragma once
-#include <ATen/CPUGeneral.h>
+#include <ATen/core/ATenGeneral.h>
#include <ATen/Type.h>
#include <ATen/TypeExtendedInterface.h>
#include <ATen/Utils.h>
@@ -164,12 +164,6 @@
static inline void init() {
globalContext();
- if (const char *env_p = std::getenv("OMP_NUM_THREADS")) {
- at::set_num_threads(std::stoi(env_p));
- }
- if (const char *env_p = std::getenv("MKL_NUM_THREADS")) {
- at::set_num_threads(std::stoi(env_p));
- }
}
static inline TypeExtendedInterface& getNonVariableType(Backend p, ScalarType s) {
diff --git a/aten/src/ATen/Parallel.cpp b/aten/src/ATen/Parallel.cpp
new file mode 100644
index 0000000..3345e09
--- /dev/null
+++ b/aten/src/ATen/Parallel.cpp
@@ -0,0 +1,63 @@
+#include <ATen/Parallel.h>
+
+#include <atomic>
+
+#ifdef TH_BLAS_MKL
+#include <mkl.h>
+#endif
+
+namespace at {
+
+namespace {
+// Number of threads set by the user
+std::atomic<int> num_threads(-1);
+}
+
+void init_num_threads() {
+ auto nthreads = num_threads.load();
+ if (nthreads > 0) {
+ set_num_threads(nthreads);
+ } else {
+#if defined(_OPENMP) && defined(TH_BLAS_MKL)
+ // If we are using MKL and OpenMP, make sure the numbers of threads match.
+ // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
+ // size of the OpenMP thread pool, resulting in worse performance (and memory
+ // leaks in GCC 5.4)
+ omp_set_num_threads(mkl_get_max_threads());
+#endif
+ }
+}
+
+void set_num_threads(size_t nthreads) {
+ if (nthreads == 0) {
+ return;
+ }
+ num_threads.store(nthreads);
+#ifdef _OPENMP
+ omp_set_num_threads(nthreads);
+#endif
+#ifdef TH_BLAS_MKL
+ mkl_set_num_threads(nthreads);
+
+ // Because PyTorch also uses OpenMP outside of MKL
+ // invocations, we want this flag to be false so that
+ // threads aren't destroyed and recreated at every
+ // MKL / non-MKL boundary of OpenMP usage.
+ // See https://github.com/pytorch/pytorch/issues/13757
+ mkl_set_dynamic(false);
+#endif
+}
+
+// Explicitly call omp_get_max_threads() here, as the size of the parallel
+// region might differ in a newly created thread;
+// use init_num_threads() during thread initialization to ensure
+// a consistent parallel region size across threads.
+size_t get_num_threads() {
+#ifdef _OPENMP
+ return omp_get_max_threads();
+#else
+ return 1;
+#endif
+}
+
+}
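As a side note, the sketch below (not part of this file) spells out the raw OpenMP/MKL calls that init_num_threads() coordinates in its default branch; it only does anything when both _OPENMP and TH_BLAS_MKL are defined.

    #include <cstdio>
    #ifdef _OPENMP
    #include <omp.h>
    #endif
    #ifdef TH_BLAS_MKL
    #include <mkl.h>
    #endif

    int main() {
    #if defined(_OPENMP) && defined(TH_BLAS_MKL)
      std::printf("OpenMP pool: %d, MKL pool: %d\n",
                  omp_get_max_threads(), mkl_get_max_threads());
      // If these differ, alternating MKL and non-MKL OpenMP regions keeps
      // resizing the pool; aligning them, as init_num_threads() does, avoids
      // that churn (and the GCC 5.4 leak mentioned above).
      omp_set_num_threads(mkl_get_max_threads());
      mkl_set_dynamic(false);
    #endif
      return 0;
    }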
diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h
index fe073f1..0d00a9e 100644
--- a/aten/src/ATen/Parallel.h
+++ b/aten/src/ATen/Parallel.h
@@ -22,14 +22,17 @@
return (x + y - 1) / y;
}
-inline int get_max_threads() {
-#ifdef _OPENMP
- return omp_get_max_threads();
-#else
- return 1;
-#endif
-}
+// Called during new thread initialization
+C10_API void init_num_threads();
+// Sets the number of threads to be used in the parallel region
+C10_API void set_num_threads(size_t);
+
+// Returns the number of threads used in the parallel region
+C10_API size_t get_num_threads();
+
+// Returns the current thread number (starting from 0)
+// in the current parallel region, or 0 in the sequential region
inline int get_thread_num() {
#ifdef _OPENMP
return omp_get_thread_num();
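A hedged sketch of how these declarations combine with the existing at::parallel_for (used in the TensorIterator changes below); it assumes worker thread numbers stay within [0, get_num_threads()), and sum_parallel is just an illustrative name.

    #include <ATen/Parallel.h>
    #include <cstdint>
    #include <vector>

    double sum_parallel(const std::vector<double>& v) {
      std::vector<double> partial(at::get_num_threads(), 0.0);
      at::parallel_for(0, (int64_t)v.size(), at::internal::GRAIN_SIZE,
          [&](int64_t begin, int64_t end) {
            double local = 0.0;
            for (int64_t i = begin; i < end; ++i) {
              local += v[i];
            }
            // One slot per worker; get_thread_num() indexes the current worker.
            partial[at::get_thread_num()] += local;
          });
      double total = 0.0;
      for (double p : partial) {
        total += p;
      }
      return total;
    }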
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 3da24a1..d9ba8fa 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -73,7 +73,7 @@
template<class scalar_t>
void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
AT_ERROR("triangular_solve only takes float or double Tensors");
-}
+}
#ifdef USE_LAPACK
template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index e86444a..ae823a3 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -357,7 +357,7 @@
int64_t numel = this->numel();
if (numel == 0) {
return;
- } else if (numel < internal::GRAIN_SIZE || at::get_max_threads() == 1) {
+ } else if (numel < internal::GRAIN_SIZE || at::get_num_threads() == 1) {
return serial_for_each(loop, {0, numel});
} else {
at::parallel_for(0, numel, internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp
index f6d2028..228bb13 100644
--- a/aten/src/ATen/native/TensorIteratorReduce.cpp
+++ b/aten/src/ATen/native/TensorIteratorReduce.cpp
@@ -16,7 +16,8 @@
void TensorIterator::parallel_reduce(const loop2d_t& loop) {
AT_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output");
int64_t numel = this->numel();
- if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
+ if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+ at::in_parallel_region()) {
serial_for_each(loop, {0, numel});
} else if (use_two_pass_reduction(*this)) {
two_pass_reduction(*this, loop);
@@ -30,7 +31,7 @@
}
static void two_pass_reduction(TensorIterator& iter, const loop2d_t& loop) {
- int max_threads = at::get_max_threads();
+ int max_threads = at::get_num_threads();
auto& dst = iter.tensor(0);
auto buffer_shape = DimVector(dst.sizes());
@@ -65,7 +66,7 @@
/// Chooses a dimension over which to parallelize. Prefers the outer-most
/// dimension that's larger than the number of available threads.
static int find_split_dim(TensorIterator& iter) {
- int num_threads = at::get_max_threads();
+ int num_threads = at::get_num_threads();
auto shape = iter.shape();
// start with the outer-most dimension
@@ -125,7 +126,8 @@
if (tensor(0).numel() == 1) {
loop(*this);
}
- else if (numel() < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region() || !parallelize) {
+ else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+ at::in_parallel_region() || !parallelize) {
auto reduce_dims = num_reduce_dims();
auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);
@@ -154,7 +156,7 @@
sub_iter.narrow(dim, begin, end - begin);
// On some broken setups, `#ifdef _OPENMP` is true,
- // and `get_max_threads` returns > 1, but
+ // and `get_num_threads` returns > 1, but
// `#pragma omp parallel` is ignored.
// There is no API to check for this, so we need to explicitly
// stop trying to parallelize if we've already gotten here.
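The guard used above (and again in Reduce.h below) can be summarized by a small helper; this is a hedged sketch, not an ATen function, and run_maybe_parallel is a made-up name.

    #include <ATen/Parallel.h>
    #include <cstdint>

    // Go serial when the work is small, only one thread is configured, or we
    // are already inside a parallel region (to avoid nested parallelism).
    template <typename SerialFn, typename ParallelFn>
    void run_maybe_parallel(int64_t numel, SerialFn serial, ParallelFn parallel) {
      if (numel < at::internal::GRAIN_SIZE ||
          at::get_num_threads() == 1 ||
          at::in_parallel_region()) {
        serial();
      } else {
        parallel();
      }
    }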
diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h
index 7500781..dbc469a 100644
--- a/aten/src/ATen/native/cpu/Reduce.h
+++ b/aten/src/ATen/native/cpu/Reduce.h
@@ -97,10 +97,11 @@
};
acc_t total_acc = init;
auto numel = sub_iter.numel();
- if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
+ if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+ at::in_parallel_region()) {
total_acc = reduction_body(total_acc, 0, numel);
} else {
- int max_threads = at::get_max_threads();
+ int max_threads = at::get_num_threads();
AT_ASSERT(max_threads > 0);
static_assert(
!std::is_same<acc_t, bool>::value,
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index bcee7ef..20b643c 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -19,7 +19,7 @@
${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/thread_init_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
diff --git a/aten/src/ATen/test/tbb_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp
similarity index 73%
rename from aten/src/ATen/test/tbb_init_test.cpp
rename to aten/src/ATen/test/thread_init_test.cpp
index 5e17fae..adbb324 100644
--- a/aten/src/ATen/test/tbb_init_test.cpp
+++ b/aten/src/ATen/test/thread_init_test.cpp
@@ -9,24 +9,24 @@
// will throw an exception when multiple threads call
// their first parallel construct.
void test(int given_num_threads) {
+ at::init_num_threads();
auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
- if (given_num_threads >= 0) {
- ASSERT(at::get_num_threads() == given_num_threads);
- } else {
- ASSERT(at::get_num_threads() == -1);
- }
+ ASSERT(given_num_threads >= 0);
+ ASSERT(at::get_num_threads() == given_num_threads);
auto t_sum = t.sum();
- for (int i = 0; i < 1000; i ++) {
+ for (int i = 0; i < 1000; ++i) {
t_sum = t_sum + t.sum();
}
}
int main() {
+ at::init_num_threads();
at::manual_seed(123);
- test(-1);
- std::thread t1(test, -1);
+ test(at::get_num_threads());
+ std::thread t1(test, at::get_num_threads());
t1.join();
+
at::set_num_threads(4);
std::thread t2(test, 4);
std::thread t3(test, 4);
@@ -34,6 +34,7 @@
t4.join();
t3.join();
t2.join();
+
at::set_num_threads(5);
test(5);
diff --git a/aten/src/TH/THGeneral.cpp b/aten/src/TH/THGeneral.cpp
index 8df8ff5..f8f8a3b 100644
--- a/aten/src/TH/THGeneral.cpp
+++ b/aten/src/TH/THGeneral.cpp
@@ -4,10 +4,6 @@
#include <c10/core/CPUAllocator.h>
#endif
-#ifdef _OPENMP
-#include <omp.h>
-#endif
-
#ifndef TH_HAVE_THREAD
#define __thread
#elif _MSC_VER
@@ -24,10 +20,6 @@
#include <malloc/malloc.h>
#endif
-#ifdef TH_BLAS_MKL
-#include <mkl.h>
-#endif
-
/* Torch Error Handling */
static void defaultErrorHandlerFunction(const char *msg, void *data)
{
@@ -224,53 +216,6 @@
return expm1(x);
}
-void THSetNumThreads(int num_threads)
-{
-#ifdef _OPENMP
- omp_set_num_threads(num_threads);
-#endif
-#ifdef TH_BLAS_MKL
- mkl_set_num_threads(num_threads);
-
- // because PyTorch uses OpenMP outside of MKL invocations
- // as well, we want this flag to be false, so that
- // threads aren't destroyed and recreated across every
- // MKL / non-MKL boundary of OpenMP usage
- // See https://github.com/pytorch/pytorch/issues/13757
- mkl_set_dynamic(false);
-#endif
-
-}
-
-int THGetNumThreads(void)
-{
-#ifdef _OPENMP
- return omp_get_max_threads();
-#else
- return 1;
-#endif
-}
-
-int THGetNumCores(void)
-{
-#ifdef _OPENMP
- return omp_get_num_procs();
-#else
- return 1;
-#endif
-}
-
-TH_API void THInferNumThreads(void)
-{
-#if defined(_OPENMP) && defined(TH_BLAS_MKL)
- // If we are using MKL an OpenMP make sure the number of threads match.
- // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
- // size of the OpenMP thread pool, resulting in worse performance (and memory
- // leaks in GCC 5.4)
- omp_set_num_threads(mkl_get_max_threads());
-#endif
-}
-
THDescBuff _THSizeDesc(const int64_t *size, const int64_t ndim) {
const int L = TH_DESC_BUFF_LEN;
THDescBuff buf;
diff --git a/aten/src/TH/THGeneral.h.in b/aten/src/TH/THGeneral.h.in
index 895ebf4..f5ea797 100644
--- a/aten/src/TH/THGeneral.h.in
+++ b/aten/src/TH/THGeneral.h.in
@@ -100,10 +100,6 @@
TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data );
// this hook should only be called by custom allocator functions
TH_API void THHeapUpdate(ptrdiff_t size);
-TH_API void THSetNumThreads(int num_threads);
-TH_API int THGetNumThreads(void);
-TH_API int THGetNumCores(void);
-TH_API void THInferNumThreads(void);
#define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 522fcca..98512af 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2094,7 +2094,7 @@
r"""
get_num_threads() -> int
-Gets the number of OpenMP threads used for parallelizing CPU operations
+Gets the number of threads used for parallelizing CPU operations
""")
add_docstr(torch.gt,
@@ -4282,7 +4282,10 @@
r"""
set_num_threads(int)
-Sets the number of OpenMP threads used for parallelizing CPU operations
+Sets the number of threads used for parallelizing CPU operations.
+WARNING:
+To ensure that the correct number of threads is used, set_num_threads
+must be called before running eager, JIT or autograd code.
""")
add_docstr(torch.sigmoid,
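To illustrate the WARNING at the ATen level, here is a sketch under the assumption that worker threads (like the autograd engine's, see engine.cpp below) only pick up the configured value when they run init_num_threads() at startup.

    #include <ATen/Parallel.h>
    #include <thread>

    int main() {
      at::init_num_threads();
      at::set_num_threads(4);     // seen by threads that start after this point
      std::thread worker([] {
        at::init_num_threads();   // copies the stored value into this thread's
                                  // OpenMP/MKL state
        // ... run work with a 4-thread parallel region ...
      });
      worker.join();
      // A later at::set_num_threads(8) would not retroactively resize the pool
      // of a worker that already finished its init_num_threads() setup.
      return 0;
    }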
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index df44a3e..9d47ae7 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -13,6 +13,7 @@
#include <ATen/ExpandUtils.h>
#include <ATen/dlpack.h>
#include <ATen/DLConvertor.h>
+#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
@@ -147,14 +148,13 @@
static PyObject * THPModule_getNumThreads(PyObject *module)
{
- return PyLong_FromLong(THGetNumThreads());
+ return PyLong_FromLong(at::get_num_threads());
}
static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
{
THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
"but got %s", THPUtils_typename(arg));
- THSetNumThreads((int)THPUtils_unpackLong(arg));
at::set_num_threads((int)THPUtils_unpackLong(arg));
Py_RETURN_NONE;
}
@@ -541,7 +541,7 @@
#endif
PyObject* initModule() {
HANDLE_TH_ERRORS
- THInferNumThreads();
+ at::init_num_threads();
#define ASSERT_TRUE(cmd) if (!(cmd)) return nullptr
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8882edc..b4716dd 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -9,6 +9,7 @@
#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
#include <c10/util/Exception.h>
#include <atomic>
@@ -204,7 +205,7 @@
Engine::~Engine() = default;
auto Engine::thread_init(int device) -> void {
- THInferNumThreads();
+ at::init_num_threads();
// Note [Allocating GPUs to autograd threads]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What's our strategy here? Originally, the autograd engine was written