ART: mprotect() bottom page of ThreadPoolWorker stacks.
This will catch stack overflows instead of creating hard
to find issues later on.
Bug: 24133462
Change-Id: I3ae5d5da70b8167867936b2561830f3ff47f14fc
diff --git a/runtime/thread_pool.cc b/runtime/thread_pool.cc
index d8f80fa..0527d3a 100644
--- a/runtime/thread_pool.cc
+++ b/runtime/thread_pool.cc
@@ -16,7 +16,9 @@
#include "thread_pool.h"
+#include "base/bit_utils.h"
#include "base/casts.h"
+#include "base/logging.h"
#include "base/stl_util.h"
#include "base/time_utils.h"
#include "runtime.h"
@@ -30,10 +32,15 @@
size_t stack_size)
: thread_pool_(thread_pool),
name_(name) {
+ // Add an inaccessible page to catch stack overflow.
+ stack_size += kPageSize;
std::string error_msg;
stack_.reset(MemMap::MapAnonymous(name.c_str(), nullptr, stack_size, PROT_READ | PROT_WRITE,
false, false, &error_msg));
CHECK(stack_.get() != nullptr) << error_msg;
+ CHECK_ALIGNED(stack_->Begin(), kPageSize);
+ int mprotect_result = mprotect(stack_->Begin(), kPageSize, PROT_NONE);
+ CHECK_EQ(mprotect_result, 0) << "Failed to mprotect() bottom page of thread pool worker stack.";
const char* reason = "new thread pool worker thread";
pthread_attr_t attr;
CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), reason);
@@ -92,7 +99,8 @@
while (GetThreadCount() < num_threads) {
const std::string worker_name = StringPrintf("%s worker thread %zu", name_.c_str(),
GetThreadCount());
- threads_.push_back(new ThreadPoolWorker(this, worker_name, ThreadPoolWorker::kDefaultStackSize));
+ threads_.push_back(
+ new ThreadPoolWorker(this, worker_name, ThreadPoolWorker::kDefaultStackSize));
}
// Wait for all of the threads to attach.
creation_barier_.Wait(self);