aaudio: use weak pointer to prevent UAF

Avoid using the mServiceEndpoint smart pointer
from multiple threads.

Bug: 74122779
Test: see bug for test instructions
Merged-In: Idaf9e32a163b25e51bde35d6f5ea10a372b5d916
Change-Id: Idaf9e32a163b25e51bde35d6f5ea10a372b5d916
(cherry picked from commit 92c3e26338637a7ddac54988187c89fb0fe49430)
diff --git a/services/oboeservice/AAudioServiceEndpoint.h b/services/oboeservice/AAudioServiceEndpoint.h
index 2ef6234..c19cc35 100644
--- a/services/oboeservice/AAudioServiceEndpoint.h
+++ b/services/oboeservice/AAudioServiceEndpoint.h
@@ -49,9 +49,9 @@
 
     virtual aaudio_result_t close() = 0;
 
-    virtual aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);
+    aaudio_result_t registerStream(android::sp<AAudioServiceStreamBase> stream);
 
-    virtual aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);
+    aaudio_result_t unregisterStream(android::sp<AAudioServiceStreamBase> stream);
 
     virtual aaudio_result_t startStream(android::sp<AAudioServiceStreamBase> stream,
                                         audio_port_handle_t *clientHandle) = 0;
diff --git a/services/oboeservice/AAudioServiceStreamBase.cpp b/services/oboeservice/AAudioServiceStreamBase.cpp
index ff0b037..d9b8766 100644
--- a/services/oboeservice/AAudioServiceStreamBase.cpp
+++ b/services/oboeservice/AAudioServiceStreamBase.cpp
@@ -104,6 +104,9 @@
             goto error;
         }
 
+        // This is not protected by a lock because the stream cannot be
+        // referenced until the service returns a handle to the client.
+        // So only one thread can open a stream.
         mServiceEndpoint = mEndpointManager.openEndpoint(mAudioService,
                                                          request,
                                                          sharingMode);
@@ -112,6 +115,9 @@
             result = AAUDIO_ERROR_UNAVAILABLE;
             goto error;
         }
+        // Save a weak pointer that we will use to access the endpoint.
+        mServiceEndpointWeak = mServiceEndpoint;
+
         mFramesPerBurst = mServiceEndpoint->getFramesPerBurst();
         copyFrom(*mServiceEndpoint);
     }
@@ -130,15 +136,19 @@
 
     stop();
 
-    if (mServiceEndpoint == nullptr) {
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
         result = AAUDIO_ERROR_INVALID_STATE;
     } else {
-        mServiceEndpoint->unregisterStream(this);
-        AAudioEndpointManager &mEndpointManager = AAudioEndpointManager::getInstance();
-        mEndpointManager.closeEndpoint(mServiceEndpoint);
-        mServiceEndpoint.clear();
+        endpoint->unregisterStream(this);
+        AAudioEndpointManager &endpointManager = AAudioEndpointManager::getInstance();
+        endpointManager.closeEndpoint(endpoint);
+
+        // AAudioService::closeStream() prevents two threads from closing at the same time.
+        mServiceEndpoint.clear(); // endpoint will hold the pointer until this method returns.
     }
 
+
     {
         std::lock_guard<std::mutex> lock(mUpMessageQueueLock);
         stopTimestampThread();
@@ -152,7 +162,12 @@
 
 aaudio_result_t AAudioServiceStreamBase::startDevice() {
     mClientHandle = AUDIO_PORT_HANDLE_NONE;
-    return mServiceEndpoint->startStream(this, &mClientHandle);
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    return endpoint->startStream(this, &mClientHandle);
 }
 
 /**
@@ -162,16 +177,11 @@
  */
 aaudio_result_t AAudioServiceStreamBase::start() {
     aaudio_result_t result = AAUDIO_OK;
+
     if (isRunning()) {
         return AAUDIO_OK;
     }
 
-    if (mServiceEndpoint == nullptr) {
-        ALOGE("AAudioServiceStreamBase::start() missing endpoint");
-        result = AAUDIO_ERROR_INVALID_STATE;
-        goto error;
-    }
-
     // Start with fresh presentation timestamps.
     mAtomicTimestamp.clear();
 
@@ -198,10 +208,6 @@
     if (!isRunning()) {
         return result;
     }
-    if (mServiceEndpoint == nullptr) {
-        ALOGE("AAudioServiceStreamShared::pause() missing endpoint");
-        return AAUDIO_ERROR_INVALID_STATE;
-    }
 
     // Send it now because the timestamp gets rounded up when stopStream() is called below.
     // Also we don't need the timestamps while we are shutting down.
@@ -213,7 +219,12 @@
         return result;
     }
 
-    result = mServiceEndpoint->stopStream(this, mClientHandle);
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    result = endpoint->stopStream(this, mClientHandle);
     if (result != AAUDIO_OK) {
         ALOGE("AAudioServiceStreamShared::pause() mServiceEndpoint returned %d", result);
         disconnect(); // TODO should we return or pause Base first?
@@ -230,11 +241,6 @@
         return result;
     }
 
-    if (mServiceEndpoint == nullptr) {
-        ALOGE("AAudioServiceStreamShared::stop() missing endpoint");
-        return AAUDIO_ERROR_INVALID_STATE;
-    }
-
     // Send it now because the timestamp gets rounded up when stopStream() is called below.
     // Also we don't need the timestamps while we are shutting down.
     sendCurrentTimestamp(); // warning - this calls a virtual function
@@ -244,8 +250,13 @@
         return result;
     }
 
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
     // TODO wait for data to be played out
