RESTRICT AUTOMERGE Validate TrackedBuffer in onBufferDestroyed
Test: atest CtsMediaTestCases -- \
--module-arg CtsMediaTestCases:size:small
Bug: 135140854
Change-Id: Ide95a619305a30b008b1e0bd5010ea2e359f4c99
(cherry picked from commit 101f53c592f5f6ddd8298ffead3c6533e49ab3c8)
diff --git a/media/codec2/hidl/1.0/utils/InputBufferManager.cpp b/media/codec2/hidl/1.0/utils/InputBufferManager.cpp
index a023a05..8c0d0a4 100644
--- a/media/codec2/hidl/1.0/utils/InputBufferManager.cpp
+++ b/media/codec2/hidl/1.0/utils/InputBufferManager.cpp
@@ -70,7 +70,7 @@
<< ".";
std::lock_guard<std::mutex> lock(mMutex);
- std::set<TrackedBuffer> &bufferIds =
+ std::set<TrackedBuffer*> &bufferIds =
mTrackedBuffersMap[listener][frameIndex];
for (size_t i = 0; i < input.buffers.size(); ++i) {
@@ -79,13 +79,14 @@
<< "Input buffer at index " << i << " is null.";
continue;
}
- const TrackedBuffer &bufferId =
- *bufferIds.emplace(listener, frameIndex, i, input.buffers[i]).
- first;
+ TrackedBuffer *bufferId =
+ new TrackedBuffer(listener, frameIndex, i, input.buffers[i]);
+ mTrackedBufferCache.emplace(bufferId);
+ bufferIds.emplace(bufferId);
c2_status_t status = input.buffers[i]->registerOnDestroyNotify(
onBufferDestroyed,
- const_cast<void*>(reinterpret_cast<const void*>(&bufferId)));
+ reinterpret_cast<void*>(bufferId));
if (status != C2_OK) {
LOG(DEBUG) << "InputBufferManager::_registerFrameData -- "
<< "registerOnDestroyNotify() failed "
@@ -119,31 +120,32 @@
auto findListener = mTrackedBuffersMap.find(listener);
if (findListener != mTrackedBuffersMap.end()) {
- std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds
+ std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds
= findListener->second;
auto findFrameIndex = frameIndex2BufferIds.find(frameIndex);
if (findFrameIndex != frameIndex2BufferIds.end()) {
- std::set<TrackedBuffer> &bufferIds = findFrameIndex->second;
- for (const TrackedBuffer& bufferId : bufferIds) {
- std::shared_ptr<C2Buffer> buffer = bufferId.buffer.lock();
+ std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second;
+ for (TrackedBuffer* bufferId : bufferIds) {
+ std::shared_ptr<C2Buffer> buffer = bufferId->buffer.lock();
if (buffer) {
c2_status_t status = buffer->unregisterOnDestroyNotify(
onBufferDestroyed,
- const_cast<void*>(
- reinterpret_cast<const void*>(&bufferId)));
+ reinterpret_cast<void*>(bufferId));
if (status != C2_OK) {
LOG(DEBUG) << "InputBufferManager::_unregisterFrameData "
<< "-- unregisterOnDestroyNotify() failed "
<< "(listener @ 0x"
<< std::hex
- << bufferId.listener.unsafe_get()
+ << bufferId->listener.unsafe_get()
<< ", frameIndex = "
- << std::dec << bufferId.frameIndex
- << ", bufferIndex = " << bufferId.bufferIndex
+ << std::dec << bufferId->frameIndex
+ << ", bufferIndex = " << bufferId->bufferIndex
<< ") => status = " << status
<< ".";
}
}
+ mTrackedBufferCache.erase(bufferId);
+ delete bufferId;
}
frameIndex2BufferIds.erase(findFrameIndex);
@@ -179,31 +181,32 @@
auto findListener = mTrackedBuffersMap.find(listener);
if (findListener != mTrackedBuffersMap.end()) {
- std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds =
+ std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds =
findListener->second;
for (auto findFrameIndex = frameIndex2BufferIds.begin();
findFrameIndex != frameIndex2BufferIds.end();
++findFrameIndex) {
- std::set<TrackedBuffer> &bufferIds = findFrameIndex->second;
- for (const TrackedBuffer& bufferId : bufferIds) {
- std::shared_ptr<C2Buffer> buffer = bufferId.buffer.lock();
+ std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second;
+ for (TrackedBuffer* bufferId : bufferIds) {
+ std::shared_ptr<C2Buffer> buffer = bufferId->buffer.lock();
if (buffer) {
c2_status_t status = buffer->unregisterOnDestroyNotify(
onBufferDestroyed,
- const_cast<void*>(
- reinterpret_cast<const void*>(&bufferId)));
+ reinterpret_cast<void*>(bufferId));
if (status != C2_OK) {
LOG(DEBUG) << "InputBufferManager::_unregisterFrameData "
<< "-- unregisterOnDestroyNotify() failed "
<< "(listener @ 0x"
<< std::hex
- << bufferId.listener.unsafe_get()
+ << bufferId->listener.unsafe_get()
<< ", frameIndex = "
- << std::dec << bufferId.frameIndex
- << ", bufferIndex = " << bufferId.bufferIndex
+ << std::dec << bufferId->frameIndex
+ << ", bufferIndex = " << bufferId->bufferIndex
<< ") => status = " << status
<< ".";
}
+ mTrackedBufferCache.erase(bufferId);
+ delete bufferId;
}
}
}
@@ -236,50 +239,59 @@
<< std::dec << ".";
return;
}
- TrackedBuffer id(*reinterpret_cast<TrackedBuffer*>(arg));
+
+ std::lock_guard<std::mutex> lock(mMutex);
+ TrackedBuffer *bufferId = reinterpret_cast<TrackedBuffer*>(arg);
+
+ if (mTrackedBufferCache.find(bufferId) == mTrackedBufferCache.end()) {
+ LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- called with "
+ << "unregistered buffer: "
+ << "buf @ 0x" << std::hex << buf
+ << ", arg @ 0x" << std::hex << arg
+ << std::dec << ".";
+ return;
+ }
+
LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- called with "
<< "buf @ 0x" << std::hex << buf
<< ", arg @ 0x" << std::hex << arg
<< std::dec << " -- "
- << "listener @ 0x" << std::hex << id.listener.unsafe_get()
- << ", frameIndex = " << std::dec << id.frameIndex
- << ", bufferIndex = " << id.bufferIndex
+ << "listener @ 0x" << std::hex << bufferId->listener.unsafe_get()
+ << ", frameIndex = " << std::dec << bufferId->frameIndex
+ << ", bufferIndex = " << bufferId->bufferIndex
<< ".";
-
- std::lock_guard<std::mutex> lock(mMutex);
-
- auto findListener = mTrackedBuffersMap.find(id.listener);
+ auto findListener = mTrackedBuffersMap.find(bufferId->listener);
if (findListener == mTrackedBuffersMap.end()) {
- LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- "
- << "received invalid listener: "
- << "listener @ 0x" << std::hex << id.listener.unsafe_get()
- << " (frameIndex = " << std::dec << id.frameIndex
- << ", bufferIndex = " << id.bufferIndex
- << ").";
+ LOG(VERBOSE) << "InputBufferManager::_onBufferDestroyed -- "
+ << "received invalid listener: "
+ << "listener @ 0x" << std::hex << bufferId->listener.unsafe_get()
+ << " (frameIndex = " << std::dec << bufferId->frameIndex
+ << ", bufferIndex = " << bufferId->bufferIndex
+ << ").";
return;
}
- std::map<uint64_t, std::set<TrackedBuffer>> &frameIndex2BufferIds
+ std::map<uint64_t, std::set<TrackedBuffer*>> &frameIndex2BufferIds
= findListener->second;
- auto findFrameIndex = frameIndex2BufferIds.find(id.frameIndex);
+ auto findFrameIndex = frameIndex2BufferIds.find(bufferId->frameIndex);
if (findFrameIndex == frameIndex2BufferIds.end()) {
LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- "
<< "received invalid frame index: "
- << "frameIndex = " << id.frameIndex
- << " (listener @ 0x" << std::hex << id.listener.unsafe_get()
- << ", bufferIndex = " << std::dec << id.bufferIndex
+ << "frameIndex = " << bufferId->frameIndex
+ << " (listener @ 0x" << std::hex << bufferId->listener.unsafe_get()
+ << ", bufferIndex = " << std::dec << bufferId->bufferIndex
<< ").";
return;
}
- std::set<TrackedBuffer> &bufferIds = findFrameIndex->second;
- auto findBufferId = bufferIds.find(id);
+ std::set<TrackedBuffer*> &bufferIds = findFrameIndex->second;
+ auto findBufferId = bufferIds.find(bufferId);
if (findBufferId == bufferIds.end()) {
LOG(DEBUG) << "InputBufferManager::_onBufferDestroyed -- "
<< "received invalid buffer index: "
- << "bufferIndex = " << id.bufferIndex
- << " (frameIndex = " << id.frameIndex
- << ", listener @ 0x" << std::hex << id.listener.unsafe_get()
+ << "bufferIndex = " << bufferId->bufferIndex
+ << " (frameIndex = " << bufferId->frameIndex
+ << ", listener @ 0x" << std::hex << bufferId->listener.unsafe_get()
<< std::dec << ").";
return;
}
@@ -292,10 +304,13 @@
}
}
- DeathNotifications &deathNotifications = mDeathNotifications[id.listener];
- deathNotifications.indices[id.frameIndex].emplace_back(id.bufferIndex);
+ DeathNotifications &deathNotifications = mDeathNotifications[bufferId->listener];
+ deathNotifications.indices[bufferId->frameIndex].emplace_back(bufferId->bufferIndex);
++deathNotifications.count;
mOnBufferDestroyed.notify_one();
+
+ mTrackedBufferCache.erase(bufferId);
+ delete bufferId;
}
// Notify the clients about buffer destructions.
diff --git a/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h b/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h
index b6857d5..42fa557 100644
--- a/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h
+++ b/media/codec2/hidl/1.0/utils/include/codec2/hidl/1.0/InputBufferManager.h
@@ -196,13 +196,9 @@
frameIndex(frameIndex),
bufferIndex(bufferIndex),
buffer(buffer) {}
- TrackedBuffer(const TrackedBuffer&) = default;
- bool operator<(const TrackedBuffer& other) const {
- return bufferIndex < other.bufferIndex;
- }
};
- // Map: listener -> frameIndex -> set<TrackedBuffer>.
+ // Map: listener -> frameIndex -> set<TrackedBuffer*>.
// Essentially, this is used to store triples (listener, frameIndex,
// bufferIndex) that's searchable by listener and (listener, frameIndex).
// However, the value of the innermost map is TrackedBuffer, which also
@@ -210,7 +206,7 @@
// because onBufferDestroyed() needs to know listener and frameIndex too.
typedef std::map<wp<IComponentListener>,
std::map<uint64_t,
- std::set<TrackedBuffer>>> TrackedBuffersMap;
+ std::set<TrackedBuffer*>>> TrackedBuffersMap;
// Storage for pending (unsent) death notifications for one listener.
// Each pair in member named "indices" are (frameIndex, bufferIndex) from
@@ -247,6 +243,16 @@
// Mutex for the management of all input buffers.
std::mutex mMutex;
+ // Cache for all TrackedBuffers.
+ //
+ // Whenever registerOnDestroyNotify() is called, an argument of type
+ // TrackedBuffer is created and stored into this cache.
+ // Whenever unregisterOnDestroyNotify() or onBufferDestroyed() is called,
+ // the TrackedBuffer is removed from this cache.
+ //
+ // mTrackedBuffersMap stores references to TrackedBuffers inside this cache.
+ std::set<TrackedBuffer*> mTrackedBufferCache;
+
// Tracked input buffers.
TrackedBuffersMap mTrackedBuffersMap;