Introducing winograd transformed fp16 nnpack to PT for unet 106 (#47925)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47925
ghstack-source-id: 117004847
Test Plan:
buck run caffe2/fb/custom_ops/unet_106_pt:unet_106_rewrite
buck run caffe2/fb/custom_ops/unet_106_pt:tests
Reviewed By: dreiss
Differential Revision: D24822418
fbshipit-source-id: 0c0bc0772e4c878e979ee3d2078105377e220c43
diff --git a/caffe2/utils/threadpool/pthreadpool.h b/caffe2/utils/threadpool/pthreadpool.h
index 27935fe..0c6cc36 100644
--- a/caffe2/utils/threadpool/pthreadpool.h
+++ b/caffe2/utils/threadpool/pthreadpool.h
@@ -8,6 +8,25 @@
#include <stddef.h> // for size_t
#include <stdint.h> // for uint32_t
+#ifdef USE_PTHREADPOOL
+// This is a hack.
+// It was introduced for the following reasons:
+// 1. NNPACK may be compiled to use the internal legacy threadpool implementation, because much of C2 depends on it.
+// 2. PyTorch uses the new pthreadpool, so to use NNPACK from PyTorch we have to supply a new pthreadpool pointer
+//    to NNPACK. This does not work if NNPACK is compiled with the internal legacy threadpool. The guard below,
+//    together with the changes in pthreadpool_impl.cc, allows us to override that behavior.
+// It enables us to use NNPACK from PyTorch via `caffe2::pthreadpool_()`.
+namespace caffe2 {
+class WithCastToNewThreadPool { // scoped guard: the constructor sets a thread-local flag and the destructor restores it (see pthreadpool_impl.cc)
+ public:
+ explicit WithCastToNewThreadPool(bool use_new_threadpool); // saves the current flag value, then sets the flag to use_new_threadpool
+ ~WithCastToNewThreadPool(); // writes the saved flag value back
+ private:
+ bool use_new_threadpool_; // NOTE: despite its name, this holds the flag value saved at construction, not the requested one
+};
+} // namespace caffe2
+#endif
+
typedef struct pthreadpool* legacy_pthreadpool_t;
typedef void (*legacy_pthreadpool_function_1d_t)(void*, size_t);
diff --git a/caffe2/utils/threadpool/pthreadpool_impl.cc b/caffe2/utils/threadpool/pthreadpool_impl.cc
index 66326ee..8165ae3 100644
--- a/caffe2/utils/threadpool/pthreadpool_impl.cc
+++ b/caffe2/utils/threadpool/pthreadpool_impl.cc
@@ -1,6 +1,21 @@
#include "caffe2/utils/threadpool/pthreadpool.h"
+#include "caffe2/utils/threadpool/pthreadpool-cpp.h"
#include "caffe2/utils/threadpool/ThreadPool.h"
+#ifdef USE_PTHREADPOOL
+namespace caffe2 {
+namespace {
+thread_local bool using_new_threadpool = false; // per-thread switch; checked further down to choose pthreadpool_parallelize_1d over caffe2::ThreadPool::run
+} // namespace
+WithCastToNewThreadPool::WithCastToNewThreadPool(bool use_new_threadpool)
+    : use_new_threadpool_(using_new_threadpool) { // remember this thread's setting before overriding it
+  using_new_threadpool = use_new_threadpool;
+}
+WithCastToNewThreadPool::~WithCastToNewThreadPool() {
+  using_new_threadpool = use_new_threadpool_; // put back the setting captured by the constructor
+}
+} // namespace caffe2
+#endif
//
// External API
@@ -19,12 +34,25 @@
}
return;
}
+#ifdef USE_PTHREADPOOL
+ if (caffe2::using_new_threadpool) {
+ pthreadpool_parallelize_1d(threadpool, function, argument, range, 0u);
+ } else {
+ reinterpret_cast<caffe2::ThreadPool*>(threadpool)
+ ->run(
+ [function, argument](int threadId, size_t workId) {
+ function(argument, workId);
+ },
+ range);
+ }
+#else
reinterpret_cast<caffe2::ThreadPool*>(threadpool)
->run(
[function, argument](int threadId, size_t workId) {
function(argument, workId);
},
range);
+#endif
}
void legacy_pthreadpool_parallelize_1d(