Respect num_workers parameter in async net executor
Make sure we honor the num_workers parameter from the net definition when sizing thread pools in the async executor.
diff --git a/caffe2/core/net_async_base.cc b/caffe2/core/net_async_base.cc
index 5f83914..273dd91 100644
--- a/caffe2/core/net_async_base.cc
+++ b/caffe2/core/net_async_base.cc
@@ -28,7 +28,7 @@
CAFFE2_DEFINE_int(
caffe2_net_async_cpu_pool_size,
0,
- "Number of threads in CPU pool (default - number of cores)");
+ "Number of threads in CPU pool by default");
CAFFE2_DEFINE_bool(
caffe2_net_async_check_stream_status,
@@ -62,24 +62,20 @@
events_.push_back(&op->event());
}
- gpu_pools_.resize(FLAGS_caffe2_net_async_max_gpus);
- cpu_pools_.resize(FLAGS_caffe2_net_async_max_numa_nodes);
- DeviceOption cpu_option;
- cpu_option.set_device_type(CPU);
- cpu_pool_ = ThreadPoolRegistry()->Create(
- DeviceTypeName(cpu_option.device_type()), cpu_option);
+ num_workers_ = net_def->has_num_workers() ? net_def->num_workers() : -1;
}
std::shared_ptr<TaskThreadPool> AsyncNetBase::pool_getter(
- std::vector<std::shared_ptr<TaskThreadPool>>& pools,
- int pool_idx,
- const DeviceOption& device_option) {
+ PoolsMap& pools,
+ int device_type,
+ int device_id,
+ int pool_size) {
std::unique_lock<std::mutex> pools_lock(pools_mutex_);
- auto pool = pools[pool_idx];
+ auto pool = pools[device_id][pool_size];
if (!pool) {
pool = ThreadPoolRegistry()->Create(
- DeviceTypeName(device_option.device_type()), device_option);
- pools[pool_idx] = pool;
+ DeviceTypeName(device_type), device_id, pool_size);
+ pools[device_id][pool_size] = pool;
}
return pool;
}
@@ -88,21 +84,17 @@
const DeviceOption& device_option) {
if (device_option.device_type() == CPU) {
auto numa_node_id = device_option.numa_node_id();
- if (numa_node_id == -1) {
- return cpu_pool_;
- } else {
- CAFFE_ENFORCE(
- numa_node_id >= 0 &&
- numa_node_id < FLAGS_caffe2_net_async_max_numa_nodes,
- "Invalid NUMA node id: " + caffe2::to_string(numa_node_id));
- return pool_getter(cpu_pools_, numa_node_id, device_option);
- }
+ CAFFE_ENFORCE(
+ numa_node_id >= -1 &&
+ numa_node_id < FLAGS_caffe2_net_async_max_numa_nodes,
+ "Invalid NUMA node id: " + caffe2::to_string(numa_node_id));
+ return pool_getter(cpu_pools_, CPU, numa_node_id, num_workers_);
} else if (device_option.device_type() == CUDA) {
auto gpu_id = device_option.cuda_gpu_id();
CAFFE_ENFORCE(
gpu_id >= 0 && gpu_id < FLAGS_caffe2_net_async_max_gpus,
"Invalid GPU id: " + caffe2::to_string(gpu_id));
- return pool_getter(gpu_pools_, gpu_id, device_option);
+ return pool_getter(gpu_pools_, CUDA, gpu_id, num_workers_);
} else {
CAFFE_THROW(
"Unsupported device type " +
@@ -238,46 +230,45 @@
AsyncNetBase::~AsyncNetBase() {}
-CAFFE_DEFINE_SHARED_REGISTRY(
- ThreadPoolRegistry,
- TaskThreadPool,
- const DeviceOption&);
+CAFFE_DEFINE_SHARED_REGISTRY(ThreadPoolRegistry, TaskThreadPool, int, int);
-namespace {
-std::shared_ptr<TaskThreadPool> AsyncNetCPUThreadPoolCreator(
- const DeviceOption& device_option) {
- CAFFE_ENFORCE_EQ(
- device_option.device_type(),
- CPU,
- "Unexpected device type for CPU thread pool");
- return GetAsyncNetCPUThreadPool(device_option.numa_node_id());
-}
-} // namespace
-
-CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CPU, AsyncNetCPUThreadPoolCreator);
+CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CPU, GetAsyncNetCPUThreadPool);
/* static */
-std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(int numa_node_id) {
+std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(
+ int numa_node_id,
+ int pool_size) {
// Note: numa_node_id = -1 (DeviceOption's default value) corresponds to
// no NUMA used
- static std::unordered_map<int, std::weak_ptr<TaskThreadPool>> pools;
+ static std::
+ unordered_map<int, std::unordered_map<int, std::weak_ptr<TaskThreadPool>>>
+ pools;
static std::mutex pool_mutex;
std::lock_guard<std::mutex> lock(pool_mutex);
- std::shared_ptr<TaskThreadPool> shared_pool = nullptr;
- if (pools.count(numa_node_id)) {
- shared_pool = pools.at(numa_node_id).lock();
- }
- if (!shared_pool) {
- auto pool_size = FLAGS_caffe2_net_async_cpu_pool_size;
- if (pool_size <= 0) {
+ if (pool_size <= 0) {
+ if (FLAGS_caffe2_net_async_cpu_pool_size > 0) {
+ pool_size = FLAGS_caffe2_net_async_cpu_pool_size;
+ LOG(INFO) << "Using default CPU pool size: " << pool_size
+ << "; NUMA node id: " << numa_node_id;
+ } else {
auto num_cores = std::thread::hardware_concurrency();
CAFFE_ENFORCE(num_cores > 0, "Failed to get number of CPU cores");
+ LOG(INFO) << "Using estimated CPU pool size: " << num_cores
+ << "; NUMA node id: " << numa_node_id;
pool_size = num_cores;
}
- LOG(INFO) << "Using cpu pool size: " << pool_size;
+ } else {
+ LOG(INFO) << "Using specified CPU pool size: " << pool_size
+ << "; NUMA node id: " << numa_node_id;
+ }
+
+ auto shared_pool = pools[numa_node_id][pool_size].lock();
+ if (!shared_pool) {
+ LOG(INFO) << "Created CPU pool, size: " << pool_size
+ << "; NUMA node id: " << numa_node_id;
shared_pool = std::make_shared<TaskThreadPool>(pool_size, numa_node_id);
- pools[numa_node_id] = shared_pool;
+ pools[numa_node_id][pool_size] = shared_pool;
}
return shared_pool;
}
diff --git a/caffe2/core/net_async_base.h b/caffe2/core/net_async_base.h
index 0feb1ae..a69f8e3 100644
--- a/caffe2/core/net_async_base.h
+++ b/caffe2/core/net_async_base.h
@@ -58,26 +58,28 @@
// Pools and streams
std::mutex pools_mutex_;
- std::shared_ptr<TaskThreadPool> cpu_pool_;
- std::vector<std::shared_ptr<TaskThreadPool>> cpu_pools_;
- std::vector<std::shared_ptr<TaskThreadPool>> gpu_pools_;
+ // first key - device id, second key - pool size; one pool per (device, size) pair
+ typedef std::unordered_map<
+ int,
+ std::unordered_map<int, std::shared_ptr<TaskThreadPool>>>
+ PoolsMap;
+ PoolsMap cpu_pools_;
+ PoolsMap gpu_pools_;
static thread_local std::vector<int> stream_counters_;
+ int num_workers_;
DISABLE_COPY_AND_ASSIGN(AsyncNetBase);
private:
- std::shared_ptr<TaskThreadPool> pool_getter(
- std::vector<std::shared_ptr<TaskThreadPool>>& pools,
- int pool_idx,
- const DeviceOption& device_option);
+ std::shared_ptr<TaskThreadPool>
+ pool_getter(PoolsMap& pools, int device_type, int device_id, int pool_size);
};
-CAFFE_DECLARE_SHARED_REGISTRY(
- ThreadPoolRegistry,
- TaskThreadPool,
- const DeviceOption&);
+CAFFE_DECLARE_SHARED_REGISTRY(ThreadPoolRegistry, TaskThreadPool, int, int);
-std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(int numa_node_id);
+std::shared_ptr<TaskThreadPool> GetAsyncNetCPUThreadPool(
+ int numa_node_id,
+ int pool_size);
} // namespace caffe2
diff --git a/caffe2/core/net_async_gpu_thread_pool.h b/caffe2/core/net_async_gpu_thread_pool.h
index 4fa7caf..faf7082 100644
--- a/caffe2/core/net_async_gpu_thread_pool.h
+++ b/caffe2/core/net_async_gpu_thread_pool.h
@@ -5,7 +5,9 @@
namespace caffe2 {
-std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(int gpu_id);
+std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(
+ int gpu_id,
+ int pool_size);
} // namespace caffe2
diff --git a/caffe2/core/net_async_gpu_thread_pool_gpu.cc b/caffe2/core/net_async_gpu_thread_pool_gpu.cc
index 93726fb..d653879 100644
--- a/caffe2/core/net_async_gpu_thread_pool_gpu.cc
+++ b/caffe2/core/net_async_gpu_thread_pool_gpu.cc
@@ -6,20 +6,16 @@
namespace caffe2 {
-namespace {
-std::shared_ptr<TaskThreadPool> AsyncNetGPUThreadPoolCreator(
- const DeviceOption& device_option) {
- CAFFE_ENFORCE_EQ(
- device_option.device_type(),
- CUDA,
- "Unexpected device type for CUDA thread pool");
- return GetAsyncNetGPUThreadPool(device_option.cuda_gpu_id());
-}
-} // namespace
+CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CUDA, GetAsyncNetGPUThreadPool);
-CAFFE_REGISTER_CREATOR(ThreadPoolRegistry, CUDA, AsyncNetGPUThreadPoolCreator);
-
-std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(int gpu_id) {
+std::shared_ptr<TaskThreadPool> GetAsyncNetGPUThreadPool(
+ int gpu_id,
+ int pool_size) {
+ // For GPUs, use per-device thread pools of a predefined constant size
+ if (pool_size != FLAGS_caffe2_threads_per_gpu) {
+ LOG(INFO) << "Overriding GPU pool size: using "
+ << FLAGS_caffe2_threads_per_gpu << " threads per GPU";
+ }
static std::unordered_map<int, std::weak_ptr<TaskThreadPool>> pools;
static std::mutex pool_mutex;
std::lock_guard<std::mutex> lock(pool_mutex);