RESTRICT AUTOMERGE Validate TrackedBuffer in onBufferDestroyed Test: atest CtsMediaTestCases -- \ --module-arg CtsMediaTestCases:size:small Bug: 135140854 Change-Id: Ide95a619305a30b008b1e0bd5010ea2e359f4c99 (cherry picked from commit 101f53c592f5f6ddd8298ffead3c6533e49ab3c8)
diff --git a/media/codec2/hidl/1.0/utils/InputBufferManager.cpp b/media/codec2/hidl/1.0/utils/InputBufferManager.cpp index a023a05..8c0d0a4 100644 --- a/media/codec2/hidl/1.0/utils/InputBufferManager.cpp +++ b/media/codec2/hidl/1.0/utils/InputBufferManager.cpp
@@ -70,7 +70,7 @@ << "."; std::lock_guard<std::mutex> lock(mMutex); - std::set<TrackedBuffer> &bufferIds = + std::set<TrackedBuffer*> &bufferIds = mTrackedBuffersMap[listener][frameIndex]; for (size_t i = 0; i < input.buffers.size(); ++i) { @@ -79,13 +79,14 @@ << "Input buffer at index " << i << " is null."; continue; } - const TrackedBuffer &bufferId = - *bufferIds.emplace(listener, frameIndex, i, input.buffers[i]). - first; + TrackedBuffer *bufferId = + new TrackedBuffer(listener, frameIndex, i, input.buffers[i]); + mTrackedBufferCache.emplace(bufferId); + bufferIds.emplace(bufferId); c2_status_t status = input.buffers[i]->registerOnDestroyNotify( onBufferDestroyed, - const_cast<void*>(reinterpret_cast<const void*>(&bufferId))); + reinterpret_cast<void*>(bufferId)); if (status != C2_OK) { LOG(DEBUG) << "InputBufferManager::_registerFrameData -- " << "registerOnDestroyNotify() failed " @@ -119,31 +120,32 @@ auto findListener = mTrackedBuffersMap.find(listener); if (findListener != mTrackedBuffersMap.end()) { - std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds + std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds = findListener->second; auto findFrameIndex = frameIndex2BufferIds.find(frameIndex); if (findFrameIndex != frameIndex2BufferIds.end()) { - std::set<TrackedBuffer> &bufferIds = findFrameIndex->second; - for (const TrackedBuffer& bufferId : bufferIds) { - std::shared_ptr<C2Buffer> buffer = bufferId.buffer.lock(); + std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second; + for (TrackedBuffer* bufferId : bufferIds) { + std::shared_ptr<C2Buffer> buffer = bufferId->buffer.lock(); if (buffer) { c2_status_t status = buffer->unregisterOnDestroyNotify( onBufferDestroyed, - const_cast<void*>( - reinterpret_cast<const void*>(&bufferId))); + reinterpret_cast<void*>(bufferId)); if (status != C2_OK) { LOG(DEBUG) << "InputBufferManager::_unregisterFrameData " << "-- unregisterOnDestroyNotify() failed " << "(listener @ 0x" << std::hex - << bufferId.listener.unsafe_get() + << bufferId->listener.unsafe_get() << ", frameIndex = " - << std::dec << bufferId.frameIndex - << ", bufferIndex = " << bufferId.bufferIndex + << std::dec << bufferId->frameIndex + << ", bufferIndex = " << bufferId->bufferIndex << ") => status = " << status << "."; } } + mTrackedBufferCache.erase(bufferId); + delete bufferId; } frameIndex2BufferIds.erase(findFrameIndex); @@ -179,31 +181,32 @@ auto findListener = mTrackedBuffersMap.find(listener); if (findListener != mTrackedBuffersMap.end()) { - std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds = + std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds = findListener->second; for (auto findFrameIndex = frameIndex2BufferIds.begin(); findFrameIndex != frameIndex2BufferIds.end(); ++findFrameIndex) { - std::set<TrackedBuffer> &bufferIds = findFrameIndex->second; - for (const TrackedBuffer& bufferId : bufferIds) { - std::shared_ptr<C2Buffer> buffer = bufferId.buffer.lock(); + std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second; + for (TrackedBuffer* bufferId : bufferIds) { + std::shared_ptr<C2Buffer> buffer = bufferId->buffer.lock(); if (buffer) { c2_status_t status = buffer->unregisterOnDestroyNotify( onBufferDestroyed, - const_cast<void*>( - reinterpret_cast<const void*>(&bufferId))); + reinterpret_cast<void*>(bufferId)); if (status != C2_OK) { LOG(DEBUG) << "InputBufferManager::_unregisterFrameData " << "-- unregisterOnDestroyNotify() failed " << "(listener @ 0x" << std::hex - << bufferId.listener.unsafe_get() + << bufferId->listener.unsafe_get() << ", frameIndex = " - << std::dec << bufferId.frameIndex - << ", bufferIndex = " << bufferId.bufferIndex + << std::dec << bufferId->frameIndex + << ", bufferIndex = " << bufferId->bufferIndex << ") => status = " << status << "."; } + mTrackedBufferCache.erase(bufferId); + delete bufferId; } } } @@ -236,50 +239,59 @@ << std::dec << "."; return; } - TrackedBuffer id(*reinterpret_cast<TrackedBuffer*>(arg)); + + std::lock_guard<std::mutex> lock(mMutex); + TrackedBuffer *bufferId = reinterpret_cast<TrackedBuffer*>(arg); + + if (mTrackedBufferCache.find(bufferId) == mTrackedBufferCache.end()) { + LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- called with " + << "unregistered buffer: " + << "buf @ 0x" << std::hex << buf + << ", arg @ 0x" << std::hex << arg + << std::dec << "."; + return; + } + LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- called with " << "buf @ 0x" << std::hex << buf << ", arg @ 0x" << std::hex << arg << std::dec << " -- " - << "listener @ 0x" << std::hex << id.listener.unsafe_get() - << ", frameIndex = " << std::dec << id.frameIndex - << ", bufferIndex = " << id.bufferIndex + << "listener @ 0x" << std::hex << bufferId->listener.unsafe_get() + << ", frameIndex = " << std::dec << bufferId->frameIndex + << ", bufferIndex = " << bufferId->bufferIndex << "."; - - std::lock_guard<std::mutex> lock(mMutex); - - auto findListener = mTrackedBuffersMap.find(id.listener); + auto findListener = mTrackedBuffersMap.find(bufferId->listener); if (findListener == mTrackedBuffersMap.end()) { - LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- " - << "received invalid listener: " - << "listener @ 0x" << std::hex << id.listener.unsafe_get() - << " (frameIndex = " << std::dec << id.frameIndex - << ", bufferIndex = " << id.bufferIndex - << ")."; + LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- " + << "received invalid listener: " + << "listener @ 0x" << std::hex << bufferId->listener.unsafe_get() + << " (frameIndex = " << std::dec << bufferId->frameIndex + << ", bufferIndex = " << bufferId->bufferIndex + << ")."; return; } - std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds + std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds = findListener->second; - auto findFrameIndex = frameIndex2BufferIds.find(id.frameIndex); + auto findFrameIndex = frameIndex2BufferIds.find(bufferId->frameIndex); if (findFrameIndex == frameIndex2BufferIds.end()) { LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- " << "received invalid frame index: " - << "frameIndex = " << id.frameIndex - << " (listener @ 0x" << std::hex << id.listener.unsafe_get() - << ", bufferIndex = " << std::dec << id.bufferIndex + << "frameIndex = " << bufferId->frameIndex + << " (listener @ 0x" << std::hex << bufferId->listener.unsafe_get() + << ", bufferIndex = " << std::dec << bufferId->bufferIndex << ")."; return; } - std::set<TrackedBuffer> &bufferIds = findFrameIndex->second; - auto findBufferId = bufferIds.find(id); + std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second; + auto findBufferId = bufferIds.find(bufferId); if (findBufferId == bufferIds.end()) { LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- " << "received invalid buffer index: " - << "bufferIndex = " << id.bufferIndex - << " (frameIndex = " << id.frameIndex - << ", listener @ 0x" << std::hex << id.listener.unsafe_get() + << "bufferIndex = " << bufferId->bufferIndex + << " (frameIndex = " << bufferId->frameIndex + << ", listener @ 0x" << std::hex << bufferId->listener.unsafe_get() << std::dec << ")."; return; } @@ -292,10 +304,13 @@ } } - DeathNotifications &deathNotifications = mDeathNotifications[id.listener]; - deathNotifications.indices[id.frameIndex].emplace_back(id.bufferIndex); + DeathNotifications &deathNotifications = mDeathNotifications[bufferId->listener]; + deathNotifications.indices[bufferId->frameIndex].emplace_back(bufferId->bufferIndex); ++deathNotifications.count; mOnBufferDestroyed.notify_one(); + + mTrackedBufferCache.erase(bufferId); + delete bufferId; } // Notify the clients about buffer destructions.
diff --git a/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h b/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h index b6857d5..42fa557 100644 --- a/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h +++ b/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h
@@ -196,13 +196,9 @@ frameIndex(frameIndex), bufferIndex(bufferIndex), buffer(buffer) {} - TrackedBuffer(const TrackedBuffer&) = default; - bool operator<(const TrackedBuffer& other) const { - return bufferIndex < other.bufferIndex; - } }; - // Map: listener -> frameIndex -> set<TrackedBuffer>. + // Map: listener -> frameIndex -> set<TrackedBuffer*>. // Essentially, this is used to store triples (listener, frameIndex, // bufferIndex) that's searchable by listener and (listener, frameIndex). // However, the value of the innermost map is TrackedBuffer, which also @@ -210,7 +206,7 @@ // because onBufferDestroyed() needs to know listener and frameIndex too. typedef std::map<wp<IComponentListener>, std::map<uint64_t, - std::set<TrackedBuffer>>> TrackedBuffersMap; + std::set<TrackedBuffer*>>> TrackedBuffersMap; // Storage for pending (unsent) death notifications for one listener. // Each pair in member named "indices" are (frameIndex, bufferIndex) from @@ -247,6 +243,16 @@ // Mutex for the management of all input buffers. std::mutex mMutex; + // Cache for all TrackedBuffers. + // + // Whenever registerOnDestroyNotify() is called, an argument of type + // TrackedBuffer is created and stored into this cache. + // Whenever unregisterOnDestroyNotify() or onBufferDestroyed() is called, + // the TrackedBuffer is removed from this cache. + // + // mTrackedBuffersMap stores references to TrackedBuffers inside this cache. + std::set<TrackedBuffer*> mTrackedBufferCache; + // Tracked input buffers. TrackedBuffersMap mTrackedBuffersMap;