BpfMap: cache bpf map file descriptors

We switch back to int from ParcelFileDescriptor,
and eliminate all calls to close().  Bpf Map FDs
now live till process exit.

Bug: 230880517
Test: TreeHugger, atest com.android.networkstack.tethering.BpfMapTest
Signed-off-by: Maciej Żenczykowski <maze@google.com>
Change-Id: I89b6dc88ea56cb1e50695f8daf54ed79bce3fba2
(cherry picked from commit 8888c198daa8141988ee806adcf54b43e68b1076)
Merged-In: I89b6dc88ea56cb1e50695f8daf54ed79bce3fba2
diff --git a/staticlibs/device/com/android/net/module/util/BpfMap.java b/staticlibs/device/com/android/net/module/util/BpfMap.java
index a7fbab7..854b9fd 100644
--- a/staticlibs/device/com/android/net/module/util/BpfMap.java
+++ b/staticlibs/device/com/android/net/module/util/BpfMap.java
@@ -20,6 +20,7 @@
 
 import android.os.ParcelFileDescriptor;
 import android.system.ErrnoException;
+import android.util.Pair;
 
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
@@ -32,6 +33,7 @@
 import java.nio.ByteOrder;
 import java.util.NoSuchElementException;
 import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
 
 /**
  * BpfMap is a key -> value mapping structure that is designed to maintained the bpf map entries.
@@ -65,6 +67,27 @@
     private final int mKeySize;
     private final int mValueSize;
 
+    private static ConcurrentHashMap<Pair<String, Integer>, ParcelFileDescriptor> sFdCache =
+            new ConcurrentHashMap<>();
+
+    private static ParcelFileDescriptor cachedBpfFdGet(String path, int mode)
+            throws ErrnoException, NullPointerException {
+        Pair<String, Integer> key = Pair.create(path, mode);
+        // unlocked fetch is safe: map is concurrent read capable, and only inserted into
+        ParcelFileDescriptor fd = sFdCache.get(key);
+        if (fd != null) return fd;
+        // ok, no cached fd present, need to grab a lock
+        synchronized (BpfMap.class) {
+            // need to redo the check
+            fd = sFdCache.get(key);
+            if (fd != null) return fd;
+            // okay, we really haven't opened this before...
+            fd = ParcelFileDescriptor.adoptFd(nativeBpfFdGet(path, mode));
+            sFdCache.put(key, fd);
+            return fd;
+        }
+    }
+
     /**
      * Create a BpfMap map wrapper with "path" of filesystem.
      *
@@ -74,7 +97,7 @@
      */
     public BpfMap(@NonNull final String path, final int flag, final Class<K> key,
             final Class<V> value) throws ErrnoException, NullPointerException {
-        mMapFd = ParcelFileDescriptor.adoptFd(bpfFdGet(path, flag));
+        mMapFd = cachedBpfFdGet(path, flag);
         mKeyClass = key;
         mValueClass = value;
         mKeySize = Struct.getSize(key);
@@ -90,7 +113,7 @@
      */
     @VisibleForTesting
     protected BpfMap(final Class<K> key, final Class<V> value) {
-        mMapFd = ParcelFileDescriptor.adoptFd(-1 /*invalid*/);  // unused
+        mMapFd = null;  // unused
         mKeyClass = key;
         mValueClass = value;
         mKeySize = Struct.getSize(key);
@@ -103,7 +126,7 @@
      */
     @Override
     public void updateEntry(K key, V value) throws ErrnoException {
-        writeToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_ANY);
+        nativeWriteToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_ANY);
     }
 
     /**
@@ -114,7 +137,8 @@
     public void insertEntry(K key, V value)
             throws ErrnoException, IllegalStateException {
         try {
-            writeToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_NOEXIST);
+            nativeWriteToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(),
+                    BPF_NOEXIST);
         } catch (ErrnoException e) {
             if (e.errno == EEXIST) throw new IllegalStateException(key + " already exists");
 
@@ -130,7 +154,8 @@
     public void replaceEntry(K key, V value)
             throws ErrnoException, NoSuchElementException {
         try {
-            writeToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_EXIST);
+            nativeWriteToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(),
+                    BPF_EXIST);
         } catch (ErrnoException e) {
             if (e.errno == ENOENT) throw new NoSuchElementException(key + " not found");
 
@@ -148,13 +173,15 @@
     public boolean insertOrReplaceEntry(K key, V value)
             throws ErrnoException {
         try {
-            writeToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_NOEXIST);
+            nativeWriteToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(),
+                    BPF_NOEXIST);
             return true;   /* insert succeeded */
         } catch (ErrnoException e) {
             if (e.errno != EEXIST) throw e;
         }
         try {
-            writeToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(), BPF_EXIST);
+            nativeWriteToMapEntry(mMapFd.getFd(), key.writeToBytes(), value.writeToBytes(),
+                    BPF_EXIST);
             return false;   /* replace succeeded */
         } catch (ErrnoException e) {
             if (e.errno != ENOENT) throw e;
@@ -171,7 +198,7 @@
     /** Remove existing key from eBpf map. Return false if map was not modified. */
     @Override
     public boolean deleteEntry(K key) throws ErrnoException {
-        return deleteMapEntry(mMapFd.getFd(), key.writeToBytes());
+        return nativeDeleteMapEntry(mMapFd.getFd(), key.writeToBytes());
     }
 
     /** Returns {@code true} if this map contains no elements. */
@@ -204,7 +231,7 @@
 
     private byte[] getNextRawKey(@Nullable final byte[] key) throws ErrnoException {
         byte[] nextKey = new byte[mKeySize];
-        if (getNextMapKey(mMapFd.getFd(), key, nextKey)) return nextKey;
+        if (nativeGetNextMapKey(mMapFd.getFd(), key, nextKey)) return nextKey;
 
         return null;
     }
@@ -239,7 +266,7 @@
 
     private byte[] getRawValue(final byte[] key) throws ErrnoException {
         byte[] value = new byte[mValueSize];
-        if (findMapEntry(mMapFd.getFd(), key, value)) return value;
+        if (nativeFindMapEntry(mMapFd.getFd(), key, value)) return value;
 
         return null;
     }
@@ -263,9 +290,13 @@
         }
     }
 
