[Pytorch] Add python binding to use mobile cpu allocator. (#52323)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52323
Using the default cpu allocator for ops executed on the qnnpack backend will result in
ASAN heap-overflow failures, since qnnpack (and xnnpack) kernels can read input
beyond its end and/or before its beginning.
Here we expose this binding specifically to enable the dynamic sparse linear op test
with the qnnpack engine. In the dynamic linear op, the fp32 bias is not packed and
can therefore be read out of bounds.
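For context, a minimal usage sketch of the new bindings (hypothetical test code, not
part of this diff; it assumes a build where the qnnpack engine is available):

  import torch

  # Route CPU allocations through the mobile allocator. It pads each allocation with
  # guard bytes, so qnnpack/xnnpack reads slightly past a tensor's nominal bounds
  # stay inside valid memory and do not trip ASAN.
  torch._C._set_default_mobile_cpu_allocator()
  try:
      torch.backends.quantized.engine = 'qnnpack'
      m = torch.nn.quantized.dynamic.Linear(16, 8)  # dynamic linear keeps an fp32 bias
      y = m(torch.randn(4, 16))
  finally:
      # Always restore the previous allocator; nesting another set without an unset
      # (or calling unset without a prior set) raises.
      torch._C._unset_default_mobile_cpu_allocator()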
Test Plan: test_set_default_mobile_cpu_allocator.py
Reviewed By: z-a-f
Differential Revision: D26263481
fbshipit-source-id: a49227cac7e6781b0db4a156ca734d7671972d9f
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index a927934..709c668 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -267,4 +267,21 @@
display_vmap_fallback_warnings_ = enabled;
}
+void Context::setDefaultMobileCPUAllocator() {
+ TORCH_CHECK(prev_allocator_ptr_ == nullptr,
+ "Already within the scope of another non-default cpu allocator."
+ "Cannot set another allocator.");
+ // Setting the priority high to make sure no other allocator gets used instead of this.
+ prev_allocator_ptr_ = c10::GetCPUAllocator();
+ c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100);
+}
+
+void Context::unsetDefaultMobileCPUAllocator() {
+ TORCH_CHECK(prev_allocator_ptr_ != nullptr,
+ "setDefaultMobileCPUAllocator must have been called "
+ "before unsetDefaultMobileCPUAllocator.");
+ // Restore the previously saved allocator, again at high priority.
+ c10::SetCPUAllocator(prev_allocator_ptr_, /*priority*/ 100);
+ prev_allocator_ptr_ = nullptr;
+}
} // namespace at
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 2041aa6..4dfb316 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -199,6 +199,9 @@
void setDisplayVmapFallbackWarnings(bool enabled);
bool areVmapFallbackWarningsEnabled() const;
+ void setDefaultMobileCPUAllocator();
+ void unsetDefaultMobileCPUAllocator();
+
private:
void initCUDAIfNeeded(DeviceType p) {
if (p == DeviceType::CUDA) {
@@ -229,6 +232,8 @@
c10::optional<at::QEngine> quantized_engine = c10::nullopt;
std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
std::unique_ptr<THHState, void(*)(THHState*)> thh_state;
+
+ Allocator* prev_allocator_ptr_{nullptr};
};
TORCH_API Context& globalContext();
diff --git a/test/run_test.py b/test/run_test.py
index 273abfe..19e3688 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -60,6 +60,7 @@
'test_optim',
'test_pytree',
'test_mobile_optimizer',
+ 'test_set_default_mobile_cpu_allocator',
'test_xnnpack_integration',
'test_vulkan',
'test_sparse',
diff --git a/test/test_set_default_mobile_cpu_allocator.py b/test/test_set_default_mobile_cpu_allocator.py
new file mode 100644
index 0000000..cb2938a
--- /dev/null
+++ b/test/test_set_default_mobile_cpu_allocator.py
@@ -0,0 +1,27 @@
+import torch
+from torch.testing._internal.common_utils import TestCase, run_tests
+
+class TestSetDefaultMobileCPUAllocator(TestCase):
+    def test_no_exception(self):
+        torch._C._set_default_mobile_cpu_allocator()
+        torch._C._unset_default_mobile_cpu_allocator()
+
+    def test_exception(self):
+        with self.assertRaises(Exception):
+            torch._C._unset_default_mobile_cpu_allocator()
+
+        with self.assertRaises(Exception):
+            torch._C._set_default_mobile_cpu_allocator()
+            torch._C._set_default_mobile_cpu_allocator()
+
+        # Must reset to a good state
+        # for the next check.
+        torch._C._unset_default_mobile_cpu_allocator()
+
+        with self.assertRaises(Exception):
+            torch._C._set_default_mobile_cpu_allocator()
+            torch._C._unset_default_mobile_cpu_allocator()
+            torch._C._unset_default_mobile_cpu_allocator()
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 58bd77e..032aad2 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -589,6 +589,28 @@
else Py_RETURN_FALSE;
}
+PyObject *THPModule_setDefaultMobileCPUAllocator(PyObject *_unused, PyObject *noargs)
+{
+ try {
+ at::globalContext().setDefaultMobileCPUAllocator();
+ } catch (c10::Error& e) {
+ THPUtils_setError(e.what());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject *THPModule_unsetDefaultMobileCPUAllocator(PyObject *_unused, PyObject *noargs)
+{
+ try {
+ at::globalContext().unsetDefaultMobileCPUAllocator();
+ } catch (c10::Error& e) {
+ THPUtils_setError(e.what());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
static PyObject * THPModule_vmapmode_increment_nesting(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
@@ -673,6 +693,8 @@
{"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
{"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
{"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
+ {"_set_default_mobile_cpu_allocator", THPModule_setDefaultMobileCPUAllocator, METH_NOARGS, nullptr},
+ {"_unset_default_mobile_cpu_allocator", THPModule_unsetDefaultMobileCPUAllocator, METH_NOARGS, nullptr},
{"_is_torch_function_enabled", THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr},
{"_disabled_torch_function_impl", THPModule_disable_torch_function, METH_VARARGS, nullptr},
{"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},