Guard concurrent scratch_buffer and gemmlowp::GemmContext access
- scrach_buffer is a file scoped static intended to minimize the
allocation cost for Conv2D computations when possible.
- In order to make it thread safe for multi-threaded execution,
we need to make sure no concucrrent access to it.
- Similarly for gemmlowp::GemmContext used in Conv2D and
FullyConnected.
- The mutex lock is added to prevent concurrent executions that may
access the static scratch buffer and static gemmlowp::GemmContext.
Bug: 80430825
Bug: 80465406
Test: NeuralNetworksTest_mt_static
Test: NeuralNetworksApiBenchmark no visible performance impact
Merged-In: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2
Change-Id: I6b0df63a03d1f16a1e43a0c1062a997bfbe8f3f2
(cherry picked from commit 9c63a9c428e5489bc8d118f52687a12206967208)
diff --git a/common/operations/Conv2D.cpp b/common/operations/Conv2D.cpp
index 344ab8f..7ae35f1 100644
--- a/common/operations/Conv2D.cpp
+++ b/common/operations/Conv2D.cpp
@@ -26,6 +26,11 @@
static constexpr size_t kStaticBufferSize = 1605632;
static char static_scratch_buffer[kStaticBufferSize];
+// executionMutex is used to protect concurrent access of the static_scratch_buffer
+// and other non-threadsafe resources like gemmlowp::GemmContext.
+// std::mutex is safe for pthreads on Android.
+static std::mutex executionMutex;
+
#define ANDROID_NN_CONV_PARAMETERS(Type) \
uint32_t height = getSizeOfDimension(inputShape, 1); \
uint32_t width = getSizeOfDimension(inputShape, 2); \
@@ -86,6 +91,8 @@
CalculateActivationRangeFloat(activation, &output_activation_min,
&output_activation_max);
+ // Prevent concurrent executions that may access the scratch buffer.
+ std::unique_lock<std::mutex> lock(executionMutex);
tflite::optimized_ops::Conv(
inputData, convertShapeToDims(inputShape),
filterData, convertShapeToDims(filterShape),
@@ -129,9 +136,12 @@
&output_activation_max);
static gemmlowp::GemmContext gemm_context;
- // Alow gemmlowp automatcally decide how many threads to use.
- gemm_context.set_max_num_threads(0);
+ // Prevent concurrent executions that may access the scratch buffer and
+ // gemm_context.
+ std::unique_lock<std::mutex> lock(executionMutex);
+ // Alow gemmlowp automatically decide how many threads to use.
+ gemm_context.set_max_num_threads(0);
tflite::optimized_ops::Conv(
inputData, convertShapeToDims(inputShape), inputOffset,
filterData, convertShapeToDims(filterShape), filterOffset,
diff --git a/common/operations/FullyConnected.cpp b/common/operations/FullyConnected.cpp
index bc99a28..4d2008d 100644
--- a/common/operations/FullyConnected.cpp
+++ b/common/operations/FullyConnected.cpp
@@ -22,6 +22,11 @@
namespace android {
namespace nn {
+// executionMutex is used to protect concurrent access of non-threadsafe resources
+// like gemmlowp::GemmContext.
+// std::mutex is safe for pthreads on Android.
+static std::mutex executionMutex;
+
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
const float* weightsData, const Shape& weightsShape,
const float* biasData, const Shape& biasShape,
@@ -67,7 +72,10 @@
&output_activation_max);
static gemmlowp::GemmContext gemm_context;
- // Alow gemmlowp automatcally decide how many threads to use.
+
+ // Prevent concurrent executions that access gemm_context.
+ std::unique_lock<std::mutex> lock(executionMutex);
+ // Alow gemmlowp automatically decide how many threads to use.
gemm_context.set_max_num_threads(0);
tflite::optimized_ops::FullyConnected(