+    /* Empty implementation to implement AutoCloseable, so we can use BpfMaps
+     * with try with resources, but due to persistent FD cache, there is no actual
+     * need to close anything.  File descriptors will actually be closed when we
+     * unlock the BpfMap class and destroy the ParcelFileDescriptor objects.
+     */
     @Override
     public void close() throws IOException {
-        mMapFd.close();
     }
 
     /**
@@ -283,17 +314,25 @@
         }
     }
 
-    private native int bpfFdGet(String path, int mode) throws ErrnoException, NullPointerException;
+    private static native int nativeBpfFdGet(String path, int mode)
+            throws ErrnoException, NullPointerException;
 
-    private native void writeToMapEntry(int fd, byte[] key, byte[] value, int flags)
+    // Note: the following methods appear to not require the object by virtue of taking the
+    // fd as an int argument, but the hidden reference to this is actually what prevents
+    // the object from being garbage collected (and thus potentially maps closed) prior
+    // to the native code actually running (with a possibly already closed fd).
+
+    private native void nativeWriteToMapEntry(int fd, byte[] key, byte[] value, int flags)
             throws ErrnoException;
 
-    private native boolean deleteMapEntry(int fd, byte[] key) throws ErrnoException;
+    private native boolean nativeDeleteMapEntry(int fd, byte[] key) throws ErrnoException;
 
     // If key is found, the operation returns true and the nextKey would reference to the next
     // element.  If key is not found, the operation returns true and the nextKey would reference to
     // the first element.  If key is the last element, false is returned.
-    private native boolean getNextMapKey(int fd, byte[] key, byte[] nextKey) throws ErrnoException;
+    private native boolean nativeGetNextMapKey(int fd, byte[] key, byte[] nextKey)
+            throws ErrnoException;
 
-    private native boolean findMapEntry(int fd, byte[] key, byte[] value) throws ErrnoException;
+    private native boolean nativeFindMapEntry(int fd, byte[] key, byte[] value)
+            throws ErrnoException;
 }
diff --git a/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp b/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp
index e3f48e5..2e88fc8 100644
--- a/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp
+++ b/staticlibs/native/bpfmapjni/com_android_net_module_util_BpfMap.cpp
@@ -27,18 +27,18 @@
 
 namespace android {
 
-static jint com_android_net_module_util_BpfMap_bpfFdGet(JNIEnv *env, jobject clazz,
+static jint com_android_net_module_util_BpfMap_nativeBpfFdGet(JNIEnv *env, jclass clazz,
         jstring path, jint mode) {
     ScopedUtfChars pathname(env, path);
 
     jint fd = bpf::bpfFdGet(pathname.c_str(), static_cast<unsigned>(mode));
 
-    if (fd < 0) jniThrowErrnoException(env, "bpfFdGet", errno);
+    if (fd < 0) jniThrowErrnoException(env, "nativeBpfFdGet", errno);
 
     return fd;
 }
 
-static void com_android_net_module_util_BpfMap_writeToMapEntry(JNIEnv *env, jobject clazz,
+static void com_android_net_module_util_BpfMap_nativeWriteToMapEntry(JNIEnv *env, jobject self,
         jint fd, jbyteArray key, jbyteArray value, jint flags) {
     ScopedByteArrayRO keyRO(env, key);
     ScopedByteArrayRO valueRO(env, value);
@@ -46,7 +46,7 @@
     int ret = bpf::writeToMapEntry(static_cast<int>(fd), keyRO.get(), valueRO.get(),
             static_cast<int>(flags));
 
-    if (ret) jniThrowErrnoException(env, "writeToMapEntry", errno);
+    if (ret) jniThrowErrnoException(env, "nativeWriteToMapEntry", errno);
 }
 
 static jboolean throwIfNotEnoent(JNIEnv *env, const char* functionName, int ret, int err) {
@@ -56,7 +56,7 @@
     return false;
 }
 
-static jboolean com_android_net_module_util_BpfMap_deleteMapEntry(JNIEnv *env, jobject clazz,
+static jboolean com_android_net_module_util_BpfMap_nativeDeleteMapEntry(JNIEnv *env, jobject self,
         jint fd, jbyteArray key) {
     ScopedByteArrayRO keyRO(env, key);
 
@@ -64,10 +64,10 @@
     // to ENOENT.
     int ret = bpf::deleteMapEntry(static_cast<int>(fd), keyRO.get());
 
-    return throwIfNotEnoent(env, "deleteMapEntry", ret, errno);
+    return throwIfNotEnoent(env, "nativeDeleteMapEntry", ret, errno);
 }
 
-static jboolean com_android_net_module_util_BpfMap_getNextMapKey(JNIEnv *env, jobject clazz,
+static jboolean com_android_net_module_util_BpfMap_nativeGetNextMapKey(JNIEnv *env, jobject self,
         jint fd, jbyteArray key, jbyteArray nextKey) {
     // If key is found, the operation returns zero and sets the next key pointer to the key of the
     // next element.  If key is not found, the operation returns zero and sets the next key pointer
@@ -83,10 +83,10 @@
         ret = bpf::getNextMapKey(static_cast<int>(fd), keyRO.get(), nextKeyRW.get());
     }
 
-    return throwIfNotEnoent(env, "getNextMapKey", ret, errno);
+    return throwIfNotEnoent(env, "nativeGetNextMapKey", ret, errno);
 }
 
-static jboolean com_android_net_module_util_BpfMap_findMapEntry(JNIEnv *env, jobject clazz,
+static jboolean com_android_net_module_util_BpfMap_nativeFindMapEntry(JNIEnv *env, jobject self,
         jint fd, jbyteArray key, jbyteArray value) {
     ScopedByteArrayRO keyRO(env, key);
     ScopedByteArrayRW valueRW(env, value);
@@ -95,7 +95,7 @@
     // "value".  If no element is found, the operation returns -1 and sets errno to ENOENT.
     int ret = bpf::findMapEntry(static_cast<int>(fd), keyRO.get(), valueRW.get());
 
-    return throwIfNotEnoent(env, "findMapEntry", ret, errno);
+    return throwIfNotEnoent(env, "nativeFindMapEntry", ret, errno);
 }
 
 /*
@@ -103,16 +103,16 @@
  */
 static const JNINativeMethod gMethods[] = {
     /* name, signature, funcPtr */
-    { "bpfFdGet", "(Ljava/lang/String;I)I",
-        (void*) com_android_net_module_util_BpfMap_bpfFdGet },
-    { "writeToMapEntry", "(I[B[BI)V",
-        (void*) com_android_net_module_util_BpfMap_writeToMapEntry },
-    { "deleteMapEntry", "(I[B)Z",
-        (void*) com_android_net_module_util_BpfMap_deleteMapEntry },
-    { "getNextMapKey", "(I[B[B)Z",
-        (void*) com_android_net_module_util_BpfMap_getNextMapKey },
-    { "findMapEntry", "(I[B[B)Z",
-        (void*) com_android_net_module_util_BpfMap_findMapEntry },
+    { "nativeBpfFdGet", "(Ljava/lang/String;I)I",
+        (void*) com_android_net_module_util_BpfMap_nativeBpfFdGet },
+    { "nativeWriteToMapEntry", "(I[B[BI)V",
+        (void*) com_android_net_module_util_BpfMap_nativeWriteToMapEntry },
+    { "nativeDeleteMapEntry", "(I[B)Z",
+        (void*) com_android_net_module_util_BpfMap_nativeDeleteMapEntry },
+    { "nativeGetNextMapKey", "(I[B[B)Z",
+        (void*) com_android_net_module_util_BpfMap_nativeGetNextMapKey },
+    { "nativeFindMapEntry", "(I[B[B)Z",
+        (void*) com_android_net_module_util_BpfMap_nativeFindMapEntry },
 
 };