Refactor get_mem_region and put_mem_region in gralloc

Bug: 128324105
Test: make
Change-Id: Id57fc7358e071e9775236c7fe596dc6c94923482
Signed-off-by: Roman Kiryanov <rkir@google.com>
diff --git a/system/gralloc/gralloc.cpp b/system/gralloc/gralloc.cpp
index fc6320b..812ee04 100644
--- a/system/gralloc/gralloc.cpp
+++ b/system/gralloc/gralloc.cpp
@@ -44,6 +44,7 @@
 #include <cutils/properties.h>
 
 #include <set>
+#include <map>
 #include <string>
 #include <sstream>
 
@@ -99,20 +100,6 @@
 
 static void fallback_init(void);  // forward
 
-struct MemRegionInfo {
-    void* ashmemBase;
-    mutable uint32_t refCount;
-};
-
-struct MemRegionInfoCmp {
-    bool operator()(const MemRegionInfo& a, const MemRegionInfo& b) const {
-        return a.ashmemBase < b.ashmemBase;
-    }
-};
-
-typedef std::set<MemRegionInfo, MemRegionInfoCmp> MemRegionSet;
-typedef MemRegionSet::iterator mem_region_handle_t;
-
 //
 // Our gralloc device structure (alloc interface)
 //
@@ -123,7 +110,10 @@
 };
 
 struct gralloc_memregions_t {
-    MemRegionSet ashmemRegions;
+    typedef std::map<void*, uint32_t> MemRegionMap;  // base -> refCount
+    typedef MemRegionMap::const_iterator mem_region_handle_t;
+
+    MemRegionMap ashmemRegions;
     pthread_mutex_t lock;
 };
 
@@ -140,10 +130,12 @@
 static gralloc_memregions_t* s_memregions = NULL;
 static gralloc_dmaregion_t* s_grdma = NULL;
 
-static void init_gralloc_memregions() {
-    if (s_memregions) return;
+static gralloc_memregions_t* init_gralloc_memregions() {
+    if (s_memregions) return s_memregions;
+
     s_memregions = new gralloc_memregions_t;
     pthread_mutex_init(&s_memregions->lock, NULL);
+    return s_memregions;
 }
 
 static void init_gralloc_dmaregion() {
@@ -224,51 +216,43 @@
 }
 
 static void get_mem_region(void* ashmemBase) {
-    init_gralloc_memregions();
-    D("%s: call for %p", __FUNCTION__, ashmemBase);
-    MemRegionInfo lookup;
-    lookup.ashmemBase = ashmemBase;
-    pthread_mutex_lock(&s_memregions->lock);
-    mem_region_handle_t handle = s_memregions->ashmemRegions.find(lookup);
-    if (handle == s_memregions->ashmemRegions.end()) {
-        MemRegionInfo newRegion;
-        newRegion.ashmemBase = ashmemBase;
-        newRegion.refCount = 1;
-        s_memregions->ashmemRegions.insert(newRegion);
-    } else {
-        handle->refCount++;
-    }
-    pthread_mutex_unlock(&s_memregions->lock);
+    D("%s: call for %p", __func__, ashmemBase);
+
+    gralloc_memregions_t* memregions = init_gralloc_memregions();
+
+    pthread_mutex_lock(&memregions->lock);
+    ++memregions->ashmemRegions[ashmemBase];
+    pthread_mutex_unlock(&memregions->lock);
 }
 
 static bool put_mem_region(void* ashmemBase) {
-    init_gralloc_memregions();
-    D("%s: call for %p", __FUNCTION__, ashmemBase);
-    MemRegionInfo lookup;
-    lookup.ashmemBase = ashmemBase;
-    pthread_mutex_lock(&s_memregions->lock);
-    mem_region_handle_t handle = s_memregions->ashmemRegions.find(lookup);
-    if (handle == s_memregions->ashmemRegions.end()) {
-        ALOGE("%s: error: tried to put nonexistent mem region!", __FUNCTION__);
-        pthread_mutex_unlock(&s_memregions->lock);
-        return true;
+    D("%s: call for %p", __func__, ashmemBase);
+
+    gralloc_memregions_t* memregions = init_gralloc_memregions();
+    bool shouldRemove;
+
+    pthread_mutex_lock(&memregions->lock);
+    gralloc_memregions_t::MemRegionMap::iterator i = memregions->ashmemRegions.find(ashmemBase);
+    if (i == memregions->ashmemRegions.end()) {
+        shouldRemove = true;
+        ALOGE("%s: error: tried to put a nonexistent mem region (%p)!", __func__, ashmemBase);
     } else {
-        handle->refCount--;
-        bool shouldRemove = !handle->refCount;
+        shouldRemove = --i->second == 0;
         if (shouldRemove) {
-            s_memregions->ashmemRegions.erase(lookup);
+            memregions->ashmemRegions.erase(i);
         }
-        pthread_mutex_unlock(&s_memregions->lock);
-        return shouldRemove;
     }
+    pthread_mutex_unlock(&memregions->lock);
+
+    return shouldRemove;
 }
 
 static void dump_regions() {
-    init_gralloc_memregions();
-    mem_region_handle_t curr = s_memregions->ashmemRegions.begin();
+    gralloc_memregions_t* memregions = init_gralloc_memregions();
+    gralloc_memregions_t::mem_region_handle_t curr = memregions->ashmemRegions.begin();
     std::stringstream res;
-    for (; curr != s_memregions->ashmemRegions.end(); curr++) {
-        res << "\tashmem base " << curr->ashmemBase << " refcount " << curr->refCount << "\n";
+    for (; curr != memregions->ashmemRegions.end(); ++curr) {
+        res << "\tashmem base " << curr->first << " refcount " << curr->second << "\n";
     }
     ALOGD("ashmem region dump [\n%s]", res.str().c_str());
 }