Fix race condition for cas sessions

Change the session to shared_ptr and use atomic_load/store.

Test: POC; CTS MediaCasTest; CTS MediaDrmClearkeyTest#
testClearKeyPlaybackMpeg2ts
bug: 113027383

Change-Id: I75f4cb33a022f28d45918442d64c5c46df2640ef
(cherry picked from commit 7934a8f7ee0988b181980b9e6f24ca7dd357f5e3)
diff --git a/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.cpp b/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.cpp
index 73ed8c3..1558e8b 100644
--- a/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.cpp
+++ b/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.cpp
@@ -118,9 +118,9 @@
 
 status_t ClearKeyCasPlugin::closeSession(const CasSessionId &sessionId) {
     ALOGV("closeSession: sessionId=%s", sessionIdToString(sessionId).string());
-    sp<ClearKeyCasSession> session =
+    std::shared_ptr<ClearKeyCasSession> session =
             ClearKeySessionLibrary::get()->findSession(sessionId);
-    if (session == NULL) {
+    if (session.get() == nullptr) {
         return ERROR_CAS_SESSION_NOT_OPENED;
     }
 
@@ -132,9 +132,9 @@
         const CasSessionId &sessionId, const CasData & /*data*/) {
     ALOGV("setSessionPrivateData: sessionId=%s",
             sessionIdToString(sessionId).string());
-    sp<ClearKeyCasSession> session =
+    std::shared_ptr<ClearKeyCasSession> session =
             ClearKeySessionLibrary::get()->findSession(sessionId);
-    if (session == NULL) {
+    if (session.get() == nullptr) {
         return ERROR_CAS_SESSION_NOT_OPENED;
     }
     return OK;
@@ -143,9 +143,9 @@
 status_t ClearKeyCasPlugin::processEcm(
         const CasSessionId &sessionId, const CasEcm& ecm) {
     ALOGV("processEcm: sessionId=%s", sessionIdToString(sessionId).string());
-    sp<ClearKeyCasSession> session =
+    std::shared_ptr<ClearKeyCasSession> session =
             ClearKeySessionLibrary::get()->findSession(sessionId);
-    if (session == NULL) {
+    if (session.get() == nullptr) {
         return ERROR_CAS_SESSION_NOT_OPENED;
     }
 
@@ -418,15 +418,15 @@
         const CasSessionId &sessionId) {
     ALOGV("setMediaCasSession: sessionId=%s", sessionIdToString(sessionId).string());
 
-    sp<ClearKeyCasSession> session =
+    std::shared_ptr<ClearKeyCasSession> session =
             ClearKeySessionLibrary::get()->findSession(sessionId);
 
-    if (session == NULL) {
+    if (session.get() == nullptr) {
         ALOGE("ClearKeyDescramblerPlugin: session not found");
         return ERROR_CAS_SESSION_NOT_OPENED;
     }
 
-    mCASSession = session;
+    std::atomic_store(&mCASSession, session);
     return OK;
 }
 
@@ -447,12 +447,14 @@
           subSamplesToString(subSamples, numSubSamples).string(),
           srcPtr, dstPtr, srcOffset, dstOffset);
 
-    if (mCASSession == NULL) {
+    std::shared_ptr<ClearKeyCasSession> session = std::atomic_load(&mCASSession);
+
+    if (session.get() == nullptr) {
         ALOGE("Uninitialized CAS session!");
         return ERROR_CAS_DECRYPT_UNIT_NOT_INITIALIZED;
     }
 
-    return mCASSession->decrypt(
+    return session->decrypt(
             secure, scramblingControl,
             numSubSamples, subSamples,
             (uint8_t*)srcPtr + srcOffset,
diff --git a/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.h b/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.h
index 42cfb8f..389e172 100644
--- a/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.h
+++ b/drm/mediacas/plugins/clearkey/ClearKeyCasPlugin.h
@@ -120,7 +120,7 @@
             AString *errorDetailMsg) override;
 
 private:
-    sp<ClearKeyCasSession> mCASSession;
+    std::shared_ptr<ClearKeyCasSession> mCASSession;
 
     String8 subSamplesToString(
             SubSample const *subSamples,
diff --git a/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.cpp b/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.cpp
index 4b4051d..3bb1176 100644
--- a/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.cpp
+++ b/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.cpp
@@ -56,7 +56,7 @@
 
     Mutex::Autolock lock(mSessionsLock);
 
-    sp<ClearKeyCasSession> session = new ClearKeyCasSession(plugin);
+    std::shared_ptr<ClearKeyCasSession> session(new ClearKeyCasSession(plugin));
 
     uint8_t *byteArray = (uint8_t *) &mNextSessionId;
     sessionId->push_back(byteArray[3]);
@@ -69,7 +69,7 @@
     return OK;
 }
 
-sp<ClearKeyCasSession> ClearKeySessionLibrary::findSession(
+std::shared_ptr<ClearKeyCasSession> ClearKeySessionLibrary::findSession(
         const CasSessionId& sessionId) {
     Mutex::Autolock lock(mSessionsLock);
 
@@ -88,7 +88,7 @@
         return;
     }
 
-    sp<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
+    std::shared_ptr<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
     mIDToSessionMap.removeItemsAt(index);
 }
 
@@ -96,7 +96,7 @@
     Mutex::Autolock lock(mSessionsLock);
 
     for (ssize_t index = (ssize_t)mIDToSessionMap.size() - 1; index >= 0; index--) {
-        sp<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
+        std::shared_ptr<ClearKeyCasSession> session = mIDToSessionMap.valueAt(index);
         if (session->getPlugin() == plugin) {
             mIDToSessionMap.removeItemsAt(index);
         }
diff --git a/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.h b/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.h
index 01f5f47..a537e63 100644
--- a/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.h
+++ b/drm/mediacas/plugins/clearkey/ClearKeySessionLibrary.h
@@ -32,6 +32,10 @@
 
 class ClearKeyCasSession : public RefBase {
 public:
+    explicit ClearKeyCasSession(CasPlugin *plugin);
+
+    virtual ~ClearKeyCasSession();
+
     ssize_t decrypt(
             bool secure,
             DescramblerPlugin::ScramblingControl scramblingControl,
@@ -58,8 +62,6 @@
 
     friend class ClearKeySessionLibrary;
 
-    explicit ClearKeyCasSession(CasPlugin *plugin);
-    virtual ~ClearKeyCasSession();
     CasPlugin* getPlugin() const { return mPlugin; }
     status_t decryptPayload(
             const AES_KEY& key, size_t length, size_t offset, char* buffer) const;
@@ -73,7 +75,7 @@
 
     status_t addSession(CasPlugin *plugin, CasSessionId *sessionId);
 
-    sp<ClearKeyCasSession> findSession(const CasSessionId& sessionId);
+    std::shared_ptr<ClearKeyCasSession> findSession(const CasSessionId& sessionId);
 
     void destroySession(const CasSessionId& sessionId);
 
@@ -85,7 +87,7 @@
 
     Mutex mSessionsLock;
     uint32_t mNextSessionId;
-    KeyedVector<CasSessionId, sp<ClearKeyCasSession>> mIDToSessionMap;
+    KeyedVector<CasSessionId, std::shared_ptr<ClearKeyCasSession>> mIDToSessionMap;
 
     ClearKeySessionLibrary();
     DISALLOW_EVIL_CONSTRUCTORS(ClearKeySessionLibrary);