-    result = mServiceEndpoint->stopStream(this, mClientHandle);
+    result = endpoint->stopStream(this, mClientHandle);
     if (result != AAUDIO_OK) {
         ALOGE("AAudioServiceStreamShared::stop() mServiceEndpoint returned %d", result);
         disconnect();
diff --git a/services/oboeservice/AAudioServiceStreamBase.h b/services/oboeservice/AAudioServiceStreamBase.h
index af435b4..815b1cc 100644
--- a/services/oboeservice/AAudioServiceStreamBase.h
+++ b/services/oboeservice/AAudioServiceStreamBase.h
@@ -233,7 +233,12 @@
     SimpleDoubleBuffer<Timestamp>  mAtomicTimestamp;
 
     android::AAudioService &mAudioService;
+
+    // The mServiceEndpoint variable can be accessed by multiple threads.
+    // So we access it by locally promoting a weak pointer to a smart pointer,
+    // which is thread-safe.
     android::sp<AAudioServiceEndpoint> mServiceEndpoint;
+    android::wp<AAudioServiceEndpoint> mServiceEndpointWeak;
 
 private:
     aaudio_handle_t         mHandle = -1;
diff --git a/services/oboeservice/AAudioServiceStreamMMAP.cpp b/services/oboeservice/AAudioServiceStreamMMAP.cpp
index 44ba1ca..2e7ff8b 100644
--- a/services/oboeservice/AAudioServiceStreamMMAP.cpp
+++ b/services/oboeservice/AAudioServiceStreamMMAP.cpp
@@ -70,14 +70,19 @@
         return result;
     }
 
-    result = mServiceEndpoint->registerStream(keep);
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+
+    result = endpoint->registerStream(keep);
     if (result != AAUDIO_OK) {
-        goto error;
+        return result;
     }
 
     setState(AAUDIO_STREAM_STATE_OPEN);
 
-error:
     return AAUDIO_OK;
 }
 
@@ -122,21 +127,37 @@
 
 aaudio_result_t AAudioServiceStreamMMAP::startClient(const android::AudioClient& client,
                                                        audio_port_handle_t *clientHandle) {
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
     // Start the client on behalf of the application. Generate a new porthandle.
-    aaudio_result_t result = mServiceEndpoint->startClient(client, clientHandle);
+    aaudio_result_t result = endpoint->startClient(client, clientHandle);
     return result;
 }
 
 aaudio_result_t AAudioServiceStreamMMAP::stopClient(audio_port_handle_t clientHandle) {
-    aaudio_result_t result = mServiceEndpoint->stopClient(clientHandle);
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    aaudio_result_t result = endpoint->stopClient(clientHandle);
     return result;
 }
 
 // Get free-running DSP or DMA hardware position from the HAL.
 aaudio_result_t AAudioServiceStreamMMAP::getFreeRunningPosition(int64_t *positionFrames,
                                                                   int64_t *timeNanos) {
-    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
-            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
+            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());
+
     aaudio_result_t result = serviceEndpointMMAP->getFreeRunningPosition(positionFrames, timeNanos);
     if (result == AAUDIO_OK) {
         Timestamp timestamp(*positionFrames, *timeNanos);
@@ -152,8 +173,15 @@
 // Get timestamp that was written by getFreeRunningPosition()
 aaudio_result_t AAudioServiceStreamMMAP::getHardwareTimestamp(int64_t *positionFrames,
                                                                 int64_t *timeNanos) {
-    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
-            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
+
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
+            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());
+
     // TODO Get presentation timestamp from the HAL
     if (mAtomicTimestamp.isValid()) {
         Timestamp timestamp = mAtomicTimestamp.read();
@@ -171,7 +199,12 @@
 aaudio_result_t AAudioServiceStreamMMAP::getAudioDataDescription(
         AudioEndpointParcelable &parcelable)
 {
-    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP{
-            static_cast<AAudioServiceEndpointMMAP *>(mServiceEndpoint.get())};
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+    sp<AAudioServiceEndpointMMAP> serviceEndpointMMAP =
+            static_cast<AAudioServiceEndpointMMAP *>(endpoint.get());
     return serviceEndpointMMAP->getDownDataDescription(parcelable);
 }
diff --git a/services/oboeservice/AAudioServiceStreamShared.cpp b/services/oboeservice/AAudioServiceStreamShared.cpp
index 084f996..f55c40d 100644
--- a/services/oboeservice/AAudioServiceStreamShared.cpp
+++ b/services/oboeservice/AAudioServiceStreamShared.cpp
@@ -128,6 +128,11 @@
 
     const AAudioStreamConfiguration &configurationInput = request.getConstantConfiguration();
 
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        result = AAUDIO_ERROR_INVALID_STATE;
+        goto error;
+    }
 
     // Is the request compatible with the shared endpoint?
     setFormat(configurationInput.getFormat());
@@ -141,20 +146,20 @@
 
     setSampleRate(configurationInput.getSampleRate());
     if (getSampleRate() == AAUDIO_UNSPECIFIED) {
-        setSampleRate(mServiceEndpoint->getSampleRate());
-    } else if (getSampleRate() != mServiceEndpoint->getSampleRate()) {
+        setSampleRate(endpoint->getSampleRate());
+    } else if (getSampleRate() != endpoint->getSampleRate()) {
         ALOGE("AAudioServiceStreamShared::open() mSampleRate = %d, need %d",
-              getSampleRate(), mServiceEndpoint->getSampleRate());
+              getSampleRate(), endpoint->getSampleRate());
         result = AAUDIO_ERROR_INVALID_RATE;
         goto error;
     }
 
     setSamplesPerFrame(configurationInput.getSamplesPerFrame());
     if (getSamplesPerFrame() == AAUDIO_UNSPECIFIED) {
-        setSamplesPerFrame(mServiceEndpoint->getSamplesPerFrame());
-    } else if (getSamplesPerFrame() != mServiceEndpoint->getSamplesPerFrame()) {
+        setSamplesPerFrame(endpoint->getSamplesPerFrame());
+    } else if (getSamplesPerFrame() != endpoint->getSamplesPerFrame()) {
         ALOGE("AAudioServiceStreamShared::open() mSamplesPerFrame = %d, need %d",
-              getSamplesPerFrame(), mServiceEndpoint->getSamplesPerFrame());
+              getSamplesPerFrame(), endpoint->getSamplesPerFrame());
         result = AAUDIO_ERROR_OUT_OF_RANGE;
         goto error;
     }
@@ -181,9 +186,9 @@
     }
 
     ALOGD("AAudioServiceStreamShared::open() actual rate = %d, channels = %d, deviceId = %d",
-          getSampleRate(), getSamplesPerFrame(), mServiceEndpoint->getDeviceId());
+          getSampleRate(), getSamplesPerFrame(), endpoint->getDeviceId());
 
-    result = mServiceEndpoint->registerStream(keep);
+    result = endpoint->registerStream(keep);
     if (result != AAUDIO_OK) {
         goto error;
     }
@@ -250,7 +255,13 @@
                                                                 int64_t *timeNanos) {
 
     int64_t position = 0;
-    aaudio_result_t result = mServiceEndpoint->getTimestamp(&position, timeNanos);
+    sp<AAudioServiceEndpoint> endpoint = mServiceEndpointWeak.promote();
+    if (endpoint == nullptr) {
+        ALOGE("%s() has no endpoint", __func__);
+        return AAUDIO_ERROR_INVALID_STATE;
+    }
+
+    aaudio_result_t result = endpoint->getTimestamp(&position, timeNanos);
     if (result == AAUDIO_OK) {
         int64_t offset = mTimestampPositionOffset.load();
         // TODO, do not go below starting value